1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00

Fix reward

This commit is contained in:
Mo8it 2022-01-15 18:55:01 +01:00
parent 978c3f39fb
commit 28fd6bab95
3 changed files with 13 additions and 16 deletions

View file

@ -2,4 +2,4 @@
image:https://img.shields.io/badge/code%20style-blue-4495d1.svg[Code Style: Blue, link=https://github.com/invenia/BlueStyle] image:https://img.shields.io/badge/code%20style-blue-4495d1.svg[Code Style: Blue, link=https://github.com/invenia/BlueStyle]
**Re**inforcement learning of **co**llective behaviour. **Re**inforcement learning of **co**llective behavior.

View file

@ -156,11 +156,7 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
distance_to_local_center_of_mass_sum / n_particles distance_to_local_center_of_mass_sum / n_particles
env_helper.add_shape_reward_term = env_helper.add_shape_reward_term =
mean_distance_to_local_center_of_mass / mean_distance_to_local_center_of_mass /
env_helper.max_distance_to_local_center_of_mass < 0.32 env_helper.max_distance_to_local_center_of_mass < 0.3
if env_helper.add_shape_reward_term
#println(mean_distance_to_local_center_of_mass / env_helper.max_distance_to_local_center_of_mass) # TODO: Remove
end
env_helper.center_of_mass = ReCo.center_of_mass(particles, env_helper.half_box_len) env_helper.center_of_mass = ReCo.center_of_mass(particles, env_helper.half_box_len)
@ -180,7 +176,11 @@ end
Returns the reward such that it is 0 for value=max_value and 1 for value=0. Returns the reward such that it is 0 for value=max_value and 1 for value=0.
""" """
function minimizing_reward(value::Float64, max_value::Float64) function minimizing_reward(value::Float64, max_value::Float64)
return (max_value - value) / (max_value + value) if value > max_value
error("value > max_value")
end
return ((max_value - value) / (max_value + value))^2
end end
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle) function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
@ -207,12 +207,9 @@ function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particl
env_helper.half_box_len, env_helper.half_box_len,
) )
reward += unnormalized_reward( reward += minimizing_reward(
elliptical_distance, elliptical_distance, env_helper.max_elliptical_distance
env_helper.max_elliptical_distance, # TODO: Fix sq
) )
# println(elliptical_distance / env_helper.max_elliptical_distance) # TODO: Remove
end end
env.shared.reward = reward / normalization env.shared.reward = reward / normalization

View file

@ -25,8 +25,8 @@ include("Hooks.jl")
function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64) function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
# TODO: Optimize warmup and decay # TODO: Optimize warmup and decay
warmup_steps = 200_000 warmup_steps = 500_000
decay_steps = 1_000_000 decay_steps = 5_000_000
policy = QBasedPolicy(; policy = QBasedPolicy(;
learner=MonteCarloLearner(; learner=MonteCarloLearner(;
@ -135,8 +135,8 @@ function run_rl(;
agent(POST_EPISODE_STAGE, env) agent(POST_EPISODE_STAGE, env)
# TODO: Replace with live plot # TODO: Replace with live plot
display(hook.rewards) @show hook.rewards
display(agent.policy.explorer.step) @show agent.policy.explorer.step
end end
# Post experiment # Post experiment