# ReCo.jl/src/RL/RL.jl

module RL

export run_rl, LocalCOMEnv

using Base: OneTo
using ReinforcementLearning
using Flux: InvDecay
using Intervals
using StaticArrays: SVector
using LoopVectorization: @turbo
using Random: Random
using ProgressMeter: ProgressMeter

using ..ReCo: ReCo

const INITIAL_STATE_IND = 1
const INITIAL_REWARD = 0.0

include("Env.jl")
include("EnvHelper.jl")
include("States.jl")
include("Hooks.jl")
function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
    # TODO: Optimize warmup and decay
    warmup_steps = 500_000
    decay_steps = 5_000_000

    policy = QBasedPolicy(;
        learner=MonteCarloLearner(;
            approximator=TabularQApproximator(;
                n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
            ),
        ),
        explorer=EpsilonGreedyExplorer(;
            kind=:linear,
            ϵ_init=1.0,
            ϵ_stable=ϵ_stable,
            warmup_steps=warmup_steps,
            decay_steps=decay_steps,
        ),
    )

    trajectory = VectorSARTTrajectory(;
        state=Int64, action=Int64, reward=Float64, terminal=Bool
    )

    return Agent(; policy=policy, trajectory=trajectory)
end
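
"""
    run_rl(; EnvType, parent_dir_appendix, elliptical_a_b_ratio, kwargs...)

Run the reinforcement-learning training loop: set up a simulation with
`n_particles` particles, then run `n_episodes` episodes of length
`episode_duration`, updating the agent's actions every `update_actions_at`
time units. Returns the environment helper.
"""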
function run_rl(;
    EnvType::Type{E},
    parent_dir_appendix::String,
    elliptical_a_b_ratio::Float64,
    n_episodes::Int64=200,
    episode_duration::Float64=50.0,
    update_actions_at::Float64=0.1,
    n_particles::Int64=100,
    seed::Int64=42,
    ϵ_stable::Float64=0.0001,
    skin_to_interaction_r_ratio::Float64=ReCo.DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
    packing_ratio::Float64=0.15,
    show_progress::Bool=true,
) where {E<:Env}
    @assert 0.0 <= elliptical_a_b_ratio <= 1.0
    @assert n_episodes > 0
    @assert episode_duration > 0
    @assert update_actions_at in 0.001:0.001:episode_duration
    @assert n_particles > 0
    @assert 0.0 < ϵ_stable < 1.0

    # Setup
    Random.seed!(seed)

    sim_consts = ReCo.gen_sim_consts(
        n_particles,
        0.0;
        skin_to_interaction_r_ratio=skin_to_interaction_r_ratio,
        packing_ratio=packing_ratio,
    )
    n_particles = sim_consts.n_particles  # Not always equal to the input!

    env_args = (skin_r=sim_consts.skin_r,)
    env = EnvType(; args=env_args)

    agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable)

    n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)

    hook = TotalRewardPerEpisode()

    env_helper_shared = EnvHelperSharedProps(
        env, agent, hook, n_steps_before_actions_update, elliptical_a_b_ratio, n_particles
    )
    env_helper_args = (half_box_len=sim_consts.half_box_len, skin_r=sim_consts.skin_r)
    env_helper = gen_env_helper(env, env_helper_shared; args=env_helper_args)

    parent_dir = "RL_" * parent_dir_appendix

    # Pre experiment
    hook(PRE_EXPERIMENT_STAGE, agent, env)
    agent(PRE_EXPERIMENT_STAGE, env)

    progress = ProgressMeter.Progress(n_episodes; dt=2, enabled=show_progress, desc="RL: ")

    for episode in 1:n_episodes
        dir = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir=parent_dir)

        # Reset
        reset!(env)

        # Pre episode
        hook(PRE_EPISODE_STAGE, agent, env)
        agent(PRE_EPISODE_STAGE, env)

        # Episode
        ReCo.run_sim(dir; duration=episode_duration, seed=episode, env_helper=env_helper)
        env.shared.terminated = true

        # Post episode
        hook(POST_EPISODE_STAGE, agent, env)
        agent(POST_EPISODE_STAGE, env)

        ProgressMeter.next!(progress; showvalues=[(:rewards, hook.rewards)])
    end

    # Post experiment
    hook(POST_EXPERIMENT_STAGE, agent, env)

    return env_helper
end
include("LocalCOMEnv.jl")
end # module
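
# A minimal usage sketch (hypothetical argument values; assumes `LocalCOMEnv`,
# defined in LocalCOMEnv.jl, as the concrete environment type):
#
# env_helper = RL.run_rl(;
#     EnvType=RL.LocalCOMEnv,
#     parent_dir_appendix="demo",
#     elliptical_a_b_ratio=0.3,
#     n_episodes=10,
#     episode_duration=10.0,
# )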