2021-12-12 23:19:18 +00:00
|
|
|
module RL
|
2021-12-12 14:29:08 +00:00
|
|
|
|
2021-12-15 03:45:15 +00:00
|
|
|
export run_rl
|
2021-12-14 03:03:14 +00:00
|
|
|
|
2021-12-10 02:16:45 +00:00
|
|
|
using ReinforcementLearning
|
2021-12-12 17:27:56 +00:00
|
|
|
using Flux: InvDecay
|
2021-12-12 14:29:08 +00:00
|
|
|
using Intervals
|
2021-12-12 17:27:56 +00:00
|
|
|
using StaticArrays: SVector
|
2021-12-14 03:03:14 +00:00
|
|
|
using LoopVectorization: @turbo
|
2021-12-12 23:19:18 +00:00
|
|
|
using Random: Random
|
2021-12-13 01:24:34 +00:00
|
|
|
using ProgressMeter: @showprogress
|
2021-12-12 23:19:18 +00:00
|
|
|
|
2021-12-16 12:50:38 +00:00
|
|
|
using ..ReCo: ReCo, Particle, angle2
|
2021-12-12 14:29:08 +00:00
|
|
|
|
2021-12-14 03:03:14 +00:00
|
|
|
# Reward value a freshly constructed `EnvParams` starts with; `reset!(::EnvParams)` restores it.
const INITIAL_REWARD = 0.0
|
|
|
|
|
2021-12-12 14:29:08 +00:00
|
|
|
# Split the sorted grid `bounds` into `length(bounds) - 1` consecutive, non-overlapping
# intervals covering `[first(bounds), last(bounds)]`: the first interval is closed on both
# sides, every following one is left-open and right-closed.
function gen_interval_space(bounds::AbstractVector{Float64})
    intervals = Vector{Interval}(undef, length(bounds) - 1)

    for i in eachindex(intervals)
        left_bound = i == 1 ? Closed : Open
        intervals[i] = Interval{Float64,left_bound,Closed}(bounds[i], bounds[i + 1])
    end

    return intervals
end

"""
    EnvParams(min_distance, max_distance; n_v_actions=2, n_ω_actions=3, max_v=80.0,
              max_ω=π / 2, n_distance_states=2, n_angle_states=2)

Discretised action and state spaces shared by all particle environments, together with
the reward value they share.

Actions are all `(v, ω)` pairs built from `n_v_actions` evenly spaced speeds in
`[0, max_v]` and `n_ω_actions` evenly spaced angular velocities in `[-max_ω, max_ω]`
(`v` varies slowest). States are all `(distance_interval, angle_interval)` pairs, where the
distance intervals partition `[min_distance, max_distance]` into `n_distance_states`
pieces and the angle intervals partition `[-π, π]` into `n_angle_states` pieces. One extra
sentinel state `(nothing, nothing)` is appended at index `n_states`.

Invalid arguments raise an `AssertionError`.
"""
mutable struct EnvParams
    action_space::Vector{Tuple{Float64,Float64}}
    action_ind_space::Vector{Int64}

    distance_state_space::Vector{Interval}
    angle_state_space::Vector{Interval}
    state_space::Vector{Union{Tuple{Interval,Interval},Tuple{Nothing,Nothing}}}
    state_ind_space::Vector{Int64}
    n_states::Int64

    reward::Float64

    function EnvParams(
        min_distance::Float64,
        max_distance::Float64;
        n_v_actions::Int64=2,
        n_ω_actions::Int64=3,
        max_v::Float64=80.0,
        max_ω::Float64=π / 2,
        n_distance_states::Int64=2,
        n_angle_states::Int64=2,
    )
        @assert min_distance > 0.0
        @assert max_distance > min_distance
        @assert n_v_actions > 1
        @assert n_ω_actions > 1
        @assert max_v > 0
        @assert max_ω > 0
        # Without these, the grid step below would be Inf or negative.
        @assert n_distance_states > 0
        @assert n_angle_states > 0

        # `range(...; length)` fixes the number of grid points and hits both endpoints
        # exactly, unlike `start:step:stop` with a floating-point step.
        v_action_space = range(0.0, max_v; length=n_v_actions)
        ω_action_space = range(-max_ω, max_ω; length=n_ω_actions)

        # Outer generator varies slowest: all ω values for the first v, then the next v...
        action_space = [(v, ω) for v in v_action_space for ω in ω_action_space]
        action_ind_space = collect(1:length(action_space))

        distance_state_space = gen_interval_space(
            range(min_distance, max_distance; length=n_distance_states + 1)
        )
        angle_state_space = gen_interval_space(
            range(-float(π), float(π); length=n_angle_states + 1)
        )

        n_states = n_distance_states * n_angle_states + 1

        state_space = Vector{Union{Tuple{Interval,Interval},Tuple{Nothing,Nothing}}}()
        sizehint!(state_space, n_states)
        for distance_state in distance_state_space, angle_state in angle_state_space
            push!(state_space, (distance_state, angle_state))
        end
        # Sentinel state at the last index, used when no distance interval applies.
        push!(state_space, (nothing, nothing))

        state_ind_space = collect(1:n_states)

        return new(
            action_space,
            action_ind_space,
            distance_state_space,
            angle_state_space,
            state_space,
            state_ind_space,
            n_states,
            INITIAL_REWARD,
        )
    end
