1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00

Follow Julian style guide

This commit is contained in:
Mo8it 2022-01-18 02:17:52 +01:00
parent eceeda4099
commit a31b6ed2a0
11 changed files with 94 additions and 65 deletions

View file

@ -2,6 +2,8 @@ module Error
export method_not_implemented export method_not_implemented
method_not_implemented() = error("Method not implemented!") function method_not_implemented()
return error("Method not implemented!")
end
end # module end # module

View file

@ -5,11 +5,11 @@ export angle2, norm2d, sq_norm2d
using StaticArrays: SVector using StaticArrays: SVector
""" """
angle2(a::SVector{2,Float64}, b::SVector{2,Float64}) angle2(a::SVector{2,Real}, b::SVector{2,Real})
Returns the angle φ from vector a to b while φ [-π, π]. Returns the angle φ from vector a to b while φ [-π, π].
""" """
function angle2(a::SVector{2,Float64}, b::SVector{2,Float64}) function angle2(a::SVector{2,Real}, b::SVector{2,Real})
θ_a = atan(a[2], a[1]) θ_a = atan(a[2], a[1])
θ_b = atan(b[2], b[1]) θ_b = atan(b[2], b[1])
@ -18,7 +18,12 @@ function angle2(a::SVector{2,Float64}, b::SVector{2,Float64})
return rem2pi(θ, RoundNearest) return rem2pi(θ, RoundNearest)
end end
sq_norm2d(v::SVector{2,Float64}) = v[1]^2 + v[2]^2 function sq_norm2d(v::SVector{2,Real})
norm2d(v::SVector{2,Float64}) = sqrt(sq_norm2d(v)) return v[1]^2 + v[2]^2
end
function norm2d(v::SVector{2,Real})
return sqrt(sq_norm2d(v))
end
end # module end # module

View file

@ -18,13 +18,13 @@ function push!(pv::PreVector{T}, x::T) where {T}
return nothing return nothing
end end
function reset!(pv::PreVector{T}) where {T} function reset!(pv::PreVector)
pv.last_ind = 0 pv.last_ind = 0
return nothing return nothing
end end
function iterate(pv::PreVector{T}, state=1) where {T} function iterate(pv::PreVector, state::UInt64=UInt64(1))
if state > pv.last_ind if state > pv.last_ind
return nothing return nothing
else else

View file

@ -73,12 +73,22 @@ function reset!(env::Env)
return nothing return nothing
end end
RLBase.state_space(env::Env) = env.shared.state_id_space function RLBase.state_space(env::Env)
return env.shared.state_id_space
end
RLBase.state(env::Env) = env.shared.state_id function RLBase.state(env::Env)
return env.shared.state_id
end
RLBase.action_space(env::Env) = env.shared.action_id_space function RLBase.action_space(env::Env)
return env.shared.action_id_space
end
RLBase.reward(env::Env) = env.shared.reward function RLBase.reward(env::Env)
return env.shared.reward
end
RLBase.is_terminated(env::Env) = env.shared.terminated function RLBase.is_terminated(env::Env)
return env.shared.terminated
end

View file

