mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Added minimizing reward
This commit is contained in:
parent
c68d7f3e45
commit
4e054653f5
3 changed files with 7 additions and 15 deletions
|
@ -159,19 +159,6 @@ function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{Par
|
|||
return nothing
|
||||
end
|
||||
|
||||
"""
    minimizing_reward(value::Real, max_value::Real)

Return a reward that is 1 for `value = 0` and 0 for `value = max_value`.

The reward is computed as `((max_value - value) / (max_value + value))^2`.

# Arguments
- `value`: the quantity that should be minimized.
- `max_value`: the largest value `value` is allowed to take.

# Throws
- `ErrorException` if `value > max_value`.

Note: if `value` and `max_value` are both zero, the quotient is `0 / 0` and the
result is `NaN`.
"""
function minimizing_reward(value::Real, max_value::Real)
    if value > max_value
        error("value > max_value")
    end

    # Squaring the normalized difference keeps the reward non-negative.
    return ((max_value - value) / (max_value + value))^2
end
|
||||
|
||||
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
|
||||
id = particle.id
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ export run_rl, LocalCOMEnv
|
|||
using Base: OneTo
|
||||
|
||||
using ReinforcementLearning
|
||||
using Flux: InvDecay
|
||||
using Flux: Flux
|
||||
using Intervals
|
||||
using StaticArrays: SVector
|
||||
using LoopVectorization: @turbo
|
||||
|
@ -22,6 +22,7 @@ include("EnvHelper.jl")
|
|||
|
||||
include("States.jl")
|
||||
include("Hooks.jl")
|
||||
include("Reward.jl")
|
||||
|
||||
function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
|
||||
# TODO: Optimize warmup and decay
|
||||
|
@ -31,8 +32,9 @@ function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
|
|||
policy = QBasedPolicy(;
|
||||
learner=MonteCarloLearner(;
|
||||
approximator=TabularQApproximator(;
|
||||
n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
|
||||
n_state=n_states, n_action=n_actions, opt=Flux.InvDecay(1.0)
|
||||
),
|
||||
γ=0.95, # Reward discount
|
||||
),
|
||||
explorer=EpsilonGreedyExplorer(;
|
||||
kind=:linear,
|
||||
|
|
3
src/RL/Reward.jl
Normal file
3
src/RL/Reward.jl
Normal file
|
@ -0,0 +1,3 @@
|
|||
"""
    minimizing_reward(value::Real, max_value::Real)

Return a Gaussian-shaped reward for a quantity that should be minimized.

The reward is `exp(-0.5 * (value / σ)^2)` with `σ = max_value / 3`, so it is 1
for `value = 0` and `exp(-4.5)` (about 0.011) for `value = max_value`.

# Arguments
- `value`: the quantity that should be minimized.
- `max_value`: the scale of `value`; it lies three standard deviations from 0.

No range check is performed: a `value` larger than `max_value` is accepted and
simply yields a reward even closer to 0.
"""
function minimizing_reward(value::Real, max_value::Real)
    # Width of the Gaussian: max_value corresponds to 3σ.
    σ = max_value / 3

    return exp(-0.5 * (value / σ)^2)
end
|
Loading…
Reference in a new issue