Mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git (synced 2024-12-21 00:51:21 +00:00)
Only one agent
Commit 46e9a7fb60 (parent 5fc3df66cd)
5 changed files with 220 additions and 219 deletions
src/RL.jl (344 changes)
@@ -14,19 +14,23 @@ using ..ReCo: ReCo, Particle, angle2

 const INITIAL_REWARD = 0.0

-mutable struct EnvParams
-    action_space::Vector{Tuple{Float64,Float64}}
+mutable struct Env <: AbstractEnv
+    n_actions::Int64
+    action_space::Vector{SVector{2,Float64}}
     action_ind_space::Vector{Int64}

     distance_state_space::Vector{Interval}
     angle_state_space::Vector{Interval}
-    state_space::Vector{Union{Tuple{Interval,Interval},Tuple{Nothing,Nothing}}}
-    state_ind_space::Vector{Int64}
     n_states::Int64
+    state_space::Vector{SVector{2,Interval}}
+    state_ind_space::Vector{Int64}
+    state_ind::Int64

     reward::Float64
+    terminated::Bool

-    function EnvParams(
+    function Env(
         min_distance::Float64,
         max_distance::Float64;
         n_v_actions::Int64=2,

@@ -48,12 +52,12 @@ mutable struct EnvParams

        n_actions = n_v_actions * n_ω_actions

-        action_space = Vector{Tuple{Float64,Float64}}(undef, n_actions)
+        action_space = Vector{SVector{2,Float64}}(undef, n_actions)

         ind = 1
         for v in v_action_space
             for ω in ω_action_space
-                action_space[ind] = (v, ω)
+                action_space[ind] = SVector(v, ω)
                 ind += 1
             end
         end

@@ -95,156 +99,112 @@ mutable struct EnvParams

         n_states = n_distance_states * n_angle_states + 1

-        state_space = Vector{Union{Tuple{Interval,Interval},Tuple{Nothing,Nothing}}}(
-            undef, n_states
-        )
+        state_space = Vector{SVector{2,Interval}}(undef, n_states - 1)

         ind = 1
         for distance_state in distance_state_space
             for angle_state in angle_state_space
-                state_space[ind] = (distance_state, angle_state)
+                state_space[ind] = SVector(distance_state, angle_state)
                 ind += 1
             end
         end
-        state_space[ind] = (nothing, nothing)
+        # Last state is SVector(nothing, nothing)

         state_ind_space = collect(1:n_states)

+        # initial_state = SVector(nothing, nothing)
+        initial_state_ind = n_states

         return new(
+            n_actions,
             action_space,
             action_ind_space,
             distance_state_space,
             angle_state_space,
+            n_states,
             state_space,
             state_ind_space,
-            n_states,
+            initial_state_ind,
             INITIAL_REWARD,
+            false,
         )
     end
 end
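Note: the constructor above flattens the grid of translational speeds v and angular velocities ω into one indexable vector of SVectors so that the agent only ever sees integer action indices. A minimal standalone sketch of that flattening, with made-up action values (the real ranges are derived from the constructor arguments):

using StaticArrays: SVector

# Hypothetical action values; in the package they come from the Env constructor arguments.
v_action_space = (0.0, 40.0)          # translational speeds
ω_action_space = (-π / 2, 0.0, π / 2) # angular velocities

n_actions = length(v_action_space) * length(ω_action_space)
action_space = Vector{SVector{2,Float64}}(undef, n_actions)

# Flatten the (v, ω) grid into a single indexable vector, as in the Env constructor.
ind = 1
for v in v_action_space, ω in ω_action_space
    action_space[ind] = SVector(v, ω)
    ind += 1
end

action_ind_space = collect(1:n_actions)  # integer indices exposed to the agent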

-function reset!(env_params::EnvParams)
-    env_params.reward = INITIAL_REWARD
+function reset!(env::Env)
+    env.state_ind = env.n_states
+    env.reward = INITIAL_REWARD
+    env.terminated = false

     return nothing
 end

-mutable struct Env <: AbstractEnv
-    params::EnvParams
-    particle::Particle
-    state_ind::Int64
-
-    function Env(params::EnvParams, particle::Particle)
-        # initial_state = (nothing, nothing)
-        initial_state_ind = params.n_states
-
-        return new(params, particle, initial_state_ind)
-    end
-end
-
-function reset!(env::Env, particle::Particle)
-    env.particle = particle
-    env.state_ind = env.params.n_states
-
-    return nothing
-end
-
-RLBase.state_space(env::Env) = env.params.state_ind_space
+RLBase.state_space(env::Env) = env.state_ind_space

 RLBase.state(env::Env) = env.state_ind

-RLBase.action_space(env::Env) = env.params.action_ind_space
+RLBase.action_space(env::Env) = env.action_ind_space

-RLBase.reward(env::Env) = env.params.reward
+RLBase.reward(env::Env) = env.reward

-RLBase.is_terminated(::Env) = false
+RLBase.is_terminated(env::Env) = env.terminated

-function gen_policy(n_states::Int64, n_actions::Int64)
-    return QBasedPolicy(;
-        learner=MonteCarloLearner(;
-            approximator=TabularQApproximator(;
-                n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
-            ),
-        ),
-        explorer=EpsilonGreedyExplorer(0.1),
-    )
-end
-
 struct Params{H<:AbstractHook}
-    envs::Vector{Env}
-    agents::Vector{Agent}
-    hooks::Vector{H}
-    actions::Vector{Tuple{Float64,Float64}}
-    env_params::EnvParams
+    env::Env
+    agent::Agent
+    hook::H
+
+    old_states_ind::Vector{Int64}
+    states_ind::Vector{Int64}
+
+    actions::Vector{SVector{2,Float64}}
+    actions_ind::Vector{Int64}

     n_steps_before_actions_update::Int64
-    min_sq_distances::Vector{Float64}
-    vecs_r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
     goal_shape_ratio::Float64

-    function Params{H}(
-        n_particles::Int64,
-        env_params::EnvParams,
+    n_particles::Int64
+    min_sq_distances::Vector{Float64}
+    vecs_r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
+
+    function Params(
+        env::Env,
+        agent::Agent,
+        hook::H,
         n_steps_before_actions_update::Int64,
         goal_shape_ratio::Float64,
+        n_particles::Int64,
     ) where {H<:AbstractHook}
-        envs = [Env(env_params, ReCo.gen_tmp_particle()) for i in 1:n_particles]
-
-        agents = [
-            Agent(;
-                policy=gen_policy(env_params.n_states, length(env_params.action_space)),
-                trajectory=VectorSARTTrajectory(),
-            ) for i in 1:n_particles
-        ]
-
-        hooks = [H() for i in 1:n_particles]
-
-        actions = Vector{Tuple{Float64,Float64}}(undef, n_particles)
-
-        min_sq_distances = fill(Inf64, n_particles)
-
-        vecs_r⃗₁₂_to_min_distance_particle = fill(SVector(0.0, 0.0), n_particles)
-
-        return new(
-            envs,
-            agents,
-            hooks,
-            actions,
-            env_params,
+        n_states = env.n_states
+
+        return new{H}(
+            env,
+            agent,
+            hook,
+            fill(0, n_particles),
+            fill(n_states, n_particles),
+            fill(SVector(0.0, 0.0), n_particles),
+            fill(0, n_particles),
             n_steps_before_actions_update,
-            min_sq_distances,
-            vecs_r⃗₁₂_to_min_distance_particle,
             goal_shape_ratio,
+            n_particles,
+            fill(Inf64, n_particles),
+            fill(SVector(0.0, 0.0), n_particles),
        )
     end
 end
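Note: the new Params constructor replaces the per-particle Env/Agent/Hook vectors with per-particle index and action buffers, all sized by n_particles. A standalone sketch of that buffer layout, with hypothetical sizes:

using StaticArrays: SVector

# Hypothetical sizes; in the package they come from the simulation constants and the Env.
n_particles = 100
n_states = 41  # e.g. n_distance_states * n_angle_states + 1

old_states_ind = fill(0, n_particles)           # previous state index per particle
states_ind = fill(n_states, n_particles)        # start every particle in the "no neighbour" state
actions = fill(SVector(0.0, 0.0), n_particles)  # (v, ω) chosen per particle
actions_ind = fill(0, n_particles)              # index of that action
min_sq_distances = fill(Inf64, n_particles)     # reset before each neighbour pass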

-function get_env_agent_hook(rl_params::Params, ind::Int64)
-    return (rl_params.envs[ind], rl_params.agents[ind], rl_params.hooks[ind])
-end
-
-function pre_integration_hook!(rl_params::Params, n_particles::Int64)
-    @simd for i in 1:n_particles
-        env, agent, hook = get_env_agent_hook(rl_params, i)
-
-        # Update action
-        action_ind = agent(env)
-        action = rl_params.env_params.action_space[action_ind]
-        rl_params.actions[i] = action
-
-        # Pre act
-        agent(PRE_ACT_STAGE, env, action_ind)
-        hook(PRE_ACT_STAGE, agent, env, action_ind)
-    end
-
-    @turbo for i in 1:n_particles
+function pre_integration_hook(rl_params::Params)
+    @turbo for i in 1:(rl_params.n_particles)
         rl_params.min_sq_distances[i] = Inf64
     end

     return nothing
 end

-function state_hook(
-    id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64, rl_params::Params
+function state_update_helper_hook(
+    rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64
 )
     if rl_params.min_sq_distances[id1] > distance²
         rl_params.min_sq_distances[id1] = distance²

@@ -261,56 +221,33 @@ function state_hook(
     return nothing
 end

-function integration_hook!(
-    particle::Particle, rl_params::Params, δt::Float64, si::Float64, co::Float64
-)
-    # Apply action
-    action = rl_params.actions[particle.id]
-
-    vδt = action[1] * δt
-    particle.tmp_c += SVector(vδt * co, vδt * si)
-    particle.φ += action[2] * δt
-
-    return nothing
-end
-
-function get_state_ind(state::Tuple{Interval,Interval}, env_params::EnvParams)
-    return findfirst(x -> x == state, env_params.state_space)
-end
-
-function get_state_ind(::Tuple{Nothing,Nothing}, env_params::EnvParams)
-    return env_params.n_states
-end
-
-function post_integration_hook(
-    rl_params::Params,
-    n_particles::Int64,
-    particles::Vector{Particle},
-    half_box_len::Float64,
-)
-    # Update reward
-    rl_params.env_params.reward =
-        1 -
-        (
-            ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len) -
-            rl_params.goal_shape_ratio
-        )^2
-
-    # Update states
-    n_states = rl_params.env_params.n_states
-
-    env_angle_state = rl_params.env_params.angle_state_space[1]
-
+function get_state_ind(state::S, state_space::Vector{S}) where {S<:SVector{2,Interval}}
+    return findfirst(x -> x == state, state_space)
+end
+
+function state_update_hook(
+    rl_params::Params, particles::Vector{Particle}, n_particles::Int64
+)
+    @turbo for i in 1:n_particles
+        rl_params.old_states_ind[i] = rl_params.states_ind[i]
+    end
+
+    env = rl_params.env
+
+    n_states = env.n_states
+
+    env_angle_state = env.angle_state_space[1]
+
+    state_space = env.state_space
+
     for i in 1:n_particles
-        env, agent, hook = get_env_agent_hook(rl_params, i)
-
         env_distance_state::Union{Interval,Nothing} = nothing

         min_sq_distance = rl_params.min_sq_distances[i]
         min_distance = sqrt(min_sq_distance)

         if !isinf(min_sq_distance)
-            for distance_state in rl_params.env_params.distance_state_space
+            for distance_state in env.distance_state_space
                 if min_distance in distance_state
                     env_distance_state = distance_state
                     break

@@ -318,10 +255,10 @@ function post_integration_hook(
                 end
             end
         end

-        if isnothing(env_distance_state)
-            # (nothing, nothing)
-            env.state_ind = n_states
-        else
+        # (nothing, nothing)
+        state_ind = n_states
+
+        if !isnothing(env_distance_state)
             r⃗₁₂ = rl_params.vecs_r⃗₁₂_to_min_distance_particle[i]
             si, co = sincos(particles[i].φ)

@@ -337,24 +274,99 @@ function post_integration_hook(
             =#
             angle = angle2(SVector(co, si), r⃗₁₂)

-            for angle_state in rl_params.env_params.angle_state_space
+            for angle_state in env.angle_state_space
                 if angle in angle_state
                     env_angle_state = angle_state
                 end
             end

-            state = (env_distance_state, env_angle_state)
-            env.state_ind = get_state_ind(state, env.params)
+            state = SVector{2,Interval}(env_distance_state, env_angle_state)
+            state_ind = get_state_ind(state, state_space)
         end

+        rl_params.states_ind[i] = state_ind
+    end
+
+    return nothing
+end
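Note: get_state_ind is a linear search over the precomputed state vector. A simplified standalone sketch, using Float64 pairs as a stand-in for the package's SVector{2,Interval} states:

using StaticArrays: SVector

# Simplified stand-in: Float64 pairs instead of the package's SVector{2,Interval} states.
state_space = [SVector(d, a) for d in 1.0:3.0 for a in 0.0:0.5:1.0]

# Linear lookup of a state's index, as in get_state_ind.
get_state_ind(state, state_space) = findfirst(x -> x == state, state_space)

state = SVector(2.0, 0.5)
ind = get_state_ind(state, state_space)  # index into the Q-table's state dimension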

+function get_env_agent_hook(rl_params::Params)
+    return (rl_params.env, rl_params.agent, rl_params.hook)
+end
+
+function update_table_and_actions_hook(
+    rl_params::Params, particle::Particle, first_integration_step::Bool
+)
+    env, agent, hook = get_env_agent_hook(rl_params)
+
+    id = particle.id
+
+    if !first_integration_step
+        # Old state
+        env.state_ind = rl_params.old_states_ind[id]
+
+        action_ind = rl_params.actions_ind[id]
+
+        # Pre act
+        agent(PRE_ACT_STAGE, env, action_ind)
+        hook(PRE_ACT_STAGE, agent, env, action_ind)
+
+        # Update to current state
+        env.state_ind = rl_params.states_ind[id]
+
+        # Update reward
+        env.reward = -(particle.c[1]^2 + particle.c[2]^2)
+
+        #=
+        1 -
+        (
+            ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len) -
+            rl_params.goal_shape_ratio
+        )^2
+        =#

         # Post act
         agent(POST_ACT_STAGE, env)
         hook(POST_ACT_STAGE, agent, env)
     end

+    # Update action
+    action_ind = agent(env)
+    action = env.action_space[action_ind]
+
+    rl_params.actions[id] = action
+    rl_params.actions_ind[id] = action_ind
+
     return nothing
 end

+act_hook(::Nothing, args...) = nothing
+
+function act_hook(
+    rl_params::Params, particle::Particle, δt::Float64, si::Float64, co::Float64
+)
+    # Apply action
+    action = rl_params.actions[particle.id]
+
+    vδt = action[1] * δt
+    particle.tmp_c += SVector(vδt * co, vδt * si)
+    particle.φ += action[2] * δt
+
+    return nothing
+end
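Note: act_hook applies the chosen action as one explicit Euler step: the speed moves the particle along its current orientation and the angular velocity turns it. A standalone sketch with hypothetical values (in the package these fields live in the Particle struct):

using StaticArrays: SVector

# Hypothetical particle state and action.
c = SVector(0.0, 0.0)        # position
φ = 0.0                      # orientation
δt = 1e-5
action = SVector(40.0, π / 2)  # (v, ω) picked by the agent

# One explicit Euler step of the self-propulsion part, as in act_hook.
si, co = sincos(φ)
vδt = action[1] * δt
c += SVector(vδt * co, vδt * si)  # move along the current orientation
φ += action[2] * δt               # turn with the chosen angular velocity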

+function gen_agent(n_states::Int64, n_actions::Int64)
+    policy = QBasedPolicy(;
+        learner=MonteCarloLearner(;
+            approximator=TabularQApproximator(;
+                n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
+            ),
+        ),
+        explorer=EpsilonGreedyExplorer(0.1),
+    )
+
+    return Agent(; policy=policy, trajectory=VectorSARTTrajectory())
+end
+
 function run_rl(;
     goal_shape_ratio::Float64,
     n_episodes::Int64=100,

@@ -375,60 +387,46 @@ function run_rl(;
     sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=1.6)
     n_particles = sim_consts.n_particles

-    env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
+    env = Env(sim_consts.particle_radius, sim_consts.skin_r)
+
+    agent = gen_agent(env.n_states, env.n_actions)

     n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)

-    rl_params = Params{TotalRewardPerEpisode}(
-        n_particles, env_params, n_steps_before_actions_update, goal_shape_ratio
+    hook = TotalRewardPerEpisode()
+
+    rl_params = Params(
+        env, agent, hook, n_steps_before_actions_update, goal_shape_ratio, n_particles
     )

     # Pre experiment
-    @simd for i in 1:n_particles
-        env, agent, hook = get_env_agent_hook(rl_params, i)
-
     hook(PRE_EXPERIMENT_STAGE, agent, env)
     agent(PRE_EXPERIMENT_STAGE, env)
-    end

     @showprogress 0.6 for episode in 1:n_episodes
-        dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")
+        dir = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")

         # Reset
-        @simd for i in 1:n_particles
-            reset!(rl_params.envs[i], particles[i])
-        end
-
-        reset!(rl_params.env_params)
+        reset!(env)

         # Pre espisode
-        @simd for i in 1:n_particles
-            env, agent, hook = get_env_agent_hook(rl_params, i)
-
         hook(PRE_EPISODE_STAGE, agent, env)
         agent(PRE_EPISODE_STAGE, env)
-        end

         # Episode
         ReCo.run_sim(
             dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
         )

-        # Post episode
-        @simd for i in 1:n_particles
-            env, agent, hook = get_env_agent_hook(rl_params, i)
-
+        env.terminated = true
+
+        # Post episode
         hook(POST_EPISODE_STAGE, agent, env)
         agent(POST_EPISODE_STAGE, env)
     end
-    end

     # Post experiment
-    @simd for i in 1:n_particles
-        env, agent, hook = get_env_agent_hook(rl_params, i)
-
     hook(POST_EXPERIMENT_STAGE, agent, env)
-    end

     return rl_params
 end
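Note: a hypothetical invocation of the training loop above, assuming the ReCo package is installed and run_rl is reachable as ReCo.RL.run_rl; only the keyword arguments visible in this diff are passed, the goal_shape_ratio value is made up, and any other required keywords (episode_duration, update_actions_at, n_particles appear in the body but their defaults are not shown here) would need to be supplied as well:

# Hypothetical usage sketch, not a call confirmed by this diff.
using ReCo: ReCo

rl_params = ReCo.RL.run_rl(; goal_shape_ratio=0.5, n_episodes=10)

# rl_params.hook is the TotalRewardPerEpisode hook collected over the episodes,
# and rl_params.agent holds the learned tabular Q-approximator.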

@@ -13,8 +13,6 @@ using CellListMap: Box, CellList, map_pairwise!, UpdateCellList!
 using Random: Random
 using Dates: Dates, now

-import Base: wait
-
 include("PreVectors.jl")
 using .PreVectors

src/run.jl (11 changes)

@@ -102,14 +102,6 @@ function run_sim(
         ),
     )

-    if !isnothing(rl_params)
-        pre_integration_hook! = RL.pre_integration_hook!
-        integration_hook! = RL.integration_hook!
-        post_integration_hook = RL.post_integration_hook
-    else
-        pre_integration_hook! = integration_hook! = post_integration_hook = empty_hook
-    end
-
     simulate(
         args,
         T0,

@@ -120,9 +112,6 @@ function run_sim(
         dir,
         save_data,
         rl_params,
-        pre_integration_hook!,
-        integration_hook!,
-        post_integration_hook,
     )

     return nothing

@@ -142,7 +142,7 @@ function init_sim_with_sim_consts(

     wait(task)

-    return (dir, particles)
+    return dir
 end

 function init_sim(;

@@ -165,5 +165,5 @@ function init_sim(;

     return init_sim_with_sim_consts(
         sim_consts; exports_dir=exports_dir, parent_dir=parent_dir, comment=comment
-    )[1]
+    )
 end
@@ -34,9 +34,11 @@ end

 function euler!(
     args,
-    state_hook::Function,
-    integration_hook!::Function,
+    first_integration_step::Bool,
     rl_params::Union{RL.Params,Nothing},
+    state_update_helper_hook::Function,
+    state_update_hook::Function,
+    update_table_and_actions_hook::Function,
 )
     for id1 in 1:(args.n_particles - 1)
         p1 = args.particles[id1]

@@ -50,18 +52,20 @@ function euler!(
                 p1_c, p2.c, args.interaction_r², args.half_box_len
             )

-            state_hook(id1, id2, r⃗₁₂, distance², rl_params)
+            state_update_helper_hook(rl_params, id1, id2, r⃗₁₂, distance²)

             if overlapping
-                c = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
-
-                dc = c * r⃗₁₂
+                factor = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
+                dc = factor * r⃗₁₂
+
                 p1.tmp_c -= dc
                 p2.tmp_c += dc
             end
         end
     end

+    state_update_hook(rl_params, args.particles, args.n_particles)
+
     @simd for p in args.particles
         si, co = sincos(p.φ)
         p.tmp_c += SVector(

@@ -71,7 +75,9 @@ function euler!(

         restrict_coordinates!(p, args.half_box_len)

-        integration_hook!(p, rl_params, args.δt, si, co)
+        update_table_and_actions_hook(rl_params, p, first_integration_step)
+
+        RL.act_hook(rl_params, p, args.δt, si, co)

         p.φ += args.c₄ * rand_normal01()

@@ -81,11 +87,11 @@ function euler!(
     return nothing
 end

-wait(::Nothing) = nothing
+Base.wait(::Nothing) = nothing

-gen_run_hooks(::Nothing, args...) = false
+gen_run_additional_hooks(::Nothing, args...) = false

-function gen_run_hooks(rl_params::RL.Params, integration_step::Int64)
+function gen_run_additional_hooks(rl_params::RL.Params, integration_step::Int64)
     return (integration_step % rl_params.n_steps_before_actions_update == 0) ||
            (integration_step == 1)
 end
@ -100,9 +106,6 @@ function simulate(
|
||||||
dir::String,
|
dir::String,
|
||||||
save_data::Bool,
|
save_data::Bool,
|
||||||
rl_params::Union{RL.Params,Nothing},
|
rl_params::Union{RL.Params,Nothing},
|
||||||
pre_integration_hook!::Function,
|
|
||||||
integration_hook!::Function,
|
|
||||||
post_integration_hook::Function,
|
|
||||||
)
|
)
|
||||||
bundle_snapshot_counter = 0
|
bundle_snapshot_counter = 0
|
||||||
|
|
||||||
|
@ -111,8 +114,11 @@ function simulate(
|
||||||
cl = CellList(args.particles_c, args.box; parallel=false)
|
cl = CellList(args.particles_c, args.box; parallel=false)
|
||||||
cl = update_verlet_lists!(args, cl)
|
cl = update_verlet_lists!(args, cl)
|
||||||
|
|
||||||
|
first_integration_step = true
|
||||||
|
|
||||||
run_hooks = false
|
run_hooks = false
|
||||||
state_hook = empty_hook
|
state_update_helper_hook =
|
||||||
|
state_update_hook = update_table_and_actions_hook = empty_hook
|
||||||
|
|
||||||
start_time = now()
|
start_time = now()
|
||||||
println("Started simulation at $start_time.")
|
println("Started simulation at $start_time.")
|
||||||
|
@ -138,21 +144,31 @@ function simulate(
|
||||||
cl = update_verlet_lists!(args, cl)
|
cl = update_verlet_lists!(args, cl)
|
||||||
end
|
end
|
||||||
|
|
||||||
run_hooks = gen_run_hooks(rl_params, integration_step)
|
run_additional_hooks = gen_run_additional_hooks(rl_params, integration_step)
|
||||||
|
|
||||||
if run_hooks
|
if run_additional_hooks
|
||||||
pre_integration_hook!(rl_params, args.n_particles)
|
RL.pre_integration_hook(rl_params)
|
||||||
state_hook = RL.state_hook
|
|
||||||
|
state_update_helper_hook = RL.state_update_helper_hook
|
||||||
|
state_update_hook = RL.state_update_hook
|
||||||
|
update_table_and_actions_hook = RL.update_table_and_actions_hook
|
||||||
end
|
end
|
||||||
|
|
||||||
euler!(args, state_hook, integration_hook!, rl_params)
|
euler!(
|
||||||
|
args,
|
||||||
if run_hooks
|
first_integration_step,
|
||||||
post_integration_hook(
|
rl_params,
|
||||||
rl_params, args.n_particles, args.particles, args.half_box_len
|
state_update_helper_hook,
|
||||||
|
state_update_hook,
|
||||||
|
update_table_and_actions_hook,
|
||||||
)
|
)
|
||||||
state_hook = empty_hook
|
|
||||||
|
if run_additional_hooks
|
||||||
|
state_update_helper_hook =
|
||||||
|
state_update_hook = update_table_and_actions_hook = empty_hook
|
||||||
end
|
end
|
||||||
|
|
||||||
|
first_integration_step = false
|
||||||
end
|
end
|
||||||
|
|
||||||
wait(task)
|
wait(task)
|
||||||
|
|