end
|
|
|
|
|
2021-12-14 03:03:14 +00:00
|
|
|
"""
    reset!(env_params::EnvParams)

Put the shared reward stored in `env_params` back to `INITIAL_REWARD`.
"""
reset!(env_params::EnvParams) = (env_params.reward = INITIAL_REWARD; nothing)
|
|
|
|
|
2021-12-12 14:29:08 +00:00
|
|
|
"""
    Env(params::EnvParams, particle::Particle)

Per-particle reinforcement-learning environment. It holds the shared `params`, the
`particle` it controls, and the particle's current state index. A new environment starts
in the sentinel state `params.n_states`, which `EnvParams` reserves for the
`(nothing, nothing)` state.
"""
mutable struct Env <: AbstractEnv
    params::EnvParams
    particle::Particle
    state_ind::Int64

    function Env(params::EnvParams, particle::Particle)
        # No neighbour information exists yet, so begin in the sentinel state.
        sentinel_state_ind = params.n_states
        return new(params, particle, sentinel_state_ind)
    end
end
|
|
|
|
|
2021-12-15 19:50:18 +00:00
|
|
|
"""
    reset!(env::Env, particle::Particle)

Attach `particle` to `env` and move the environment back to the sentinel state
`env.params.n_states`.
"""
function reset!(env::Env, particle::Particle)
    sentinel_state_ind = env.params.n_states

    env.particle = particle
    env.state_ind = sentinel_state_ind

    return nothing
end
|
|
|
|
|
|
|
|
# ReinforcementLearning.jl environment interface for `Env`. States and actions are exposed
# as integer indices into the discretised spaces held by the shared `EnvParams`.

function RLBase.state_space(env::Env)
    return env.params.state_ind_space
end

function RLBase.state(env::Env)
    return env.state_ind
end

function RLBase.action_space(env::Env)
    return env.params.action_ind_space
end

# The reward lives in the shared `EnvParams`, so every particle reports the same value.
function RLBase.reward(env::Env)
    return env.params.reward
end

# Episodes end only when the simulation duration runs out, never from inside the env.
function RLBase.is_terminated(::Env)
    return false
end
|
|
|
|
|
2021-12-13 01:24:34 +00:00
|
|
|
"""
    gen_policy(n_states, n_actions; ϵ=0.1, opt=InvDecay(1.0))

Build a Q-based policy backed by a tabular `MonteCarloLearner` over `n_states × n_actions`,
exploring with an ε-greedy explorer.

# Keywords
- `ϵ::Float64=0.1`: exploration rate passed to `EpsilonGreedyExplorer`.
- `opt=InvDecay(1.0)`: optimiser used by the `TabularQApproximator`. A fresh default is
  created on every call, so policies never share optimiser state.
"""
function gen_policy(
    n_states::Int64, n_actions::Int64; ϵ::Float64=0.1, opt=InvDecay(1.0)
)
    return QBasedPolicy(;
        learner=MonteCarloLearner(;
            approximator=TabularQApproximator(;
                n_state=n_states, n_action=n_actions, opt=opt
            ),
        ),
        explorer=EpsilonGreedyExplorer(ϵ),
    )
