mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-11-08 22:21:08 +00:00
Added norm2d and sq_norm2d
This commit is contained in:
parent
bb3246a1e7
commit
0cf59e04a5
6 changed files with 59 additions and 51 deletions
|
@ -1,6 +1,6 @@
|
||||||
module Geometry
|
module Geometry
|
||||||
|
|
||||||
export angle2
|
export angle2, norm2d, sq_norm2d
|
||||||
|
|
||||||
using StaticArrays: SVector
|
using StaticArrays: SVector
|
||||||
|
|
||||||
|
@ -18,4 +18,7 @@ function angle2(a::SVector{2,Float64}, b::SVector{2,Float64})
|
||||||
return rem2pi(θ, RoundNearest)
|
return rem2pi(θ, RoundNearest)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
sq_norm2d(v::SVector{2,Float64}) = v[1]^2 + v[2]^2
|
||||||
|
norm2d(v::SVector{2,Float64}) = sqrt(sq_norm2d(v))
|
||||||
|
|
||||||
end # module
|
end # module
|
|
@ -64,7 +64,7 @@ function are_overlapping(
|
||||||
|
|
||||||
r⃗₁₂ = minimum_image(r⃗₁₂, half_box_len)
|
r⃗₁₂ = minimum_image(r⃗₁₂, half_box_len)
|
||||||
|
|
||||||
distance² = r⃗₁₂[1]^2 + r⃗₁₂[2]^2
|
distance² = sq_norm2d(r⃗₁₂)
|
||||||
|
|
||||||
overlapping = distance² < overlapping_r²
|
overlapping = distance² < overlapping_r²
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
export LocalCOMEnv
|
export LocalCOMEnv
|
||||||
|
|
||||||
struct LocalCOMEnv <: Env
|
struct LocalCOMEnv <: Env
|
||||||
params::EnvParams
|
shared::EnvSharedProps
|
||||||
|
|
||||||
distance_state_space::Vector{Interval}
|
distance_state_space::Vector{Interval}
|
||||||
direction_angle_state_space::Vector{Interval}
|
direction_angle_state_space::Vector{Interval}
|
||||||
|
@ -35,31 +35,35 @@ struct LocalCOMEnv <: Env
|
||||||
end
|
end
|
||||||
# Last state is when no particle is in the skin radius
|
# Last state is when no particle is in the skin radius
|
||||||
|
|
||||||
params = EnvParams(n_states, state_space)
|
shared = EnvSharedProps(n_states, state_space)
|
||||||
|
|
||||||
return new(params, distance_state_space, direction_angle_state_space, max_distance)
|
return new(shared, distance_state_space, direction_angle_state_space, max_distance)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
struct LocalCOMEnvHelper <: EnvHelper
|
struct LocalCOMEnvHelper <: EnvHelper
|
||||||
params::EnvHelperParams
|
shared::EnvHelperSharedProps
|
||||||
|
|
||||||
vec_to_neighbour_sums::Vector{SVector{2,Float64}}
|
vec_to_neighbour_sums::Vector{SVector{2,Float64}}
|
||||||
n_neighbours::Vector{Int64}
|
n_neighbours::Vector{Int64}
|
||||||
|
sq_norm2d_vec_to_local_center_of_mass::Vector{Float64}
|
||||||
|
|
||||||
function LocalCOMEnvHelper(params::EnvHelperParams)
|
function LocalCOMEnvHelper(shared::EnvHelperSharedProps)
|
||||||
return new(
|
return new(
|
||||||
params, fill(SVector(0.0, 0.0), params.n_particles), fill(0, params.n_particles)
|
shared,
|
||||||
|
fill(SVector(0.0, 0.0), shared.n_particles),
|
||||||
|
fill(0, shared.n_particles),
|
||||||
|
zeros(shared.n_particles),
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperParams)
|
function gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperSharedProps)
|
||||||
return LocalCOMEnvHelper(env_helper_params)
|
return LocalCOMEnvHelper(env_helper_params)
|
||||||
end
|
end
|
||||||
|
|
||||||
function pre_integration_hook(env_helper::LocalCOMEnvHelper)
|
function pre_integration_hook(env_helper::LocalCOMEnvHelper)
|
||||||
@simd for id in 1:(env_helper.params.n_particles)
|
@simd for id in 1:(env_helper.shared.n_particles)
|
||||||
env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
|
env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
|
||||||
env_helper.n_neighbours[id] = 0
|
env_helper.n_neighbours[id] = 0
|
||||||
end
|
end
|
||||||
|
@ -80,26 +84,28 @@ function state_update_helper_hook(
|
||||||
end
|
end
|
||||||
|
|
||||||
function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})
|
function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})
|
||||||
n_particles = env_helper.params.n_particles
|
n_particles = env_helper.shared.n_particles
|
||||||
|
|
||||||
@turbo for id in 1:(n_particles)
|
@turbo for id in 1:(n_particles)
|
||||||
env_helper.params.old_states_ind[id] = env_helper.params.states_ind[id]
|
env_helper.shared.old_states_ind[id] = env_helper.shared.states_ind[id]
|
||||||
end
|
end
|
||||||
|
|
||||||
env = env_helper.params.env
|
env = env_helper.shared.env
|
||||||
|
|
||||||
for id in 1:n_particles
|
for id in 1:n_particles
|
||||||
n_neighbours = env_helper.n_neighbours[id]
|
n_neighbours = env_helper.n_neighbours[id]
|
||||||
|
|
||||||
if n_neighbours == 0
|
if n_neighbours == 0
|
||||||
state_ind = env.params.n_states
|
state_ind = env.shared.n_states
|
||||||
else
|
else
|
||||||
vec_to_local_center_of_mass =
|
vec_to_local_center_of_mass =
|
||||||
env_helper.vec_to_neighbour_sums[id] / n_neighbours
|
env_helper.vec_to_neighbour_sums[id] / n_neighbours
|
||||||
|
|
||||||
distance = sqrt(
|
sq_norm2d_vec_to_local_center_of_mass = sq_norm2d(vec_to_local_center_of_mass)
|
||||||
vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2
|
env_helper.sq_norm2d_vec_to_local_center_of_mass[id] =
|
||||||
)
|
sq_norm2d_vec_to_local_center_of_mass
|
||||||
|
|
||||||
|
distance = sqrt(sq_norm2d_vec_to_local_center_of_mass)
|
||||||
|
|
||||||
distance_state = find_state_interval(distance, env.distance_state_space)
|
distance_state = find_state_interval(distance, env.distance_state_space)
|
||||||
|
|
||||||
|
@ -112,10 +118,10 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
|
||||||
)
|
)
|
||||||
|
|
||||||
state = SVector{2,Interval}(distance_state, direction_angle_state)
|
state = SVector{2,Interval}(distance_state, direction_angle_state)
|
||||||
state_ind = find_state_ind(state, env.params.state_space)
|
state_ind = find_state_ind(state, env.shared.state_space)
|
||||||
end
|
end
|
||||||
|
|
||||||
env_helper.params.states_ind[id] = state_ind
|
env_helper.shared.states_ind[id] = state_ind
|
||||||
end
|
end
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
|
@ -124,16 +130,14 @@ end
|
||||||
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
|
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
|
||||||
id = particle.id
|
id = particle.id
|
||||||
|
|
||||||
normalization = (env.max_distance * env_helper.params.n_particles)
|
normalization = (env.max_distance * env_helper.shared.n_particles)
|
||||||
|
|
||||||
n_neighbours = env_helper.n_neighbours[id]
|
n_neighbours = env_helper.n_neighbours[id]
|
||||||
if n_neighbours == 0
|
if n_neighbours == 0
|
||||||
env.params.reward = -(env.max_distance^2) / normalization
|
env.shared.reward = -(env.max_distance^2) / normalization
|
||||||
else
|
else
|
||||||
vec_to_local_center_of_mass = env_helper.vec_to_neighbour_sums[id] / n_neighbours # TODO: Reuse vec_to_local_center_of_mass from state_update_hook
|
env.shared.reward =
|
||||||
env.params.reward =
|
-(env_helper.sq_norm2d_vec_to_local_center_of_mass[id]) / normalization # TODO: Add shape term
|
||||||
-(vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2) /
|
|
||||||
normalization
|
|
||||||
end
|
end
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
|
|
47
src/RL/RL.jl
47
src/RL/RL.jl
|
@ -12,7 +12,8 @@ using LoopVectorization: @turbo
|
||||||
using Random: Random
|
using Random: Random
|
||||||
using ProgressMeter: @showprogress
|
using ProgressMeter: @showprogress
|
||||||
|
|
||||||
using ..ReCo: ReCo, Particle, angle2, Shape, DEFAULT_SKIN_TO_INTERACTION_R_RATIO
|
using ..ReCo:
|
||||||
|
ReCo, Particle, angle2, norm2d, sq_norm2d, Shape, DEFAULT_SKIN_TO_INTERACTION_R_RATIO
|
||||||
|
|
||||||
const INITIAL_STATE_IND = 1
|
const INITIAL_STATE_IND = 1
|
||||||
const INITIAL_REWARD = 0.0
|
const INITIAL_REWARD = 0.0
|
||||||
|
@ -69,7 +70,7 @@ end
|
||||||
|
|
||||||
abstract type Env <: AbstractEnv end
|
abstract type Env <: AbstractEnv end
|
||||||
|
|
||||||
mutable struct EnvParams{state_dims}
|
mutable struct EnvSharedProps{state_dims}
|
||||||
n_actions::Int64
|
n_actions::Int64
|
||||||
action_space::Vector{SVector{2,Float64}}
|
action_space::Vector{SVector{2,Float64}}
|
||||||
action_ind_space::OneTo{Int64}
|
action_ind_space::OneTo{Int64}
|
||||||
|
@ -82,7 +83,7 @@ mutable struct EnvParams{state_dims}
|
||||||
reward::Float64
|
reward::Float64
|
||||||
terminated::Bool
|
terminated::Bool
|
||||||
|
|
||||||
function EnvParams(
|
function EnvSharedProps(
|
||||||
n_states::Int64,
|
n_states::Int64,
|
||||||
state_space::Vector{SVector{state_dims,Interval}};
|
state_space::Vector{SVector{state_dims,Interval}};
|
||||||
n_v_actions::Int64=2,
|
n_v_actions::Int64=2,
|
||||||
|
@ -129,22 +130,22 @@ mutable struct EnvParams{state_dims}
|
||||||
end
|
end
|
||||||
|
|
||||||
function reset!(env::Env)
|
function reset!(env::Env)
|
||||||
env.params.terminated = false
|
env.shared.terminated = false
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
RLBase.state_space(env::Env) = env.params.state_ind_space
|
RLBase.state_space(env::Env) = env.shared.state_ind_space
|
||||||
|
|
||||||
RLBase.state(env::Env) = env.params.state_ind
|
RLBase.state(env::Env) = env.shared.state_ind
|
||||||
|
|
||||||
RLBase.action_space(env::Env) = env.params.action_ind_space
|
RLBase.action_space(env::Env) = env.shared.action_ind_space
|
||||||
|
|
||||||
RLBase.reward(env::Env) = env.params.reward
|
RLBase.reward(env::Env) = env.shared.reward
|
||||||
|
|
||||||
RLBase.is_terminated(env::Env) = env.params.terminated
|
RLBase.is_terminated(env::Env) = env.shared.terminated
|
||||||
|
|
||||||
struct EnvHelperParams{H<:AbstractHook}
|
struct EnvHelperSharedProps{H<:AbstractHook}
|
||||||
env::Env
|
env::Env
|
||||||
agent::Agent
|
agent::Agent
|
||||||
hook::H
|
hook::H
|
||||||
|
@ -161,7 +162,7 @@ struct EnvHelperParams{H<:AbstractHook}
|
||||||
actions::Vector{SVector{2,Float64}}
|
actions::Vector{SVector{2,Float64}}
|
||||||
actions_ind::Vector{Int64}
|
actions_ind::Vector{Int64}
|
||||||
|
|
||||||
function EnvHelperParams(
|
function EnvHelperSharedProps(
|
||||||
env::Env,
|
env::Env,
|
||||||
agent::Agent,
|
agent::Agent,
|
||||||
hook::H,
|
hook::H,
|
||||||
|
@ -186,7 +187,7 @@ end
|
||||||
|
|
||||||
abstract type EnvHelper end
|
abstract type EnvHelper end
|
||||||
|
|
||||||
function gen_env_helper(::Env, env_helper_params::EnvHelperParams)
|
function gen_env_helper(::Env, env_helper_params::EnvHelperSharedProps)
|
||||||
return method_not_implemented()
|
return method_not_implemented()
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -217,7 +218,7 @@ function state_update_hook(::EnvHelper, particles::Vector{Particle})
|
||||||
end
|
end
|
||||||
|
|
||||||
function get_env_agent_hook(env_helper::EnvHelper)
|
function get_env_agent_hook(env_helper::EnvHelper)
|
||||||
return (env_helper.params.env, env_helper.params.agent, env_helper.params.hook)
|
return (env_helper.shared.env, env_helper.shared.agent, env_helper.shared.hook)
|
||||||
end
|
end
|
||||||
|
|
||||||
function update_reward!(::Env, ::EnvHelper, particle::Particle)
|
function update_reward!(::Env, ::EnvHelper, particle::Particle)
|
||||||
|
@ -233,16 +234,16 @@ function update_table_and_actions_hook(
|
||||||
|
|
||||||
if !first_integration_step
|
if !first_integration_step
|
||||||
# Old state
|
# Old state
|
||||||
env.params.state_ind = env_helper.params.old_states_ind[id]
|
env.shared.state_ind = env_helper.shared.old_states_ind[id]
|
||||||
|
|
||||||
action_ind = env_helper.params.actions_ind[id]
|
action_ind = env_helper.shared.actions_ind[id]
|
||||||
|
|
||||||
# Pre act
|
# Pre act
|
||||||
agent(PRE_ACT_STAGE, env, action_ind)
|
agent(PRE_ACT_STAGE, env, action_ind)
|
||||||
hook(PRE_ACT_STAGE, agent, env, action_ind)
|
hook(PRE_ACT_STAGE, agent, env, action_ind)
|
||||||
|
|
||||||
# Update to current state
|
# Update to current state
|
||||||
env.params.state_ind = env_helper.params.states_ind[id]
|
env.shared.state_ind = env_helper.shared.states_ind[id]
|
||||||
|
|
||||||
# Update reward
|
# Update reward
|
||||||
update_reward!(env, env_helper, particle)
|
update_reward!(env, env_helper, particle)
|
||||||
|
@ -254,10 +255,10 @@ function update_table_and_actions_hook(
|
||||||
|
|
||||||
# Update action
|
# Update action
|
||||||
action_ind = agent(env)
|
action_ind = agent(env)
|
||||||
action = env.params.action_space[action_ind]
|
action = env.shared.action_space[action_ind]
|
||||||
|
|
||||||
env_helper.params.actions[id] = action
|
env_helper.shared.actions[id] = action
|
||||||
env_helper.params.actions_ind[id] = action_ind
|
env_helper.shared.actions_ind[id] = action_ind
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
@ -268,7 +269,7 @@ function act_hook(
|
||||||
env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64
|
env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64
|
||||||
)
|
)
|
||||||
# Apply action
|
# Apply action
|
||||||
action = env_helper.params.actions[particle.id]
|
action = env_helper.shared.actions[particle.id]
|
||||||
|
|
||||||
vδt = action[1] * δt
|
vδt = action[1] * δt
|
||||||
particle.tmp_c += SVector(vδt * co, vδt * si)
|
particle.tmp_c += SVector(vδt * co, vδt * si)
|
||||||
|
@ -333,13 +334,13 @@ function run_rl(;
|
||||||
|
|
||||||
env = EnvType(sim_consts)
|
env = EnvType(sim_consts)
|
||||||
|
|
||||||
agent = gen_agent(env.params.n_states, env.params.n_actions, ϵ_stable)
|
agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable)
|
||||||
|
|
||||||
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
||||||
|
|
||||||
hook = TotalRewardPerEpisode()
|
hook = TotalRewardPerEpisode()
|
||||||
|
|
||||||
env_helper_params = EnvHelperParams(
|
env_helper_params = EnvHelperSharedProps(
|
||||||
env,
|
env,
|
||||||
agent,
|
agent,
|
||||||
hook,
|
hook,
|
||||||
|
@ -374,7 +375,7 @@ function run_rl(;
|
||||||
env_helper=env_helper,
|
env_helper=env_helper,
|
||||||
)
|
)
|
||||||
|
|
||||||
env.params.terminated = true
|
env.shared.terminated = true
|
||||||
|
|
||||||
# Post episode
|
# Post episode
|
||||||
hook(POST_EPISODE_STAGE, agent, env)
|
hook(POST_EPISODE_STAGE, agent, env)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
using CairoMakie, LaTeXStrings
|
using CairoMakie, LaTeXStrings
|
||||||
using LoopVectorization: @turbo
|
using LoopVectorization: @turbo
|
||||||
|
|
||||||
using ReCo: minimum_image
|
using ReCo: minimum_image, norm2d
|
||||||
|
|
||||||
function plot_g(radius, g, variables)
|
function plot_g(radius, g, variables)
|
||||||
fig = Figure()
|
fig = Figure()
|
||||||
|
@ -44,7 +44,7 @@ function pair_correlation(sol, variables)
|
||||||
|
|
||||||
r⃗₁₂ = minimum_image(r⃗₁₂, variables.half_box_len)
|
r⃗₁₂ = minimum_image(r⃗₁₂, variables.half_box_len)
|
||||||
|
|
||||||
distance = sqrt(r⃗₁₂[1]^2 + r⃗₁₂[2]^2)
|
distance = norm2d(r⃗₁₂)
|
||||||
|
|
||||||
if (distance >= r) && (distance <= r + dr)
|
if (distance >= r) && (distance <= r + dr)
|
||||||
N_g[i, r_ind] += 1
|
N_g[i, r_ind] += 1
|
||||||
|
|
|
@ -92,7 +92,7 @@ Base.wait(::Nothing) = nothing
|
||||||
gen_run_additional_hooks(::Nothing, args...) = false
|
gen_run_additional_hooks(::Nothing, args...) = false
|
||||||
|
|
||||||
function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64)
|
function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64)
|
||||||
return (integration_step % env_helper.params.n_steps_before_actions_update == 0) ||
|
return (integration_step % env_helper.shared.n_steps_before_actions_update == 0) ||
|
||||||
(integration_step == 1)
|
(integration_step == 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue