1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-12-21 00:51:21 +00:00

Follow Julian style guide

This commit is contained in:
Mo8it 2022-01-18 02:17:52 +01:00
parent eceeda4099
commit a31b6ed2a0
11 changed files with 94 additions and 65 deletions

View file

@ -2,6 +2,8 @@ module Error
export method_not_implemented
method_not_implemented() = error("Method not implemented!")
function method_not_implemented()
return error("Method not implemented!")
end
end # module

View file

@ -5,11 +5,11 @@ export angle2, norm2d, sq_norm2d
using StaticArrays: SVector
"""
angle2(a::SVector{2,Float64}, b::SVector{2,Float64})
angle2(a::SVector{2,Real}, b::SVector{2,Real})
Returns the angle φ from vector a to b while φ [-π, π].
"""
function angle2(a::SVector{2,Float64}, b::SVector{2,Float64})
function angle2(a::SVector{2,Real}, b::SVector{2,Real})
θ_a = atan(a[2], a[1])
θ_b = atan(b[2], b[1])
@ -18,7 +18,12 @@ function angle2(a::SVector{2,Float64}, b::SVector{2,Float64})
return rem2pi(θ, RoundNearest)
end
sq_norm2d(v::SVector{2,Float64}) = v[1]^2 + v[2]^2
norm2d(v::SVector{2,Float64}) = sqrt(sq_norm2d(v))
function sq_norm2d(v::SVector{2,Real})
return v[1]^2 + v[2]^2
end
function norm2d(v::SVector{2,Real})
return sqrt(sq_norm2d(v))
end
end # module

View file

@ -18,13 +18,13 @@ function push!(pv::PreVector{T}, x::T) where {T}
return nothing
end
function reset!(pv::PreVector{T}) where {T}
function reset!(pv::PreVector)
pv.last_ind = 0
return nothing
end
function iterate(pv::PreVector{T}, state=1) where {T}
function iterate(pv::PreVector, state::UInt64=UInt64(1))
if state > pv.last_ind
return nothing
else

View file

@ -73,12 +73,22 @@ function reset!(env::Env)
return nothing
end
RLBase.state_space(env::Env) = env.shared.state_id_space
function RLBase.state_space(env::Env)
return env.shared.state_id_space
end
RLBase.state(env::Env) = env.shared.state_id
function RLBase.state(env::Env)
return env.shared.state_id
end
RLBase.action_space(env::Env) = env.shared.action_id_space
function RLBase.action_space(env::Env)
return env.shared.action_id_space
end
RLBase.reward(env::Env) = env.shared.reward
function RLBase.reward(env::Env)
return env.shared.reward
end
RLBase.is_terminated(env::Env) = env.shared.terminated
function RLBase.is_terminated(env::Env)
return env.shared.terminated
end

View file

