1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-11-08 22:21:08 +00:00

Added norm2d and sq_norm2d

This commit is contained in:
Mo8it 2022-01-11 18:39:38 +01:00
parent bb3246a1e7
commit 0cf59e04a5
6 changed files with 59 additions and 51 deletions

View file

@ -1,6 +1,6 @@
module Geometry module Geometry
export angle2 export angle2, norm2d, sq_norm2d
using StaticArrays: SVector using StaticArrays: SVector
@ -18,4 +18,7 @@ function angle2(a::SVector{2,Float64}, b::SVector{2,Float64})
return rem2pi(θ, RoundNearest) return rem2pi(θ, RoundNearest)
end end
sq_norm2d(v::SVector{2,Float64}) = v[1]^2 + v[2]^2
norm2d(v::SVector{2,Float64}) = sqrt(sq_norm2d(v))
end # module end # module

View file

@ -64,7 +64,7 @@ function are_overlapping(
r⃗₁₂ = minimum_image(r⃗₁₂, half_box_len) r⃗₁₂ = minimum_image(r⃗₁₂, half_box_len)
distance² = r⃗₁₂[1]^2 + r⃗₁₂[2]^2 distance² = sq_norm2d(r⃗₁₂)
overlapping = distance² < overlapping_r² overlapping = distance² < overlapping_r²

View file

@ -1,7 +1,7 @@
export LocalCOMEnv export LocalCOMEnv
struct LocalCOMEnv <: Env struct LocalCOMEnv <: Env
params::EnvParams shared::EnvSharedProps
distance_state_space::Vector{Interval} distance_state_space::Vector{Interval}
direction_angle_state_space::Vector{Interval} direction_angle_state_space::Vector{Interval}
@ -35,31 +35,35 @@ struct LocalCOMEnv <: Env
end end
# Last state is when no particle is in the skin radius # Last state is when no particle is in the skin radius
params = EnvParams(n_states, state_space) shared = EnvSharedProps(n_states, state_space)
return new(params, distance_state_space, direction_angle_state_space, max_distance) return new(shared, distance_state_space, direction_angle_state_space, max_distance)
end end
end end
struct LocalCOMEnvHelper <: EnvHelper struct LocalCOMEnvHelper <: EnvHelper
params::EnvHelperParams shared::EnvHelperSharedProps
vec_to_neighbour_sums::Vector{SVector{2,Float64}} vec_to_neighbour_sums::Vector{SVector{2,Float64}}
n_neighbours::Vector{Int64} n_neighbours::Vector{Int64}
sq_norm2d_vec_to_local_center_of_mass::Vector{Float64}
function LocalCOMEnvHelper(params::EnvHelperParams) function LocalCOMEnvHelper(shared::EnvHelperSharedProps)
return new( return new(
params, fill(SVector(0.0, 0.0), params.n_particles), fill(0, params.n_particles) shared,
fill(SVector(0.0, 0.0), shared.n_particles),
fill(0, shared.n_particles),
zeros(shared.n_particles),
) )
end end
end end
function gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperParams) function gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperSharedProps)
return LocalCOMEnvHelper(env_helper_params) return LocalCOMEnvHelper(env_helper_params)
end end
function pre_integration_hook(env_helper::LocalCOMEnvHelper) function pre_integration_hook(env_helper::LocalCOMEnvHelper)
@simd for id in 1:(env_helper.params.n_particles) @simd for id in 1:(env_helper.shared.n_particles)
env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0) env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
env_helper.n_neighbours[id] = 0 env_helper.n_neighbours[id] = 0
end end
@ -80,26 +84,28 @@ function state_update_helper_hook(
end end
function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle}) function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})
n_particles = env_helper.params.n_particles n_particles = env_helper.shared.n_particles
@turbo for id in 1:(n_particles) @turbo for id in 1:(n_particles)
env_helper.params.old_states_ind[id] = env_helper.params.states_ind[id] env_helper.shared.old_states_ind[id] = env_helper.shared.states_ind[id]
end end
env = env_helper.params.env env = env_helper.shared.env
for id in 1:n_particles for id in 1:n_particles
n_neighbours = env_helper.n_neighbours[id] n_neighbours = env_helper.n_neighbours[id]
if n_neighbours == 0 if n_neighbours == 0
state_ind = env.params.n_states state_ind = env.shared.n_states
else else
vec_to_local_center_of_mass = vec_to_local_center_of_mass =
env_helper.vec_to_neighbour_sums[id] / n_neighbours env_helper.vec_to_neighbour_sums[id] / n_neighbours
distance = sqrt( sq_norm2d_vec_to_local_center_of_mass = sq_norm2d(vec_to_local_center_of_mass)
vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2 env_helper.sq_norm2d_vec_to_local_center_of_mass[id] =
) sq_norm2d_vec_to_local_center_of_mass
distance = sqrt(sq_norm2d_vec_to_local_center_of_mass)
distance_state = find_state_interval(distance, env.distance_state_space) distance_state = find_state_interval(distance, env.distance_state_space)
@ -112,10 +118,10 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
) )
state = SVector{2,Interval}(distance_state, direction_angle_state) state = SVector{2,Interval}(distance_state, direction_angle_state)
state_ind = find_state_ind(state, env.params.state_space) state_ind = find_state_ind(state, env.shared.state_space)
end end
env_helper.params.states_ind[id] = state_ind env_helper.shared.states_ind[id] = state_ind
end end
return nothing return nothing
@ -124,16 +130,14 @@ end
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle) function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
id = particle.id id = particle.id
normalization = (env.max_distance * env_helper.params.n_particles) normalization = (env.max_distance * env_helper.shared.n_particles)
n_neighbours = env_helper.n_neighbours[id] n_neighbours = env_helper.n_neighbours[id]
if n_neighbours == 0 if n_neighbours == 0
env.params.reward = -(env.max_distance^2) / normalization env.shared.reward = -(env.max_distance^2) / normalization
else else
vec_to_local_center_of_mass = env_helper.vec_to_neighbour_sums[id] / n_neighbours # TODO: Reuse vec_to_local_center_of_mass from state_update_hook env.shared.reward =
env.params.reward = -(env_helper.sq_norm2d_vec_to_local_center_of_mass[id]) / normalization # TODO: Add shape term
-(vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2) /
normalization
end end
return nothing return nothing

