diff --git a/src/RL/LocalCOMEnv.jl b/src/RL/LocalCOMEnv.jl
index 900b9f7..771ce14 100644
--- a/src/RL/LocalCOMEnv.jl
+++ b/src/RL/LocalCOMEnv.jl
@@ -159,19 +159,6 @@ function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{Par
     return nothing
 end
 
-"""
-    minimizing_reward(value::Float64, max_value::Float64)
-
-Returns the reward such that it is 0 for value=max_value and 1 for value=0.
-"""
-function minimizing_reward(value::Float64, max_value::Float64)
-    if value > max_value
-        error("value > max_value")
-    end
-
-    return ((max_value - value) / (max_value + value))^2
-end
-
 function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
     id = particle.id
 
diff --git a/src/RL/RL.jl b/src/RL/RL.jl
index c28514f..51aaf3e 100644
--- a/src/RL/RL.jl
+++ b/src/RL/RL.jl
@@ -5,7 +5,7 @@ export run_rl, LocalCOMEnv
 using Base: OneTo
 
 using ReinforcementLearning
-using Flux: InvDecay
+using Flux: Flux
 using Intervals
 using StaticArrays: SVector
 using LoopVectorization: @turbo
@@ -22,6 +22,7 @@ include("EnvHelper.jl")
 include("States.jl")
 include("Hooks.jl")
 
+include("Reward.jl")
 
 function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
     # TODO: Optimize warmup and decay
@@ -31,8 +32,9 @@ function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
     policy = QBasedPolicy(;
         learner=MonteCarloLearner(;
             approximator=TabularQApproximator(;
-                n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
+                n_state=n_states, n_action=n_actions, opt=Flux.InvDecay(1.0)
             ),
+            γ=0.95, # Reward discount
         ),
         explorer=EpsilonGreedyExplorer(;
             kind=:linear,
diff --git a/src/RL/Reward.jl b/src/RL/Reward.jl
new file mode 100644
index 0000000..a8552c6
--- /dev/null
+++ b/src/RL/Reward.jl
@@ -0,0 +1,3 @@
+function minimizing_reward(value::Float64, max_value::Float64)
+    return exp(-0.5 * (value / (max_value / 3))^2)
+end
\ No newline at end of file
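
For context on the reward change: the removed formula returned exactly 0 at `value = max_value` and raised an error for `value > max_value`, while the replacement in `src/RL/Reward.jl` is a Gaussian with standard deviation `max_value / 3`, so it returns 1 at `value = 0`, about `exp(-4.5) ≈ 0.011` at `value = max_value`, and keeps decaying smoothly beyond it instead of erroring. A minimal standalone sketch of that behavior; the `max_value` and probe points below are illustrative assumptions, not values from the package:

```julia
# Standalone copy of the new reward from src/RL/Reward.jl:
# a Gaussian in `value` with standard deviation max_value / 3.
function minimizing_reward(value::Float64, max_value::Float64)
    return exp(-0.5 * (value / (max_value / 3))^2)
end

max_value = 10.0  # hypothetical upper bound, for illustration only

for value in (0.0, 5.0, 10.0, 12.0)  # 12.0 > max_value no longer errors
    reward = minimizing_reward(value, max_value)
    println("value = $value  ->  reward = ", round(reward; digits=4))
end
```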