end
|
|
|
|
|
|
|
|
"""
    Params{H}(n_particles, env_params, n_steps_before_actions_update, goal_shape_ratio)

Reinforcement-learning state for one run. It holds one environment, agent and hook
(of type `H`) per particle, the shared `env_params`, and per-particle scratch buffers:
the current action, the minimum squared neighbour distance, and the vector to that
nearest neighbour.
"""
struct Params{H<:AbstractHook}
    envs::Vector{Env}
    agents::Vector{Agent}
    hooks::Vector{H}
    actions::Vector{Tuple{Float64,Float64}}
    env_params::EnvParams
    n_steps_before_actions_update::Int64
    min_sq_distances::Vector{Float64}
    vecs_r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
    goal_shape_ratio::Float64

    function Params{H}(
        n_particles::Int64,
        env_params::EnvParams,
        n_steps_before_actions_update::Int64,
        goal_shape_ratio::Float64,
    ) where {H<:AbstractHook}
        n_actions = length(env_params.action_space)

        envs = [Env(env_params, ReCo.gen_tmp_particle()) for _ in 1:n_particles]

        # Each particle gets its own independent policy and trajectory.
        agents = [
            Agent(;
                policy=gen_policy(env_params.n_states, n_actions),
                trajectory=VectorSARTTrajectory(),
            ) for _ in 1:n_particles
        ]

        hooks = [H() for _ in 1:n_particles]

        # Zero-initialised rather than `undef`, so no garbage is ever read before the
        # first action update fills the buffer.
        actions = fill((0.0, 0.0), n_particles)

        min_sq_distances = fill(Inf, n_particles)
        vecs_r⃗₁₂_to_min_distance_particle = fill(SVector(0.0, 0.0), n_particles)

        return new(
            envs,
            agents,
            hooks,
            actions,
            env_params,
            n_steps_before_actions_update,
            min_sq_distances,
            vecs_r⃗₁₂_to_min_distance_particle,
            goal_shape_ratio,
        )
    end
end
|
|
|
|
|
2021-12-14 03:03:14 +00:00
|
|
|
"""
    get_env_agent_hook(rl_params::Params, ind::Int64)

Return the `(env, agent, hook)` triple belonging to particle `ind`.
"""
function get_env_agent_hook(rl_params::Params, ind::Int64)
    env = rl_params.envs[ind]
    agent = rl_params.agents[ind]
    hook = rl_params.hooks[ind]

    return (env, agent, hook)
end
|
2021-12-13 01:24:34 +00:00
|
|
|
|
2021-12-12 17:27:56 +00:00
|
|
|
"""
    pre_integration_hook!(rl_params::Params, n_particles::Int64)

For each of the first `n_particles` particles:
- query its agent for an action index;
- store the corresponding `(v, ω)` action in `rl_params.actions`;
- run the agent and hook `PRE_ACT_STAGE` callbacks.

Afterwards, reset the minimum squared neighbour distances to `Inf` so they can be
recomputed (presumably by `state_hook` during the next neighbour sweep; confirm against
the caller in ReCo).
"""
function pre_integration_hook!(rl_params::Params, n_particles::Int64)
    # Plain loop, not `@simd`: agents and hooks have side effects and consume RNG state,
    # so iterations must run in order.
    for i in 1:n_particles
        env, agent, hook = get_env_agent_hook(rl_params, i)

        # Update action
        action_ind = agent(env)
        action = rl_params.env_params.action_space[action_ind]
        rl_params.actions[i] = action

        # Pre act
        agent(PRE_ACT_STAGE, env, action_ind)
        hook(PRE_ACT_STAGE, agent, env, action_ind)
    end

    fill!(view(rl_params.min_sq_distances, 1:n_particles), Inf)

    return nothing
