1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-12-21 00:51:21 +00:00

Added copy states hook

This commit is contained in:
Mo8it 2022-01-29 16:45:36 +01:00
parent bcf760243c
commit 8850e5dd34
3 changed files with 24 additions and 15 deletions

View file

@ -1,6 +1,6 @@
export LocalCOMWithAdditionalShapeRewardEnv export LocalCOMWithAdditionalShapeRewardEnv
using ..ReCo: Particle using ..ReCo: ReCo
struct LocalCOMWithAdditionalShapeRewardEnv <: Env struct LocalCOMWithAdditionalShapeRewardEnv <: Env
shared::EnvSharedProps shared::EnvSharedProps
@ -51,7 +51,7 @@ mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper
max_elliptical_distance::Float64 max_elliptical_distance::Float64
function LocalCOMWithAdditionalShapeRewardEnvHelper( function LocalCOMWithAdditionalShapeRewardEnvHelper(
shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius::Float64
) )
max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio
@ -106,14 +106,10 @@ function state_update_helper_hook!(
end end
function state_update_hook!( function state_update_hook!(
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{Particle} env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{ReCo.Particle}
) )
n_particles = env_helper.shared.n_particles n_particles = env_helper.shared.n_particles
@turbo for id in 1:n_particles
env_helper.shared.old_states_id[id] = env_helper.shared.states_id[id]
end
env = env_helper.shared.env env = env_helper.shared.env
distance_to_local_center_of_mass_sum = 0.0 distance_to_local_center_of_mass_sum = 0.0
@ -173,7 +169,7 @@ end
function update_reward!( function update_reward!(
env::LocalCOMWithAdditionalShapeRewardEnv, env::LocalCOMWithAdditionalShapeRewardEnv,
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
particle::Particle, particle::ReCo.Particle,
) )
id = particle.id id = particle.id

View file

@ -23,20 +23,18 @@ function update_table_and_actions_hook!(
) )
env, agent, hook = get_env_agent_hook(env_helper) env, agent, hook = get_env_agent_hook(env_helper)
id = particle.id
if !first_integration_step if !first_integration_step
# Old state # Old state
env.shared.state_id = env_helper.shared.old_states_id[id] env.shared.state_id = env_helper.shared.old_states_id[particle.id]
action_id = env_helper.shared.actions_id[id] action_id = env_helper.shared.actions_id[particle.id]
# Pre act # Pre act
agent(PRE_ACT_STAGE, env, action_id) agent(PRE_ACT_STAGE, env, action_id)
hook(PRE_ACT_STAGE, agent, env, action_id) hook(PRE_ACT_STAGE, agent, env, action_id)
# Update to current state # Update to current state
env.shared.state_id = env_helper.shared.states_id[id] env.shared.state_id = env_helper.shared.states_id[particle.id]
# Update reward # Update reward
update_reward!(env, env_helper, particle) update_reward!(env, env_helper, particle)
@ -50,8 +48,8 @@ function update_table_and_actions_hook!(
action_id = agent(env) action_id = agent(env)
action = env.shared.action_space[action_id] action = env.shared.action_space[action_id]
env_helper.shared.actions[id] = action env_helper.shared.actions[particle.id] = action
env_helper.shared.actions_id[id] = action_id env_helper.shared.actions_id[particle.id] = action_id
return nothing return nothing
end end
@ -72,3 +70,17 @@ function act_hook!(
return nothing return nothing
end end
"""
    copy_states_to_old_states_hook!(env_helper::EnvHelper)

Copy the first `env_helper.shared.n_particles` entries of
`env_helper.shared.states_id` into `env_helper.shared.old_states_id`, in place.

Return `nothing`.

# Throws
- `BoundsError` if either array holds fewer than `n_particles` elements.
"""
function copy_states_to_old_states_hook!(env_helper::EnvHelper)
    shared = env_helper.shared

    # `copyto!` replaces the hand-written element loop: it is bounds-checked and
    # copies exactly the elements 1:n_particles, leaving any remaining entries of
    # `old_states_id` untouched.
    # NOTE(review): assumes `n_particles` is a non-negative integer — `copyto!`
    # rejects a negative count, whereas an empty `1:n` loop would silently do nothing.
    copyto!(shared.old_states_id, 1, shared.states_id, 1, shared.n_particles)

    return nothing
end
"""
    copy_states_to_old_states_hook!(::Nothing)

No-op method for a `nothing` argument; always returns `nothing`.

NOTE(review): presumably selected when no environment helper is in use —
confirm against the caller.
"""
copy_states_to_old_states_hook!(::Nothing) = nothing

View file

@ -66,6 +66,7 @@ function euler!(
end end
end end
RL.copy_states_to_old_states_hook!(env_helper)
state_update_hook!(env_helper, args.particles) state_update_hook!(env_helper, args.particles)
@simd for p in args.particles @simd for p in args.particles