@ -20,12 +20,12 @@ struct EnvHelperSharedProps{H<:AbstractHook}
function EnvHelperSharedProps(
env::Env,
agent::Agent,
hook::H,
hook::AbstractHook,
n_steps_before_actions_update::Int64,
goal_gyration_tensor_eigvals_ratio::Float64,
n_particles::Int64,
) where {H<:AbstractHook}
return new{H}(
)
return new(
env,
agent,
hook,

View file

@ -1,16 +1,16 @@
using ..ReCo: Particle
function pre_integration_hook(::EnvHelper)
function pre_integration_hook!(::EnvHelper)
return ReCo.method_not_implemented()
end
function state_update_helper_hook(
function state_update_helper_hook!(
::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
)
return ReCo.method_not_implemented()
end
function state_update_hook(::EnvHelper, particles::Vector{Particle})
function state_update_hook!(::EnvHelper, particles::Vector{Particle})
return ReCo.method_not_implemented()
end
@ -18,7 +18,7 @@ function update_reward!(::Env, ::EnvHelper, particle::Particle)
return ReCo.method_not_implemented()
end
function update_table_and_actions_hook(
function update_table_and_actions_hook!(
env_helper::EnvHelper, particle::Particle, first_integration_step::Bool
)
env, agent, hook = get_env_agent_hook(env_helper)
@ -56,10 +56,12 @@ function update_table_and_actions_hook(
return nothing
end
act_hook(::Nothing, args...) = nothing
function act_hook!(::Particle, ::Nothing, args...)
return nothing
end
function act_hook(
env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64
function act_hook!(
particle::Particle, env_helper::EnvHelper, δt::Float64, si::Float64, co::Float64
)
# Apply action
action = env_helper.shared.actions[particle.id]

View file

@ -76,7 +76,7 @@ function gen_env_helper(::LocalCOMEnv, env_helper_shared::EnvHelperSharedProps;
return LocalCOMEnvHelper(env_helper_shared, args.half_box_len, args.skin_r)
end
function pre_integration_hook(env_helper::LocalCOMEnvHelper)
function pre_integration_hook!(env_helper::LocalCOMEnvHelper)
@simd for id in 1:(env_helper.shared.n_particles)
env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
env_helper.n_neighbours[id] = 0
@ -85,7 +85,7 @@ function pre_integration_hook(env_helper::LocalCOMEnvHelper)
return nothing
end
function state_update_helper_hook(
function state_update_helper_hook!(
env_helper::LocalCOMEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
)
env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂
@ -97,7 +97,7 @@ function state_update_helper_hook(
return nothing
end
function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})
function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})
n_particles = env_helper.shared.n_particles
@turbo for id in 1:n_particles

View file

@ -61,7 +61,7 @@ function run_rl(;
seed::Int64=42,
ϵ_stable::Float64=0.0001,
skin_to_interaction_r_ratio::Float64=ReCo.DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
packing_ratio=0.22,
packing_ratio::Float64=0.22,
) where {E<:Env}
@assert 0.0 <= goal_gyration_tensor_eigvals_ratio <= 1.0
@assert n_episodes > 0

View file

@ -8,21 +8,21 @@ using LinearAlgebra: eigvals, eigvecs, Hermitian, dot
using ..ReCo: ReCo, Particle
function project_to_unit_circle(x::Float64, half_box_len::Float64)
function project_to_unit_circle(x::Real, half_box_len::Real)
φ = (x + half_box_len) * π / half_box_len
si, co = sincos(φ)
return SVector(co, si)
end
function project_back_from_unit_circle(θ::T, half_box_len::Float64) where {T<:Real}
function project_back_from_unit_circle(θ::Real, half_box_len::Real)
x = θ * half_box_len / π - half_box_len
return ReCo.restrict_coordinate(x, half_box_len)
end
function center_of_mass_from_proj_sums(
x_proj_sum::SVector{2,Float64}, y_proj_sum::SVector{2,Float64}, half_box_len::Float64
x_proj_sum::SVector{2,Real}, y_proj_sum::SVector{2,Real}, half_box_len::Real
)
# Prevent for example atan(1e-16, 1e-15) != 0 with rounding
digits = 5
@ -47,7 +47,7 @@ function center_of_mass_from_proj_sums(
return SVector(COM_x, COM_y)
end
function center_of_mass(centers::AbstractVector{SVector{2,Float64}}, half_box_len::Float64)
function center_of_mass(centers::AbstractVector{SVector{2,Real}}, half_box_len::Real)
x_proj_sum = SVector(0.0, 0.0)
y_proj_sum = SVector(0.0, 0.0)
@ -59,7 +59,7 @@ function center_of_mass(centers::AbstractVector{SVector{2,Float64}}, half_box_le
return center_of_mass_from_proj_sums(x_proj_sum, y_proj_sum, half_box_len)
end
function center_of_mass(particles::Vector{Particle}, half_box_len::Float64)
function center_of_mass(particles::AbstractVector{Particle}, half_box_len::Real)
x_proj_sum = SVector(0.0, 0.0)
y_proj_sum = SVector(0.0, 0.0)
@ -72,7 +72,7 @@ function center_of_mass(particles::Vector{Particle}, half_box_len::Float64)
end
function gyration_tensor(
particles::Vector{Particle}, half_box_len::Float64, COM::SVector{2,Float64}
particles::AbstractVector{Particle}, half_box_len::Real, COM::SVector{2,Real}
)
S11 = 0.0
S12 = 0.0
@ -89,20 +89,22 @@ function gyration_tensor(
return Hermitian(SMatrix{2,2}(S11, S12, S12, S22))
end
function gyration_tensor(particles::Vector{Particle}, half_box_len::Float64)
function gyration_tensor(particles::AbstractVector{Particle}, half_box_len::Real)
COM = center_of_mass(particles, half_box_len)
return gyration_tensor(particles, half_box_len, COM)
end
function gyration_tensor_eigvals_ratio(particles::Vector{Particle}, half_box_len::Float64)
function gyration_tensor_eigvals_ratio(
particles::AbstractVector{Particle}, half_box_len::Real
)
g_tensor = gyration_tensor(particles, half_box_len)
ev = eigvals(g_tensor) # Eigenvalues are sorted
return ev[1] / ev[2]
end
function gyration_tensor_eigvecs(
particles::Vector{Particle}, half_box_len::Float64, COM::SVector{2,Float64}
particles::AbstractVector{Particle}, half_box_len::Real, COM::SVector{2,Real}
)
g_tensor = gyration_tensor(particles, half_box_len, COM)
eig_vecs = eigvecs(g_tensor)
@ -114,12 +116,12 @@ function gyration_tensor_eigvecs(
end
function elliptical_distance(
v::SVector{2,Float64},
COM::SVector{2,Float64},
gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Float64},
gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Float64},
goal_gyration_tensor_eigvals_ratio::Float64,
half_box_len::Float64,
v::SVector{2,Real},
COM::SVector{2,Real},
gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Real},
gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Real},
goal_gyration_tensor_eigvals_ratio::Real,
half_box_len::Real,
)
v = ReCo.minimum_image(v - COM, half_box_len)

View file

@ -1,4 +1,6 @@
empty_hook(args...) = nothing
function empty_hook(args...)
return nothing
end
function run_sim(
dir::String;
@ -101,7 +103,7 @@ function run_sim(
),
)
simulate(
simulate!(
args,
T0,
T,

View file

@ -1,4 +1,6 @@
rand_normal01() = rand(Normal(0, 1))
function rand_normal01()
return rand(Normal(0, 1))
end
function push_to_verlet_list!(verlet_lists, i, j)
if i < j
@ -36,9 +38,9 @@ function euler!(
args,
first_integration_step::Bool,
env_helper::Union{RL.EnvHelper,Nothing},
state_update_helper_hook::Function,
state_update_hook::Function,
update_table_and_actions_hook::Function,
state_update_helper_hook!::Function,
state_update_hook!::Function,
update_table_and_actions_hook!::Function,
)
for id1 in 1:(args.n_particles - 1)
p1 = args.particles[id1]
@ -52,7 +54,7 @@ function euler!(
p1_c, p2.c, args.interaction_r², args.half_box_len
)
state_update_helper_hook(env_helper, id1, id2, r⃗₁₂)
state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂)
if overlapping
factor = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
@ -64,7 +66,7 @@ function euler!(
end
end
state_update_hook(env_helper, args.particles)
state_update_hook!(env_helper, args.particles)
@simd for p in args.particles
si, co = sincos(p.φ)
@ -75,9 +77,9 @@ function euler!(
restrict_coordinates!(p, args.half_box_len)
update_table_and_actions_hook(env_helper, p, first_integration_step)
update_table_and_actions_hook!(env_helper, p, first_integration_step)
RL.act_hook(env_helper, p, args.δt, si, co)
RL.act_hook!(p, env_helper, args.δt, si, co)
p.φ += args.c₄ * rand_normal01()
@ -87,16 +89,20 @@ function euler!(
return nothing
end
Base.wait(::Nothing) = nothing
function Base.wait(::Nothing)
return nothing
end
gen_run_additional_hooks(::Nothing, args...) = false
function gen_run_additional_hooks(::Nothing, args...)
return false
end
function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64)
return (integration_step % env_helper.shared.n_steps_before_actions_update == 0) ||
(integration_step == 1)
end
function simulate(
function simulate!(
args,
T0::Float64,
T::Float64,
@ -116,8 +122,8 @@ function simulate(
first_integration_step = true
state_update_helper_hook =
state_update_hook = update_table_and_actions_hook = empty_hook
state_update_helper_hook! =
state_update_hook! = update_table_and_actions_hook! = empty_hook
start_time = now()
println("Started simulation at $start_time.")
@ -146,25 +152,25 @@ function simulate(
run_additional_hooks = gen_run_additional_hooks(env_helper, integration_step)
if run_additional_hooks
RL.pre_integration_hook(env_helper)
RL.pre_integration_hook!(env_helper)
state_update_helper_hook = RL.state_update_helper_hook
state_update_hook = RL.state_update_hook
update_table_and_actions_hook = RL.update_table_and_actions_hook
state_update_helper_hook! = RL.state_update_helper_hook!
state_update_hook! = RL.state_update_hook!
update_table_and_actions_hook! = RL.update_table_and_actions_hook!
end
euler!(
args,
first_integration_step,
env_helper,
state_update_helper_hook,
state_update_hook,
update_table_and_actions_hook,
state_update_helper_hook!,
state_update_hook!,
update_table_and_actions_hook!,
)
if run_additional_hooks
state_update_helper_hook =
state_update_hook = update_table_and_actions_hook = empty_hook
state_update_helper_hook! =
state_update_hook! = update_table_and_actions_hook! = empty_hook
end
first_integration_step = false