end
|
|
|
|
|
|
|
|
"""
    state_hook(id1, id2, r⃗₁₂, distance², rl_params)

Record a pair observation between particles `id1` and `id2`. If `distance²` beats the
smallest squared distance stored so far for either particle, replace that particle's
minimum and remember the pair vector: `r⃗₁₂` for `id1` and `-r⃗₁₂` for `id2`.
(Presumably `r⃗₁₂` points from `id1` towards `id2`; confirm with the caller in ReCo.)
"""
function state_hook(
    id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64, rl_params::Params
)
    sq_distances = rl_params.min_sq_distances
    nearest_vecs = rl_params.vecs_r⃗₁₂_to_min_distance_particle

    for (id, r⃗) in ((id1, r⃗₁₂), (id2, -r⃗₁₂))
        distance² < sq_distances[id] || continue

        sq_distances[id] = distance²
        nearest_vecs[id] = r⃗
    end

    return nothing
end
|
|
|
|
|
2021-12-14 03:03:14 +00:00
|
|
|
"""
    integration_hook!(particle, rl_params, δt, si, co)

Apply the particle's current `(v, ω)` action over one time step `δt`:
- add the displacement `v * δt * (co, si)` to `particle.tmp_c`;
- add `ω * δt` to `particle.φ`.

`si` and `co` are supplied by the caller; presumably they are `sin`/`cos` of the
particle's orientation (confirm in ReCo).
"""
function integration_hook!(
    particle::Particle, rl_params::Params, δt::Float64, si::Float64, co::Float64
)
    v, ω = rl_params.actions[particle.id]

    step_length = v * δt
    particle.tmp_c += SVector(step_length * co, step_length * si)
    particle.φ += ω * δt

    return nothing
end
|
|
|
|
|
2021-12-16 13:54:52 +00:00
|
|
|
"""
    get_state_ind(state, env_params::EnvParams) -> Int64

Return the index of `state` in `env_params.state_space`.

A `(distance_interval, angle_interval)` state is looked up by equality; an
`ArgumentError` is thrown if it is not part of the state space. Without this check, a
`nothing` result would only fail later, on assignment to an `Int64` field. The sentinel
`(nothing, nothing)` state always maps to the last index, `env_params.n_states`.
"""
function get_state_ind(state::Tuple{Interval,Interval}, env_params::EnvParams)
    ind = findfirst(==(state), env_params.state_space)

    if isnothing(ind)
        throw(ArgumentError("state $state is not part of the state space"))
    end

    return ind
end

get_state_ind(::Tuple{Nothing,Nothing}, env_params::EnvParams) = env_params.n_states
|
|
|
|
|
2021-12-12 17:27:56 +00:00
|
|
|
"""
    post_integration_hook(rl_params, n_particles, particles, half_box_len)

After an integration step:
1. Set the shared reward to `1 - (ratio - goal_shape_ratio)^2`, where `ratio` is
   `ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len)`.
2. Update each particle environment's state index from its nearest-neighbour distance and
   the angle between its orientation and the vector to that neighbour.
3. Run the agent and hook `POST_ACT_STAGE` callbacks.

A particle whose minimum distance is infinite, or falls outside every distance interval,
is put into the sentinel state `n_states`.
"""
function post_integration_hook(
    rl_params::Params,
    n_particles::Int64,
    particles::Vector{Particle},
    half_box_len::Float64,
)
    env_params = rl_params.env_params

    # Update reward
    shape_ratio = ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len)
    env_params.reward = 1 - (shape_ratio - rl_params.goal_shape_ratio)^2

    # Update states
    n_states = env_params.n_states
    distance_state_space = env_params.distance_state_space
    angle_state_space = env_params.angle_state_space

    for i in 1:n_particles
        env, agent, hook = get_env_agent_hook(rl_params, i)

        min_sq_distance = rl_params.min_sq_distances[i]

        distance_state_ind = if isinf(min_sq_distance)
            # No neighbour was recorded for this particle.
            nothing
        else
            min_distance = sqrt(min_sq_distance)
            findfirst(s -> min_distance in s, distance_state_space)
        end

        if isnothing(distance_state_ind)
            # (nothing, nothing)
            env.state_ind = n_states
        else
            r⃗₁₂ = rl_params.vecs_r⃗₁₂_to_min_distance_particle[i]
            si, co = sincos(particles[i].φ)

            # Signed angle between the unit orientation vector and the neighbour vector.
            # The distance is finite here, because infinite distances were sent to the
            # sentinel branch above.
            angle = angle2(SVector(co, si), r⃗₁₂)

            # Searched per particle, so a miss never reuses another particle's state.
            angle_state_ind = findfirst(s -> angle in s, angle_state_space)
            if isnothing(angle_state_ind)
                # NOTE(review): assumes angle2 returns values in [-π, π]. Floating-point
                # round-off in the grid endpoints can leave ±π just outside, so snap to
                # the nearest boundary interval.
                angle_state_ind = if angle < 0
                    firstindex(angle_state_space)
                else
                    lastindex(angle_state_space)
                end
            end

            state = (
                distance_state_space[distance_state_ind],
                angle_state_space[angle_state_ind],
            )
            env.state_ind = get_state_ind(state, env.params)
        end

        # Post act
        agent(POST_ACT_STAGE, env)
        hook(POST_ACT_STAGE, agent, env)
    end

    return nothing
