mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-11-08 22:21:08 +00:00
Added copy states hook
This commit is contained in:
parent
bcf760243c
commit
8850e5dd34
3 changed files with 24 additions and 15 deletions
|
@ -1,6 +1,6 @@
|
||||||
export LocalCOMWithAdditionalShapeRewardEnv
|
export LocalCOMWithAdditionalShapeRewardEnv
|
||||||
|
|
||||||
using ..ReCo: Particle
|
using ..ReCo: ReCo
|
||||||
|
|
||||||
struct LocalCOMWithAdditionalShapeRewardEnv <: Env
|
struct LocalCOMWithAdditionalShapeRewardEnv <: Env
|
||||||
shared::EnvSharedProps
|
shared::EnvSharedProps
|
||||||
|
@ -51,7 +51,7 @@ mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper
|
||||||
max_elliptical_distance::Float64
|
max_elliptical_distance::Float64
|
||||||
|
|
||||||
function LocalCOMWithAdditionalShapeRewardEnvHelper(
|
function LocalCOMWithAdditionalShapeRewardEnvHelper(
|
||||||
shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius
|
shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius::Float64
|
||||||
)
|
)
|
||||||
max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio
|
max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio
|
||||||
|
|
||||||
|
@ -106,14 +106,10 @@ function state_update_helper_hook!(
|
||||||
end
|
end
|
||||||
|
|
||||||
function state_update_hook!(
|
function state_update_hook!(
|
||||||
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{Particle}
|
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{ReCo.Particle}
|
||||||
)
|
)
|
||||||
n_particles = env_helper.shared.n_particles
|
n_particles = env_helper.shared.n_particles
|
||||||
|
|
||||||
@turbo for id in 1:n_particles
|
|
||||||
env_helper.shared.old_states_id[id] = env_helper.shared.states_id[id]
|
|
||||||
end
|
|
||||||
|
|
||||||
env = env_helper.shared.env
|
env = env_helper.shared.env
|
||||||
|
|
||||||
distance_to_local_center_of_mass_sum = 0.0
|
distance_to_local_center_of_mass_sum = 0.0
|
||||||
|
@ -173,7 +169,7 @@ end
|
||||||
function update_reward!(
|
function update_reward!(
|
||||||
env::LocalCOMWithAdditionalShapeRewardEnv,
|
env::LocalCOMWithAdditionalShapeRewardEnv,
|
||||||
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
|
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
|
||||||
particle::Particle,
|
particle::ReCo.Particle,
|
||||||
)
|
)
|
||||||
id = particle.id
|
id = particle.id
|
||||||
|
|
||||||
|
|
|
@ -23,20 +23,18 @@ function update_table_and_actions_hook!(
|
||||||
)
|
)
|
||||||
env, agent, hook = get_env_agent_hook(env_helper)
|
env, agent, hook = get_env_agent_hook(env_helper)
|
||||||
|
|
||||||
id = particle.id
|
|
||||||
|
|
||||||
if !first_integration_step
|
if !first_integration_step
|
||||||
# Old state
|
# Old state
|
||||||
env.shared.state_id = env_helper.shared.old_states_id[id]
|
env.shared.state_id = env_helper.shared.old_states_id[particle.id]
|
||||||
|
|
||||||
action_id = env_helper.shared.actions_id[id]
|
action_id = env_helper.shared.actions_id[particle.id]
|
||||||
|
|
||||||
# Pre act
|
# Pre act
|
||||||
agent(PRE_ACT_STAGE, env, action_id)
|
agent(PRE_ACT_STAGE, env, action_id)
|
||||||
hook(PRE_ACT_STAGE, agent, env, action_id)
|
hook(PRE_ACT_STAGE, agent, env, action_id)
|
||||||
|
|
||||||
# Update to current state
|
# Update to current state
|
||||||
env.shared.state_id = env_helper.shared.states_id[id]
|
env.shared.state_id = env_helper.shared.states_id[particle.id]
|
||||||
|
|
||||||
# Update reward
|
# Update reward
|
||||||
update_reward!(env, env_helper, particle)
|
update_reward!(env, env_helper, particle)
|
||||||
|
@ -50,8 +48,8 @@ function update_table_and_actions_hook!(
|
||||||
action_id = agent(env)
|
action_id = agent(env)
|
||||||
action = env.shared.action_space[action_id]
|
action = env.shared.action_space[action_id]
|
||||||
|
|
||||||
env_helper.shared.actions[id] = action
|
env_helper.shared.actions[particle.id] = action
|
||||||
env_helper.shared.actions_id[id] = action_id
|
env_helper.shared.actions_id[particle.id] = action_id
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
@ -72,3 +70,17 @@ function act_hook!(
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function copy_states_to_old_states_hook!(env_helper::EnvHelper)
|
||||||
|
n_particles = env_helper.shared.n_particles
|
||||||
|
|
||||||
|
@turbo for particle_id in 1:n_particles
|
||||||
|
env_helper.shared.old_states_id[particle_id] = env_helper.shared.states_id[particle_id]
|
||||||
|
end
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function copy_states_to_old_states_hook!(::Nothing)
|
||||||
|
return nothing
|
||||||
|
end
|
|
@ -66,6 +66,7 @@ function euler!(
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
RL.copy_states_to_old_states_hook!(env_helper)
|
||||||
state_update_hook!(env_helper, args.particles)
|
state_update_hook!(env_helper, args.particles)
|
||||||
|
|
||||||
@simd for p in args.particles
|
@simd for p in args.particles
|
||||||
|
|
Loading…
Reference in a new issue