mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Added copy states hook
This commit is contained in:
parent
bcf760243c
commit
8850e5dd34
3 changed files with 24 additions and 15 deletions
|
@ -1,6 +1,6 @@
|
|||
export LocalCOMWithAdditionalShapeRewardEnv
|
||||
|
||||
using ..ReCo: Particle
|
||||
using ..ReCo: ReCo
|
||||
|
||||
struct LocalCOMWithAdditionalShapeRewardEnv <: Env
|
||||
shared::EnvSharedProps
|
||||
|
@ -51,7 +51,7 @@ mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper
|
|||
max_elliptical_distance::Float64
|
||||
|
||||
function LocalCOMWithAdditionalShapeRewardEnvHelper(
|
||||
shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius
|
||||
shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius::Float64
|
||||
)
|
||||
max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio
|
||||
|
||||
|
@ -106,14 +106,10 @@ function state_update_helper_hook!(
|
|||
end
|
||||
|
||||
function state_update_hook!(
|
||||
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{Particle}
|
||||
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{ReCo.Particle}
|
||||
)
|
||||
n_particles = env_helper.shared.n_particles
|
||||
|
||||
@turbo for id in 1:n_particles
|
||||
env_helper.shared.old_states_id[id] = env_helper.shared.states_id[id]
|
||||
end
|
||||
|
||||
env = env_helper.shared.env
|
||||
|
||||
distance_to_local_center_of_mass_sum = 0.0
|
||||
|
@ -173,7 +169,7 @@ end
|
|||
function update_reward!(
|
||||
env::LocalCOMWithAdditionalShapeRewardEnv,
|
||||
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
|
||||
particle::Particle,
|
||||
particle::ReCo.Particle,
|
||||
)
|
||||
id = particle.id
|
||||
|
||||
|
|
|
@ -23,20 +23,18 @@ function update_table_and_actions_hook!(
|
|||
)
|
||||
env, agent, hook = get_env_agent_hook(env_helper)
|
||||
|
||||
id = particle.id
|
||||
|
||||
if !first_integration_step
|
||||
# Old state
|
||||
env.shared.state_id = env_helper.shared.old_states_id[id]
|
||||
env.shared.state_id = env_helper.shared.old_states_id[particle.id]
|
||||
|
||||
action_id = env_helper.shared.actions_id[id]
|
||||
action_id = env_helper.shared.actions_id[particle.id]
|
||||
|
||||
# Pre act
|
||||
agent(PRE_ACT_STAGE, env, action_id)
|
||||
hook(PRE_ACT_STAGE, agent, env, action_id)
|
||||
|
||||
# Update to current state
|
||||
env.shared.state_id = env_helper.shared.states_id[id]
|
||||
env.shared.state_id = env_helper.shared.states_id[particle.id]
|
||||
|
||||
# Update reward
|
||||
update_reward!(env, env_helper, particle)
|
||||
|
@ -50,8 +48,8 @@ function update_table_and_actions_hook!(
|
|||
action_id = agent(env)
|
||||
action = env.shared.action_space[action_id]
|
||||
|
||||
env_helper.shared.actions[id] = action
|
||||
env_helper.shared.actions_id[id] = action_id
|
||||
env_helper.shared.actions[particle.id] = action
|
||||
env_helper.shared.actions_id[particle.id] = action_id
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
@ -72,3 +70,17 @@ function act_hook!(
|
|||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function copy_states_to_old_states_hook!(env_helper::EnvHelper)
|
||||
n_particles = env_helper.shared.n_particles
|
||||
|
||||
@turbo for particle_id in 1:n_particles
|
||||
env_helper.shared.old_states_id[particle_id] = env_helper.shared.states_id[particle_id]
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function copy_states_to_old_states_hook!(::Nothing)
|
||||
return nothing
|
||||
end
|
|
@ -66,6 +66,7 @@ function euler!(
|
|||
end
|
||||
end
|
||||
|
||||
RL.copy_states_to_old_states_hook!(env_helper)
|
||||
state_update_hook!(env_helper, args.particles)
|
||||
|
||||
@simd for p in args.particles
|
||||
|
|
Loading…
Reference in a new issue