View file

@ -12,7 +12,8 @@ using LoopVectorization: @turbo
using Random: Random using Random: Random
using ProgressMeter: @showprogress using ProgressMeter: @showprogress
using ..ReCo: ReCo, Particle, angle2, Shape, DEFAULT_SKIN_TO_INTERACTION_R_RATIO using ..ReCo:
ReCo, Particle, angle2, norm2d, sq_norm2d, Shape, DEFAULT_SKIN_TO_INTERACTION_R_RATIO
const INITIAL_STATE_IND = 1 const INITIAL_STATE_IND = 1
const INITIAL_REWARD = 0.0 const INITIAL_REWARD = 0.0
@ -69,7 +70,7 @@ end
abstract type Env <: AbstractEnv end abstract type Env <: AbstractEnv end
mutable struct EnvParams{state_dims} mutable struct EnvSharedProps{state_dims}
n_actions::Int64 n_actions::Int64
action_space::Vector{SVector{2,Float64}} action_space::Vector{SVector{2,Float64}}
action_ind_space::OneTo{Int64} action_ind_space::OneTo{Int64}
@ -82,7 +83,7 @@ mutable struct EnvParams{state_dims}
reward::Float64 reward::Float64
terminated::Bool terminated::Bool
function EnvParams( function EnvSharedProps(
n_states::Int64, n_states::Int64,
state_space::Vector{SVector{state_dims,Interval}}; state_space::Vector{SVector{state_dims,Interval}};
n_v_actions::Int64=2, n_v_actions::Int64=2,
@ -129,22 +130,22 @@ mutable struct EnvParams{state_dims}
end end
function reset!(env::Env) function reset!(env::Env)
env.params.terminated = false env.shared.terminated = false
return nothing return nothing
end end
RLBase.state_space(env::Env) = env.params.state_ind_space RLBase.state_space(env::Env) = env.shared.state_ind_space
RLBase.state(env::Env) = env.params.state_ind RLBase.state(env::Env) = env.shared.state_ind
RLBase.action_space(env::Env) = env.params.action_ind_space RLBase.action_space(env::Env) = env.shared.action_ind_space
RLBase.reward(env::Env) = env.params.reward RLBase.reward(env::Env) = env.shared.reward
RLBase.is_terminated(env::Env) = env.params.terminated RLBase.is_terminated(env::Env) = env.shared.terminated
struct EnvHelperParams{H<:AbstractHook} struct EnvHelperSharedProps{H<:AbstractHook}
env::Env env::Env
agent::Agent agent::Agent
hook::H hook::H
@ -161,7 +162,7 @@ struct EnvHelperParams{H<:AbstractHook}
actions::Vector{SVector{2,Float64}} actions::Vector{SVector{2,Float64}}
actions_ind::Vector{Int64} actions_ind::Vector{Int64}
function EnvHelperParams( function EnvHelperSharedProps(
env::Env, env::Env,
agent::Agent, agent::Agent,
hook::H, hook::H,
@ -186,7 +187,7 @@ end
abstract type EnvHelper end abstract type EnvHelper end
function gen_env_helper(::Env, env_helper_params::EnvHelperParams) function gen_env_helper(::Env, env_helper_params::EnvHelperSharedProps)
return method_not_implemented() return method_not_implemented()
end end
@ -217,7 +218,7 @@ function state_update_hook(::EnvHelper, particles::Vector{Particle})
end end
function get_env_agent_hook(env_helper::EnvHelper) function get_env_agent_hook(env_helper::EnvHelper)
return (env_helper.params.env, env_helper.params.agent, env_helper.params.hook) return (env_helper.shared.env, env_helper.shared.agent, env_helper.shared.hook)
end end
function update_reward!(::Env, ::EnvHelper, particle::Particle) function update_reward!(::Env, ::EnvHelper, particle::Particle)
@ -233,16 +234,16 @@ function update_table_and_actions_hook(
if !first_integration_step if !first_integration_step
# Old state # Old state
env.params.state_ind = env_helper.params.old_states_ind[id] env.shared.state_ind = env_helper.shared.old_states_ind[id]
action_ind = env_helper.params.actions_ind[id] action_ind = env_helper.shared.actions_ind[id]
# Pre act # Pre act
agent(PRE_ACT_STAGE, env, action_ind) agent(PRE_ACT_STAGE, env, action_ind)
hook(PRE_ACT_STAGE, agent, env, action_ind) hook(PRE_ACT_STAGE, agent, env, action_ind)
# Update to current state # Update to current state
env.params.state_ind = env_helper.params.states_ind[id] env.shared.state_ind = env_helper.shared.states_ind[id]
# Update reward # Update reward
update_reward!(env, env_helper, particle) update_reward!(env, env_helper, particle)
@ -254,10 +255,10 @@ function update_table_and_actions_hook(
# Update action # Update action
action_ind = agent(env) action_ind = agent(env)
action = env.params.action_space[action_ind] action = env.shared.action_space[action_ind]
env_helper.params.actions[id] = action env_helper.shared.actions[id] = action
env_helper.params.actions_ind[id] = action_ind env_helper.shared.actions_ind[id] = action_ind
return nothing return nothing
end end
@ -268,7 +269,7 @@ function act_hook(
env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64 env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64
) )
# Apply action # Apply action
action = env_helper.params.actions[particle.id] action = env_helper.shared.actions[particle.id]
vδt = action[1] * δt vδt = action[1] * δt
particle.tmp_c += SVector(vδt * co, vδt * si) particle.tmp_c += SVector(vδt * co, vδt * si)
@ -333,13 +334,13 @@ function run_rl(;
env = EnvType(sim_consts) env = EnvType(sim_consts)
agent = gen_agent(env.params.n_states, env.params.n_actions, ϵ_stable) agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable)
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt) n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
hook = TotalRewardPerEpisode() hook = TotalRewardPerEpisode()
env_helper_params = EnvHelperParams( env_helper_params = EnvHelperSharedProps(
env, env,
agent, agent,
hook, hook,
@ -374,7 +375,7 @@ function run_rl(;
env_helper=env_helper, env_helper=env_helper,
) )
env.params.terminated = true env.shared.terminated = true
# Post episode # Post episode
hook(POST_EPISODE_STAGE, agent, env) hook(POST_EPISODE_STAGE, agent, env)

View file

@ -1,7 +1,7 @@
using CairoMakie, LaTeXStrings using CairoMakie, LaTeXStrings
using LoopVectorization: @turbo using LoopVectorization: @turbo
using ReCo: minimum_image using ReCo: minimum_image, norm2d
function plot_g(radius, g, variables) function plot_g(radius, g, variables)
fig = Figure() fig = Figure()
@ -44,7 +44,7 @@ function pair_correlation(sol, variables)
r⃗₁₂ = minimum_image(r⃗₁₂, variables.half_box_len) r⃗₁₂ = minimum_image(r⃗₁₂, variables.half_box_len)
distance = sqrt(r⃗₁₂[1]^2 + r⃗₁₂[2]^2) distance = norm2d(r⃗₁₂)
if (distance >= r) && (distance <= r + dr) if (distance >= r) && (distance <= r + dr)
N_g[i, r_ind] += 1 N_g[i, r_ind] += 1

View file

@ -92,7 +92,7 @@ Base.wait(::Nothing) = nothing
gen_run_additional_hooks(::Nothing, args...) = false gen_run_additional_hooks(::Nothing, args...) = false
function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64) function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64)
return (integration_step % env_helper.params.n_steps_before_actions_update == 0) || return (integration_step % env_helper.shared.n_steps_before_actions_update == 0) ||
(integration_step == 1) (integration_step == 1)
end end