diff --git a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl
index 52078e7..99ac5b8 100644
--- a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl
+++ b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl
@@ -1,6 +1,6 @@
 export LocalCOMWithAdditionalShapeRewardEnv
 
-using ..ReCo: Particle
+using ..ReCo: ReCo
 
 struct LocalCOMWithAdditionalShapeRewardEnv <: Env
     shared::EnvSharedProps
@@ -51,7 +51,7 @@ mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper
     max_elliptical_distance::Float64
 
     function LocalCOMWithAdditionalShapeRewardEnvHelper(
-        shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius
+        shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius::Float64
     )
         max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio
 
@@ -106,14 +106,10 @@ function state_update_helper_hook!(
 end
 
 function state_update_hook!(
-    env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{Particle}
+    env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{ReCo.Particle}
 )
     n_particles = env_helper.shared.n_particles
 
-    @turbo for id in 1:n_particles
-        env_helper.shared.old_states_id[id] = env_helper.shared.states_id[id]
-    end
-
     env = env_helper.shared.env
 
     distance_to_local_center_of_mass_sum = 0.0
@@ -173,7 +169,7 @@ end
 function update_reward!(
     env::LocalCOMWithAdditionalShapeRewardEnv,
     env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
-    particle::Particle,
+    particle::ReCo.Particle,
 )
     id = particle.id
diff --git a/src/RL/Hooks.jl b/src/RL/Hooks.jl
index 6fc13b5..18a4756 100644
--- a/src/RL/Hooks.jl
+++ b/src/RL/Hooks.jl
@@ -23,20 +23,18 @@ function update_table_and_actions_hook!(
 )
     env, agent, hook = get_env_agent_hook(env_helper)
 
-    id = particle.id
-
     if !first_integration_step
         # Old state
-        env.shared.state_id = env_helper.shared.old_states_id[id]
+        env.shared.state_id = env_helper.shared.old_states_id[particle.id]
 
-        action_id = env_helper.shared.actions_id[id]
+        action_id = env_helper.shared.actions_id[particle.id]
 
         # Pre act
         agent(PRE_ACT_STAGE, env, action_id)
         hook(PRE_ACT_STAGE, agent, env, action_id)
 
         # Update to current state
-        env.shared.state_id = env_helper.shared.states_id[id]
+        env.shared.state_id = env_helper.shared.states_id[particle.id]
 
         # Update reward
         update_reward!(env, env_helper, particle)
@@ -50,8 +48,8 @@
     action_id = agent(env)
     action = env.shared.action_space[action_id]
 
-    env_helper.shared.actions[id] = action
-    env_helper.shared.actions_id[id] = action_id
+    env_helper.shared.actions[particle.id] = action
+    env_helper.shared.actions_id[particle.id] = action_id
 
     return nothing
 end
@@ -70,5 +68,19 @@ function act_hook!(
     particle.tmp_c += SVector(vδt * co, vδt * si)
     particle.φ += action[2] * δt
 
+    return nothing
+end
+
+function copy_states_to_old_states_hook!(env_helper::EnvHelper)
+    n_particles = env_helper.shared.n_particles
+
+    @turbo for particle_id in 1:n_particles
+        env_helper.shared.old_states_id[particle_id] = env_helper.shared.states_id[particle_id]
+    end
+
+    return nothing
+end
+
+function copy_states_to_old_states_hook!(::Nothing)
     return nothing
 end
\ No newline at end of file
diff --git a/src/simulation.jl b/src/simulation.jl
index b85eb27..6a71558 100644
--- a/src/simulation.jl
+++ b/src/simulation.jl
@@ -66,6 +66,7 @@ function euler!(
         end
     end
 
+    RL.copy_states_to_old_states_hook!(env_helper)
     state_update_hook!(env_helper, args.particles)
 
     @simd for p in args.particles