mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Seperated processes and plotting for reward discount analysis
This commit is contained in:
parent
946e646594
commit
3cd7395d60
1 changed files with 34 additions and 6 deletions
|
@ -5,10 +5,9 @@ using ReCo: ReCo
|
||||||
|
|
||||||
includet("../src/Visualization/common_CairoMakie.jl")
|
includet("../src/Visualization/common_CairoMakie.jl")
|
||||||
|
|
||||||
function run_reward_discount_analysis()
|
function run_rl_prcesses_reward_discount(γs::AbstractVector)
|
||||||
γs = 0.0:0.2:1.0
|
|
||||||
n_γs = length(γs)
|
n_γs = length(γs)
|
||||||
γ_rewards = Vector{Vector{Float64}}(undef, n_γs)
|
env_helpers = Vector{ReCo.RL.EnvHelper}(undef, n_γs)
|
||||||
|
|
||||||
Threads.@threads for γ_ind in 1:n_γs
|
Threads.@threads for γ_ind in 1:n_γs
|
||||||
γ = γs[γ_ind]
|
γ = γs[γ_ind]
|
||||||
|
@ -24,8 +23,23 @@ function run_reward_discount_analysis()
|
||||||
show_simulation_progress=false,
|
show_simulation_progress=false,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
env_helpers[γ_ind] = env_helper
|
||||||
|
end
|
||||||
|
|
||||||
|
return env_helpers
|
||||||
|
end
|
||||||
|
|
||||||
|
function plot_reward_discount_analysis(
|
||||||
|
γs::AbstractVector, env_helpers::Vector{ReCo.RL.EnvHelper}, linestyles::NTuple{N,Symbol}
|
||||||
|
) where {N}
|
||||||
|
n_γs = length(γs)
|
||||||
|
@assert n_γs == length(env_helpers) == N
|
||||||
|
|
||||||
|
γ_rewards = Vector{Vector{Float64}}(undef, n_γs)
|
||||||
|
|
||||||
|
for (env_helper_ind, env_helper) in enumerate(env_helpers)
|
||||||
rewards = env_helper.shared.hook.rewards
|
rewards = env_helper.shared.hook.rewards
|
||||||
γ_rewards[γ_ind] = rewards
|
γ_rewards[env_helper_ind] = rewards
|
||||||
end
|
end
|
||||||
|
|
||||||
init_cairomakie!()
|
init_cairomakie!()
|
||||||
|
@ -35,8 +49,10 @@ function run_reward_discount_analysis()
|
||||||
ax = Axis(fig[1, 1]; xlabel="Episode", ylabel="Reward")
|
ax = Axis(fig[1, 1]; xlabel="Episode", ylabel="Reward")
|
||||||
|
|
||||||
rewards_plots = []
|
rewards_plots = []
|
||||||
for rewards in γ_rewards
|
for (rewards, linestyle) in zip(γ_rewards, linestyles)
|
||||||
rewards_plot = lines!(ax, 1:length(rewards), rewards)
|
rewards_plot = lines!(
|
||||||
|
ax, 1:length(rewards), rewards; linestyle=linestyle, linewidth=0.6
|
||||||
|
)
|
||||||
push!(rewards_plots, rewards_plot)
|
push!(rewards_plots, rewards_plot)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -48,3 +64,15 @@ function run_reward_discount_analysis()
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function run_reward_discount_analysis()
|
||||||
|
γs = 0.0:0.25:1.0
|
||||||
|
|
||||||
|
env_helpers = run_rl_prcesses_reward_discount(γs)
|
||||||
|
|
||||||
|
plot_reward_discount_analysis(
|
||||||
|
γs, env_helpers, (:solid, :dash, :dashdot, :solid, :solid)
|
||||||
|
)
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
Loading…
Reference in a new issue