1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00
ReCo.jl/analysis/reward_discount_analysis.jl

78 lines
2 KiB
Julia
Raw Normal View History

2022-01-30 01:28:34 +00:00
using CairoMakie
using LaTeXStrings: @L_str
using ReCo: ReCo
2022-02-08 22:06:22 +00:00
include("../src/Visualization/common_CairoMakie.jl")
2022-01-30 01:28:34 +00:00
"""
    run_rl_prcesses_reward_discount(γs::AbstractVector)

Call `ReCo.run_rl` once for every reward discount in `γs` and collect the returned
environment helpers.

The runs are distributed over the available threads with `Threads.@threads`. Every run
receives its own `process_dir` (`reward_discount_analysis/<index of γ>`).

Returns a `Vector{ReCo.RL.EnvHelper}` whose `i`-th entry belongs to `γs[i]`.
"""
function run_rl_prcesses_reward_discount(γs::AbstractVector)
    # Settings shared by all runs; only the discount and the process directory vary.
    common_kwargs = (;
        EnvType=ReCo.OriginEnv,
        n_episodes=400,
        episode_duration=15.0,
        n_particles=150,
        update_actions_at=0.08,
        ϵ_stable=0.00001,
        show_simulation_progress=false,
    )

    helpers = Vector{ReCo.RL.EnvHelper}(undef, length(γs))

    Threads.@threads for idx in 1:length(γs)
        # `run_rl` also returns a second value, which is not needed here.
        helper, _ = ReCo.run_rl(;
            common_kwargs...,
            process_dir="reward_discount_analysis/$idx",
            reward_discount=γs[idx],
        )
        # Each iteration writes to its own slot, so no synchronization is required.
        helpers[idx] = helper
    end

    return helpers
end
"""
    plot_reward_discount_analysis(γs, env_helpers, linestyles)

Draw the rewards stored in each entry of `env_helpers` as one line per reward discount
and hand the finished figure to `save_fig` under the name `reward_discount_analysis.pdf`.

# Arguments
- `γs::AbstractVector`: reward discounts; used for the legend labels.
- `env_helpers::Vector{ReCo.RL.EnvHelper}`: one helper per entry of `γs`; the plotted
  values are read from `env_helper.shared.hook.rewards`.
- `linestyles::NTuple{N,Symbol}`: one line style per entry of `γs`.

# Returns
`nothing`.

# Throws
- `DimensionMismatch`: if `γs`, `env_helpers` and `linestyles` do not all have the same
  length.
"""
function plot_reward_discount_analysis(
    γs::AbstractVector, env_helpers::Vector{ReCo.RL.EnvHelper}, linestyles::NTuple{N,Symbol}
) where {N}
    n_γs = length(γs)
    n_helpers = length(env_helpers)

    # Explicit check instead of `@assert`: assertions may be disabled, and `zip` below
    # would then silently drop the surplus entries of mismatched inputs.
    if !(n_γs == n_helpers == N)
        throw(
            DimensionMismatch(
                "γs ($n_γs), env_helpers ($n_helpers) and linestyles ($N) " *
                "must have the same length",
            ),
        )
    end

    # The typed comprehension converts every entry to `Vector{Float64}`.
    γ_rewards = Vector{Float64}[helper.shared.hook.rewards for helper in env_helpers]

    init_cairomakie!()
    fig = gen_figure()
    ax = Axis(fig[1, 1]; xlabel="Episode", ylabel="Reward")

    # The x values are the episode numbers 1, 2, …; the plot handles are kept for the
    # legend.
    rewards_plots = [
        lines!(ax, 1:length(rewards), rewards; linestyle=linestyle, linewidth=0.6) for
        (rewards, linestyle) in zip(γ_rewards, linestyles)
    ]

    Legend(fig[1, 2], rewards_plots, [L"\gamma = %$γ" for γ in γs])
    set_gaps!(fig)
    save_fig("reward_discount_analysis.pdf", fig)

    return nothing
end
"""
    run_reward_discount_analysis()

Run `run_rl_prcesses_reward_discount` for the reward discounts `0.0:0.25:1.0` and pass
the resulting environment helpers to `plot_reward_discount_analysis`.

Returns `nothing`.
"""
function run_reward_discount_analysis()
    discounts = 0.0:0.25:1.0
    # One line style per discount, in the same order as `discounts`.
    styles = (:solid, :dash, :dashdot, :solid, :solid)

    helpers = run_rl_prcesses_reward_discount(discounts)
    plot_reward_discount_analysis(discounts, helpers, styles)

    return nothing
end