diff --git a/src/Geometry.jl b/src/Geometry.jl index 49a43bd..e7e418c 100644 --- a/src/Geometry.jl +++ b/src/Geometry.jl @@ -1,6 +1,6 @@ module Geometry -export angle2 +export angle2, norm2d, sq_norm2d using StaticArrays: SVector @@ -18,4 +18,7 @@ function angle2(a::SVector{2,Float64}, b::SVector{2,Float64}) return rem2pi(θ, RoundNearest) end +sq_norm2d(v::SVector{2,Float64}) = v[1]^2 + v[2]^2 +norm2d(v::SVector{2,Float64}) = sqrt(sq_norm2d(v)) + end # module \ No newline at end of file diff --git a/src/Particle.jl b/src/Particle.jl index 058980e..2947b53 100644 --- a/src/Particle.jl +++ b/src/Particle.jl @@ -64,7 +64,7 @@ function are_overlapping( r⃗₁₂ = minimum_image(r⃗₁₂, half_box_len) - distance² = r⃗₁₂[1]^2 + r⃗₁₂[2]^2 + distance² = sq_norm2d(r⃗₁₂) overlapping = distance² < overlapping_r² diff --git a/src/RL/LocalCOMEnv.jl b/src/RL/LocalCOMEnv.jl index 4f58451..d2d351f 100644 --- a/src/RL/LocalCOMEnv.jl +++ b/src/RL/LocalCOMEnv.jl @@ -1,7 +1,7 @@ export LocalCOMEnv struct LocalCOMEnv <: Env - params::EnvParams + shared::EnvSharedProps distance_state_space::Vector{Interval} direction_angle_state_space::Vector{Interval} @@ -35,31 +35,35 @@ struct LocalCOMEnv <: Env end # Last state is when no particle is in the skin radius - params = EnvParams(n_states, state_space) + shared = EnvSharedProps(n_states, state_space) - return new(params, distance_state_space, direction_angle_state_space, max_distance) + return new(shared, distance_state_space, direction_angle_state_space, max_distance) end end struct LocalCOMEnvHelper <: EnvHelper - params::EnvHelperParams + shared::EnvHelperSharedProps vec_to_neighbour_sums::Vector{SVector{2,Float64}} n_neighbours::Vector{Int64} + sq_norm2d_vec_to_local_center_of_mass::Vector{Float64} - function LocalCOMEnvHelper(params::EnvHelperParams) + function LocalCOMEnvHelper(shared::EnvHelperSharedProps) return new( - params, fill(SVector(0.0, 0.0), params.n_particles), fill(0, params.n_particles) + shared, + fill(SVector(0.0, 0.0), shared.n_particles), + fill(0, shared.n_particles), + zeros(shared.n_particles), ) end end -function gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperParams) +function gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperSharedProps) return LocalCOMEnvHelper(env_helper_params) end function pre_integration_hook(env_helper::LocalCOMEnvHelper) - @simd for id in 1:(env_helper.params.n_particles) + @simd for id in 1:(env_helper.shared.n_particles) env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0) env_helper.n_neighbours[id] = 0 end @@ -80,26 +84,28 @@ function state_update_helper_hook( end function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle}) - n_particles = env_helper.params.n_particles + n_particles = env_helper.shared.n_particles @turbo for id in 1:(n_particles) - env_helper.params.old_states_ind[id] = env_helper.params.states_ind[id] + env_helper.shared.old_states_ind[id] = env_helper.shared.states_ind[id] end - env = env_helper.params.env + env = env_helper.shared.env for id in 1:n_particles n_neighbours = env_helper.n_neighbours[id] if n_neighbours == 0 - state_ind = env.params.n_states + state_ind = env.shared.n_states else vec_to_local_center_of_mass = env_helper.vec_to_neighbour_sums[id] / n_neighbours - distance = sqrt( - vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2 - ) + sq_norm2d_vec_to_local_center_of_mass = sq_norm2d(vec_to_local_center_of_mass) + env_helper.sq_norm2d_vec_to_local_center_of_mass[id] = + sq_norm2d_vec_to_local_center_of_mass + + distance = sqrt(sq_norm2d_vec_to_local_center_of_mass) distance_state = find_state_interval(distance, env.distance_state_space) @@ -112,10 +118,10 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part ) state = SVector{2,Interval}(distance_state, direction_angle_state) - state_ind = find_state_ind(state, env.params.state_space) + state_ind = find_state_ind(state, env.shared.state_space) end - env_helper.params.states_ind[id] = state_ind + env_helper.shared.states_ind[id] = state_ind end return nothing @@ -124,16 +130,14 @@ end function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle) id = particle.id - normalization = (env.max_distance * env_helper.params.n_particles) + normalization = (env.max_distance * env_helper.shared.n_particles) n_neighbours = env_helper.n_neighbours[id] if n_neighbours == 0 - env.params.reward = -(env.max_distance^2) / normalization + env.shared.reward = -(env.max_distance^2) / normalization else - vec_to_local_center_of_mass = env_helper.vec_to_neighbour_sums[id] / n_neighbours # TODO: Reuse vec_to_local_center_of_mass from state_update_hook - env.params.reward = - -(vec_to_local_center_of_mass[1]^2 + vec_to_local_center_of_mass[2]^2) / - normalization + env.shared.reward = + -(env_helper.sq_norm2d_vec_to_local_center_of_mass[id]) / normalization # TODO: Add shape term end return nothing diff --git a/src/RL/RL.jl b/src/RL/RL.jl index a301472..bf7250a 100644 --- a/src/RL/RL.jl +++ b/src/RL/RL.jl @@ -12,7 +12,8 @@ using LoopVectorization: @turbo using Random: Random using ProgressMeter: @showprogress -using ..ReCo: ReCo, Particle, angle2, Shape, DEFAULT_SKIN_TO_INTERACTION_R_RATIO +using ..ReCo: + ReCo, Particle, angle2, norm2d, sq_norm2d, Shape, DEFAULT_SKIN_TO_INTERACTION_R_RATIO const INITIAL_STATE_IND = 1 const INITIAL_REWARD = 0.0 @@ -69,7 +70,7 @@ end abstract type Env <: AbstractEnv end -mutable struct EnvParams{state_dims} +mutable struct EnvSharedProps{state_dims} n_actions::Int64 action_space::Vector{SVector{2,Float64}} action_ind_space::OneTo{Int64} @@ -82,7 +83,7 @@ mutable struct EnvParams{state_dims} reward::Float64 terminated::Bool - function EnvParams( + function EnvSharedProps( n_states::Int64, state_space::Vector{SVector{state_dims,Interval}}; n_v_actions::Int64=2, @@ -129,22 +130,22 @@ mutable struct EnvParams{state_dims} end function reset!(env::Env) - env.params.terminated = false + env.shared.terminated = false return nothing end -RLBase.state_space(env::Env) = env.params.state_ind_space +RLBase.state_space(env::Env) = env.shared.state_ind_space -RLBase.state(env::Env) = env.params.state_ind +RLBase.state(env::Env) = env.shared.state_ind -RLBase.action_space(env::Env) = env.params.action_ind_space +RLBase.action_space(env::Env) = env.shared.action_ind_space -RLBase.reward(env::Env) = env.params.reward +RLBase.reward(env::Env) = env.shared.reward -RLBase.is_terminated(env::Env) = env.params.terminated +RLBase.is_terminated(env::Env) = env.shared.terminated -struct EnvHelperParams{H<:AbstractHook} +struct EnvHelperSharedProps{H<:AbstractHook} env::Env agent::Agent hook::H @@ -161,7 +162,7 @@ struct EnvHelperParams{H<:AbstractHook} actions::Vector{SVector{2,Float64}} actions_ind::Vector{Int64} - function EnvHelperParams( + function EnvHelperSharedProps( env::Env, agent::Agent, hook::H, @@ -186,7 +187,7 @@ end abstract type EnvHelper end -function gen_env_helper(::Env, env_helper_params::EnvHelperParams) +function gen_env_helper(::Env, env_helper_params::EnvHelperSharedProps) return method_not_implemented() end @@ -217,7 +218,7 @@ function state_update_hook(::EnvHelper, particles::Vector{Particle}) end function get_env_agent_hook(env_helper::EnvHelper) - return (env_helper.params.env, env_helper.params.agent, env_helper.params.hook) + return (env_helper.shared.env, env_helper.shared.agent, env_helper.shared.hook) end function update_reward!(::Env, ::EnvHelper, particle::Particle) @@ -233,16 +234,16 @@ function update_table_and_actions_hook( if !first_integration_step # Old state - env.params.state_ind = env_helper.params.old_states_ind[id] + env.shared.state_ind = env_helper.shared.old_states_ind[id] - action_ind = env_helper.params.actions_ind[id] + action_ind = env_helper.shared.actions_ind[id] # Pre act agent(PRE_ACT_STAGE, env, action_ind) hook(PRE_ACT_STAGE, agent, env, action_ind) # Update to current state - env.params.state_ind = env_helper.params.states_ind[id] + env.shared.state_ind = env_helper.shared.states_ind[id] # Update reward update_reward!(env, env_helper, particle) @@ -254,10 +255,10 @@ function update_table_and_actions_hook( # Update action action_ind = agent(env) - action = env.params.action_space[action_ind] + action = env.shared.action_space[action_ind] - env_helper.params.actions[id] = action - env_helper.params.actions_ind[id] = action_ind + env_helper.shared.actions[id] = action + env_helper.shared.actions_ind[id] = action_ind return nothing end @@ -268,7 +269,7 @@ function act_hook( env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64 ) # Apply action - action = env_helper.params.actions[particle.id] + action = env_helper.shared.actions[particle.id] vδt = action[1] * δt particle.tmp_c += SVector(vδt * co, vδt * si) @@ -333,13 +334,13 @@ function run_rl(; env = EnvType(sim_consts) - agent = gen_agent(env.params.n_states, env.params.n_actions, ϵ_stable) + agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable) n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt) hook = TotalRewardPerEpisode() - env_helper_params = EnvHelperParams( + env_helper_params = EnvHelperSharedProps( env, agent, hook, @@ -374,7 +375,7 @@ function run_rl(; env_helper=env_helper, ) - env.params.terminated = true + env.shared.terminated = true # Post episode hook(POST_EPISODE_STAGE, agent, env) diff --git a/src/analysis/pair_correlation_function.jl b/src/analysis/pair_correlation_function.jl index 4236bda..7d23ed1 100644 --- a/src/analysis/pair_correlation_function.jl +++ b/src/analysis/pair_correlation_function.jl @@ -1,7 +1,7 @@ using CairoMakie, LaTeXStrings using LoopVectorization: @turbo -using ReCo: minimum_image +using ReCo: minimum_image, norm2d function plot_g(radius, g, variables) fig = Figure() @@ -44,7 +44,7 @@ function pair_correlation(sol, variables) r⃗₁₂ = minimum_image(r⃗₁₂, variables.half_box_len) - distance = sqrt(r⃗₁₂[1]^2 + r⃗₁₂[2]^2) + distance = norm2d(r⃗₁₂) if (distance >= r) && (distance <= r + dr) N_g[i, r_ind] += 1 diff --git a/src/simulation.jl b/src/simulation.jl index 08d3fa2..b31cebf 100644 --- a/src/simulation.jl +++ b/src/simulation.jl @@ -92,7 +92,7 @@ Base.wait(::Nothing) = nothing gen_run_additional_hooks(::Nothing, args...) = false function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64) - return (integration_step % env_helper.params.n_steps_before_actions_update == 0) || + return (integration_step % env_helper.shared.n_steps_before_actions_update == 0) || (integration_step == 1) end