mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-11-08 22:21:08 +00:00
Fixes incl. state
This commit is contained in:
parent
cfb38c6a19
commit
8ad67229a8
5 changed files with 170 additions and 117 deletions
|
@ -1,6 +1,6 @@
|
||||||
module ReCo
|
module ReCo
|
||||||
|
|
||||||
export init_sim, run_sim
|
export init_sim, run_sim, RL
|
||||||
|
|
||||||
include("PreVector.jl")
|
include("PreVector.jl")
|
||||||
include("Particle.jl")
|
include("Particle.jl")
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
module RL
|
module RL
|
||||||
|
|
||||||
|
export run
|
||||||
|
|
||||||
using ReinforcementLearning
|
using ReinforcementLearning
|
||||||
using Flux: InvDecay
|
using Flux: InvDecay
|
||||||
using Intervals
|
using Intervals
|
||||||
using StaticArrays: SVector
|
using StaticArrays: SVector
|
||||||
|
using LoopVectorization: @turbo
|
||||||
using Random: Random
|
using Random: Random
|
||||||
using ProgressMeter: @showprogress
|
using ProgressMeter: @showprogress
|
||||||
|
|
||||||
|
@ -11,6 +14,8 @@ using ..ReCo
|
||||||
|
|
||||||
import Base: run
|
import Base: run
|
||||||
|
|
||||||
|
const INITIAL_REWARD = 0.0
|
||||||
|
|
||||||
struct DistanceState{L<:Bound}
|
struct DistanceState{L<:Bound}
|
||||||
interval::Interval{Float64,L,Closed}
|
interval::Interval{Float64,L,Closed}
|
||||||
|
|
||||||
|
@ -29,22 +34,25 @@ end
|
||||||
|
|
||||||
mutable struct EnvParams
|
mutable struct EnvParams
|
||||||
action_space::Vector{Tuple{Float64,Float64}}
|
action_space::Vector{Tuple{Float64,Float64}}
|
||||||
action_space_ind::Vector{Int64}
|
action_ind_space::Vector{Int64}
|
||||||
|
|
||||||
distance_state_space::Vector{DistanceState}
|
distance_state_space::Vector{DistanceState}
|
||||||
direction_state_space::Vector{DirectionState}
|
direction_state_space::Vector{DirectionState}
|
||||||
state_space::Vector{Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}}
|
state_space::Vector{Union{Tuple{DistanceState,DirectionState},Tuple{Nothing,Nothing}}}
|
||||||
state_space_ind::Vector{Int64}
|
state_ind_space::Vector{Int64}
|
||||||
|
n_states::Int64
|
||||||
|
|
||||||
reward::Float64
|
reward::Float64
|
||||||
|
|
||||||
function EnvParams(
|
function EnvParams(
|
||||||
min_distance::Float64,
|
min_distance::Float64,
|
||||||
max_distance::Float64;
|
max_distance::Float64;
|
||||||
n_v_actions::Int64=5,
|
n_v_actions::Int64=3,
|
||||||
n_ω_actions::Int64=5,
|
n_ω_actions::Int64=3,
|
||||||
max_v::Float64=80.0,
|
max_v::Float64=80.0,
|
||||||
max_ω::Float64=π / 1.5,
|
max_ω::Float64=π / 1.5,
|
||||||
n_distance_states::Int64=3,
|
n_distance_states::Int64=2,
|
||||||
n_direction_states::Int64=4,
|
n_direction_states::Int64=2,
|
||||||
)
|
)
|
||||||
@assert min_distance > 0.0
|
@assert min_distance > 0.0
|
||||||
@assert max_distance > min_distance
|
@assert max_distance > min_distance
|
||||||
|
@ -68,14 +76,14 @@ mutable struct EnvParams
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
action_space_ind = collect(1:n_actions)
|
action_ind_space = collect(1:n_actions)
|
||||||
|
|
||||||
distance_range =
|
distance_range =
|
||||||
min_distance:((max_distance - min_distance) / n_distance_states):max_distance
|
min_distance:((max_distance - min_distance) / n_distance_states):max_distance
|
||||||
|
|
||||||
distance_state_space = Vector{DistanceState}(undef, n_distance_states)
|
distance_state_space = Vector{DistanceState}(undef, n_distance_states)
|
||||||
|
|
||||||
for i in 1:n_distance_states
|
@simd for i in 1:n_distance_states
|
||||||
if i == 1
|
if i == 1
|
||||||
bound = Closed
|
bound = Closed
|
||||||
else
|
else
|
||||||
|
@ -91,7 +99,7 @@ mutable struct EnvParams
|
||||||
|
|
||||||
direction_state_space = Vector{DirectionState}(undef, n_direction_states)
|
direction_state_space = Vector{DirectionState}(undef, n_direction_states)
|
||||||
|
|
||||||
for i in 1:n_direction_states
|
@simd for i in 1:n_direction_states
|
||||||
direction_state_space[i] = DirectionState(
|
direction_state_space[i] = DirectionState(
|
||||||
direction_range[i], direction_range[i + 1]
|
direction_range[i], direction_range[i + 1]
|
||||||
)
|
)
|
||||||
|
@ -100,7 +108,7 @@ mutable struct EnvParams
|
||||||
n_states = n_distance_states * n_direction_states + 1
|
n_states = n_distance_states * n_direction_states + 1
|
||||||
|
|
||||||
state_space = Vector{
|
state_space = Vector{
|
||||||
Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
|
Union{Tuple{DistanceState,DirectionState},Tuple{Nothing,Nothing}}
|
||||||
}(
|
}(
|
||||||
undef, n_states
|
undef, n_states
|
||||||
)
|
)
|
||||||
|
@ -114,22 +122,27 @@ mutable struct EnvParams
|
||||||
end
|
end
|
||||||
state_space[ind] = (nothing, nothing)
|
state_space[ind] = (nothing, nothing)
|
||||||
|
|
||||||
state_space_ind = collect(1:n_states)
|
state_ind_space = collect(1:n_states)
|
||||||
|
|
||||||
initial_reward = 0.0
|
|
||||||
|
|
||||||
return new(
|
return new(
|
||||||
action_space,
|
action_space,
|
||||||
action_space_ind,
|
action_ind_space,
|
||||||
distance_state_space,
|
distance_state_space,
|
||||||
direction_state_space,
|
direction_state_space,
|
||||||
state_space,
|
state_space,
|
||||||
state_space_ind,
|
state_ind_space,
|
||||||
initial_reward,
|
n_states,
|
||||||
|
INITIAL_REWARD,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function reset!(env_params::EnvParams)
|
||||||
|
env_params.reward = INITIAL_REWARD
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
mutable struct Env <: AbstractEnv
|
mutable struct Env <: AbstractEnv
|
||||||
params::EnvParams
|
params::EnvParams
|
||||||
particle::ReCo.Particle
|
particle::ReCo.Particle
|
||||||
|
@ -137,12 +150,29 @@ mutable struct Env <: AbstractEnv
|
||||||
|
|
||||||
function Env(params::EnvParams, particle::ReCo.Particle)
|
function Env(params::EnvParams, particle::ReCo.Particle)
|
||||||
# initial_state = (nothing, nothing)
|
# initial_state = (nothing, nothing)
|
||||||
initial_state_ind = length(params.state_space_ind)
|
initial_state_ind = params.n_states
|
||||||
|
|
||||||
return new(params, particle, initial_state_ind)
|
return new(params, particle, initial_state_ind)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function reset!(env::Env, particle::ReCo.Particle)
|
||||||
|
env.particle = particle
|
||||||
|
env.state_ind = env.params.n_states
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
RLBase.state_space(env::Env) = env.params.state_ind_space
|
||||||
|
|
||||||
|
RLBase.state(env::Env) = env.state_ind
|
||||||
|
|
||||||
|
RLBase.action_space(env::Env) = env.params.action_ind_space
|
||||||
|
|
||||||
|
RLBase.reward(env::Env) = env.params.reward
|
||||||
|
|
||||||
|
RLBase.is_terminated(::Env) = false
|
||||||
|
|
||||||
function gen_policy(n_states::Int64, n_actions::Int64)
|
function gen_policy(n_states::Int64, n_actions::Int64)
|
||||||
return QBasedPolicy(;
|
return QBasedPolicy(;
|
||||||
learner=MonteCarloLearner(;
|
learner=MonteCarloLearner(;
|
||||||
|
@ -161,8 +191,8 @@ struct Params{H<:AbstractHook}
|
||||||
actions::Vector{Tuple{Float64,Float64}}
|
actions::Vector{Tuple{Float64,Float64}}
|
||||||
env_params::EnvParams
|
env_params::EnvParams
|
||||||
n_steps_before_actions_update::Int64
|
n_steps_before_actions_update::Int64
|
||||||
min_distance²::Vector{Float64}
|
min_sq_distances::Vector{Float64}
|
||||||
r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
|
vecs_r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
|
||||||
goal_shape_ratio::Float64
|
goal_shape_ratio::Float64
|
||||||
|
|
||||||
function Params{H}(
|
function Params{H}(
|
||||||
|
@ -171,48 +201,57 @@ struct Params{H<:AbstractHook}
|
||||||
n_steps_before_actions_update::Int64,
|
n_steps_before_actions_update::Int64,
|
||||||
goal_shape_ratio::Float64,
|
goal_shape_ratio::Float64,
|
||||||
) where {H<:AbstractHook}
|
) where {H<:AbstractHook}
|
||||||
policies = [
|
envs = [Env(env_params, ReCo.gen_tmp_particle()) for i in 1:n_particles]
|
||||||
gen_policy(length(env_params.state_space), length(env_params.action_space)) for
|
|
||||||
i in 1:n_particles
|
|
||||||
]
|
|
||||||
agents = [
|
agents = [
|
||||||
Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
|
Agent(;
|
||||||
|
policy=gen_policy(env_params.n_states, length(env_params.action_space)),
|
||||||
|
trajectory=VectorSARTTrajectory(),
|
||||||
|
) for i in 1:n_particles
|
||||||
]
|
]
|
||||||
|
|
||||||
|
hooks = [H() for i in 1:n_particles]
|
||||||
|
|
||||||
|
actions = Vector{Tuple{Float64,Float64}}(undef, n_particles)
|
||||||
|
|
||||||
|
min_sq_distances = fill(Inf64, n_particles)
|
||||||
|
|
||||||
|
vecs_r⃗₁₂_to_min_distance_particle = fill(SVector(0.0, 0.0), n_particles)
|
||||||
|
|
||||||
return new(
|
return new(
|
||||||
[Env(env_params, ReCo.gen_tmp_particle()) for i in 1:n_particles],
|
envs,
|
||||||
agents,
|
agents,
|
||||||
[H() for i in 1:n_particles],
|
hooks,
|
||||||
Vector{Tuple{Float64,Float64}}(undef, n_particles),
|
actions,
|
||||||
env_params,
|
env_params,
|
||||||
n_steps_before_actions_update,
|
n_steps_before_actions_update,
|
||||||
zeros(n_particles),
|
min_sq_distances,
|
||||||
fill(SVector(0.0, 0.0), n_particles),
|
vecs_r⃗₁₂_to_min_distance_particle,
|
||||||
goal_shape_ratio,
|
goal_shape_ratio,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
RLBase.state_space(env::Env) = env.params.state_space_ind
|
function get_env_agent_hook(rl_params::Params, ind::Int64)
|
||||||
|
return (rl_params.envs[ind], rl_params.agents[ind], rl_params.hooks[ind])
|
||||||
RLBase.state(env::Env) = env.state_ind
|
end
|
||||||
|
|
||||||
RLBase.action_space(env::Env) = env.params.action_space_ind
|
|
||||||
|
|
||||||
RLBase.reward(env::Env) = env.params.reward
|
|
||||||
|
|
||||||
RLBase.is_terminated(::Env) = false
|
|
||||||
|
|
||||||
function pre_integration_hook!(rl_params::Params, n_particles::Int64)
|
function pre_integration_hook!(rl_params::Params, n_particles::Int64)
|
||||||
for i in 1:n_particles
|
@simd for i in 1:n_particles
|
||||||
env = rl_params.envs[i]
|
env, agent, hook = get_env_agent_hook(rl_params, i)
|
||||||
agent = rl_params.agents[i]
|
|
||||||
|
|
||||||
|
# Update action
|
||||||
action_ind = agent(env)
|
action_ind = agent(env)
|
||||||
action = rl_params.env_params.action_space[action_ind]
|
action = rl_params.env_params.action_space[action_ind]
|
||||||
rl_params.actions[i] = action
|
rl_params.actions[i] = action
|
||||||
|
|
||||||
|
# Pre act
|
||||||
agent(PRE_ACT_STAGE, env, action_ind)
|
agent(PRE_ACT_STAGE, env, action_ind)
|
||||||
rl_params.hooks[i](PRE_ACT_STAGE, agent, env, action_ind)
|
hook(PRE_ACT_STAGE, agent, env, action_ind)
|
||||||
|
end
|
||||||
|
|
||||||
|
@turbo for i in 1:n_particles
|
||||||
|
rl_params.min_sq_distances[i] = Inf64
|
||||||
end
|
end
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
|
@ -221,24 +260,25 @@ end
|
||||||
function state_hook(
|
function state_hook(
|
||||||
id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64, rl_params::Params
|
id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64, rl_params::Params
|
||||||
)
|
)
|
||||||
if rl_params.min_distance²[id1] > distance²
|
if rl_params.min_sq_distances[id1] > distance²
|
||||||
rl_params.min_distance²[id1] = distance²
|
rl_params.min_sq_distances[id1] = distance²
|
||||||
|
|
||||||
rl_params.r⃗₁₂_to_min_distance_particle[id1] = r⃗₁₂
|
rl_params.vecs_r⃗₁₂_to_min_distance_particle[id1] = r⃗₁₂
|
||||||
end
|
end
|
||||||
|
|
||||||
if rl_params.min_distance²[id2] > distance²
|
if rl_params.min_sq_distances[id2] > distance²
|
||||||
rl_params.min_distance²[id2] = distance²
|
rl_params.min_sq_distances[id2] = distance²
|
||||||
|
|
||||||
rl_params.r⃗₁₂_to_min_distance_particle[id2] = -r⃗₁₂
|
rl_params.vecs_r⃗₁₂_to_min_distance_particle[id2] = -r⃗₁₂
|
||||||
end
|
end
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function integration_hook(
|
function integration_hook!(
|
||||||
particle::ReCo.Particle, rl_params::Params, δt::Float64, si::Float64, co::Float64
|
particle::ReCo.Particle, rl_params::Params, δt::Float64, si::Float64, co::Float64
|
||||||
)
|
)
|
||||||
|
# Apply action
|
||||||
action = rl_params.actions[particle.id]
|
action = rl_params.actions[particle.id]
|
||||||
|
|
||||||
vδt = action[1] * δt
|
vδt = action[1] * δt
|
||||||
|
@ -248,10 +288,12 @@ function integration_hook(
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function get_state_ind(
|
function get_state_ind(state::Tuple{DistanceState,DirectionState}, env_params::EnvParams)
|
||||||
state::T, states::Vector{T}
|
return findfirst(x -> x == state, env_params.state_space)
|
||||||
) where {T<:Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}}
|
end
|
||||||
return findfirst(x -> x == state, states)
|
|
||||||
|
function get_state_ind(::Tuple{Nothing,Nothing}, env_params::EnvParams)
|
||||||
|
return env_params.n_states
|
||||||
end
|
end
|
||||||
|
|
||||||
function post_integration_hook(
|
function post_integration_hook(
|
||||||
|
@ -260,28 +302,41 @@ function post_integration_hook(
|
||||||
particles::Vector{ReCo.Particle},
|
particles::Vector{ReCo.Particle},
|
||||||
half_box_len::Float64,
|
half_box_len::Float64,
|
||||||
)
|
)
|
||||||
|
# Update reward
|
||||||
|
rl_params.env_params.reward =
|
||||||
|
1 -
|
||||||
|
(
|
||||||
|
ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len) -
|
||||||
|
rl_params.goal_shape_ratio
|
||||||
|
)^2
|
||||||
|
|
||||||
|
# Update states
|
||||||
|
n_states = rl_params.env_params.n_states
|
||||||
|
|
||||||
env_direction_state = rl_params.env_params.direction_state_space[1]
|
env_direction_state = rl_params.env_params.direction_state_space[1]
|
||||||
|
|
||||||
for i in 1:n_particles
|
for i in 1:n_particles
|
||||||
env = rl_params.envs[i]
|
env, agent, hook = get_env_agent_hook(rl_params, i)
|
||||||
agent = rl_params.agents[i]
|
|
||||||
|
|
||||||
min_distance = sqrt(rl_params.min_distance²[i])
|
|
||||||
|
|
||||||
env_distance_state::Union{DistanceState,Nothing} = nothing
|
env_distance_state::Union{DistanceState,Nothing} = nothing
|
||||||
|
|
||||||
|
min_sq_distance = rl_params.min_sq_distances[i]
|
||||||
|
min_distance = sqrt(min_sq_distance)
|
||||||
|
|
||||||
|
if !isinf(min_sq_distance)
|
||||||
for distance_state in rl_params.env_params.distance_state_space
|
for distance_state in rl_params.env_params.distance_state_space
|
||||||
if min_distance in distance_state.interval
|
if min_distance in distance_state.interval
|
||||||
env_distance_state = distance_state
|
env_distance_state = distance_state
|
||||||
break
|
break
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
if isnothing(env_distance_state)
|
if isnothing(env_distance_state)
|
||||||
# (nothing, nothing)
|
# (nothing, nothing)
|
||||||
env.state_ind = length(env.params.state_space)
|
env.state_ind = n_states
|
||||||
else
|
else
|
||||||
r⃗₁₂ = rl_params.r⃗₁₂_to_min_distance_particle[i]
|
r⃗₁₂ = rl_params.vecs_r⃗₁₂_to_min_distance_particle[i]
|
||||||
si, co = sincos(particles[i].φ)
|
si, co = sincos(particles[i].φ)
|
||||||
|
|
||||||
#=
|
#=
|
||||||
|
@ -290,28 +345,25 @@ function post_integration_hook(
|
||||||
angle = acos(dot(r⃗₁₂, e) / (norm(r⃗₁₂) * norm(e)))
|
angle = acos(dot(r⃗₁₂, e) / (norm(r⃗₁₂) * norm(e)))
|
||||||
norm(r⃗₁₂) == min_distance
|
norm(r⃗₁₂) == min_distance
|
||||||
norm(e) == 1
|
norm(e) == 1
|
||||||
|
|
||||||
|
min_distance is not infinite, because otherwise
|
||||||
|
env_direction_state would be nothing and this else block will not be called
|
||||||
=#
|
=#
|
||||||
direction = acos((r⃗₁₂[1] * co + r⃗₁₂[2] * si) / min_distance)
|
direction = acos((r⃗₁₂[1] * co + r⃗₁₂[2] * si) / min_distance)
|
||||||
|
|
||||||
for direction_state in rl_params.env_params.direction_state_space
|
for direction_state in rl_params.env_params.direction_state_space
|
||||||
if direction in direction_state
|
if direction in direction_state.interval
|
||||||
env_direction_state = direction_state
|
env_direction_state = direction_state
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
state = (env_distance_state, env_direction_state)
|
state = (env_distance_state, env_direction_state)
|
||||||
env.state_ind = get_state_ind(state, env.params.state_space)
|
env.state_ind = get_state_ind(state, env.params)
|
||||||
end
|
end
|
||||||
|
|
||||||
env.params.reward =
|
# Post act
|
||||||
1 -
|
|
||||||
(
|
|
||||||
ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len) -
|
|
||||||
rl_params.goal_shape_ratio
|
|
||||||
)^2
|
|
||||||
|
|
||||||
agent(POST_ACT_STAGE, env)
|
agent(POST_ACT_STAGE, env)
|
||||||
rl_params.hooks[i](POST_ACT_STAGE, agent, env)
|
hook(POST_ACT_STAGE, agent, env)
|
||||||
end
|
end
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
|
@ -320,9 +372,10 @@ end
|
||||||
function run(;
|
function run(;
|
||||||
goal_shape_ratio::Float64,
|
goal_shape_ratio::Float64,
|
||||||
n_episodes::Int64=100,
|
n_episodes::Int64=100,
|
||||||
episode_duration::Float64=100.0,
|
episode_duration::Float64=50.0,
|
||||||
update_actions_at::Float64=0.1,
|
update_actions_at::Float64=0.2,
|
||||||
n_particles::Int64=100,
|
n_particles::Int64=100,
|
||||||
|
seed::Int64=42,
|
||||||
)
|
)
|
||||||
@assert 0.0 <= goal_shape_ratio <= 1.0
|
@assert 0.0 <= goal_shape_ratio <= 1.0
|
||||||
@assert n_episodes > 0
|
@assert n_episodes > 0
|
||||||
|
@ -330,9 +383,10 @@ function run(;
|
||||||
@assert update_actions_at in 0.01:0.01:episode_duration
|
@assert update_actions_at in 0.01:0.01:episode_duration
|
||||||
@assert n_particles > 0
|
@assert n_particles > 0
|
||||||
|
|
||||||
Random.seed!(42)
|
# Setup
|
||||||
|
Random.seed!(seed)
|
||||||
|
|
||||||
sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=4.0)
|
sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=3.0)
|
||||||
n_particles = sim_consts.n_particles
|
n_particles = sim_consts.n_particles
|
||||||
|
|
||||||
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
|
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
|
||||||
|
@ -343,50 +397,51 @@ function run(;
|
||||||
n_particles, env_params, n_steps_before_actions_update, goal_shape_ratio
|
n_particles, env_params, n_steps_before_actions_update, goal_shape_ratio
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in 1:n_particles
|
# Pre experiment
|
||||||
env = rl_params.envs[i]
|
@simd for i in 1:n_particles
|
||||||
agent = rl_params.agents[i]
|
env, agent, hook = get_env_agent_hook(rl_params, i)
|
||||||
|
|
||||||
rl_params.hooks[i](PRE_EXPERIMENT_STAGE, agent, env)
|
hook(PRE_EXPERIMENT_STAGE, agent, env)
|
||||||
agent(PRE_EXPERIMENT_STAGE, env)
|
agent(PRE_EXPERIMENT_STAGE, env)
|
||||||
end
|
end
|
||||||
|
|
||||||
@showprogress 0.6 for episode in 1:n_episodes
|
@showprogress 0.6 for episode in 1:n_episodes
|
||||||
dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")
|
dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")
|
||||||
|
|
||||||
for i in 1:n_particles
|
# Reset
|
||||||
env = rl_params.envs[i]
|
@simd for i in 1:n_particles
|
||||||
agent = rl_params.agents[i]
|
reset!(rl_params.envs[i], particles[i])
|
||||||
|
end
|
||||||
|
|
||||||
rl_params.hooks[i](PRE_EPISODE_STAGE, agent, env)
|
reset!(rl_params.env_params)
|
||||||
|
|
||||||
|
# Pre espisode
|
||||||
|
@simd for i in 1:n_particles
|
||||||
|
env, agent, hook = get_env_agent_hook(rl_params, i)
|
||||||
|
|
||||||
|
hook(PRE_EPISODE_STAGE, agent, env)
|
||||||
agent(PRE_EPISODE_STAGE, env)
|
agent(PRE_EPISODE_STAGE, env)
|
||||||
end
|
end
|
||||||
|
|
||||||
for i in 1:n_particles
|
# Episode
|
||||||
rl_params.envs[i].particle = particles[i]
|
|
||||||
rl_params.envs[i].state_ind = length(rl_params.env_params.state_space)
|
|
||||||
end
|
|
||||||
|
|
||||||
rl_params.env_params.reward = 0.0
|
|
||||||
|
|
||||||
run_sim(
|
run_sim(
|
||||||
dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
|
dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in 1:n_particles
|
# Post episode
|
||||||
env = rl_params.envs[i]
|
@simd for i in 1:n_particles
|
||||||
agent = rl_params.agents[i]
|
env, agent, hook = get_env_agent_hook(rl_params, i)
|
||||||
|
|
||||||
rl_params.hooks[i](POST_EPISODE_STAGE, agent, env)
|
hook(POST_EPISODE_STAGE, agent, env)
|
||||||
agent(POST_EPISODE_STAGE, env)
|
agent(POST_EPISODE_STAGE, env)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
for i in 1:n_particles
|
# Post experiment
|
||||||
env = rl_params.envs[i]
|
@simd for i in 1:n_particles
|
||||||
agent = rl_params.agents[i]
|
env, agent, hook = get_env_agent_hook(rl_params, i)
|
||||||
|
|
||||||
rl_params.hooks[i](POST_EXPERIMENT_STAGE, agent, env)
|
hook(POST_EXPERIMENT_STAGE, agent, env)
|
||||||
end
|
end
|
||||||
|
|
||||||
return rl_params
|
return rl_params
|
||||||
|
|
|
@ -109,12 +109,10 @@ function run_sim(
|
||||||
|
|
||||||
if !isnothing(rl_params)
|
if !isnothing(rl_params)
|
||||||
pre_integration_hook! = RL.pre_integration_hook!
|
pre_integration_hook! = RL.pre_integration_hook!
|
||||||
integration_hook = RL.integration_hook
|
integration_hook! = RL.integration_hook!
|
||||||
post_integration_hook = RL.post_integration_hook
|
post_integration_hook = RL.post_integration_hook
|
||||||
else
|
else
|
||||||
pre_integration_hook! = empty_hook
|
pre_integration_hook! = integration_hook! = post_integration_hook = empty_hook
|
||||||
integration_hook = empty_hook
|
|
||||||
post_integration_hook = empty_hook
|
|
||||||
end
|
end
|
||||||
|
|
||||||
simulate(
|
simulate(
|
||||||
|
@ -128,7 +126,7 @@ function run_sim(
|
||||||
save_data,
|
save_data,
|
||||||
rl_params,
|
rl_params,
|
||||||
pre_integration_hook!,
|
pre_integration_hook!,
|
||||||
integration_hook,
|
integration_hook!,
|
||||||
post_integration_hook,
|
post_integration_hook,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -65,11 +65,11 @@ function gen_sim_consts(
|
||||||
ϵ = 100.0
|
ϵ = 100.0
|
||||||
interaction_r = 2^(1 / 6) * σ
|
interaction_r = 2^(1 / 6) * σ
|
||||||
|
|
||||||
if v₀ != 0.0
|
|
||||||
buffer = 1.8
|
buffer = 1.8
|
||||||
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
|
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
|
||||||
@assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
|
@assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
|
||||||
|
|
||||||
|
if v₀ != 0.0
|
||||||
n_steps_before_verlet_list_update = round(
|
n_steps_before_verlet_list_update = round(
|
||||||
Int64,
|
Int64,
|
||||||
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step,
|
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step,
|
||||||
|
|
|
@ -43,7 +43,7 @@ end
|
||||||
function euler!(
|
function euler!(
|
||||||
args,
|
args,
|
||||||
state_hook::Function,
|
state_hook::Function,
|
||||||
integration_hook::Function,
|
integration_hook!::Function,
|
||||||
rl_params::Union{RL.Params,Nothing},
|
rl_params::Union{RL.Params,Nothing},
|
||||||
)
|
)
|
||||||
for id1 in 1:(args.n_particles - 1)
|
for id1 in 1:(args.n_particles - 1)
|
||||||
|
@ -79,7 +79,7 @@ function euler!(
|
||||||
|
|
||||||
restrict_coordinates!(p, args.half_box_len)
|
restrict_coordinates!(p, args.half_box_len)
|
||||||
|
|
||||||
integration_hook(p, rl_params, args.δt, si, co)
|
integration_hook!(p, rl_params, args.δt, si, co)
|
||||||
|
|
||||||
p.φ += args.c₄ * rand_normal01()
|
p.φ += args.c₄ * rand_normal01()
|
||||||
|
|
||||||
|
@ -94,8 +94,8 @@ wait(::Nothing) = nothing
|
||||||
gen_run_hooks(::Nothing, args...) = false
|
gen_run_hooks(::Nothing, args...) = false
|
||||||
|
|
||||||
function gen_run_hooks(rl_params::RL.Params, integration_step::Int64)
|
function gen_run_hooks(rl_params::RL.Params, integration_step::Int64)
|
||||||
return (integration_step == 1) ||
|
return (integration_step % rl_params.n_steps_before_actions_update == 0) ||
|
||||||
(integration_step % rl_params.n_steps_before_actions_update == 0)
|
(integration_step == 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
function simulate(
|
function simulate(
|
||||||
|
@ -109,7 +109,7 @@ function simulate(
|
||||||
save_data::Bool,
|
save_data::Bool,
|
||||||
rl_params::Union{RL.Params,Nothing},
|
rl_params::Union{RL.Params,Nothing},
|
||||||
pre_integration_hook!::Function,
|
pre_integration_hook!::Function,
|
||||||
integration_hook::Function,
|
integration_hook!::Function,
|
||||||
post_integration_hook::Function,
|
post_integration_hook::Function,
|
||||||
)
|
)
|
||||||
bundle_snapshot_counter = 0
|
bundle_snapshot_counter = 0
|
||||||
|
@ -153,7 +153,7 @@ function simulate(
|
||||||
state_hook = RL.state_hook
|
state_hook = RL.state_hook
|
||||||
end
|
end
|
||||||
|
|
||||||
euler!(args, state_hook, integration_hook, rl_params)
|
euler!(args, state_hook, integration_hook!, rl_params)
|
||||||
|
|
||||||
if run_hooks
|
if run_hooks
|
||||||
post_integration_hook(
|
post_integration_hook(
|
||||||
|
|
Loading…
Reference in a new issue