diff --git a/src/Error.jl b/src/Error.jl index d91e1cb..23105fb 100644 --- a/src/Error.jl +++ b/src/Error.jl @@ -2,6 +2,8 @@ module Error export method_not_implemented -method_not_implemented() = error("Method not implemented!") +function method_not_implemented() + return error("Method not implemented!") +end end # module \ No newline at end of file diff --git a/src/Geometry.jl b/src/Geometry.jl index e7e418c..3e48e0b 100644 --- a/src/Geometry.jl +++ b/src/Geometry.jl @@ -5,11 +5,11 @@ export angle2, norm2d, sq_norm2d using StaticArrays: SVector """ - angle2(a::SVector{2,Float64}, b::SVector{2,Float64}) + angle2(a::SVector{2,Real}, b::SVector{2,Real}) Returns the angle φ from vector a to b while φ ∈ [-π, π]. """ -function angle2(a::SVector{2,Float64}, b::SVector{2,Float64}) +function angle2(a::SVector{2,Real}, b::SVector{2,Real}) θ_a = atan(a[2], a[1]) θ_b = atan(b[2], b[1]) @@ -18,7 +18,12 @@ function angle2(a::SVector{2,Float64}, b::SVector{2,Float64}) return rem2pi(θ, RoundNearest) end -sq_norm2d(v::SVector{2,Float64}) = v[1]^2 + v[2]^2 -norm2d(v::SVector{2,Float64}) = sqrt(sq_norm2d(v)) +function sq_norm2d(v::SVector{2,Real}) + return v[1]^2 + v[2]^2 +end + +function norm2d(v::SVector{2,Real}) + return sqrt(sq_norm2d(v)) +end end # module \ No newline at end of file diff --git a/src/PreVectors.jl b/src/PreVectors.jl index 0562c26..554150e 100644 --- a/src/PreVectors.jl +++ b/src/PreVectors.jl @@ -18,13 +18,13 @@ function push!(pv::PreVector{T}, x::T) where {T} return nothing end -function reset!(pv::PreVector{T}) where {T} +function reset!(pv::PreVector) pv.last_ind = 0 return nothing end -function iterate(pv::PreVector{T}, state=1) where {T} +function iterate(pv::PreVector, state::UInt64=UInt64(1)) if state > pv.last_ind return nothing else diff --git a/src/RL/Env.jl b/src/RL/Env.jl index f80fa22..cf5d2d9 100644 --- a/src/RL/Env.jl +++ b/src/RL/Env.jl @@ -73,12 +73,22 @@ function reset!(env::Env) return nothing end -RLBase.state_space(env::Env) = env.shared.state_id_space +function RLBase.state_space(env::Env) + return env.shared.state_id_space +end -RLBase.state(env::Env) = env.shared.state_id +function RLBase.state(env::Env) + return env.shared.state_id +end -RLBase.action_space(env::Env) = env.shared.action_id_space +function RLBase.action_space(env::Env) + return env.shared.action_id_space +end -RLBase.reward(env::Env) = env.shared.reward +function RLBase.reward(env::Env) + return env.shared.reward +end -RLBase.is_terminated(env::Env) = env.shared.terminated \ No newline at end of file +function RLBase.is_terminated(env::Env) + return env.shared.terminated +end \ No newline at end of file diff --git a/src/RL/EnvHelper.jl b/src/RL/EnvHelper.jl index f0c47bb..2e1f821 100644 --- a/src/RL/EnvHelper.jl +++ b/src/RL/EnvHelper.jl @@ -20,12 +20,12 @@ struct EnvHelperSharedProps{H<:AbstractHook} function EnvHelperSharedProps( env::Env, agent::Agent, - hook::H, + hook::AbstractHook, n_steps_before_actions_update::Int64, goal_gyration_tensor_eigvals_ratio::Float64, n_particles::Int64, - ) where {H<:AbstractHook} - return new{H}( + ) + return new( env, agent, hook, diff --git a/src/RL/Hooks.jl b/src/RL/Hooks.jl index 82666fb..6fc13b5 100644 --- a/src/RL/Hooks.jl +++ b/src/RL/Hooks.jl @@ -1,16 +1,16 @@ using ..ReCo: Particle -function pre_integration_hook(::EnvHelper) +function pre_integration_hook!(::EnvHelper) return ReCo.method_not_implemented() end -function state_update_helper_hook( +function state_update_helper_hook!( ::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} ) return ReCo.method_not_implemented() end -function state_update_hook(::EnvHelper, particles::Vector{Particle}) +function state_update_hook!(::EnvHelper, particles::Vector{Particle}) return ReCo.method_not_implemented() end @@ -18,7 +18,7 @@ function update_reward!(::Env, ::EnvHelper, particle::Particle) return ReCo.method_not_implemented() end -function update_table_and_actions_hook( +function update_table_and_actions_hook!( env_helper::EnvHelper, particle::Particle, first_integration_step::Bool ) env, agent, hook = get_env_agent_hook(env_helper) @@ -56,10 +56,12 @@ function update_table_and_actions_hook( return nothing end -act_hook(::Nothing, args...) = nothing +function act_hook!(::Particle, ::Nothing, args...) + return nothing +end -function act_hook( - env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64 +function act_hook!( + particle::Particle, env_helper::EnvHelper, δt::Float64, si::Float64, co::Float64 ) # Apply action action = env_helper.shared.actions[particle.id] diff --git a/src/RL/LocalCOMEnv.jl b/src/RL/LocalCOMEnv.jl index 6b73030..1e8699d 100644 --- a/src/RL/LocalCOMEnv.jl +++ b/src/RL/LocalCOMEnv.jl @@ -76,7 +76,7 @@ function gen_env_helper(::LocalCOMEnv, env_helper_shared::EnvHelperSharedProps; return LocalCOMEnvHelper(env_helper_shared, args.half_box_len, args.skin_r) end -function pre_integration_hook(env_helper::LocalCOMEnvHelper) +function pre_integration_hook!(env_helper::LocalCOMEnvHelper) @simd for id in 1:(env_helper.shared.n_particles) env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0) env_helper.n_neighbours[id] = 0 @@ -85,7 +85,7 @@ function pre_integration_hook(env_helper::LocalCOMEnvHelper) return nothing end -function state_update_helper_hook( +function state_update_helper_hook!( env_helper::LocalCOMEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} ) env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂ @@ -97,7 +97,7 @@ function state_update_helper_hook( return nothing end -function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle}) +function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{Particle}) n_particles = env_helper.shared.n_particles @turbo for id in 1:n_particles diff --git a/src/RL/RL.jl b/src/RL/RL.jl index d0378e7..c0b52fc 100644 --- a/src/RL/RL.jl +++ b/src/RL/RL.jl @@ -61,7 +61,7 @@ function run_rl(; seed::Int64=42, ϵ_stable::Float64=0.0001, skin_to_interaction_r_ratio::Float64=ReCo.DEFAULT_SKIN_TO_INTERACTION_R_RATIO, - packing_ratio=0.22, + packing_ratio::Float64=0.22, ) where {E<:Env} @assert 0.0 <= goal_gyration_tensor_eigvals_ratio <= 1.0 @assert n_episodes > 0 diff --git a/src/Shape.jl b/src/Shape.jl index 7b7419b..c82df91 100644 --- a/src/Shape.jl +++ b/src/Shape.jl @@ -8,21 +8,21 @@ using LinearAlgebra: eigvals, eigvecs, Hermitian, dot using ..ReCo: ReCo, Particle -function project_to_unit_circle(x::Float64, half_box_len::Float64) +function project_to_unit_circle(x::Real, half_box_len::Real) φ = (x + half_box_len) * π / half_box_len si, co = sincos(φ) return SVector(co, si) end -function project_back_from_unit_circle(θ::T, half_box_len::Float64) where {T<:Real} +function project_back_from_unit_circle(θ::Real, half_box_len::Real) x = θ * half_box_len / π - half_box_len return ReCo.restrict_coordinate(x, half_box_len) end function center_of_mass_from_proj_sums( - x_proj_sum::SVector{2,Float64}, y_proj_sum::SVector{2,Float64}, half_box_len::Float64 + x_proj_sum::SVector{2,Real}, y_proj_sum::SVector{2,Real}, half_box_len::Real ) # Prevent for example atan(1e-16, 1e-15) != 0 with rounding digits = 5 @@ -47,7 +47,7 @@ function center_of_mass_from_proj_sums( return SVector(COM_x, COM_y) end -function center_of_mass(centers::AbstractVector{SVector{2,Float64}}, half_box_len::Float64) +function center_of_mass(centers::AbstractVector{SVector{2,Real}}, half_box_len::Real) x_proj_sum = SVector(0.0, 0.0) y_proj_sum = SVector(0.0, 0.0) @@ -59,7 +59,7 @@ function center_of_mass(centers::AbstractVector{SVector{2,Float64}}, half_box_le return center_of_mass_from_proj_sums(x_proj_sum, y_proj_sum, half_box_len) end -function center_of_mass(particles::Vector{Particle}, half_box_len::Float64) +function center_of_mass(particles::AbstractVector{Particle}, half_box_len::Real) x_proj_sum = SVector(0.0, 0.0) y_proj_sum = SVector(0.0, 0.0) @@ -72,7 +72,7 @@ function center_of_mass(particles::Vector{Particle}, half_box_len::Float64) end function gyration_tensor( - particles::Vector{Particle}, half_box_len::Float64, COM::SVector{2,Float64} + particles::AbstractVector{Particle}, half_box_len::Real, COM::SVector{2,Real} ) S11 = 0.0 S12 = 0.0 @@ -89,20 +89,22 @@ function gyration_tensor( return Hermitian(SMatrix{2,2}(S11, S12, S12, S22)) end -function gyration_tensor(particles::Vector{Particle}, half_box_len::Float64) +function gyration_tensor(particles::AbstractVector{Particle}, half_box_len::Real) COM = center_of_mass(particles, half_box_len) return gyration_tensor(particles, half_box_len, COM) end -function gyration_tensor_eigvals_ratio(particles::Vector{Particle}, half_box_len::Float64) +function gyration_tensor_eigvals_ratio( + particles::AbstractVector{Particle}, half_box_len::Real +) g_tensor = gyration_tensor(particles, half_box_len) ev = eigvals(g_tensor) # Eigenvalues are sorted return ev[1] / ev[2] end function gyration_tensor_eigvecs( - particles::Vector{Particle}, half_box_len::Float64, COM::SVector{2,Float64} + particles::AbstractVector{Particle}, half_box_len::Real, COM::SVector{2,Real} ) g_tensor = gyration_tensor(particles, half_box_len, COM) eig_vecs = eigvecs(g_tensor) @@ -114,12 +116,12 @@ function gyration_tensor_eigvecs( end function elliptical_distance( - v::SVector{2,Float64}, - COM::SVector{2,Float64}, - gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Float64}, - gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Float64}, - goal_gyration_tensor_eigvals_ratio::Float64, - half_box_len::Float64, + v::SVector{2,Real}, + COM::SVector{2,Real}, + gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Real}, + gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Real}, + goal_gyration_tensor_eigvals_ratio::Real, + half_box_len::Real, ) v′ = ReCo.minimum_image(v - COM, half_box_len) diff --git a/src/run.jl b/src/run.jl index 57bba4f..324ae12 100644 --- a/src/run.jl +++ b/src/run.jl @@ -1,4 +1,6 @@ -empty_hook(args...) = nothing +function empty_hook(args...) + return nothing +end function run_sim( dir::String; @@ -101,7 +103,7 @@ function run_sim( ), ) - simulate( + simulate!( args, T0, T, diff --git a/src/simulation.jl b/src/simulation.jl index b31cebf..97ec77a 100644 --- a/src/simulation.jl +++ b/src/simulation.jl @@ -1,4 +1,6 @@ -rand_normal01() = rand(Normal(0, 1)) +function rand_normal01() + return rand(Normal(0, 1)) +end function push_to_verlet_list!(verlet_lists, i, j) if i < j @@ -36,9 +38,9 @@ function euler!( args, first_integration_step::Bool, env_helper::Union{RL.EnvHelper,Nothing}, - state_update_helper_hook::Function, - state_update_hook::Function, - update_table_and_actions_hook::Function, + state_update_helper_hook!::Function, + state_update_hook!::Function, + update_table_and_actions_hook!::Function, ) for id1 in 1:(args.n_particles - 1) p1 = args.particles[id1] @@ -52,7 +54,7 @@ function euler!( p1_c, p2.c, args.interaction_r², args.half_box_len ) - state_update_helper_hook(env_helper, id1, id2, r⃗₁₂) + state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂) if overlapping factor = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0) @@ -64,7 +66,7 @@ function euler!( end end - state_update_hook(env_helper, args.particles) + state_update_hook!(env_helper, args.particles) @simd for p in args.particles si, co = sincos(p.φ) @@ -75,9 +77,9 @@ function euler!( restrict_coordinates!(p, args.half_box_len) - update_table_and_actions_hook(env_helper, p, first_integration_step) + update_table_and_actions_hook!(env_helper, p, first_integration_step) - RL.act_hook(env_helper, p, args.δt, si, co) + RL.act_hook!(p, env_helper, args.δt, si, co) p.φ += args.c₄ * rand_normal01() @@ -87,16 +89,20 @@ function euler!( return nothing end -Base.wait(::Nothing) = nothing +function Base.wait(::Nothing) + return nothing +end -gen_run_additional_hooks(::Nothing, args...) = false +function gen_run_additional_hooks(::Nothing, args...) + return false +end function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64) return (integration_step % env_helper.shared.n_steps_before_actions_update == 0) || (integration_step == 1) end -function simulate( +function simulate!( args, T0::Float64, T::Float64, @@ -116,8 +122,8 @@ function simulate( first_integration_step = true - state_update_helper_hook = - state_update_hook = update_table_and_actions_hook = empty_hook + state_update_helper_hook! = + state_update_hook! = update_table_and_actions_hook! = empty_hook start_time = now() println("Started simulation at $start_time.") @@ -146,25 +152,25 @@ function simulate( run_additional_hooks = gen_run_additional_hooks(env_helper, integration_step) if run_additional_hooks - RL.pre_integration_hook(env_helper) + RL.pre_integration_hook!(env_helper) - state_update_helper_hook = RL.state_update_helper_hook - state_update_hook = RL.state_update_hook - update_table_and_actions_hook = RL.update_table_and_actions_hook + state_update_helper_hook! = RL.state_update_helper_hook! + state_update_hook! = RL.state_update_hook! + update_table_and_actions_hook! = RL.update_table_and_actions_hook! end euler!( args, first_integration_step, env_helper, - state_update_helper_hook, - state_update_hook, - update_table_and_actions_hook, + state_update_helper_hook!, + state_update_hook!, + update_table_and_actions_hook!, ) if run_additional_hooks - state_update_helper_hook = - state_update_hook = update_table_and_actions_hook = empty_hook + state_update_helper_hook! = + state_update_hook! = update_table_and_actions_hook! = empty_hook end first_integration_step = false