@ -20,12 +20,12 @@ struct EnvHelperSharedProps{H<:AbstractHook}
function EnvHelperSharedProps( function EnvHelperSharedProps(
env::Env, env::Env,
agent::Agent, agent::Agent,
hook::H, hook::AbstractHook,
n_steps_before_actions_update::Int64, n_steps_before_actions_update::Int64,
goal_gyration_tensor_eigvals_ratio::Float64, goal_gyration_tensor_eigvals_ratio::Float64,
n_particles::Int64, n_particles::Int64,
) where {H<:AbstractHook} )
return new{H}( return new(
env, env,
agent, agent,
hook, hook,

View file

@ -1,16 +1,16 @@
using ..ReCo: Particle using ..ReCo: Particle
function pre_integration_hook(::EnvHelper) function pre_integration_hook!(::EnvHelper)
return ReCo.method_not_implemented() return ReCo.method_not_implemented()
end end
function state_update_helper_hook( function state_update_helper_hook!(
::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} ::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
) )
return ReCo.method_not_implemented() return ReCo.method_not_implemented()
end end
function state_update_hook(::EnvHelper, particles::Vector{Particle}) function state_update_hook!(::EnvHelper, particles::Vector{Particle})
return ReCo.method_not_implemented() return ReCo.method_not_implemented()
end end
@ -18,7 +18,7 @@ function update_reward!(::Env, ::EnvHelper, particle::Particle)
return ReCo.method_not_implemented() return ReCo.method_not_implemented()
end end
function update_table_and_actions_hook( function update_table_and_actions_hook!(
env_helper::EnvHelper, particle::Particle, first_integration_step::Bool env_helper::EnvHelper, particle::Particle, first_integration_step::Bool
) )
env, agent, hook = get_env_agent_hook(env_helper) env, agent, hook = get_env_agent_hook(env_helper)
@ -56,10 +56,12 @@ function update_table_and_actions_hook(
return nothing return nothing
end end
act_hook(::Nothing, args...) = nothing function act_hook!(::Particle, ::Nothing, args...)
return nothing
end
function act_hook( function act_hook!(
env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64 particle::Particle, env_helper::EnvHelper, δt::Float64, si::Float64, co::Float64
) )
# Apply action # Apply action
action = env_helper.shared.actions[particle.id] action = env_helper.shared.actions[particle.id]

View file

@ -76,7 +76,7 @@ function gen_env_helper(::LocalCOMEnv, env_helper_shared::EnvHelperSharedProps;
return LocalCOMEnvHelper(env_helper_shared, args.half_box_len, args.skin_r) return LocalCOMEnvHelper(env_helper_shared, args.half_box_len, args.skin_r)
end end
function pre_integration_hook(env_helper::LocalCOMEnvHelper) function pre_integration_hook!(env_helper::LocalCOMEnvHelper)
@simd for id in 1:(env_helper.shared.n_particles) @simd for id in 1:(env_helper.shared.n_particles)
env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0) env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
env_helper.n_neighbours[id] = 0 env_helper.n_neighbours[id] = 0
@ -85,7 +85,7 @@ function pre_integration_hook(env_helper::LocalCOMEnvHelper)
return nothing return nothing
end end
function state_update_helper_hook( function state_update_helper_hook!(
env_helper::LocalCOMEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} env_helper::LocalCOMEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
) )
env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂ env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂
@ -97,7 +97,7 @@ function state_update_helper_hook(
return nothing return nothing
end end
function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle}) function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})
n_particles = env_helper.shared.n_particles n_particles = env_helper.shared.n_particles
@turbo for id in 1:n_particles @turbo for id in 1:n_particles

View file

@ -61,7 +61,7 @@ function run_rl(;
seed::Int64=42, seed::Int64=42,
ϵ_stable::Float64=0.0001, ϵ_stable::Float64=0.0001,
skin_to_interaction_r_ratio::Float64=ReCo.DEFAULT_SKIN_TO_INTERACTION_R_RATIO, skin_to_interaction_r_ratio::Float64=ReCo.DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
packing_ratio=0.22, packing_ratio::Float64=0.22,
) where {E<:Env} ) where {E<:Env}
@assert 0.0 <= goal_gyration_tensor_eigvals_ratio <= 1.0 @assert 0.0 <= goal_gyration_tensor_eigvals_ratio <= 1.0
@assert n_episodes > 0 @assert n_episodes > 0

View file

