# ReCo.jl/src/RL/RL.jl

module RL
export run_rl,
    LocalCOMWithAdditionalShapeRewardEnv, OriginEnv, NearestNeighbourEnv, LocalCOMEnv

using Base: OneTo
using ReinforcementLearning
using Flux: Flux
using Intervals
using StaticArrays: SVector
using LoopVectorization: @turbo
using Random: Random
using ProgressMeter: ProgressMeter
using JLD2: JLD2
using LaTeXStrings: LaTeXStrings, @L_str

using ..ReCo: ReCo

const INITIAL_STATE_IND = 1
const INITIAL_REWARD = 0.0

include("Env.jl")
include("EnvHelper.jl")
include("States.jl")
include("Hooks.jl")
include("Reward.jl")
function gen_agent(
    n_states::Int64, n_actions::Int64, ϵ_stable::Float64, reward_discount::Float64
)
    # TODO: Optimize warming up and decay
    # ϵ stays at ϵ_init during the warm-up steps, then decays linearly to
    # ϵ_stable over the decay steps.
    warmup_steps = 400_000
    decay_steps = 5_000_000

    policy = QBasedPolicy(;
        learner=MonteCarloLearner(;
            approximator=TabularQApproximator(;
                n_state=n_states, n_action=n_actions, opt=Flux.InvDecay(1.0)
            ),
            γ=reward_discount,
        ),
        explorer=EpsilonGreedyExplorer(;
            kind=:linear,
            ϵ_init=1.0,
            ϵ_stable=ϵ_stable,
            warmup_steps=warmup_steps,
            decay_steps=decay_steps,
        ),
    )

    trajectory = VectorSARTTrajectory(;
        state=Int64, action=Int64, reward=Float64, terminal=Bool
    )

    return Agent(; policy=policy, trajectory=trajectory)
end
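
"""
    run_rl(; EnvType::Type{<:Env}, process_dir::String, kwargs...)

Run a reinforcement learning experiment of `n_episodes` episodes, each being
one simulation of length `episode_duration` with the particles' actions
updated every `update_actions_at` time units. The environment helper is saved
as a JLD2 file every `n_episodes_before_env_helper_saving` episodes and once
more at the end, and is also returned.
"""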
function run_rl(;
    EnvType::Type{E},
    process_dir::String,
    elliptical_a_b_ratio::Float64=1.0,
    n_episodes::Int64=200,
    episode_duration::Float64=50.0,
    update_actions_at::Float64=0.1,
    n_particles::Int64=100,
    seed::Int64=42,
    ϵ_stable::Float64=0.0001,
    skin_to_interaction_radius_ratio::Float64=ReCo.DEFAULT_SKIN_TO_INTERACTION_RADIUS_RATIO,
    packing_ratio::Float64=0.15,
    show_progress::Bool=true,
    reward_discount::Float64=0.1,
    show_simulation_progress::Bool=true,
    n_episodes_before_env_helper_saving::Int64=10,
) where {E<:Env}
    @assert 0.0 <= elliptical_a_b_ratio <= 1.0
    @assert n_episodes > 0
    @assert episode_duration > 0
    @assert update_actions_at in 0.001:0.001:episode_duration
    @assert n_particles > 0
    @assert 0.0 < ϵ_stable < 1.0
    @assert 0.0 <= reward_discount <= 1.0
    @assert n_episodes_before_env_helper_saving > 0

    # Setup
    Random.seed!(seed)

    sim_consts = ReCo.gen_sim_consts(
        n_particles,
        0.0;
        skin_to_interaction_radius_ratio=skin_to_interaction_radius_ratio,
        packing_ratio=packing_ratio,
    )
    n_particles = sim_consts.n_particles # Not always equal to the input!

    env_args = (skin_radius=sim_consts.skin_radius, half_box_len=sim_consts.half_box_len)
    env = EnvType(; args=env_args)

    agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable, reward_discount)

    n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
    n_actions_updates_per_episode = ceil(Int64, episode_duration / update_actions_at)

    hook = TotalRewardPerEpisode()

    env_helper_shared = EnvHelperSharedProps(
        env,
        agent,
        hook,
        n_steps_before_actions_update,
        n_actions_updates_per_episode,
        elliptical_a_b_ratio,
        n_particles,
    )

    env_helper_args = (
        half_box_len=sim_consts.half_box_len, skin_radius=sim_consts.skin_radius
    )
    env_helper = gen_env_helper(env, env_helper_shared; args=env_helper_args)

    parent_dir = "RL/" * process_dir
    env_helper_path = ReCo.DEFAULT_EXPORTS_DIR * "/$parent_dir/env_helper.jld2"

    # Pre experiment
    hook(PRE_EXPERIMENT_STAGE, agent, env)
    agent(PRE_EXPERIMENT_STAGE, env)

    progress = ProgressMeter.Progress(n_episodes; dt=2, enabled=show_progress, desc="RL: ")

    for episode in 1:n_episodes
        dir = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir=parent_dir)

        # Reset
        reset!(env)

        # Pre episode
        hook(PRE_EPISODE_STAGE, agent, env)
        agent(PRE_EPISODE_STAGE, env)

        # Episode
        ReCo.run_sim(
            dir;
            duration=episode_duration,
            seed=episode,
            env_helper=env_helper,
            show_progress=show_simulation_progress,
        )

        env.shared.terminated = true

        # Post episode
        hook(POST_EPISODE_STAGE, agent, env)
        agent(POST_EPISODE_STAGE, env)

        if episode % n_episodes_before_env_helper_saving == 0
            JLD2.save_object(env_helper_path, env_helper)
        end

        ProgressMeter.next!(progress; showvalues=[(:rewards, hook.rewards)])
    end

    # Post experiment
    hook(POST_EXPERIMENT_STAGE, agent, env)

    JLD2.save_object(env_helper_path, env_helper)

    return env_helper
end
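
# A minimal usage sketch (illustrative values; `OriginEnv` is one of the
# exported environment types and "test_run" is an arbitrary process directory
# under the exports directory):
#
#     env_helper = run_rl(;
#         EnvType=OriginEnv,
#         process_dir="test_run",
#         n_episodes=10,
#         episode_duration=20.0,
#     )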
function gen_state_space_labels(state_label::String, state_space::Vector{Interval})
    labels = Vector{LaTeXStrings.LaTeXString}(undef, length(state_space))

    for (state_ind, state) in enumerate(state_space)
        labels[state_ind] = LaTeXStrings.latexstring(
            "\$" *
            state_label *
            "\$=$(round(state.first; digits=2)):$(round(state.last; digits=2))",
        )
    end

    return labels
end
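
# For illustration: with state_label = "\\theta" and a state space containing
# the interval from 0.0 to π/2, the generated label renders as the LaTeX
# string $\theta$=0.0:1.57 (interval bounds rounded to two digits).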
function gen_state_spaces_labels(
    states_labels::NTuple{N,String}, state_spaces::NTuple{N,Vector{Interval}}
) where {N}
    return [gen_state_space_labels(states_labels[i], state_spaces[i]) for i in 1:N]
end

include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl")
include("Envs/OriginEnv.jl")
include("Envs/NearestNeighbourEnv.jl")
include("Envs/LocalCOMEnv.jl")
end # module