mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-11-08 22:21:08 +00:00
Restructured code to include all environments
This commit is contained in:
parent
9c00da84ea
commit
bb3246a1e7
5 changed files with 265 additions and 191 deletions
140
src/RL/LocalCOMEnv.jl
Normal file
140
src/RL/LocalCOMEnv.jl
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
export LocalCOMEnv
|
||||||
|
|
||||||
|
struct LocalCOMEnv <: Env
|
||||||
|
params::EnvParams
|
||||||
|
|
||||||
|
distance_state_space::Vector{Interval}
|
||||||
|
direction_angle_state_space::Vector{Interval}
|
||||||
|
|
||||||
|
max_distance::Float64
|
||||||
|
|
||||||
|
function LocalCOMEnv(
|
||||||
|
sim_consts; n_distance_states::Int64=3, n_direction_angle_states::Int64=3
|
||||||
|
)
|
||||||
|
@assert n_direction_angle_states > 1
|
||||||
|
|
||||||
|
direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)
|
||||||
|
|
||||||
|
min_distance = 0.0
|
||||||
|
max_distance = sim_consts.skin_r
|
||||||
|
|
||||||
|
distance_state_space = gen_distance_state_space(
|
||||||
|
min_distance, max_distance, n_distance_states
|
||||||
|
)
|
||||||
|
|
||||||
|
n_states = n_distance_states * n_direction_angle_states + 1
|
||||||
|
|
||||||
|
state_space = Vector{SVector{2,Interval}}(undef, n_states - 1)
|
||||||
|
|
||||||
|
ind = 1
|
||||||
|
for distance_state in distance_state_space
|
||||||
|
for direction_angle_state in direction_angle_state_space
|
||||||
|
state_space[ind] = SVector(distance_state, direction_angle_state)
|
||||||
|
ind += 1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
# Last state is when no particle is in the skin radius
|
||||||
|
|
||||||
|
params = EnvParams(n_states, state_space)
|
||||||
|
|
||||||
|
return new(params, distance_state_space, direction_angle_state_space, max_distance)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
struct LocalCOMEnvHelper <: EnvHelper
|
||||||
|
params::EnvHelperParams
|
||||||
|
|
||||||
|
vec_to_neighbour_sums::Vector{SVector{2,Float64}}
|
||||||
|
n_neighbours::Vector{Int64}
|
||||||
|
|
||||||
|
function LocalCOMEnvHelper(params::EnvHelperParams)
|
||||||
|
return new(
|
||||||
|
params, fill(SVector(0.0, 0.0), params.n_particles), fill(0, params.n_particles)
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperParams)
|
||||||
|
return LocalCOMEnvHelper(env_helper_params)
|
||||||
|
end
|
||||||
|
|
||||||
|
function pre_integration_hook(env_helper::LocalCOMEnvHelper)
|
||||||
|
@simd for id in 1:(env_helper.params.n_particles)
|
||||||
|
env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
|
||||||
|
env_helper.n_neighbours[id] = 0
|
||||||
|
end
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function state_update_helper_hook(
|
||||||
|
env_helper::LocalCOMEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
||||||
|
)
|
||||||
|
env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂
|
||||||
|
env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂
|
||||||
|
|
||||||
|
env_helper.n_neighbours[id1] += 1
|
||||||
|
env_helper.n_neighbours[id2] += 1
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})
|
||||||
|
n_particles = env_helper.params.n_particles
|
||||||
|
|
||||||
|
@turbo for id in 1:(n_particles)
|
||||||
|
env_helper.params.old_states_ind[id] = env_helper.params.states_ind[id]
|
||||||
|
end
|
||||||
|
|
||||||
|
env = env_helper.params.env
|
||||||
|
|
||||||
|
for id in 1:n_particles
|
||||||
|
n_neighbours = env_helper.n_neighbours[id]
|
||||||
|
|
||||||
|
if n_neighbours == 0
|
||||||
|
state_ind = env.params.n_states
|
||||||
|
else
|
||||||
|
vec_to_local_center_of_mass =
|
||||||
|
env_helper.vec_to_neighbour_sums[id] / n_neighbours
|
||||||
|
|
||||||
|
distance = sqrt(
|
||||||
|
vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2
|
||||||
|
)
|
||||||
|
|
||||||
|
distance_state = find_state_interval(distance, env.distance_state_space)
|
||||||
|
|
||||||
|
si, co = sincos(particles[id].φ)
|
||||||
|
|
||||||
|
direction_angle = angle2(SVector(co, si), vec_to_local_center_of_mass)
|
||||||
|
|
||||||
|
direction_angle_state = find_state_interval(
|
||||||
|
direction_angle, env.direction_angle_state_space
|
||||||
|
)
|
||||||
|
|
||||||
|
state = SVector{2,Interval}(distance_state, direction_angle_state)
|
||||||
|
state_ind = find_state_ind(state, env.params.state_space)
|
||||||
|
end
|
||||||
|
|
||||||
|
env_helper.params.states_ind[id] = state_ind
|
||||||
|
end
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
|
||||||
|
id = particle.id
|
||||||
|
|
||||||
|
normalization = (env.max_distance * env_helper.params.n_particles)
|
||||||
|
|
||||||
|
n_neighbours = env_helper.n_neighbours[id]
|
||||||
|
if n_neighbours == 0
|
||||||
|
env.params.reward = -(env.max_distance^2) / normalization
|
||||||
|
else
|
||||||
|
vec_to_local_center_of_mass = env_helper.vec_to_neighbour_sums[id] / n_neighbours # TODO: Reuse vec_to_local_center_of_mass from state_update_hook
|
||||||
|
env.params.reward =
|
||||||
|
-(vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2) /
|
||||||
|
normalization
|
||||||
|
end
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
|
@ -1,6 +1,6 @@
|
||||||
module RL
|
module RL
|
||||||
|
|
||||||
export run_rl
|
export run_rl, LocalCOMEnv
|
||||||
|
|
||||||
using Base: OneTo
|
using Base: OneTo
|
||||||
|
|
||||||
|
@ -12,12 +12,14 @@ using LoopVectorization: @turbo
|
||||||
using Random: Random
|
using Random: Random
|
||||||
using ProgressMeter: @showprogress
|
using ProgressMeter: @showprogress
|
||||||
|
|
||||||
using ..ReCo: ReCo, Particle, angle2, Shape
|
using ..ReCo: ReCo, Particle, angle2, Shape, DEFAULT_SKIN_TO_INTERACTION_R_RATIO
|
||||||
|
|
||||||
const INITIAL_REWARD = 0.0
|
|
||||||
const INITIAL_STATE_IND = 1
|
const INITIAL_STATE_IND = 1
|
||||||
|
const INITIAL_REWARD = 0.0
|
||||||
|
|
||||||
function angle_state_space(n_angle_states::Int64)
|
method_not_implemented() = error("Method not implemented!")
|
||||||
|
|
||||||
|
function gen_angle_state_space(n_angle_states::Int64)
|
||||||
angle_range = range(; start=-π, stop=π, length=n_angle_states + 1)
|
angle_range = range(; start=-π, stop=π, length=n_angle_states + 1)
|
||||||
|
|
||||||
angle_state_space = Vector{Interval}(undef, n_angle_states)
|
angle_state_space = Vector{Interval}(undef, n_angle_states)
|
||||||
|
@ -37,57 +39,12 @@ function angle_state_space(n_angle_states::Int64)
|
||||||
return angle_state_space
|
return angle_state_space
|
||||||
end
|
end
|
||||||
|
|
||||||
mutable struct Env <: AbstractEnv
|
function gen_distance_state_space(
|
||||||
n_actions::Int64
|
min_distance::Float64, max_distance::Float64, n_distance_states::Int64
|
||||||
action_space::Vector{SVector{2,Float64}}
|
|
||||||
action_ind_space::OneTo{Int64}
|
|
||||||
|
|
||||||
distance_state_space::Vector{Interval}
|
|
||||||
direction_angle_state_space::Vector{Interval}
|
|
||||||
|
|
||||||
n_states::Int64
|
|
||||||
state_space::Vector{SVector{2,Interval}}
|
|
||||||
state_ind_space::OneTo{Int64}
|
|
||||||
state_ind::Int64
|
|
||||||
|
|
||||||
reward::Float64
|
|
||||||
terminated::Bool
|
|
||||||
|
|
||||||
function Env(;
|
|
||||||
max_distance::Float64,
|
|
||||||
min_distance::Float64=0.0,
|
|
||||||
n_v_actions::Int64=2,
|
|
||||||
n_ω_actions::Int64=3,
|
|
||||||
max_v::Float64=40.0,
|
|
||||||
max_ω::Float64=π / 2,
|
|
||||||
n_distance_states::Int64=3,
|
|
||||||
n_direction_angle_states::Int64=3,
|
|
||||||
)
|
)
|
||||||
@assert min_distance >= 0.0
|
@assert min_distance >= 0.0
|
||||||
@assert max_distance > min_distance
|
@assert max_distance > min_distance
|
||||||
@assert n_v_actions > 1
|
|
||||||
@assert n_ω_actions > 1
|
|
||||||
@assert max_v > 0
|
|
||||||
@assert max_ω > 0
|
|
||||||
@assert n_distance_states > 1
|
@assert n_distance_states > 1
|
||||||
@assert n_direction_angle_states > 1
|
|
||||||
|
|
||||||
v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions)
|
|
||||||
ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions)
|
|
||||||
|
|
||||||
n_actions = n_v_actions * n_ω_actions
|
|
||||||
|
|
||||||
action_space = Vector{SVector{2,Float64}}(undef, n_actions)
|
|
||||||
|
|
||||||
ind = 1
|
|
||||||
for v in v_action_space
|
|
||||||
for ω in ω_action_space
|
|
||||||
action_space[ind] = SVector(v, ω)
|
|
||||||
ind += 1
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
action_ind_space = OneTo(n_actions)
|
|
||||||
|
|
||||||
distance_range = range(;
|
distance_range = range(;
|
||||||
start=min_distance, stop=max_distance, length=n_distance_states + 1
|
start=min_distance, stop=max_distance, length=n_distance_states + 1
|
||||||
|
@ -107,29 +64,60 @@ mutable struct Env <: AbstractEnv
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
direction_angle_state_space = angle_state_space(n_direction_angle_states)
|
return distance_state_space
|
||||||
|
end
|
||||||
|
|
||||||
n_states = n_distance_states * n_direction_angle_states + 1
|
abstract type Env <: AbstractEnv end
|
||||||
|
|
||||||
state_space = Vector{SVector{2,Interval}}(undef, n_states - 1)
|
mutable struct EnvParams{state_dims}
|
||||||
|
n_actions::Int64
|
||||||
|
action_space::Vector{SVector{2,Float64}}
|
||||||
|
action_ind_space::OneTo{Int64}
|
||||||
|
|
||||||
|
n_states::Int64
|
||||||
|
state_space::Vector{SVector{state_dims,Interval}}
|
||||||
|
state_ind_space::OneTo{Int64}
|
||||||
|
state_ind::Int64
|
||||||
|
|
||||||
|
reward::Float64
|
||||||
|
terminated::Bool
|
||||||
|
|
||||||
|
function EnvParams(
|
||||||
|
n_states::Int64,
|
||||||
|
state_space::Vector{SVector{state_dims,Interval}};
|
||||||
|
n_v_actions::Int64=2,
|
||||||
|
n_ω_actions::Int64=3,
|
||||||
|
max_v::Float64=40.0,
|
||||||
|
max_ω::Float64=π / 2,
|
||||||
|
) where {state_dims}
|
||||||
|
@assert n_v_actions > 1
|
||||||
|
@assert n_ω_actions > 1
|
||||||
|
@assert max_v > 0
|
||||||
|
@assert max_ω > 0
|
||||||
|
|
||||||
|
v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions)
|
||||||
|
ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions)
|
||||||
|
|
||||||
|
n_actions = n_v_actions * n_ω_actions
|
||||||
|
|
||||||
|
action_space = Vector{SVector{2,Float64}}(undef, n_actions)
|
||||||
|
|
||||||
ind = 1
|
ind = 1
|
||||||
for distance_state in distance_state_space
|
for v in v_action_space
|
||||||
for direction_angle_state in direction_angle_state_space
|
for ω in ω_action_space
|
||||||
state_space[ind] = SVector(distance_state, direction_angle_state)
|
action_space[ind] = SVector(v, ω)
|
||||||
ind += 1
|
ind += 1
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
# Last state is when no particle is in the skin radius
|
|
||||||
|
action_ind_space = OneTo(n_actions)
|
||||||
|
|
||||||
state_ind_space = OneTo(n_states)
|
state_ind_space = OneTo(n_states)
|
||||||
|
|
||||||
return new(
|
return new{state_dims}(
|
||||||
n_actions,
|
n_actions,
|
||||||
action_space,
|
action_space,
|
||||||
action_ind_space,
|
action_ind_space,
|
||||||
distance_state_space,
|
|
||||||
direction_angle_state_space,
|
|
||||||
n_states,
|
n_states,
|
||||||
state_space,
|
state_space,
|
||||||
state_ind_space,
|
state_ind_space,
|
||||||
|
@ -141,94 +129,78 @@ mutable struct Env <: AbstractEnv
|
||||||
end
|
end
|
||||||
|
|
||||||
function reset!(env::Env)
|
function reset!(env::Env)
|
||||||
env.state_ind = env.n_states
|
env.params.terminated = false
|
||||||
env.terminated = false
|
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
RLBase.state_space(env::Env) = env.state_ind_space
|
RLBase.state_space(env::Env) = env.params.state_ind_space
|
||||||
|
|
||||||
RLBase.state(env::Env) = env.state_ind
|
RLBase.state(env::Env) = env.params.state_ind
|
||||||
|
|
||||||
RLBase.action_space(env::Env) = env.action_ind_space
|
RLBase.action_space(env::Env) = env.params.action_ind_space
|
||||||
|
|
||||||
RLBase.reward(env::Env) = env.reward
|
RLBase.reward(env::Env) = env.params.reward
|
||||||
|
|
||||||
RLBase.is_terminated(env::Env) = env.terminated
|
RLBase.is_terminated(env::Env) = env.params.terminated
|
||||||
|
|
||||||
struct Params{H<:AbstractHook}
|
struct EnvHelperParams{H<:AbstractHook}
|
||||||
env::Env
|
env::Env
|
||||||
agent::Agent
|
agent::Agent
|
||||||
hook::H
|
hook::H
|
||||||
|
|
||||||
|
n_steps_before_actions_update::Int64
|
||||||
|
|
||||||
|
goal_gyration_tensor_eigvals_ratio::Float64
|
||||||
|
|
||||||
|
n_particles::Int64
|
||||||
|
|
||||||
old_states_ind::Vector{Int64}
|
old_states_ind::Vector{Int64}
|
||||||
states_ind::Vector{Int64}
|
states_ind::Vector{Int64}
|
||||||
|
|
||||||
actions::Vector{SVector{2,Float64}}
|
actions::Vector{SVector{2,Float64}}
|
||||||
actions_ind::Vector{Int64}
|
actions_ind::Vector{Int64}
|
||||||
|
|
||||||
n_steps_before_actions_update::Int64
|
function EnvHelperParams(
|
||||||
|
|
||||||
goal_gyration_tensor_eigvals_ratio::Float64
|
|
||||||
|
|
||||||
n_particles::Int64
|
|
||||||
max_distance::Float64
|
|
||||||
|
|
||||||
vec_to_neighbour_sums::Vector{SVector{2,Float64}}
|
|
||||||
n_neighbours::Vector{Int64}
|
|
||||||
|
|
||||||
function Params(
|
|
||||||
env::Env,
|
env::Env,
|
||||||
agent::Agent,
|
agent::Agent,
|
||||||
hook::H,
|
hook::H,
|
||||||
n_steps_before_actions_update::Int64,
|
n_steps_before_actions_update::Int64,
|
||||||
goal_gyration_tensor_eigvals_ratio::Float64,
|
goal_gyration_tensor_eigvals_ratio::Float64,
|
||||||
n_particles::Int64,
|
n_particles::Int64,
|
||||||
max_distance::Float64,
|
|
||||||
) where {H<:AbstractHook}
|
) where {H<:AbstractHook}
|
||||||
n_states = env.n_states
|
|
||||||
|
|
||||||
return new{H}(
|
return new{H}(
|
||||||
env,
|
env,
|
||||||
agent,
|
agent,
|
||||||
hook,
|
hook,
|
||||||
fill(0, n_particles),
|
|
||||||
fill(n_states, n_particles),
|
|
||||||
fill(SVector(0.0, 0.0), n_particles),
|
|
||||||
fill(0, n_particles),
|
|
||||||
n_steps_before_actions_update,
|
n_steps_before_actions_update,
|
||||||
goal_gyration_tensor_eigvals_ratio,
|
goal_gyration_tensor_eigvals_ratio,
|
||||||
n_particles,
|
n_particles,
|
||||||
max_distance,
|
fill(0, n_particles),
|
||||||
|
fill(0, n_particles),
|
||||||
fill(SVector(0.0, 0.0), n_particles),
|
fill(SVector(0.0, 0.0), n_particles),
|
||||||
fill(0, n_particles),
|
fill(0, n_particles),
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function pre_integration_hook(rl_params::Params)
|
abstract type EnvHelper end
|
||||||
@simd for id in 1:(rl_params.n_particles)
|
|
||||||
rl_params.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
|
function gen_env_helper(::Env, env_helper_params::EnvHelperParams)
|
||||||
rl_params.n_neighbours[id] = 0
|
return method_not_implemented()
|
||||||
end
|
end
|
||||||
|
|
||||||
return nothing
|
function pre_integration_hook(::EnvHelper)
|
||||||
|
return method_not_implemented()
|
||||||
end
|
end
|
||||||
|
|
||||||
function state_update_helper_hook(
|
function state_update_helper_hook(
|
||||||
rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
||||||
)
|
)
|
||||||
rl_params.vec_to_neighbour_sums[id1] += r⃗₁₂
|
return method_not_implemented()
|
||||||
rl_params.vec_to_neighbour_sums[id2] -= r⃗₁₂
|
|
||||||
|
|
||||||
rl_params.n_neighbours[id1] += 1
|
|
||||||
rl_params.n_neighbours[id2] += 1
|
|
||||||
|
|
||||||
return nothing
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector{2,Interval}}
|
function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector}
|
||||||
return findfirst(x -> x == state, state_space)
|
return findfirst(x -> x == state, state_space)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -240,89 +212,40 @@ function find_state_interval(value::Float64, state_space::Vector{Interval})::Int
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function state_update_hook(rl_params::Params, particles::Vector{Particle})
|
function state_update_hook(::EnvHelper, particles::Vector{Particle})
|
||||||
@turbo for id in 1:(rl_params.n_particles)
|
return method_not_implemented()
|
||||||
rl_params.old_states_ind[id] = rl_params.states_ind[id]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
env = rl_params.env
|
function get_env_agent_hook(env_helper::EnvHelper)
|
||||||
|
return (env_helper.params.env, env_helper.params.agent, env_helper.params.hook)
|
||||||
for id in 1:(rl_params.n_particles)
|
|
||||||
n_neighbours = rl_params.n_neighbours[id]
|
|
||||||
|
|
||||||
if n_neighbours == 0
|
|
||||||
state_ind = env.n_states
|
|
||||||
else
|
|
||||||
vec_to_local_center_of_mass = rl_params.vec_to_neighbour_sums[id] / n_neighbours
|
|
||||||
|
|
||||||
distance = sqrt(
|
|
||||||
vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2
|
|
||||||
)
|
|
||||||
|
|
||||||
distance_state = find_state_interval(distance, env.distance_state_space)
|
|
||||||
|
|
||||||
si, co = sincos(particles[id].φ)
|
|
||||||
|
|
||||||
direction_angle = angle2(SVector(co, si), vec_to_local_center_of_mass)
|
|
||||||
|
|
||||||
direction_angle_state = find_state_interval(
|
|
||||||
direction_angle, env.direction_angle_state_space
|
|
||||||
)
|
|
||||||
|
|
||||||
state = SVector{2,Interval}(distance_state, direction_angle_state)
|
|
||||||
state_ind = find_state_ind(state, env.state_space)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
rl_params.states_ind[id] = state_ind
|
function update_reward!(::Env, ::EnvHelper, particle::Particle)
|
||||||
end
|
return method_not_implemented()
|
||||||
|
|
||||||
return nothing
|
|
||||||
end
|
|
||||||
|
|
||||||
function get_env_agent_hook(rl_params::Params)
|
|
||||||
return (rl_params.env, rl_params.agent, rl_params.hook)
|
|
||||||
end
|
|
||||||
|
|
||||||
function update_reward!(env::Env, rl_params::Params, particle::Particle)
|
|
||||||
id = particle.id
|
|
||||||
|
|
||||||
normalization = (rl_params.max_distance * rl_params.n_particles)
|
|
||||||
|
|
||||||
n_neighbours = rl_params.n_neighbours[id]
|
|
||||||
if n_neighbours == 0
|
|
||||||
env.reward = -(rl_params.max_distance^2) / normalization
|
|
||||||
else
|
|
||||||
vec_to_local_center_of_mass = rl_params.vec_to_neighbour_sums[id] / n_neighbours # TODO: Reuse vec_to_local_center_of_mass from state_update_hook
|
|
||||||
env.reward =
|
|
||||||
-(vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2) /
|
|
||||||
normalization
|
|
||||||
end
|
|
||||||
|
|
||||||
return nothing
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function update_table_and_actions_hook(
|
function update_table_and_actions_hook(
|
||||||
rl_params::Params, particle::Particle, first_integration_step::Bool
|
env_helper::EnvHelper, particle::Particle, first_integration_step::Bool
|
||||||
)
|
)
|
||||||
env, agent, hook = get_env_agent_hook(rl_params)
|
env, agent, hook = get_env_agent_hook(env_helper)
|
||||||
|
|
||||||
id = particle.id
|
id = particle.id
|
||||||
|
|
||||||
if !first_integration_step
|
if !first_integration_step
|
||||||
# Old state
|
# Old state
|
||||||
env.state_ind = rl_params.old_states_ind[id]
|
env.params.state_ind = env_helper.params.old_states_ind[id]
|
||||||
|
|
||||||
action_ind = rl_params.actions_ind[id]
|
action_ind = env_helper.params.actions_ind[id]
|
||||||
|
|
||||||
# Pre act
|
# Pre act
|
||||||
agent(PRE_ACT_STAGE, env, action_ind)
|
agent(PRE_ACT_STAGE, env, action_ind)
|
||||||
hook(PRE_ACT_STAGE, agent, env, action_ind)
|
hook(PRE_ACT_STAGE, agent, env, action_ind)
|
||||||
|
|
||||||
# Update to current state
|
# Update to current state
|
||||||
env.state_ind = rl_params.states_ind[id]
|
env.params.state_ind = env_helper.params.states_ind[id]
|
||||||
|
|
||||||
# Update reward
|
# Update reward
|
||||||
update_reward!(env, rl_params, particle)
|
update_reward!(env, env_helper, particle)
|
||||||
|
|
||||||
# Post act
|
# Post act
|
||||||
agent(POST_ACT_STAGE, env)
|
agent(POST_ACT_STAGE, env)
|
||||||
|
@ -331,10 +254,10 @@ function update_table_and_actions_hook(
|
||||||
|
|
||||||
# Update action
|
# Update action
|
||||||
action_ind = agent(env)
|
action_ind = agent(env)
|
||||||
action = env.action_space[action_ind]
|
action = env.params.action_space[action_ind]
|
||||||
|
|
||||||
rl_params.actions[id] = action
|
env_helper.params.actions[id] = action
|
||||||
rl_params.actions_ind[id] = action_ind
|
env_helper.params.actions_ind[id] = action_ind
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
@ -342,10 +265,10 @@ end
|
||||||
act_hook(::Nothing, args...) = nothing
|
act_hook(::Nothing, args...) = nothing
|
||||||
|
|
||||||
function act_hook(
|
function act_hook(
|
||||||
rl_params::Params, particle::Particle, δt::Float64, si::Float64, co::Float64
|
env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64
|
||||||
)
|
)
|
||||||
# Apply action
|
# Apply action
|
||||||
action = rl_params.actions[particle.id]
|
action = env_helper.params.actions[particle.id]
|
||||||
|
|
||||||
vδt = action[1] * δt
|
vδt = action[1] * δt
|
||||||
particle.tmp_c += SVector(vδt * co, vδt * si)
|
particle.tmp_c += SVector(vδt * co, vδt * si)
|
||||||
|
@ -378,6 +301,8 @@ function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
|
||||||
end
|
end
|
||||||
|
|
||||||
function run_rl(;
|
function run_rl(;
|
||||||
|
EnvType::Type{E},
|
||||||
|
parent_dir_appendix::String,
|
||||||
goal_gyration_tensor_eigvals_ratio::Float64,
|
goal_gyration_tensor_eigvals_ratio::Float64,
|
||||||
n_episodes::Int64=200,
|
n_episodes::Int64=200,
|
||||||
episode_duration::Float64=50.0,
|
episode_duration::Float64=50.0,
|
||||||
|
@ -385,8 +310,9 @@ function run_rl(;
|
||||||
n_particles::Int64=100,
|
n_particles::Int64=100,
|
||||||
seed::Int64=42,
|
seed::Int64=42,
|
||||||
ϵ_stable::Float64=0.0001,
|
ϵ_stable::Float64=0.0001,
|
||||||
parent_dir::String="",
|
skin_to_interaction_r_ratio::Float64=DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
|
||||||
)
|
packing_ratio=0.22,
|
||||||
|
) where {E<:Env}
|
||||||
@assert 0.0 <= goal_gyration_tensor_eigvals_ratio <= 1.0
|
@assert 0.0 <= goal_gyration_tensor_eigvals_ratio <= 1.0
|
||||||
@assert n_episodes > 0
|
@assert n_episodes > 0
|
||||||
@assert episode_duration > 0
|
@assert episode_duration > 0
|
||||||
|
@ -398,30 +324,33 @@ function run_rl(;
|
||||||
Random.seed!(seed)
|
Random.seed!(seed)
|
||||||
|
|
||||||
sim_consts = ReCo.gen_sim_consts(
|
sim_consts = ReCo.gen_sim_consts(
|
||||||
n_particles, 0.0; skin_to_interaction_r_ratio=2.0, packing_ratio=0.22
|
n_particles,
|
||||||
|
0.0;
|
||||||
|
skin_to_interaction_r_ratio=skin_to_interaction_r_ratio,
|
||||||
|
packing_ratio=packing_ratio,
|
||||||
)
|
)
|
||||||
n_particles = sim_consts.n_particles
|
n_particles = sim_consts.n_particles # This not always equal to the input!
|
||||||
|
|
||||||
max_distance = sim_consts.skin_r
|
env = EnvType(sim_consts)
|
||||||
env = Env(; max_distance=max_distance)
|
|
||||||
|
|
||||||
agent = gen_agent(env.n_states, env.n_actions, ϵ_stable)
|
agent = gen_agent(env.params.n_states, env.params.n_actions, ϵ_stable)
|
||||||
|
|
||||||
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
||||||
|
|
||||||
hook = TotalRewardPerEpisode()
|
hook = TotalRewardPerEpisode()
|
||||||
|
|
||||||
rl_params = Params(
|
env_helper_params = EnvHelperParams(
|
||||||
env,
|
env,
|
||||||
agent,
|
agent,
|
||||||
hook,
|
hook,
|
||||||
n_steps_before_actions_update,
|
n_steps_before_actions_update,
|
||||||
goal_gyration_tensor_eigvals_ratio,
|
goal_gyration_tensor_eigvals_ratio,
|
||||||
n_particles,
|
n_particles,
|
||||||
max_distance,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parent_dir = "RL" * parent_dir
|
env_helper = gen_env_helper(env, env_helper_params)
|
||||||
|
|
||||||
|
parent_dir = "RL_" * parent_dir_appendix
|
||||||
|
|
||||||
# Pre experiment
|
# Pre experiment
|
||||||
hook(PRE_EXPERIMENT_STAGE, agent, env)
|
hook(PRE_EXPERIMENT_STAGE, agent, env)
|
||||||
|
@ -439,10 +368,13 @@ function run_rl(;
|
||||||
|
|
||||||
# Episode
|
# Episode
|
||||||
ReCo.run_sim(
|
ReCo.run_sim(
|
||||||
dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
|
dir;
|
||||||
|
duration=episode_duration,
|
||||||
|
seed=rand(1:typemax(Int64)),
|
||||||
|
env_helper=env_helper,
|
||||||
)
|
)
|
||||||
|
|
||||||
env.terminated = true
|
env.params.terminated = true
|
||||||
|
|
||||||
# Post episode
|
# Post episode
|
||||||
hook(POST_EPISODE_STAGE, agent, env)
|
hook(POST_EPISODE_STAGE, agent, env)
|
||||||
|
@ -456,7 +388,9 @@ function run_rl(;
|
||||||
# Post experiment
|
# Post experiment
|
||||||
hook(POST_EXPERIMENT_STAGE, agent, env)
|
hook(POST_EXPERIMENT_STAGE, agent, env)
|
||||||
|
|
||||||
return rl_params
|
return env_helper
|
||||||
end
|
end
|
||||||
|
|
||||||
|
include("LocalCOMEnv.jl")
|
||||||
|
|
||||||
end # module
|
end # module
|
|
@ -1,6 +1,6 @@
|
||||||
module ReCo
|
module ReCo
|
||||||
|
|
||||||
export init_sim, run_sim, run_rl, animate
|
export init_sim, run_sim, run_rl, animate, LocalCOMEnv
|
||||||
|
|
||||||
using StaticArrays: SVector
|
using StaticArrays: SVector
|
||||||
using OrderedCollections: OrderedDict
|
using OrderedCollections: OrderedDict
|
||||||
|
@ -26,7 +26,7 @@ include("setup.jl")
|
||||||
include("Shape.jl")
|
include("Shape.jl")
|
||||||
using .Shape
|
using .Shape
|
||||||
|
|
||||||
include("RL.jl")
|
include("RL/RL.jl")
|
||||||
using .RL
|
using .RL
|
||||||
|
|
||||||
include("simulation.jl")
|
include("simulation.jl")
|
||||||
|
|
|
@ -6,7 +6,7 @@ function run_sim(
|
||||||
snapshot_at::Float64=0.1,
|
snapshot_at::Float64=0.1,
|
||||||
seed::Int64=42,
|
seed::Int64=42,
|
||||||
n_bundle_snapshots::Int64=100,
|
n_bundle_snapshots::Int64=100,
|
||||||
rl_params::Union{RL.Params,Nothing}=nothing,
|
env_helper::Union{RL.EnvHelper,Nothing}=nothing,
|
||||||
)
|
)
|
||||||
@assert length(dir) > 0
|
@assert length(dir) > 0
|
||||||
@assert duration > 0
|
@assert duration > 0
|
||||||
|
@ -111,7 +111,7 @@ function run_sim(
|
||||||
n_bundles,
|
n_bundles,
|
||||||
dir,
|
dir,
|
||||||
save_data,
|
save_data,
|
||||||
rl_params,
|
env_helper,
|
||||||
)
|
)
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
|
|
|
@ -35,7 +35,7 @@ end
|
||||||
function euler!(
|
function euler!(
|
||||||
args,
|
args,
|
||||||
first_integration_step::Bool,
|
first_integration_step::Bool,
|
||||||
rl_params::Union{RL.Params,Nothing},
|
env_helper::Union{RL.EnvHelper,Nothing},
|
||||||
state_update_helper_hook::Function,
|
state_update_helper_hook::Function,
|
||||||
state_update_hook::Function,
|
state_update_hook::Function,
|
||||||
update_table_and_actions_hook::Function,
|
update_table_and_actions_hook::Function,
|
||||||
|
@ -52,7 +52,7 @@ function euler!(
|
||||||
p1_c, p2.c, args.interaction_r², args.half_box_len
|
p1_c, p2.c, args.interaction_r², args.half_box_len
|
||||||
)
|
)
|
||||||
|
|
||||||
state_update_helper_hook(rl_params, id1, id2, r⃗₁₂)
|
state_update_helper_hook(env_helper, id1, id2, r⃗₁₂)
|
||||||
|
|
||||||
if overlapping
|
if overlapping
|
||||||
factor = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
|
factor = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
|
||||||
|
@ -64,7 +64,7 @@ function euler!(
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
state_update_hook(rl_params, args.particles)
|
state_update_hook(env_helper, args.particles)
|
||||||
|
|
||||||
@simd for p in args.particles
|
@simd for p in args.particles
|
||||||
si, co = sincos(p.φ)
|
si, co = sincos(p.φ)
|
||||||
|
@ -75,9 +75,9 @@ function euler!(
|
||||||
|
|
||||||
restrict_coordinates!(p, args.half_box_len)
|
restrict_coordinates!(p, args.half_box_len)
|
||||||
|
|
||||||
update_table_and_actions_hook(rl_params, p, first_integration_step)
|
update_table_and_actions_hook(env_helper, p, first_integration_step)
|
||||||
|
|
||||||
RL.act_hook(rl_params, p, args.δt, si, co)
|
RL.act_hook(env_helper, p, args.δt, si, co)
|
||||||
|
|
||||||
p.φ += args.c₄ * rand_normal01()
|
p.φ += args.c₄ * rand_normal01()
|
||||||
|
|
||||||
|
@ -91,8 +91,8 @@ Base.wait(::Nothing) = nothing
|
||||||
|
|
||||||
gen_run_additional_hooks(::Nothing, args...) = false
|
gen_run_additional_hooks(::Nothing, args...) = false
|
||||||
|
|
||||||
function gen_run_additional_hooks(rl_params::RL.Params, integration_step::Int64)
|
function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64)
|
||||||
return (integration_step % rl_params.n_steps_before_actions_update == 0) ||
|
return (integration_step % env_helper.params.n_steps_before_actions_update == 0) ||
|
||||||
(integration_step == 1)
|
(integration_step == 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -105,7 +105,7 @@ function simulate(
|
||||||
n_bundles::Int64,
|
n_bundles::Int64,
|
||||||
dir::String,
|
dir::String,
|
||||||
save_data::Bool,
|
save_data::Bool,
|
||||||
rl_params::Union{RL.Params,Nothing},
|
env_helper::Union{RL.EnvHelper,Nothing},
|
||||||
)
|
)
|
||||||
bundle_snapshot_counter = 0
|
bundle_snapshot_counter = 0
|
||||||
|
|
||||||
|
@ -143,10 +143,10 @@ function simulate(
|
||||||
cl = update_verlet_lists!(args, cl)
|
cl = update_verlet_lists!(args, cl)
|
||||||
end
|
end
|
||||||
|
|
||||||
run_additional_hooks = gen_run_additional_hooks(rl_params, integration_step)
|
run_additional_hooks = gen_run_additional_hooks(env_helper, integration_step)
|
||||||
|
|
||||||
if run_additional_hooks
|
if run_additional_hooks
|
||||||
RL.pre_integration_hook(rl_params)
|
RL.pre_integration_hook(env_helper)
|
||||||
|
|
||||||
state_update_helper_hook = RL.state_update_helper_hook
|
state_update_helper_hook = RL.state_update_helper_hook
|
||||||
state_update_hook = RL.state_update_hook
|
state_update_hook = RL.state_update_hook
|
||||||
|
@ -156,7 +156,7 @@ function simulate(
|
||||||
euler!(
|
euler!(
|
||||||
args,
|
args,
|
||||||
first_integration_step,
|
first_integration_step,
|
||||||
rl_params,
|
env_helper,
|
||||||
state_update_helper_hook,
|
state_update_helper_hook,
|
||||||
state_update_hook,
|
state_update_hook,
|
||||||
update_table_and_actions_hook,
|
update_table_and_actions_hook,
|
||||||
|
|
Loading…
Reference in a new issue