1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-17 18:51:17 +00:00
ReCo.jl/analysis/reward_discount_analysis.jl
2022-04-06 17:07:04 +02:00

78 lines
2 KiB
Julia
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using CairoMakie
using LaTeXStrings: @L_str
using ReCo: ReCo
include("../visualization/common_CairoMakie.jl")
"""
    run_reward_discount_processes(γs::AbstractVector)

Train one `ReCo.OriginEnv` reinforcement-learning run per reward discount factor in
`γs` and return the resulting env helpers.

The runs are distributed over Julia threads with `Threads.@threads`. Run number `i`
uses `γs[i]` as `reward_discount` and `"reward_discount_analysis/<i>"` as its
`process_dir`. Every run uses the same episode count, duration, particle count,
action-update interval and `ϵ_stable`, so only the discount factor varies.

# Returns
A `Vector{ReCo.RL.EnvHelper}` whose `i`-th entry belongs to the run with `γs[i]`.
"""
function run_reward_discount_processes(γs::AbstractVector)
    n_runs = length(γs)
    helpers = Vector{ReCo.RL.EnvHelper}(undef, n_runs)

    Threads.@threads for run_ind in 1:n_runs
        discount = γs[run_ind]

        # Only the first returned value is needed; the second one (presumably the RL
        # output directory) is discarded.
        helper, _ = ReCo.run_rl(
            ReCo.OriginEnv;
            n_episodes=400,
            episode_duration=15.0,
            n_particles=150,
            update_actions_at=0.08,
            ϵ_stable=0.00001,
            process_dir="reward_discount_analysis/$run_ind",
            reward_discount=discount,
            show_simulation_progress=false,
        )

        # Each iteration writes only to its own slot, so no lock is needed here.
        helpers[run_ind] = helper
    end

    return helpers
end
"""
    plot_reward_discount_analysis(γs, env_helpers, linestyles)

Plot the per-episode reward history of several RL runs, one line per reward discount
factor, and save the figure via `save_fig` as `reward_discount_analysis.pdf`.

# Arguments
- `γs::AbstractVector`: reward discount factors; used only for the legend labels.
- `env_helpers::Vector{ReCo.RL.EnvHelper}`: one finished run per entry of `γs`, in the
  same order. Rewards are read from `env_helper.shared.hook.rewards`.
- `linestyles::NTuple{N,Symbol}`: one Makie line style per entry of `γs`.

# Throws
- `DimensionMismatch` if `γs`, `env_helpers` and `linestyles` differ in length.
"""
function plot_reward_discount_analysis(
    γs::AbstractVector, env_helpers::Vector{ReCo.RL.EnvHelper}, linestyles::NTuple{N,Symbol}
) where {N}
    n_γs = length(γs)
    n_helpers = length(env_helpers)

    # Validate explicitly rather than with `@assert`, which may be disabled: a length
    # mismatch would otherwise let `zip` silently drop curves or mislabel the legend.
    if !(n_γs == n_helpers == N)
        throw(
            DimensionMismatch(
                "lengths must match: got $n_γs reward discounts, " *
                "$n_helpers env helpers and $N linestyles",
            ),
        )
    end

    # The typed comprehension converts each reward history to `Vector{Float64}`.
    γ_rewards = Vector{Float64}[env_helper.shared.hook.rewards for env_helper in env_helpers]

    init_cairomakie!()
    fig = gen_figure()
    ax = Axis(fig[1, 1]; xlabel="Episode", ylabel="Reward")

    # Build the plot handles with a comprehension instead of pushing into an untyped
    # `[]` (a `Vector{Any}`); the legend below needs them in the same order as `γs`.
    rewards_plots = [
        lines!(ax, 1:length(rewards), rewards; linestyle=linestyle, linewidth=0.6) for
        (rewards, linestyle) in zip(γ_rewards, linestyles)
    ]

    Legend(fig[1, 2], rewards_plots, [L"\gamma = %$γ" for γ in γs])
    set_gaps!(fig)
    save_fig("reward_discount_analysis.pdf", fig)

    return nothing
end
"""
    run_reward_discount_analysis()

Entry point for the reward-discount study. It trains one RL run for each
γ ∈ {0.0, 0.25, 0.5, 0.75, 1.0} with `run_reward_discount_processes`, then plots the
reward curves of all runs with `plot_reward_discount_analysis`.

Returns `nothing`.
"""
function run_reward_discount_analysis()
    reward_discounts = 0.0:0.25:1.0

    # One line style per discount factor, in the same order as `reward_discounts`.
    # NOTE(review): γ = 0.0, 0.75 and 1.0 all use `:solid`; presumably they are told
    # apart only by line color — confirm this is intended.
    styles = (:solid, :dash, :dashdot, :solid, :solid)

    helpers = run_reward_discount_processes(reward_discounts)
    plot_reward_discount_analysis(reward_discounts, helpers, styles)

    return nothing
end