1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-12-21 00:51:21 +00:00
ReCo.jl/src/RL/Hooks.jl
2022-04-05 03:12:25 +02:00

86 lines
2.2 KiB
Julia

using ..ReCo: Particle
"""
    pre_integration_hook!(env_helper::EnvHelper)

Generic fallback for the pre-integration hook: it only forwards to
`ReCo.method_not_implemented()`. Concrete `EnvHelper` subtypes are presumably expected to
add their own method (TODO confirm against the environment implementations).
"""
pre_integration_hook!(::EnvHelper) = ReCo.method_not_implemented()
"""
    state_update_helper_hook!(env_helper::EnvHelper, id1, id2, r⃗₁₂, distance²)

Generic fallback for the per-pair state-update helper hook. It receives two particle ids,
an `SVector{2,Float64}` and a `Float64`, and only forwards to
`ReCo.method_not_implemented()`. Concrete `EnvHelper` subtypes presumably override it.
"""
state_update_helper_hook!(
    ::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64
) = ReCo.method_not_implemented()
"""
    state_update_hook!(env_helper::EnvHelper, particles::Vector{Particle})

Generic fallback for the state-update hook over all particles. It only forwards to
`ReCo.method_not_implemented()`. Concrete `EnvHelper` subtypes presumably override it.
"""
state_update_hook!(::EnvHelper, particles::Vector{Particle}) =
    ReCo.method_not_implemented()
"""
    update_reward!(env::Env, env_helper::EnvHelper, particle::Particle)

Generic fallback for the per-particle reward update. It only forwards to
`ReCo.method_not_implemented()`. Concrete environment/helper types presumably override it.
"""
update_reward!(::Env, ::EnvHelper, particle::Particle) = ReCo.method_not_implemented()
"""
    update_table_and_actions_hook!(env_helper::EnvHelper, particle::Particle, first_integration_step::Bool)

Replay the particle's previously stored action through the agent/hook stage callbacks,
then choose and store the particle's next action.

The env, agent and hook are obtained from `get_env_agent_hook(env_helper)`. Unless
`first_integration_step` is `true`, the following happens in order:
1. `env.shared.state_id` is set to the particle's old state id.
2. The agent and hook are called with `PRE_ACT_STAGE` and the stored action id.
3. `env.shared.state_id` is set to the particle's current state id.
4. `update_reward!` is called.
5. The agent and hook are called with `POST_ACT_STAGE`.

In every case, the agent is then called as `agent(env)` to obtain a new action id. That
id and the matching entry of `env.shared.action_space` are written to
`env_helper.shared.actions_id` and `env_helper.shared.actions`. All per-particle buffers
are indexed by `particle.id`.

Returns `nothing`.
"""
function update_table_and_actions_hook!(
    env_helper::EnvHelper, particle::Particle, first_integration_step::Bool
)
    env, agent, hook = get_env_agent_hook(env_helper)

    # The stage replay is skipped on the first step, presumably because no action has been
    # stored for this particle yet (TODO confirm).
    if !first_integration_step
        # Old state
        # The env must expose the state in which the stored action was chosen while the
        # PRE_ACT_STAGE callbacks run, so this assignment has to come first.
        env.shared.state_id = env_helper.shared.old_states_id[particle.id]
        action_id = env_helper.shared.actions_id[particle.id]

        # Pre act
        agent(PRE_ACT_STAGE, env, action_id)
        hook(PRE_ACT_STAGE, agent, env, action_id)

        # Update to current state
        # This must happen before the reward update and POST_ACT_STAGE, which should
        # observe the state reached after the action.
        env.shared.state_id = env_helper.shared.states_id[particle.id]

        # Update reward
        update_reward!(env, env_helper, particle)

        # Post act
        agent(POST_ACT_STAGE, env)
        hook(POST_ACT_STAGE, agent, env)
    end

    # Update action
    # The agent picks a new action id for the state currently set on `env`. The id and
    # the decoded action are cached per particle so later code can read them.
    action_id = agent(env)
    action = env.shared.action_space[action_id]
    env_helper.shared.actions[particle.id] = action
    env_helper.shared.actions_id[particle.id] = action_id

    return nothing
end
"""
    act_hook!(particle::Particle, ::Nothing, args...)

No-op fallback used when no environment helper is attached (`nothing`). Any extra
arguments are ignored and `nothing` is returned.
"""
act_hook!(::Particle, ::Nothing, args...) = nothing
"""
    act_hook!(particle::Particle, env_helper::EnvHelper, δt::Float64, si::Float64, co::Float64)

Apply the action stored for `particle` in `env_helper.shared.actions` over one time step
`δt`:

- The first action component times `δt` gives a step length. `particle.tmp_c` is shifted
  by that length along `(co, si)`. These are presumably the cosine and sine of the
  particle's orientation (TODO confirm at the call site).
- The second action component times `δt` is added to `particle.φ`.

Returns `nothing`.
"""
function act_hook!(
    particle::Particle, env_helper::EnvHelper, δt::Float64, si::Float64, co::Float64
)
    chosen = env_helper.shared.actions[particle.id]

    step_length = chosen[1] * δt
    displacement = SVector(step_length * co, step_length * si)

    particle.tmp_c += displacement
    particle.φ += chosen[2] * δt

    return nothing
end
"""
    copy_states_to_old_states_hook!(env_helper::EnvHelper)

Copy entries `1:n_particles` of `env_helper.shared.states_id` into
`env_helper.shared.old_states_id`, so the current state ids become the "old" ones.
Returns `nothing`.
"""
function copy_states_to_old_states_hook!(env_helper::EnvHelper)
    shared = env_helper.shared
    src = shared.states_id
    dst = shared.old_states_id

    # Each iteration is independent, which is what `@simd` requires.
    @simd for i in 1:(shared.n_particles)
        dst[i] = src[i]
    end

    return nothing
end
"""
    copy_states_to_old_states_hook!(::Nothing)

No-op fallback used when no environment helper is attached. Returns `nothing`.
"""
copy_states_to_old_states_hook!(::Nothing) = nothing