mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Added reward discount analysis
This commit is contained in:
parent
a57bebbcb7
commit
b5767f0104
4 changed files with 65 additions and 11 deletions
|
@ -160,7 +160,7 @@ function plot_radial_distributions(;
|
||||||
max_g = maximum(maximum.(gs))
|
max_g = maximum(maximum.(gs))
|
||||||
|
|
||||||
ax = Axis(
|
ax = Axis(
|
||||||
fig[1:2, 1:2];
|
fig[1, 1];
|
||||||
xticks=0:(2 * particle_radius):floor(Int64, max_lower_radius),
|
xticks=0:(2 * particle_radius):floor(Int64, max_lower_radius),
|
||||||
yticks=0:ceil(Int64, max_g),
|
yticks=0:ceil(Int64, max_g),
|
||||||
xlabel=L"r / d",
|
xlabel=L"r / d",
|
||||||
|
|
49
analysis/reward_discount_analysis.jl
Normal file
49
analysis/reward_discount_analysis.jl
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
using CairoMakie
|
||||||
|
using LaTeXStrings: @L_str
|
||||||
|
|
||||||
|
using ReCo: ReCo
|
||||||
|
|
||||||
|
includet("../src/Visualization/common_CairoMakie.jl")
|
||||||
|
|
||||||
|
"""
    run_reward_discount_analysis()

Compare the reward curves obtained with different reward discounts `γ`.

Calls `ReCo.run_rl` once for every value in `0.0:0.2:1.0` (the loop is distributed with
`Threads.@threads`), keeps `env_helper.shared.hook.rewards` of each run, draws one line
per `γ` into a single axis with a legend next to it, and passes the figure to `save_fig`
under the name "reward_discount_analysis.pdf".

Returns `nothing`.
"""
function run_reward_discount_analysis()
    γs = 0.0:0.2:1.0
    n_γs = length(γs)
    γ_rewards = Vector{Vector{Float64}}(undef, n_γs)

    # Each iteration writes only to its own slot of `γ_rewards` and uses its own
    # `process_dir`, so the iterations share no state inside this function.
    # NOTE(review): whether `ReCo.run_rl` itself is safe to call concurrently cannot be
    # seen from here -- confirm.
    Threads.@threads for γ_ind in 1:n_γs
        env_helper = ReCo.run_rl(;
            EnvType=ReCo.OriginEnv,
            n_episodes=500,
            episode_duration=8.0,
            n_particles=200,
            update_actions_at=0.08,
            ϵ_stable=0.00001,
            process_dir="reward_discount_analysis/$γ_ind",
            reward_discount=γs[γ_ind],
        )

        γ_rewards[γ_ind] = env_helper.shared.hook.rewards
    end

    init_cairomakie!()

    fig = gen_figure()

    ax = Axis(fig[1, 1]; xlabel="Episode", ylabel="Reward")

    # A comprehension yields a concretely typed vector of plot objects (instead of the
    # `Any`-typed `[]` + `push!`) and needs no unused loop counter.
    rewards_plots = [lines!(ax, eachindex(rewards), rewards) for rewards in γ_rewards]

    # Plots and labels are both ordered like `γs`, so entry i of the legend belongs to
    # the i-th discount.
    Legend(fig[1, 2], rewards_plots, [L"\gamma = %$γ" for γ in γs])

    set_gaps!(fig)

    save_fig("reward_discount_analysis.pdf", fig)

    return nothing
end
|
|
@ -26,7 +26,9 @@ include("States.jl")
|
||||||
include("Hooks.jl")
|
include("Hooks.jl")
|
||||||
include("Reward.jl")
|
include("Reward.jl")
|
||||||
|
|
||||||
function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
|
function gen_agent(
|
||||||
|
n_states::Int64, n_actions::Int64, ϵ_stable::Float64, reward_discount::Float64
|
||||||
|
)
|
||||||
# TODO: Optimize warmup and decay
|
# TODO: Optimize warmup and decay
|
||||||
warmup_steps = 500_000
|
warmup_steps = 500_000
|
||||||
decay_steps = 5_000_000
|
decay_steps = 5_000_000
|
||||||
|
@ -36,7 +38,7 @@ function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
|
||||||
approximator=TabularQApproximator(;
|
approximator=TabularQApproximator(;
|
||||||
n_state=n_states, n_action=n_actions, opt=Flux.InvDecay(1.0)
|
n_state=n_states, n_action=n_actions, opt=Flux.InvDecay(1.0)
|
||||||
),
|
),
|
||||||
γ=0.95, # Reward discount
|
γ=reward_discount,
|
||||||
),
|
),
|
||||||
explorer=EpsilonGreedyExplorer(;
|
explorer=EpsilonGreedyExplorer(;
|
||||||
kind=:linear,
|
kind=:linear,
|
||||||
|
@ -67,6 +69,7 @@ function run_rl(;
|
||||||
skin_to_interaction_radius_ratio::Float64=ReCo.DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
|
skin_to_interaction_radius_ratio::Float64=ReCo.DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
|
||||||
packing_ratio::Float64=0.15,
|
packing_ratio::Float64=0.15,
|
||||||
show_progress::Bool=true,
|
show_progress::Bool=true,
|
||||||
|
reward_discount::Float64=1.0,
|
||||||
) where {E<:Env}
|
) where {E<:Env}
|
||||||
@assert 0.0 <= elliptical_a_b_ratio <= 1.0
|
@assert 0.0 <= elliptical_a_b_ratio <= 1.0
|
||||||
@assert n_episodes > 0
|
@assert n_episodes > 0
|
||||||
|
@ -89,7 +92,7 @@ function run_rl(;
|
||||||
env_args = (skin_radius=sim_consts.skin_radius, half_box_len=sim_consts.half_box_len)
|
env_args = (skin_radius=sim_consts.skin_radius, half_box_len=sim_consts.half_box_len)
|
||||||
env = EnvType(; args=env_args)
|
env = EnvType(; args=env_args)
|
||||||
|
|
||||||
agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable)
|
agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable, reward_discount)
|
||||||
|
|
||||||
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
using DataFrames
|
using DataFrames: DataFrames
|
||||||
using PrettyTables: pretty_table
|
using PrettyTables: pretty_table
|
||||||
|
|
||||||
function latex_table(dataframe, filename::String; path="exports/$filename")
|
function latex_table(
|
||||||
|
dataframe::DataFrames.DataFrame, filename::String; path::String="exports/$filename"
|
||||||
|
)
|
||||||
open(path, "w") do f
|
open(path, "w") do f
|
||||||
pretty_table(f, dataframe; backend=:latex, nosubheader=true, alignment=:c)
|
pretty_table(f, dataframe; backend=:latex, nosubheader=true, alignment=:c)
|
||||||
end
|
end
|
||||||
|
@ -9,8 +11,8 @@ function latex_table(dataframe, filename::String; path="exports/$filename")
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function latex_rl_table(env_helper, filename)
|
function latex_rl_table(env_helper, filename::String)
|
||||||
table = env_helper.shared.agent.policy.learner.approximator.table
|
table = copy(env_helper.shared.agent.policy.learner.approximator.table)
|
||||||
|
|
||||||
for col in 1:size(table)[2]
|
for col in 1:size(table)[2]
|
||||||
table[:, col] ./= sum(table[:, col])
|
table[:, col] ./= sum(table[:, col])
|
||||||
|
@ -31,12 +33,12 @@ function latex_rl_table(env_helper, filename)
|
||||||
|
|
||||||
for i in action_spaces_labels[1]
|
for i in action_spaces_labels[1]
|
||||||
for j in action_spaces_labels[2]
|
for j in action_spaces_labels[2]
|
||||||
push!(actions, i * "; " * j)
|
push!(actions, i * " ; " * j)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
df = DataFrame(table, states)
|
df = DataFrames.DataFrame(table, states)
|
||||||
insertcols!(df, 1, :Actions => actions)
|
DataFrames.insertcols!(df, 1, :Actions => actions)
|
||||||
|
|
||||||
latex_table(df, filename)
|
latex_table(df, filename)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue