Mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git, synced 2024-11-08 22:21:08 +00:00
Fix reward
Commit: 28fd6bab95
Parent: 978c3f39fb
3 changed files with 13 additions and 16 deletions
@@ -2,4 +2,4 @@

 image:https://img.shields.io/badge/code%20style-blue-4495d1.svg[Code Style: Blue, link=https://github.com/invenia/BlueStyle]

-**Re**inforcement learning of **co**llective behaviour.
+**Re**inforcement learning of **co**llective behavior.
@@ -156,11 +156,7 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
         distance_to_local_center_of_mass_sum / n_particles

     env_helper.add_shape_reward_term =
         mean_distance_to_local_center_of_mass /
-        env_helper.max_distance_to_local_center_of_mass < 0.32
-
-    if env_helper.add_shape_reward_term
-        #println(mean_distance_to_local_center_of_mass / env_helper.max_distance_to_local_center_of_mass) # TODO: Remove
-    end
+        env_helper.max_distance_to_local_center_of_mass < 0.3

     env_helper.center_of_mass = ReCo.center_of_mass(particles, env_helper.half_box_len)
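In words: `add_shape_reward_term` now turns on only once the mean distance of the particles to their local center of mass falls below 0.3 of the maximal possible distance (previously 0.32), and the leftover debug `if`/`#println` block is dropped. A minimal sketch of the gating condition with made-up numbers; only the condition itself comes from the commit:

    # Hedged sketch with hypothetical values in simulation units.
    mean_distance_to_local_center_of_mass = 0.8   # hypothetical
    max_distance_to_local_center_of_mass = 3.0    # hypothetical

    add_shape_reward_term =
        mean_distance_to_local_center_of_mass /
        max_distance_to_local_center_of_mass < 0.3   # 0.8 / 3.0 ≈ 0.27 < 0.3, so true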
@@ -180,7 +176,11 @@ end
 Returns the reward such that it is 0 for value=max_value and 1 for value=0.
 """
 function minimizing_reward(value::Float64, max_value::Float64)
-    return (max_value - value) / (max_value + value)
+    if value > max_value
+        error("value > max_value")
+    end
+
+    return ((max_value - value) / (max_value + value))^2
 end

 function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
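With the guard and the squared ratio, `minimizing_reward` still satisfies its docstring (1 at `value = 0`, 0 at `value = max_value`) but falls off quadratically in between, and it now errors instead of returning a negative reward when `value > max_value`. A standalone copy of the new formula for quick sanity checks:

    # Same expression as the new return statement; the calls below are worked examples.
    minimizing_reward(value, max_value) = ((max_value - value) / (max_value + value))^2

    minimizing_reward(0.0, 1.0)   # 1.0, best case
    minimizing_reward(0.5, 1.0)   # (0.5 / 1.5)^2 ≈ 0.11
    minimizing_reward(1.0, 1.0)   # 0.0, worst allowed case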
@@ -207,12 +207,9 @@ function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particl
             env_helper.half_box_len,
         )

-        reward += unnormalized_reward(
-            elliptical_distance,
-            env_helper.max_elliptical_distance, # TODO: Fix sq
+        reward += minimizing_reward(
+            elliptical_distance, env_helper.max_elliptical_distance
         )
-
-        # println(elliptical_distance / env_helper.max_elliptical_distance) # TODO: Remove
     end

     env.shared.reward = reward / normalization
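Because the squaring now happens inside `minimizing_reward`, the call site drops the `# TODO: Fix sq` workaround and passes the elliptical distance together with its maximum directly. A hedged sketch of the accumulation pattern implied by the hunk; the loop and the final normalization count are assumptions, as they are not shown in the diff:

    # Assumed shape of the surrounding loop; distances and the normalization are hypothetical.
    # Uses minimizing_reward as defined above.
    reward = 0.0
    elliptical_distances = (0.4, 1.1, 1.9)   # hypothetical per-neighbour distances
    max_elliptical_distance = 2.0            # hypothetical maximum

    for elliptical_distance in elliptical_distances
        reward += minimizing_reward(elliptical_distance, max_elliptical_distance)
    end

    shared_reward = reward / length(elliptical_distances)   # assumed normalization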
@@ -25,8 +25,8 @@ include("Hooks.jl")

 function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
     # TODO: Optimize warmup and decay
-    warmup_steps = 200_000
-    decay_steps = 1_000_000
+    warmup_steps = 500_000
+    decay_steps = 5_000_000

     policy = QBasedPolicy(;
         learner=MonteCarloLearner(;
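`warmup_steps` and `decay_steps` grow to 500_000 and 5_000_000, stretching the exploration schedule: in the ReinforcementLearning.jl explorers, these step counts typically control how long ϵ stays at its initial value and over how many steps it then decays towards `ϵ_stable`. A hedged sketch of how such values are commonly wired into a linear `EpsilonGreedyExplorer`; the actual constructor call in `gen_agent` is not part of this hunk, so every keyword except the two step counts is an assumption:

    using ReinforcementLearning

    # Hedged sketch, not the actual ReCo.jl code: a linear ϵ schedule with the new step counts.
    explorer = EpsilonGreedyExplorer(;
        kind=:linear,
        ϵ_init=1.0,             # assumed initial exploration rate
        ϵ_stable=0.01,          # assumed; gen_agent receives ϵ_stable as an argument
        warmup_steps=500_000,   # value from this commit
        decay_steps=5_000_000,  # value from this commit
    )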
@@ -135,8 +135,8 @@ function run_rl(;
         agent(POST_EPISODE_STAGE, env)

         # TODO: Replace with live plot
-        display(hook.rewards)
-        display(agent.policy.explorer.step)
+        @show hook.rewards
+        @show agent.policy.explorer.step
     end

     # Post experiment
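`@show` replaces `display` here: it prints the expression together with its value, which is easier to read back from batch logs than a bare value. For example:

    rewards = [0.1, 0.4]
    @show rewards    # prints: rewards = [0.1, 0.4]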