mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Added norm2d and sq_norm2d
This commit is contained in:
parent
bb3246a1e7
commit
0cf59e04a5
6 changed files with 59 additions and 51 deletions
|
@ -1,6 +1,6 @@
|
|||
module Geometry
|
||||
|
||||
export angle2
|
||||
export angle2, norm2d, sq_norm2d
|
||||
|
||||
using StaticArrays: SVector
|
||||
|
||||
|
@ -18,4 +18,7 @@ function angle2(a::SVector{2,Float64}, b::SVector{2,Float64})
|
|||
return rem2pi(θ, RoundNearest)
|
||||
end
|
||||
|
||||
"""
    sq_norm2d(v::SVector{2,Float64})

Return the squared Euclidean length of the two-component vector `v`, i.e. the sum of
the squares of its two components. No square root is taken.
"""
function sq_norm2d(v::SVector{2,Float64})
    return abs2(v[1]) + abs2(v[2])
end
|
||||
"""
    norm2d(v::SVector{2,Float64})

Return the Euclidean length of the two-component vector `v`, computed as the square
root of `sq_norm2d(v)`.
"""
function norm2d(v::SVector{2,Float64})
    sq_len = sq_norm2d(v)
    return sqrt(sq_len)
end
|
||||
|
||||
end # module
|
|
@ -64,7 +64,7 @@ function are_overlapping(
|
|||
|
||||
r⃗₁₂ = minimum_image(r⃗₁₂, half_box_len)
|
||||
|
||||
distance² = r⃗₁₂[1]^2 + r⃗₁₂[2]^2
|
||||
distance² = sq_norm2d(r⃗₁₂)
|
||||
|
||||
overlapping = distance² < overlapping_r²
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
export LocalCOMEnv
|
||||
|
||||
struct LocalCOMEnv <: Env
|
||||
params::EnvParams
|
||||
shared::EnvSharedProps
|
||||
|
||||
distance_state_space::Vector{Interval}
|
||||
direction_angle_state_space::Vector{Interval}
|
||||
|
@ -35,31 +35,35 @@ struct LocalCOMEnv <: Env
|
|||
end
|
||||
# Last state is when no particle is in the skin radius
|
||||
|
||||
params = EnvParams(n_states, state_space)
|
||||
shared = EnvSharedProps(n_states, state_space)
|
||||
|
||||
return new(params, distance_state_space, direction_angle_state_space, max_distance)
|
||||
return new(shared, distance_state_space, direction_angle_state_space, max_distance)
|
||||
end
|
||||
end
|
||||
|
||||
struct LocalCOMEnvHelper <: EnvHelper
|
||||
params::EnvHelperParams
|
||||
shared::EnvHelperSharedProps
|
||||
|
||||
vec_to_neighbour_sums::Vector{SVector{2,Float64}}
|
||||
n_neighbours::Vector{Int64}
|
||||
sq_norm2d_vec_to_local_center_of_mass::Vector{Float64}
|
||||
|
||||
function LocalCOMEnvHelper(params::EnvHelperParams)
|
||||
function LocalCOMEnvHelper(shared::EnvHelperSharedProps)
|
||||
return new(
|
||||
params, fill(SVector(0.0, 0.0), params.n_particles), fill(0, params.n_particles)
|
||||
shared,
|
||||
fill(SVector(0.0, 0.0), shared.n_particles),
|
||||
fill(0, shared.n_particles),
|
||||
zeros(shared.n_particles),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
function gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperParams)
|
||||
function gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperSharedProps)
|
||||
return LocalCOMEnvHelper(env_helper_params)
|
||||
end
|
||||
|
||||
function pre_integration_hook(env_helper::LocalCOMEnvHelper)
|
||||
@simd for id in 1:(env_helper.params.n_particles)
|
||||
@simd for id in 1:(env_helper.shared.n_particles)
|
||||
env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
|
||||
env_helper.n_neighbours[id] = 0
|
||||
end
|
||||
|
@ -80,26 +84,28 @@ function state_update_helper_hook(
|
|||
end
|
||||
|
||||
function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})
|
||||
n_particles = env_helper.params.n_particles
|
||||
n_particles = env_helper.shared.n_particles
|
||||
|
||||
@turbo for id in 1:(n_particles)
|
||||
env_helper.params.old_states_ind[id] = env_helper.params.states_ind[id]
|
||||
env_helper.shared.old_states_ind[id] = env_helper.shared.states_ind[id]
|
||||
end
|
||||
|
||||
env = env_helper.params.env
|
||||
env = env_helper.shared.env
|
||||
|
||||
for id in 1:n_particles
|
||||
n_neighbours = env_helper.n_neighbours[id]
|
||||
|
||||
if n_neighbours == 0
|
||||
state_ind = env.params.n_states
|
||||
state_ind = env.shared.n_states
|
||||
else
|
||||
vec_to_local_center_of_mass =
|
||||
env_helper.vec_to_neighbour_sums[id] / n_neighbours
|
||||
|
||||
distance = sqrt(
|
||||
vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2
|
||||
)
|
||||
sq_norm2d_vec_to_local_center_of_mass = sq_norm2d(vec_to_local_center_of_mass)
|
||||
env_helper.sq_norm2d_vec_to_local_center_of_mass[id] =
|
||||
sq_norm2d_vec_to_local_center_of_mass
|
||||
|
||||
distance = sqrt(sq_norm2d_vec_to_local_center_of_mass)
|
||||
|
||||
distance_state = find_state_interval(distance, env.distance_state_space)
|
||||
|
||||
|
@ -112,10 +118,10 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
|
|||
)
|
||||
|
||||
state = SVector{2,Interval}(distance_state, direction_angle_state)
|
||||
state_ind = find_state_ind(state, env.params.state_space)
|
||||
state_ind = find_state_ind(state, env.shared.state_space)
|
||||
end
|
||||
|
||||
env_helper.params.states_ind[id] = state_ind
|
||||
env_helper.shared.states_ind[id] = state_ind
|
||||
end
|
||||
|
||||
return nothing
|
||||
|
@ -124,16 +130,14 @@ end
|
|||
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
|
||||
id = particle.id
|
||||
|
||||
normalization = (env.max_distance * env_helper.params.n_particles)
|
||||
normalization = (env.max_distance * env_helper.shared.n_particles)
|
||||
|
||||
n_neighbours = env_helper.n_neighbours[id]
|
||||
if n_neighbours == 0
|
||||
env.params.reward = -(env.max_distance^2) / normalization
|
||||
env.shared.reward = -(env.max_distance^2) / normalization
|
||||
else
|
||||
vec_to_local_center_of_mass = env_helper.vec_to_neighbour_sums[id] / n_neighbours # TODO: Reuse vec_to_local_center_of_mass from state_update_hook
|
||||
env.params.reward =
|
||||
-(vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2) /
|
||||
normalization
|
||||
env.shared.reward =
|
||||
-(env_helper.sq_norm2d_vec_to_local_center_of_mass[id]) / normalization # TODO: Add shape term
|
||||
end
|
||||
|
||||
return nothing
|
||||
|
|
47
src/RL/RL.jl
47
src/RL/RL.jl
|
@ -12,7 +12,8 @@ using LoopVectorization: @turbo
|
|||
using Random: Random
|
||||
using ProgressMeter: @showprogress
|
||||
|
||||
using ..ReCo: ReCo, Particle, angle2, Shape, DEFAULT_SKIN_TO_INTERACTION_R_RATIO
|
||||
using ..ReCo:
|
||||
ReCo, Particle, angle2, norm2d, sq_norm2d, Shape, DEFAULT_SKIN_TO_INTERACTION_R_RATIO
|
||||
|
||||
const INITIAL_STATE_IND = 1
|
||||
const INITIAL_REWARD = 0.0
|
||||
|
@ -69,7 +70,7 @@ end
|
|||
|
||||
abstract type Env <: AbstractEnv end
|
||||
|
||||
mutable struct EnvParams{state_dims}
|
||||
mutable struct EnvSharedProps{state_dims}
|
||||
n_actions::Int64
|
||||
action_space::Vector{SVector{2,Float64}}
|
||||
action_ind_space::OneTo{Int64}
|
||||
|
@ -82,7 +83,7 @@ mutable struct EnvParams{state_dims}
|
|||
reward::Float64
|
||||
terminated::Bool
|
||||
|
||||
function EnvParams(
|
||||
function EnvSharedProps(
|
||||
n_states::Int64,
|
||||
state_space::Vector{SVector{state_dims,Interval}};
|
||||
n_v_actions::Int64=2,
|
||||
|
@ -129,22 +130,22 @@ mutable struct EnvParams{state_dims}
|
|||
end
|
||||
|
||||
function reset!(env::Env)
|
||||
env.params.terminated = false
|
||||
env.shared.terminated = false
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
RLBase.state_space(env::Env) = env.params.state_ind_space
|
||||
RLBase.state_space(env::Env) = env.shared.state_ind_space
|
||||
|
||||
RLBase.state(env::Env) = env.params.state_ind
|
||||
RLBase.state(env::Env) = env.shared.state_ind
|
||||
|
||||
RLBase.action_space(env::Env) = env.params.action_ind_space
|
||||
RLBase.action_space(env::Env) = env.shared.action_ind_space
|
||||
|
||||
RLBase.reward(env::Env) = env.params.reward
|
||||
RLBase.reward(env::Env) = env.shared.reward
|
||||
|
||||
RLBase.is_terminated(env::Env) = env.params.terminated
|
||||
RLBase.is_terminated(env::Env) = env.shared.terminated
|
||||
|
||||
struct EnvHelperParams{H<:AbstractHook}
|
||||
struct EnvHelperSharedProps{H<:AbstractHook}
|
||||
env::Env
|
||||
agent::Agent
|
||||
hook::H
|
||||
|
@ -161,7 +162,7 @@ struct EnvHelperParams{H<:AbstractHook}
|
|||
actions::Vector{SVector{2,Float64}}
|
||||
actions_ind::Vector{Int64}
|
||||
|
||||
function EnvHelperParams(
|
||||
function EnvHelperSharedProps(
|
||||
env::Env,
|
||||
agent::Agent,
|
||||
hook::H,
|
||||
|
@ -186,7 +187,7 @@ end
|
|||
|
||||
abstract type EnvHelper end
|
||||
|
||||
function gen_env_helper(::Env, env_helper_params::EnvHelperParams)
|
||||
function gen_env_helper(::Env, env_helper_params::EnvHelperSharedProps)
|
||||
return method_not_implemented()
|
||||
end
|
||||
|
||||
|
@ -217,7 +218,7 @@ function state_update_hook(::EnvHelper, particles::Vector{Particle})
|
|||
end
|
||||
|
||||
function get_env_agent_hook(env_helper::EnvHelper)
|
||||
return (env_helper.params.env, env_helper.params.agent, env_helper.params.hook)
|
||||
return (env_helper.shared.env, env_helper.shared.agent, env_helper.shared.hook)
|
||||
end
|
||||
|
||||
function update_reward!(::Env, ::EnvHelper, particle::Particle)
|
||||
|
@ -233,16 +234,16 @@ function update_table_and_actions_hook(
|
|||
|
||||
if !first_integration_step
|
||||
# Old state
|
||||
env.params.state_ind = env_helper.params.old_states_ind[id]
|
||||
env.shared.state_ind = env_helper.shared.old_states_ind[id]
|
||||
|
||||
action_ind = env_helper.params.actions_ind[id]
|
||||
action_ind = env_helper.shared.actions_ind[id]
|
||||
|
||||
# Pre act
|
||||
agent(PRE_ACT_STAGE, env, action_ind)
|
||||
hook(PRE_ACT_STAGE, agent, env, action_ind)
|
||||
|
||||
# Update to current state
|
||||
env.params.state_ind = env_helper.params.states_ind[id]
|
||||
env.shared.state_ind = env_helper.shared.states_ind[id]
|
||||
|
||||
# Update reward
|
||||
update_reward!(env, env_helper, particle)
|
||||
|
@ -254,10 +255,10 @@ function update_table_and_actions_hook(
|
|||
|
||||
# Update action
|
||||
action_ind = agent(env)
|
||||
action = env.params.action_space[action_ind]
|
||||
action = env.shared.action_space[action_ind]
|
||||
|
||||
env_helper.params.actions[id] = action
|
||||
env_helper.params.actions_ind[id] = action_ind
|
||||
env_helper.shared.actions[id] = action
|
||||
env_helper.shared.actions_ind[id] = action_ind
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
@ -268,7 +269,7 @@ function act_hook(
|
|||
env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64
|
||||
)
|
||||
# Apply action
|
||||
action = env_helper.params.actions[particle.id]
|
||||
action = env_helper.shared.actions[particle.id]
|
||||
|
||||
vδt = action[1] * δt
|
||||
particle.tmp_c += SVector(vδt * co, vδt * si)
|
||||
|
@ -333,13 +334,13 @@ function run_rl(;
|
|||
|
||||
env = EnvType(sim_consts)
|
||||
|
||||
agent = gen_agent(env.params.n_states, env.params.n_actions, ϵ_stable)
|
||||
agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable)
|
||||
|
||||
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
||||
|
||||
hook = TotalRewardPerEpisode()
|
||||
|
||||
env_helper_params = EnvHelperParams(
|
||||
env_helper_params = EnvHelperSharedProps(
|
||||
env,
|
||||
agent,
|
||||
hook,
|
||||
|
@ -374,7 +375,7 @@ function run_rl(;
|
|||
env_helper=env_helper,
|
||||
)
|
||||
|
||||
env.params.terminated = true
|
||||
env.shared.terminated = true
|
||||
|
||||
# Post episode
|
||||
hook(POST_EPISODE_STAGE, agent, env)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
using CairoMakie, LaTeXStrings
|
||||
using LoopVectorization: @turbo
|
||||
|
||||
using ReCo: minimum_image
|
||||
using ReCo: minimum_image, norm2d
|
||||
|
||||
function plot_g(radius, g, variables)
|
||||
fig = Figure()
|
||||
|
@ -44,7 +44,7 @@ function pair_correlation(sol, variables)
|
|||
|
||||
r⃗₁₂ = minimum_image(r⃗₁₂, variables.half_box_len)
|
||||
|
||||
distance = sqrt(r⃗₁₂[1]^2 + r⃗₁₂[2]^2)
|
||||
distance = norm2d(r⃗₁₂)
|
||||
|
||||
if (distance >= r) && (distance <= r + dr)
|
||||
N_g[i, r_ind] += 1
|
||||
|
|
|
@ -92,7 +92,7 @@ Base.wait(::Nothing) = nothing
|
|||
# Fallback for runs without an env helper (`nothing`): always returns `false`, whatever
# extra arguments are passed.
gen_run_additional_hooks(::Nothing, args...) = false
|
||||
|
||||
function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64)
|
||||
return (integration_step % env_helper.params.n_steps_before_actions_update == 0) ||
|
||||
return (integration_step % env_helper.shared.n_steps_before_actions_update == 0) ||
|
||||
(integration_step == 1)
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in a new issue