# ReCo.jl/src/RL/RL.jl
module RL

export run_rl, LocalCOMWithAdditionalShapeRewardEnv, OriginEnv

using Base: OneTo
using ReinforcementLearning
using Flux: Flux
using Intervals
using StaticArrays: SVector
using LoopVectorization: @turbo
using Random: Random
using ProgressMeter: ProgressMeter

using ..ReCo: ReCo

const INITIAL_STATE_IND = 1
const INITIAL_REWARD = 0.0

include("Env.jl")
include("EnvHelper.jl")
include("States.jl")
include("Hooks.jl")
include("Reward.jl")
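
# Construct a tabular RL agent: a Monte Carlo learner over a table of
# state-action values, combined with an ϵ-greedy explorer whose ϵ decays
# linearly from 1.0 to ϵ_stable after warmup_steps. Transitions are stored
# in a vector-based SART (state, action, reward, terminal) trajectory.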
function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
    # TODO: Optimize warmup and decay
    warmup_steps = 500_000
    decay_steps = 5_000_000

    policy = QBasedPolicy(;
        learner=MonteCarloLearner(;
            approximator=TabularQApproximator(;
                n_state=n_states, n_action=n_actions, opt=Flux.InvDecay(1.0)
            ),
            γ=0.95, # Reward discount
        ),
        explorer=EpsilonGreedyExplorer(;
            kind=:linear,
            ϵ_init=1.0,
            ϵ_stable=ϵ_stable,
            warmup_steps=warmup_steps,
            decay_steps=decay_steps,
        ),
    )

    trajectory = VectorSARTTrajectory(;
        state=Int64, action=Int64, reward=Float64, terminal=Bool
    )

    return Agent(; policy=policy, trajectory=trajectory)
end
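
# Run the reinforcement-learning loop: generate the simulation constants,
# set up the environment, agent, hooks and environment helper, then run
# n_episodes simulations of episode_duration each, with the particles'
# actions updated every update_actions_at time units. Returns the
# environment helper, which bundles the environment, agent and reward hook.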
function run_rl(;
    EnvType::Type{E},
    parent_dir_appendix::String,
    elliptical_a_b_ratio::Float64,
    n_episodes::Int64=200,
    episode_duration::Float64=50.0,
    update_actions_at::Float64=0.1,
    n_particles::Int64=100,
    seed::Int64=42,
    ϵ_stable::Float64=0.0001,
    skin_to_interaction_radius_ratio::Float64=ReCo.DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
    packing_ratio::Float64=0.15,
    show_progress::Bool=true,
) where {E<:Env}
    @assert 0.0 <= elliptical_a_b_ratio <= 1.0
    @assert n_episodes > 0
    @assert episode_duration > 0
    @assert update_actions_at in 0.001:0.001:episode_duration
    @assert n_particles > 0
    @assert 0.0 < ϵ_stable < 1.0

    # Setup
    Random.seed!(seed)

    sim_consts = ReCo.gen_sim_consts(
        n_particles,
        0.0;
        skin_to_interaction_radius_ratio=skin_to_interaction_radius_ratio,
        packing_ratio=packing_ratio,
    )
    n_particles = sim_consts.n_particles # Not always equal to the input!

    env_args = (skin_radius=sim_consts.skin_radius,)
    env = EnvType(; args=env_args)

    agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable)

    # Convert the action-update interval from simulation time to integration steps
    n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)

    hook = TotalRewardPerEpisode()

    env_helper_shared = EnvHelperSharedProps(
        env, agent, hook, n_steps_before_actions_update, elliptical_a_b_ratio, n_particles
    )

    env_helper_args = (
        half_box_len=sim_consts.half_box_len, skin_radius=sim_consts.skin_radius
    )
    env_helper = gen_env_helper(env, env_helper_shared; args=env_helper_args)

    parent_dir = "RL_" * parent_dir_appendix

    # Pre experiment
    hook(PRE_EXPERIMENT_STAGE, agent, env)
    agent(PRE_EXPERIMENT_STAGE, env)

    progress = ProgressMeter.Progress(n_episodes; dt=2, enabled=show_progress, desc="RL: ")
    for episode in 1:n_episodes
        dir = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir=parent_dir)

        # Reset
        reset!(env)

        # Pre episode
        hook(PRE_EPISODE_STAGE, agent, env)
        agent(PRE_EPISODE_STAGE, env)

        # Episode
        ReCo.run_sim(dir; duration=episode_duration, seed=episode, env_helper=env_helper)
        env.shared.terminated = true

        # Post episode
        hook(POST_EPISODE_STAGE, agent, env)
        agent(POST_EPISODE_STAGE, env)

        ProgressMeter.next!(progress; showvalues=[(:rewards, hook.rewards)])
    end

    # Post experiment
    hook(POST_EXPERIMENT_STAGE, agent, env)

    return env_helper
end

include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl")
include("Envs/OriginEnv.jl")
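
# Example call (a sketch with illustrative argument values only; OriginEnv is
# one of the environments included above):
#
#     env_helper = run_rl(;
#         EnvType=OriginEnv,
#         parent_dir_appendix="origin",
#         elliptical_a_b_ratio=1.0,
#     )
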
end # module