ReCo.jl/src/RL/RL.jl

module RL

export run_rl, LocalCOMWithAdditionalShapeRewardEnv, OriginEnv

using Base: OneTo

using ReinforcementLearning
using Flux: Flux
using Intervals
using StaticArrays: SVector
using LoopVectorization: @turbo
using Random: Random
using ProgressMeter: ProgressMeter
using JLD2: JLD2
using LaTeXStrings: LaTeXStrings, @L_str

using ..ReCo: ReCo

const INITIAL_STATE_IND = 1
const INITIAL_REWARD = 0.0

include("Env.jl")
include("EnvHelper.jl")
include("States.jl")
include("Hooks.jl")
include("Reward.jl")

function gen_agent(
    n_states::Int64, n_actions::Int64, ϵ_stable::Float64, reward_discount::Float64
)
    # TODO: Optimize warmup and decay
    warmup_steps = 500_000
    decay_steps = 5_000_000

    policy = QBasedPolicy(;
        learner=MonteCarloLearner(;
            approximator=TabularQApproximator(;
                n_state=n_states, n_action=n_actions, opt=Flux.InvDecay(1.0)
            ),
            γ=reward_discount,
        ),
        explorer=EpsilonGreedyExplorer(;
            kind=:linear,
            ϵ_init=1.0,
            ϵ_stable=ϵ_stable,
            warmup_steps=warmup_steps,
            decay_steps=decay_steps,
        ),
    )

    trajectory = VectorSARTTrajectory(;
        state=Int64, action=Int64, reward=Float64, terminal=Bool
    )

    return Agent(; policy=policy, trajectory=trajectory)
end
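# How the agent above is typically used (illustrative sketch; the first call mirrors the
# one in `run_rl` below, while `agent.policy(env)` is an assumed ReinforcementLearning.jl
# usage, not taken from this file):
#
#   agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable, reward_discount)
#   action = agent.policy(env)  # ϵ-greedy pick over the tabular Q-values of the current state
#
# With Flux.InvDecay(1.0), the tabular approximator effectively averages the Monte Carlo
# returns per state-action pair, and ϵ decays linearly from 1.0 to ϵ_stable after the
# warmup steps.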
function run_rl(;
    EnvType::Type{E},
    process_dir::String,
    elliptical_a_b_ratio::Float64=1.0,
    n_episodes::Int64=200,
    episode_duration::Float64=50.0,
    update_actions_at::Float64=0.1,
    n_particles::Int64=100,
    seed::Int64=42,
    ϵ_stable::Float64=0.0001,
    skin_to_interaction_radius_ratio::Float64=ReCo.DEFAULT_SKIN_TO_INTERACTION_RADIUS_RATIO,
    packing_ratio::Float64=0.15,
    show_progress::Bool=true,
    reward_discount::Float64=0.1,
    show_simulation_progress::Bool=true,
    n_episodes_before_env_helper_saving::Int64=10,
) where {E<:Env}
    @assert 0.0 <= elliptical_a_b_ratio <= 1.0
    @assert n_episodes > 0
    @assert episode_duration > 0
    @assert update_actions_at in 0.001:0.001:episode_duration
    @assert n_particles > 0
    @assert 0.0 < ϵ_stable < 1.0
    @assert 0.0 <= reward_discount <= 1.0
    @assert n_episodes_before_env_helper_saving > 0

    # Setup
    Random.seed!(seed)

    sim_consts = ReCo.gen_sim_consts(
        n_particles,
        0.0;
        skin_to_interaction_radius_ratio=skin_to_interaction_radius_ratio,
        packing_ratio=packing_ratio,
    )
    n_particles = sim_consts.n_particles # Not always equal to the input!

    env_args = (skin_radius=sim_consts.skin_radius, half_box_len=sim_consts.half_box_len)
    env = EnvType(; args=env_args)

    agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable, reward_discount)

    n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)

    hook = TotalRewardPerEpisode()

    n_actions_updates_per_episode = ceil(Int64, episode_duration / update_actions_at)

    env_helper_shared = EnvHelperSharedProps(
        env,
        agent,
        hook,
        n_steps_before_actions_update,
        n_actions_updates_per_episode,
        elliptical_a_b_ratio,
        n_particles,
    )

    env_helper_args = (
        half_box_len=sim_consts.half_box_len, skin_radius=sim_consts.skin_radius
    )
    env_helper = gen_env_helper(env, env_helper_shared; args=env_helper_args)

    parent_dir = "RL/" * process_dir
    env_helper_path = ReCo.DEFAULT_EXPORTS_DIR * "/$parent_dir/env_helper.jld2"

    # Pre experiment
    hook(PRE_EXPERIMENT_STAGE, agent, env)
    agent(PRE_EXPERIMENT_STAGE, env)

    progress = ProgressMeter.Progress(n_episodes; dt=2, enabled=show_progress, desc="RL: ")

    for episode in 1:n_episodes
        dir = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir=parent_dir)

        # Reset
        reset!(env)

        # Pre episode
        hook(PRE_EPISODE_STAGE, agent, env)
        agent(PRE_EPISODE_STAGE, env)

        # Episode
        ReCo.run_sim(
            dir;
            duration=episode_duration,
            seed=episode,
            env_helper=env_helper,
            show_progress=show_simulation_progress,
        )

        env.shared.terminated = true

        # Post episode
        hook(POST_EPISODE_STAGE, agent, env)
        agent(POST_EPISODE_STAGE, env)

        if episode % n_episodes_before_env_helper_saving == 0
            JLD2.save_object(env_helper_path, env_helper)
        end

        ProgressMeter.next!(progress; showvalues=[(:rewards, hook.rewards)])
    end

    # Post experiment
    hook(POST_EXPERIMENT_STAGE, agent, env)

    return env_helper
end
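# Hedged usage sketch (the keyword values below are illustrative assumptions, not taken
# from this file): train in one of the bundled environments and keep the returned helper,
# which holds the agent, the reward hook, and the shared environment properties.
#
#   env_helper = run_rl(;
#       EnvType=OriginEnv,
#       process_dir="origin_test",
#       n_episodes=100,
#       n_particles=50,
#   )
#
# During training, the helper is also saved every `n_episodes_before_env_helper_saving`
# episodes to env_helper.jld2 under ReCo.DEFAULT_EXPORTS_DIR.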
function gen_state_space_labels(state_label::String, state_space::Vector{Interval})
    labels = Vector{LaTeXStrings.LaTeXString}(undef, length(state_space))

    for (state_ind, state) in enumerate(state_space)
        labels[state_ind] = LaTeXStrings.latexstring(
            "\$" *
            state_label *
            "\$=$(round(state.first; digits=2)):$(round(state.last, digits=2))",
        )
    end

    return labels
end

function gen_state_spaces_labels(
    states_labels::NTuple{N,String}, state_spaces::NTuple{N,Vector{Interval}}
) where {N}
    return [gen_state_space_labels(states_labels[i], state_spaces[i]) for i in 1:N]
end
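# Hedged example (the intervals below are assumptions for illustration): labels for a
# single state space can be generated as
#
#   state_space = Interval[Interval(0.0, 0.5), Interval(0.5, 1.0)]
#   labels = gen_state_space_labels("d", state_space)
#
# which yields LaTeX strings of the form $d$=0.0:0.5, one per interval.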
include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl")
include("Envs/OriginEnv.jl")
end # module