mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-12-30 17:13:27 +00:00

Added copy states hook

Mo8it 2022-01-29 16:45:36 +01:00
parent bcf760243c
commit 8850e5dd34
3 changed files with 24 additions and 15 deletions
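
In short, this commit moves the copying of the current state ids into old_states_id out of state_update_hook! and into a dedicated copy_states_to_old_states_hook!, adds a ::Nothing method as a no-op fallback for runs without an RL environment helper, and calls the new hook in euler! right before state_update_hook!. The sketch below only illustrates that pattern; SketchSharedProps and SketchEnvHelper are hypothetical stand-ins for the real ReCo.jl types, and a plain for loop replaces the @turbo loop (LoopVectorization) used in the actual diff.

# Minimal sketch, not repository code: stand-in types with only the fields the hook touches.
struct SketchSharedProps
    n_particles::Int
    states_id::Vector{Int}
    old_states_id::Vector{Int}
end

struct SketchEnvHelper
    shared::SketchSharedProps
end

# Same logic as the added copy_states_to_old_states_hook!: snapshot the current
# state ids as the "old" state ids before the next state update overwrites them.
function copy_states_to_old_states_hook!(env_helper::SketchEnvHelper)
    for particle_id in 1:env_helper.shared.n_particles
        env_helper.shared.old_states_id[particle_id] = env_helper.shared.states_id[particle_id]
    end
    return nothing
end

# No-op fallback, mirroring the ::Nothing method added in the commit.
copy_states_to_old_states_hook!(::Nothing) = nothing

Hypothetical usage, assuming the stand-in types above:

helper = SketchEnvHelper(SketchSharedProps(3, [1, 2, 3], zeros(Int, 3)))
copy_states_to_old_states_hook!(helper)   # old_states_id becomes [1, 2, 3]
copy_states_to_old_states_hook!(nothing)  # no-op when no RL env helper is used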


@@ -1,6 +1,6 @@
 export LocalCOMWithAdditionalShapeRewardEnv

-using ..ReCo: Particle
+using ..ReCo: ReCo

 struct LocalCOMWithAdditionalShapeRewardEnv <: Env
     shared::EnvSharedProps
@@ -51,7 +51,7 @@ mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper
     max_elliptical_distance::Float64

     function LocalCOMWithAdditionalShapeRewardEnvHelper(
-        shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius
+        shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius::Float64
     )
         max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio
@@ -106,14 +106,10 @@ function state_update_helper_hook!(
 end

 function state_update_hook!(
-    env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{Particle}
+    env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{ReCo.Particle}
 )
     n_particles = env_helper.shared.n_particles

-    @turbo for id in 1:n_particles
-        env_helper.shared.old_states_id[id] = env_helper.shared.states_id[id]
-    end
-
     env = env_helper.shared.env

     distance_to_local_center_of_mass_sum = 0.0
@@ -173,7 +169,7 @@ end
 function update_reward!(
     env::LocalCOMWithAdditionalShapeRewardEnv,
     env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
-    particle::Particle,
+    particle::ReCo.Particle,
 )
     id = particle.id


@@ -23,20 +23,18 @@ function update_table_and_actions_hook!(
 )
     env, agent, hook = get_env_agent_hook(env_helper)

-    id = particle.id
-
     if !first_integration_step
         # Old state
-        env.shared.state_id = env_helper.shared.old_states_id[id]
+        env.shared.state_id = env_helper.shared.old_states_id[particle.id]

-        action_id = env_helper.shared.actions_id[id]
+        action_id = env_helper.shared.actions_id[particle.id]

         # Pre act
         agent(PRE_ACT_STAGE, env, action_id)
         hook(PRE_ACT_STAGE, agent, env, action_id)

         # Update to current state
-        env.shared.state_id = env_helper.shared.states_id[id]
+        env.shared.state_id = env_helper.shared.states_id[particle.id]

         # Update reward
         update_reward!(env, env_helper, particle)
@@ -50,8 +48,8 @@ function update_table_and_actions_hook!(
     action_id = agent(env)
     action = env.shared.action_space[action_id]

-    env_helper.shared.actions[id] = action
-    env_helper.shared.actions_id[id] = action_id
+    env_helper.shared.actions[particle.id] = action
+    env_helper.shared.actions_id[particle.id] = action_id

     return nothing
 end
@@ -70,5 +68,19 @@ function act_hook!(
     particle.tmp_c += SVector(vδt * co, vδt * si)
     particle.φ += action[2] * δt

     return nothing
 end
+
+function copy_states_to_old_states_hook!(env_helper::EnvHelper)
+    n_particles = env_helper.shared.n_particles
+
+    @turbo for particle_id in 1:n_particles
+        env_helper.shared.old_states_id[particle_id] = env_helper.shared.states_id[particle_id]
+    end
+
+    return nothing
+end
+
+function copy_states_to_old_states_hook!(::Nothing)
+    return nothing
+end


@@ -66,6 +66,7 @@ function euler!(
         end
     end

+    RL.copy_states_to_old_states_hook!(env_helper)
     state_update_hook!(env_helper, args.particles)

     @simd for p in args.particles