@ -8,21 +8,21 @@ using LinearAlgebra: eigvals, eigvecs, Hermitian, dot
using ..ReCo: ReCo, Particle using ..ReCo: ReCo, Particle
function project_to_unit_circle(x::Float64, half_box_len::Float64) function project_to_unit_circle(x::Real, half_box_len::Real)
φ = (x + half_box_len) * π / half_box_len φ = (x + half_box_len) * π / half_box_len
si, co = sincos(φ) si, co = sincos(φ)
return SVector(co, si) return SVector(co, si)
end end
function project_back_from_unit_circle(θ::T, half_box_len::Float64) where {T<:Real} function project_back_from_unit_circle(θ::Real, half_box_len::Real)
x = θ * half_box_len / π - half_box_len x = θ * half_box_len / π - half_box_len
return ReCo.restrict_coordinate(x, half_box_len) return ReCo.restrict_coordinate(x, half_box_len)
end end
function center_of_mass_from_proj_sums( function center_of_mass_from_proj_sums(
x_proj_sum::SVector{2,Float64}, y_proj_sum::SVector{2,Float64}, half_box_len::Float64 x_proj_sum::SVector{2,Real}, y_proj_sum::SVector{2,Real}, half_box_len::Real
) )
# Prevent for example atan(1e-16, 1e-15) != 0 with rounding # Prevent for example atan(1e-16, 1e-15) != 0 with rounding
digits = 5 digits = 5
@ -47,7 +47,7 @@ function center_of_mass_from_proj_sums(
return SVector(COM_x, COM_y) return SVector(COM_x, COM_y)
end end
function center_of_mass(centers::AbstractVector{SVector{2,Float64}}, half_box_len::Float64) function center_of_mass(centers::AbstractVector{SVector{2,Real}}, half_box_len::Real)
x_proj_sum = SVector(0.0, 0.0) x_proj_sum = SVector(0.0, 0.0)
y_proj_sum = SVector(0.0, 0.0) y_proj_sum = SVector(0.0, 0.0)
@ -59,7 +59,7 @@ function center_of_mass(centers::AbstractVector{SVector{2,Float64}}, half_box_le
return center_of_mass_from_proj_sums(x_proj_sum, y_proj_sum, half_box_len) return center_of_mass_from_proj_sums(x_proj_sum, y_proj_sum, half_box_len)
end end
function center_of_mass(particles::Vector{Particle}, half_box_len::Float64) function center_of_mass(particles::AbstractVector{Particle}, half_box_len::Real)
x_proj_sum = SVector(0.0, 0.0) x_proj_sum = SVector(0.0, 0.0)
y_proj_sum = SVector(0.0, 0.0) y_proj_sum = SVector(0.0, 0.0)
@ -72,7 +72,7 @@ function center_of_mass(particles::Vector{Particle}, half_box_len::Float64)
end end
function gyration_tensor( function gyration_tensor(
particles::Vector{Particle}, half_box_len::Float64, COM::SVector{2,Float64} particles::AbstractVector{Particle}, half_box_len::Real, COM::SVector{2,Real}
) )
S11 = 0.0 S11 = 0.0
S12 = 0.0 S12 = 0.0
@ -89,20 +89,22 @@ function gyration_tensor(
return Hermitian(SMatrix{2,2}(S11, S12, S12, S22)) return Hermitian(SMatrix{2,2}(S11, S12, S12, S22))
end end
function gyration_tensor(particles::Vector{Particle}, half_box_len::Float64) function gyration_tensor(particles::AbstractVector{Particle}, half_box_len::Real)
COM = center_of_mass(particles, half_box_len) COM = center_of_mass(particles, half_box_len)
return gyration_tensor(particles, half_box_len, COM) return gyration_tensor(particles, half_box_len, COM)
end end
function gyration_tensor_eigvals_ratio(particles::Vector{Particle}, half_box_len::Float64) function gyration_tensor_eigvals_ratio(
particles::AbstractVector{Particle}, half_box_len::Real
)
g_tensor = gyration_tensor(particles, half_box_len) g_tensor = gyration_tensor(particles, half_box_len)
ev = eigvals(g_tensor) # Eigenvalues are sorted ev = eigvals(g_tensor) # Eigenvalues are sorted
return ev[1] / ev[2] return ev[1] / ev[2]
end end
function gyration_tensor_eigvecs( function gyration_tensor_eigvecs(
particles::Vector{Particle}, half_box_len::Float64, COM::SVector{2,Float64} particles::AbstractVector{Particle}, half_box_len::Real, COM::SVector{2,Real}
) )
g_tensor = gyration_tensor(particles, half_box_len, COM) g_tensor = gyration_tensor(particles, half_box_len, COM)
eig_vecs = eigvecs(g_tensor) eig_vecs = eigvecs(g_tensor)
@ -114,12 +116,12 @@ function gyration_tensor_eigvecs(
end end
function elliptical_distance( function elliptical_distance(
v::SVector{2,Float64}, v::SVector{2,Real},
COM::SVector{2,Float64}, COM::SVector{2,Real},
gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Float64}, gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Real},
gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Float64}, gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Real},
goal_gyration_tensor_eigvals_ratio::Float64, goal_gyration_tensor_eigvals_ratio::Real,
half_box_len::Float64, half_box_len::Real,
) )
v = ReCo.minimum_image(v - COM, half_box_len) v = ReCo.minimum_image(v - COM, half_box_len)

View file