end
|
2021-12-12 14:29:08 +00:00
|
|
|
|
2021-12-15 03:45:15 +00:00
|
|
|
# Call `hook(stage, agent, env)` and then, if `call_agent`, `agent(stage, env)` for each
# of the first `n_particles` particles, in order.
function _run_stage!(rl_params::Params, n_particles::Int64, stage; call_agent::Bool=true)
    for i in 1:n_particles
        env, agent, hook = get_env_agent_hook(rl_params, i)

        hook(stage, agent, env)
        if call_agent
            agent(stage, env)
        end
    end

    return nothing
end

"""
    run_rl(; goal_shape_ratio, n_episodes=100, episode_duration=50.0,
           update_actions_at=0.2, n_particles=100, seed=42)

Train one tabular agent per particle to steer the swarm towards a gyration-tensor
eigenvalue ratio of `goal_shape_ratio`. Each episode initialises a new simulation (under
parent directory `"RL"`), runs it for `episode_duration`, and drives the
ReinforcementLearning.jl lifecycle stages. Returns the final `Params`.

# Keywords
- `goal_shape_ratio::Float64`: target ratio, in `[0, 1]`.
- `n_episodes::Int64=100`: number of episodes, `> 0`.
- `episode_duration::Float64=50.0`: simulated duration per episode, `> 0`.
- `update_actions_at::Float64=0.2`: time between action updates. Must lie on the grid
  `0.01:0.01:episode_duration`.
- `n_particles::Int64=100`: requested particle count. The count actually used is the one
  returned in the simulation constants.
- `seed::Int64=42`: seed for the global RNG.

Invalid arguments raise an `AssertionError`.
"""
function run_rl(;
    goal_shape_ratio::Float64,
    n_episodes::Int64=100,
    episode_duration::Float64=50.0,
    update_actions_at::Float64=0.2,
    n_particles::Int64=100,
    seed::Int64=42,
)
    @assert 0.0 <= goal_shape_ratio <= 1.0
    @assert n_episodes > 0
    @assert episode_duration > 0
    @assert update_actions_at in 0.01:0.01:episode_duration
    @assert n_particles > 0

    # Setup
    Random.seed!(seed)

    sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=1.6)
    n_particles = sim_consts.n_particles

    env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)

    n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)

    rl_params = Params{TotalRewardPerEpisode}(
        n_particles, env_params, n_steps_before_actions_update, goal_shape_ratio
    )

    # Plain loops throughout (previously `@simd`): agents and hooks have side effects,
    # so their calls must not be reordered.

    # Pre experiment
    _run_stage!(rl_params, n_particles, PRE_EXPERIMENT_STAGE)

    @showprogress 0.6 for episode in 1:n_episodes
        dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")

        # Reset
        for i in 1:n_particles
            reset!(rl_params.envs[i], particles[i])
        end

        reset!(rl_params.env_params)

        # Pre episode
        _run_stage!(rl_params, n_particles, PRE_EPISODE_STAGE)

        # Episode
        ReCo.run_sim(
            dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
        )

        # Post episode
        _run_stage!(rl_params, n_particles, POST_EPISODE_STAGE)
    end

    # Post experiment: only the hooks take part in this stage.
    _run_stage!(rl_params, n_particles, POST_EXPERIMENT_STAGE; call_agent=false)

    return rl_params
end
|
2021-12-10 02:16:45 +00:00
|
|
|
|
2021-12-12 14:29:08 +00:00
|
|
|
end # module
|