1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-12-21 00:51:21 +00:00

Fix reward

This commit is contained in:
Mo8it 2022-01-15 18:55:01 +01:00
parent 978c3f39fb
commit 28fd6bab95
3 changed files with 13 additions and 16 deletions

View file

@ -2,4 +2,4 @@
image:https://img.shields.io/badge/code%20style-blue-4495d1.svg[Code Style: Blue, link=https://github.com/invenia/BlueStyle]
**Re**inforcement learning of **co**llective behaviour.
**Re**inforcement learning of **co**llective behavior.

View file

@ -156,11 +156,7 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
distance_to_local_center_of_mass_sum / n_particles
env_helper.add_shape_reward_term =
mean_distance_to_local_center_of_mass /
env_helper.max_distance_to_local_center_of_mass < 0.32
if env_helper.add_shape_reward_term
#println(mean_distance_to_local_center_of_mass / env_helper.max_distance_to_local_center_of_mass) # TODO: Remove
end
env_helper.max_distance_to_local_center_of_mass < 0.3
env_helper.center_of_mass = ReCo.center_of_mass(particles, env_helper.half_box_len)
@ -180,7 +176,11 @@ end
Returns the reward such that it is 0 for value=max_value and 1 for value=0.
"""
function minimizing_reward(value::Float64, max_value::Float64)
return (max_value - value) / (max_value + value)
if value > max_value
error("value > max_value")
end
return ((max_value - value) / (max_value + value))^2
end
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
@ -207,12 +207,9 @@ function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particl
env_helper.half_box_len,
)
reward += unnormalized_reward(
elliptical_distance,
env_helper.max_elliptical_distance, # TODO: Fix sq
reward += minimizing_reward(
elliptical_distance, env_helper.max_elliptical_distance
)
# println(elliptical_distance / env_helper.max_elliptical_distance) # TODO: Remove
end
env.shared.reward = reward / normalization

View file

@ -25,8 +25,8 @@ include("Hooks.jl")
function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
# TODO: Optimize warmup and decay
warmup_steps = 200_000
decay_steps = 1_000_000
warmup_steps = 500_000
decay_steps = 5_000_000
policy = QBasedPolicy(;
learner=MonteCarloLearner(;
@ -135,8 +135,8 @@ function run_rl(;
agent(POST_EPISODE_STAGE, env)
# TODO: Replace with live plot
display(hook.rewards)
display(agent.policy.explorer.step)
@show hook.rewards
@show agent.policy.explorer.step
end
# Post experiment