@ -1,4 +1,6 @@
empty_hook(args...) = nothing function empty_hook(args...)
return nothing
end
function run_sim( function run_sim(
dir::String; dir::String;
@ -101,7 +103,7 @@ function run_sim(
), ),
) )
simulate( simulate!(
args, args,
T0, T0,
T, T,

View file

@ -1,4 +1,6 @@
rand_normal01() = rand(Normal(0, 1)) function rand_normal01()
return rand(Normal(0, 1))
end
function push_to_verlet_list!(verlet_lists, i, j) function push_to_verlet_list!(verlet_lists, i, j)
if i < j if i < j
@ -36,9 +38,9 @@ function euler!(
args, args,
first_integration_step::Bool, first_integration_step::Bool,
env_helper::Union{RL.EnvHelper,Nothing}, env_helper::Union{RL.EnvHelper,Nothing},
state_update_helper_hook::Function, state_update_helper_hook!::Function,
state_update_hook::Function, state_update_hook!::Function,
update_table_and_actions_hook::Function, update_table_and_actions_hook!::Function,
) )
for id1 in 1:(args.n_particles - 1) for id1 in 1:(args.n_particles - 1)
p1 = args.particles[id1] p1 = args.particles[id1]
@ -52,7 +54,7 @@ function euler!(
p1_c, p2.c, args.interaction_r², args.half_box_len p1_c, p2.c, args.interaction_r², args.half_box_len
) )
state_update_helper_hook(env_helper, id1, id2, r⃗₁₂) state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂)
if overlapping if overlapping
factor = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0) factor = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
@ -64,7 +66,7 @@ function euler!(
end end
end end
state_update_hook(env_helper, args.particles) state_update_hook!(env_helper, args.particles)
@simd for p in args.particles @simd for p in args.particles
si, co = sincos(p.φ) si, co = sincos(p.φ)
@ -75,9 +77,9 @@ function euler!(
restrict_coordinates!(p, args.half_box_len) restrict_coordinates!(p, args.half_box_len)
update_table_and_actions_hook(env_helper, p, first_integration_step) update_table_and_actions_hook!(env_helper, p, first_integration_step)
RL.act_hook(env_helper, p, args.δt, si, co) RL.act_hook!(p, env_helper, args.δt, si, co)
p.φ += args.c₄ * rand_normal01() p.φ += args.c₄ * rand_normal01()
@ -87,16 +89,20 @@ function euler!(
return nothing return nothing
end end
Base.wait(::Nothing) = nothing function Base.wait(::Nothing)
return nothing
end
gen_run_additional_hooks(::Nothing, args...) = false function gen_run_additional_hooks(::Nothing, args...)
return false
end
function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64) function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64)
return (integration_step % env_helper.shared.n_steps_before_actions_update == 0) || return (integration_step % env_helper.shared.n_steps_before_actions_update == 0) ||
(integration_step == 1) (integration_step == 1)
end end
function simulate( function simulate!(
args, args,
T0::Float64, T0::Float64,
T::Float64, T::Float64,
@ -116,8 +122,8 @@ function simulate(
first_integration_step = true first_integration_step = true
state_update_helper_hook = state_update_helper_hook! =
state_update_hook = update_table_and_actions_hook = empty_hook state_update_hook! = update_table_and_actions_hook! = empty_hook
start_time = now() start_time = now()
println("Started simulation at $start_time.") println("Started simulation at $start_time.")
@ -146,25 +152,25 @@ function simulate(
run_additional_hooks = gen_run_additional_hooks(env_helper, integration_step) run_additional_hooks = gen_run_additional_hooks(env_helper, integration_step)
if run_additional_hooks if run_additional_hooks
RL.pre_integration_hook(env_helper) RL.pre_integration_hook!(env_helper)
state_update_helper_hook = RL.state_update_helper_hook state_update_helper_hook! = RL.state_update_helper_hook!
state_update_hook = RL.state_update_hook state_update_hook! = RL.state_update_hook!
update_table_and_actions_hook = RL.update_table_and_actions_hook update_table_and_actions_hook! = RL.update_table_and_actions_hook!
end end
euler!( euler!(
args, args,
first_integration_step, first_integration_step,
env_helper, env_helper,
state_update_helper_hook, state_update_helper_hook!,
state_update_hook, state_update_hook!,
update_table_and_actions_hook, update_table_and_actions_hook!,
) )
if run_additional_hooks if run_additional_hooks
state_update_helper_hook = state_update_helper_hook! =
state_update_hook = update_table_and_actions_hook = empty_hook state_update_hook! = update_table_and_actions_hook! = empty_hook
end end
first_integration_step = false first_integration_step = false