1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00

Restructure for RL

This commit is contained in:
MoBit 2021-12-12 15:29:08 +01:00
parent 07b8455e24
commit 029f7f29f2
9 changed files with 339 additions and 148 deletions

View file

@ -12,7 +12,7 @@ function plot_g(radius, g, variables)
yticks=0:ceil(Int64, maximum(g)), yticks=0:ceil(Int64, maximum(g)),
xlabel=L"r", xlabel=L"r",
ylabel=L"g(r)", ylabel=L"g(r)",
title="v = $(variables.v)", title="v = $(variables.v)",
) )
scatterlines!(ax, radius, g; color=:white, markercolor=:red) scatterlines!(ax, radius, g; color=:white, markercolor=:red)

View file

@ -133,7 +133,7 @@ function animate_after_loading(dir, animation_path, sim_consts, framerate, debug
skin_colors = Observable(Vector{RGBAf}(undef, n_particles)) skin_colors = Observable(Vector{RGBAf}(undef, n_particles))
end end
particle_radius = sim_consts.particle_diameter / 2 particle_radius = sim_consts.particle_radius
interaction_r = sim_consts.interaction_r interaction_r = sim_consts.interaction_r
bundle_paths = readdir("$dir/bundles"; join=true, sort=false) bundle_paths = readdir("$dir/bundles"; join=true, sort=false)

View file

@ -8,13 +8,13 @@ using ReCo
function run_benchmarks( function run_benchmarks(
dir::String=""; dir::String="";
n_particles::Int64=1000, n_particles::Int64=1000,
v::Float64=20.0, v::Float64=20.0,
duration::Float64=2.0, duration::Float64=2.0,
n_bundle_snapshots::Int64=0, n_bundle_snapshots::Int64=0,
comment="", comment="",
) )
if length(dir) == 0 if length(dir) == 0
dir = init_sim(; n_particles=n_particles, v=v, parent_dir="benchmark") dir = init_sim(; n_particles=n_particles, v=v, parent_dir="benchmark")
end end
benchmark = @benchmark run_sim( benchmark = @benchmark run_sim(
@ -28,7 +28,7 @@ function run_benchmarks(
f, f,
Dict( Dict(
"n_particles" => n_particles, "n_particles" => n_particles,
"v" => v, "v" => v,
"duration" => duration, "duration" => duration,
"n_bundle_snapshots" => n_bundle_snapshots, "n_bundle_snapshots" => n_bundle_snapshots,
"comment" => comment, "comment" => comment,

View file

@ -1,5 +1,7 @@
using StaticArrays: SVector using StaticArrays: SVector
using JLD2: JLD2 using JLD2: JLD2
using JSON3: JSON3
using OrderedCollections: OrderedDict
struct Bundle struct Bundle
t::Vector{Float64} t::Vector{Float64}
@ -33,14 +35,6 @@ function save_snapshot!(
return nothing return nothing
end end
function set_sim_state(dir::String, n_bundles::Int64, T::Float64)
open("$dir/sim_state.json", "w") do f
JSON3.write(f, (n_bundles=n_bundles, T=round(T; digits=3)))
end
return nothing
end
function save_bundle(dir::String, bundle::Bundle, n::Int64, T::Float64) function save_bundle(dir::String, bundle::Bundle, n::Int64, T::Float64)
bundles_dir = "$dir/bundles" bundles_dir = "$dir/bundles"
mkpath(bundles_dir) mkpath(bundles_dir)
@ -49,5 +43,27 @@ function save_bundle(dir::String, bundle::Bundle, n::Int64, T::Float64)
set_sim_state(dir, n, T) set_sim_state(dir, n, T)
return nothing
end
function set_sim_state(dir::String, n_bundles::Int64, T::Float64)
open("$dir/sim_state.json", "w") do f
JSON3.write(f, (n_bundles=n_bundles, T=round(T; digits=3)))
end
return nothing
end
function struct_to_ordered_dict(s)
return OrderedDict(key => getfield(s, key) for key in propertynames(s))
end
function write_struct_to_json(s, path_without_extension::String)
ordered_dict = struct_to_ordered_dict(s)
open("$path_without_extension.json", "w") do f
JSON3.write(f, ordered_dict)
end
return nothing return nothing
end end

View file

@ -1,24 +1,45 @@
using ReinforcementLearning module ReCoRL
mutable struct ReCoEnvParams using ReinforcementLearning
n_particles::Int64 using Intervals
half_box_len::Float64
skin_r::Float64 import Base: run
struct DistanceState{L<:Bound}
interval::Interval{Float64,L,Closed}
function DistanceState{L}(lower::Float64, upper::Float64) where {L<:Bound}
return new(Interval{Float64,L,Closed}(lower, upper))
end
end
struct DirectionState
interval::Interval{Float64,Open,Closed}
function DirectionState(lower::Float64, upper::Float64)
return new(Interval{Float64,Closed,Open}(lower, upper))
end
end
mutable struct EnvParams
action_space::Vector{Tuple{Float64,Float64}} action_space::Vector{Tuple{Float64,Float64}}
state_space::Vector{Tuple{Symbol,Symbol}} distance_state_space::Vector{DistanceState}
direction_state_space::Vector{DirectionState}
state_space::Vector{Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}}
reward::Float64 reward::Float64
function ReCoEnvParams( function EnvParams(
n_particles::Int64, min_distance::Float64,
half_box_len::Float64, max_distance::Float64;
skin_r::Float64, n_v_actions::Int64=2,
n_v_actions::Int64, n_ω_actions::Int64=3,
n_ω_actions::Int64; max_v::Float64=20.0,
max_v::Float64=80.0, max_ω::Float64=π / 1.5,
max_ω::Float64=float(π), n_distance_states::Int64=3,
n_direction_states::Int64=4,
) )
@assert half_box_len > 0 @assert min_distance > 0.0
@assert skin_r > 0 @assert max_distance > min_distance
@assert n_v_actions > 1 @assert n_v_actions > 1
@assert n_ω_actoins > 1 @assert n_ω_actoins > 1
@assert max_v > 0 @assert max_v > 0
@ -39,40 +60,133 @@ mutable struct ReCoEnvParams
end end
end end
distance_state_space = (:big, :medium, :small) distance_range =
direction_state_space = (:before, :behind, :left, :right) min_distance:((max_distance - min_distance) / n_distance_states):max_distance
n_states = undef, length(distance_state_space) * length(direction_state_space) + 1 distance_state_space = Vector{DistanceState}(undef, n_distance_states)
state_space = Vector{Tuple{Symbol,Symbol}}(n_states) for i in 1:n_distance_states
if i == 1
bound = Closed
else
bound = Open
end
distance_state_space[i] = DistanceState{bound}(
distance_range[i], distance_range[i + 1]
)
end
direction_range = 0.0:(2 * π / n_direction_states):(2 * π)
direction_state_space = Vector{DirectionState}(undef, n_direction_states)
for i in 1:n_direction_states
direction_state_space[i] = DirectionState(
direction_range[i], direction_range[i + 1]
)
end
state_space = Vector{Tuple{DistanceState,DirectionState}}(undef, n_states)
ind = 1 ind = 1
for distance in distance_state_space for distance_state in distance_state_space
for direction in direction_state_space for direction_state in direction_state_space
state_space[ind] = (distance, direction) state_space[ind] = (distance_state, direction_state)
ind += 1 ind += 1
end end
end end
state_space[ind] = (:none, :none) state_space[ind] = (nothing, nothing)
return new(n_particles, half_box_len, skin_r, action_space, state_space, 0.0) return new(
action_space, distance_state_space, direction_state_space, state_space, 0.0
)
end end
end end
mutable struct ReCoEnv <: AbstractEnv mutable struct Env <: AbstractEnv
params::ReCoEnvParams params::EnvParams
particle::Particle particle::Particle
state::Tuple{Symbol,Symbol} state::Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
function ReCoEnv(params::ReCoEnvParams, particle::Particle) function Env(params::EnvParams, particle::Particle)
return new(params, particle, (:none, :none)) return new(params, particle, (nothing, nothing))
end end
end end
RLBase.state_space(env::ReCoEnv) = env.state_space struct Params
envs::Vector{Env}
# agents
actions::Vector{Tuple{Float64,Float64}}
env_params::EnvParams
n_steps_before_actions_update::Int64
RLBase.state(env::ReCoEnv) = env.state function Params(
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
)
return new(
Vector{Env}(undef, n_particles),
Vector{Tuple{Float64,Float64}}(undef, n_particles),
env_params,
n_steps_before_actions_update,
)
end
end
RLBase.action_space(env::ReCoEnv) = env.params.action_space RLBase.state_space(env::Env) = env.state_space
RLBase.reward(env::ReCoEnv) = env.params.reward RLBase.state(env::Env) = env.state
RLBase.action_space(env::Env) = env.params.action_space
RLBase.reward(env::Env) = env.params.reward
function pre_integration_hook!() end
function integration_hook() end
function post_integration_hook() end
function run(
n_episodes::Int64=100,
episode_duration::Float64=5.0,
update_actions_at::Float64=0.1,
n_particles::Int64=100,
)
@assert n_episodes > 0
@assert episode_duration > 0
@assert update_actions_at in 0.01:0.01:episode_duration
@assert episode_duration % update_actions_at == 0
@assert n_particles > 0
Random.seed!(42)
# envs
# agents
# pre_experiment
sim_consts = gen_sim_consts(n_particles, v₀)
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
rl_params = Params(n_particles, env_params, n_steps_before_actions_update)
for episode in 1:n_episodes
# reset
# pre_episode
dir = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")
run_sim(
dir;
duration=episode_duration,
seed=rand(1:typemax(Int64)),
rl_params=rl_params,
skin_r=skin_r,
)
end
return nothing
end
end # module

View file

@ -3,18 +3,20 @@ using JSON3: JSON3
using JLD2: JLD2 using JLD2: JLD2
using StaticArrays: SVector using StaticArrays: SVector
empty_hook(args...) = nothing
function run_sim( function run_sim(
dir::String; dir::String;
duration::Float64, duration::Float64,
snapshot_at::Float64=0.1, snapshot_at::Float64=0.1,
n_steps_before_verlet_list_update::Int64=100,
seed::Int64=42, seed::Int64=42,
n_bundle_snapshots::Int64=100, n_bundle_snapshots::Int64=100,
rl_params::Union{ReCo.Params,Nothing}=nothing,
) )
@assert length(dir) > 0 @assert length(dir) > 0
@assert duration > 0 @assert duration > 0
@assert snapshot_at in 0.001:0.001:duration @assert snapshot_at in 0.001:0.001:duration
@assert n_steps_before_verlet_list_update >= 0 @assert duration % snapshot_at == 0
@assert seed > 0 @assert seed > 0
@assert n_bundle_snapshots >= 0 @assert n_bundle_snapshots >= 0
@ -22,15 +24,9 @@ function run_sim(
sim_consts = JSON3.read(read("$dir/sim_consts.json", String)) sim_consts = JSON3.read(read("$dir/sim_consts.json", String))
skin_r =
1.5 * (
2 * sim_consts.v * n_steps_before_verlet_list_update * sim_consts.δt +
2 * sim_consts.interaction_r
)
integration_steps = floor(Int64, duration / sim_consts.δt) + 1 integration_steps = floor(Int64, duration / sim_consts.δt) + 1
n_steps_before_snapshot = round(Int64, snapshot_at / sim_consts.δt)
n_steps_before_snapshot = round(Int64, snapshot_at / sim_consts.δt)
n_snapshots = floor(Int64, integration_steps / n_steps_before_snapshot) + 1 n_snapshots = floor(Int64, integration_steps / n_steps_before_snapshot) + 1
n_bundle_snapshots = min(n_snapshots, n_bundle_snapshots) n_bundle_snapshots = min(n_snapshots, n_bundle_snapshots)
@ -38,27 +34,63 @@ function run_sim(
sim_state = JSON3.read(read("$dir/sim_state.json", String)) sim_state = JSON3.read(read("$dir/sim_state.json", String))
n_bundles = sim_state.n_bundles n_bundles = sim_state.n_bundles
T0::Float64 = sim_state.T
T = T0 + duration
if n_bundle_snapshots > 0
save_data = true
else
save_data = false
end
@async begin
start_datetime = now()
run_params = (;
# Input
duration,
snapshot_at,
seed,
n_bundle_snapshots,
# Calculated
integration_steps,
n_steps_before_snapshot,
n_snapshots,
T,
# Read
T0,
start_datetime,
)
next_bundle = n_bundles + 1
runs_dir = "$dir/runs"
if save_data
write_struct_to_json(run_params, "$runs_dir/run_params_$next_bundle")
end
end
bundles_dir = "$dir/bundles" bundles_dir = "$dir/bundles"
bundle = JLD2.load_object("$bundles_dir/bundle_$n_bundles.jld2") bundle = JLD2.load_object("$bundles_dir/bundle_$n_bundles.jld2")
particles = generate_particles(bundle, sim_consts.n_particles) particles = gen_particles(bundle, sim_consts.n_particles)
args = ( args = (
v=sim_consts.v, v=sim_consts.v,
skin_r=skin_r, skin_r=sim_consts.skin_r,
skin_r²=skin_r^2, skin_r²=sim_consts.skin_r^2,
n_snapshots=n_snapshots, n_snapshots=n_snapshots,
c₁=4 * sim_consts.ϵ * 6 * sim_consts.σ^6 * sim_consts.δt * sim_consts.μ, c₁=4 * sim_consts.ϵ * 6 * sim_consts.σ^6 * sim_consts.δt * sim_consts.μ,
c₂=2 * sim_consts.σ^6, c₂=2 * sim_consts.σ^6,
c₃=sqrt(2 * sim_consts.D₀ * sim_consts.δt), c₃=sqrt(2 * sim_consts.D₀ * sim_consts.δt),
c₄=sqrt(2 * sim_consts.Dᵣ * sim_consts.δt), c₄=sqrt(2 * sim_consts.Dᵣ * sim_consts.δt),
vδt=sim_consts.v * sim_consts.δt, vδt=sim_consts.v * sim_consts.δt,
μ=sim_consts.μ, μ=sim_consts.μ,
interaction_r=sim_consts.interaction_r, interaction_r=sim_consts.interaction_r,
interaction_r²=sim_consts.interaction_r^2, interaction_r²=sim_consts.interaction_r^2,
n_particles=sim_consts.n_particles, n_particles=sim_consts.n_particles,
half_box_len=sim_consts.half_box_len, half_box_len=sim_consts.half_box_len,
particle_diameter=sim_consts.particle_diameter,
particles=particles, particles=particles,
particles_c=[particles[i].c for i in 1:(sim_consts.n_particles)], particles_c=[particles[i].c for i in 1:(sim_consts.n_particles)],
verlet_lists=[ verlet_lists=[
@ -67,45 +99,20 @@ function run_sim(
], ],
n_bundle_snapshots=n_bundle_snapshots, n_bundle_snapshots=n_bundle_snapshots,
bundle=Bundle(sim_consts.n_particles, n_bundle_snapshots), bundle=Bundle(sim_consts.n_particles, n_bundle_snapshots),
box=Box(SVector(2 * sim_consts.half_box_len, 2 * sim_consts.half_box_len), skin_r), box=Box(
SVector(2 * sim_consts.half_box_len, 2 * sim_consts.half_box_len),
sim_consts.skin_r,
),
) )
T0::Float64 = sim_state.T if !isnothing(rl_params)
T = T0 + duration pre_integration_hook! = ReCoRL.pre_integration_hook
integration_hook = ReCoRL.integration_hook
start_datetime = now() post_integration_hook = ReCoRL.post_integration_hook
run_params = (;
# Input
duration,
snapshot_at,
n_steps_before_verlet_list_update,
seed,
n_bundle_snapshots,
# Calculated
skin_r,
integration_steps,
n_steps_before_snapshot,
n_snapshots,
T,
# Read
T0,
start_datetime,
)
next_bundle = n_bundles + 1
runs_dir = "$dir/runs"
mkpath(runs_dir)
if n_bundle_snapshots > 0
save_data = true
open("$runs_dir/run_params_$next_bundle.json", "w") do f
JSON3.write(f, run_params)
end
else else
save_data = false pre_integration_hook! = empty_hook
integration_hook = empty_hook
post_integration_hook = empty_hook
end end
simulate( simulate(
@ -113,12 +120,16 @@ function run_sim(
sim_consts.δt, sim_consts.δt,
T0, T0,
T, T,
n_steps_before_verlet_list_update, sim.consts.n_steps_before_verlet_list_update,
n_steps_before_snapshot, n_steps_before_snapshot,
n_bundles, n_bundles,
dir, dir,
save_data, save_data,
rl_params,
pre_integration_hook!,
integration_hook,
post_integration_hook,
) )
return dir return nothing
end end

View file

@ -1,6 +1,12 @@
using Distributions: Uniform using Distributions: Uniform
using Dates: now using Dates: now
using JSON3: JSON3
const DEFAULT_PACKING_RATIO = 0.5
const DEFAULT_δt = 1e-5
const DEFAULT_SKIN_TO_INTERACTION_R_RATIO = 3.0
const DEFAULT_EXPORTS_DIR = "exports"
const DEFAULT_PARENT_DIR = ""
const DEFAULT_COMMENT = ""
function initial_particle_grid_pos( function initial_particle_grid_pos(
i::Int64, j::Int64, grid_box_width::Float64, half_box_len::Float64 i::Int64, j::Int64, grid_box_width::Float64, half_box_len::Float64
@ -9,7 +15,7 @@ function initial_particle_grid_pos(
return SVector(i * grid_box_width + term, j * grid_box_width + term) return SVector(i * grid_box_width + term, j * grid_box_width + term)
end end
function generate_particles(grid_n::Int64, grid_box_width::Float64, half_box_len::Float64) function gen_particles(grid_n::Int64, grid_box_width::Float64, half_box_len::Float64)
particles = Vector{Particle}(undef, grid_n^2) particles = Vector{Particle}(undef, grid_n^2)
id = 1 id = 1
@ -29,7 +35,7 @@ function generate_particles(grid_n::Int64, grid_box_width::Float64, half_box_len
return particles return particles
end end
function generate_particles(bundle::Bundle, n_particles::Int64) function gen_particles(bundle::Bundle, n_particles::Int64)
particles = Vector{Particle}(undef, n_particles) particles = Vector{Particle}(undef, n_particles)
@simd for id in 1:n_particles @simd for id in 1:n_particles
@ -39,27 +45,34 @@ function generate_particles(bundle::Bundle, n_particles::Int64)
return particles return particles
end end
function init_sim(; function gen_sim_consts(
n_particles::Int64, n_particles::Int64,
v::Float64, v₀::Float64,
δt::Float64=1e-5, δt::Float64=DEFAULT_δt,
packing_ratio::Float64=0.5, packing_ratio::Float64=DEFAULT_PACKING_RATIO,
comment::String="", skin_to_interaction_r_ratio::Float64=DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
parent_dir::String="",
) )
@assert n_particles > 0 @assert n_particles > 0
@assert v >= 0 @assert v >= 0
@assert δt in 1e-7:1e-7:1e-4 @assert δt in 1e-7:1e-7:1e-4
@assert packing_ratio > 0 @assert packing_ratio > 0
μ = 1.0 μ = 1.0
D₀ = 1.0 D₀ = 1.0
particle_diameter = 1.0 particle_radius = 1.0
Dᵣ = 3 * D₀ / (particle_diameter^2) Dᵣ = 3 * D₀ / ((2 * particle_radius)^2)
σ = 1.0 σ = 1.0
ϵ = 100.0 ϵ = 100.0
interaction_r = 2^(1 / 6) * σ interaction_r = 2^(1 / 6) * σ
buffer = 6.0
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
@assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
skin_r = skin_to_interaction_r_ratio * interaction_r
n_steps_before_verlet_list_update =
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step
grid_n = round(Int64, ceil(sqrt(n_particles))) grid_n = round(Int64, ceil(sqrt(n_particles)))
n_particles = grid_n^2 n_particles = grid_n^2
@ -67,47 +80,80 @@ function init_sim(;
half_box_len = sqrt(n_particles * π / packing_ratio) * σ / 4 half_box_len = sqrt(n_particles * π / packing_ratio) * σ / 4
grid_box_width = 2 * half_box_len / grid_n grid_box_width = 2 * half_box_len / grid_n
sim_consts = (; return (;
# Input # Input
n_particles, n_particles,
v, v,
δt, δt,
packing_ratio, packing_ratio,
# Calculated # Calculated
μ, μ,
D₀, D₀,
particle_diameter, particle_radius,
Dᵣ, Dᵣ,
σ, σ,
ϵ, ϵ,
interaction_r, interaction_r,
skin_r,
n_steps_before_verlet_list_update,
grid_n, grid_n,
half_box_len, half_box_len,
grid_box_width, grid_box_width,
) )
end
particles = generate_particles(grid_n, grid_box_width, half_box_len) function init_sim_with_sim_consts(
sim_consts,
exports_dir::String=DEFAULT_EXPORTS_DIR,
parent_dir::String=DEFAULT_PARENT_DIR,
comment::String=DEFAULT_COMMENT,
)
particles = gen_particles(
sim_consts.grid_n, sim_consts.grid_box_width, sim_consts.half_box_len
)
bundle = Bundle(n_particles, 1) bundle = Bundle(n_particles, 1)
save_snapshot!(bundle, 1, 0.0, particles) save_snapshot!(bundle, 1, 0.0, particles)
particles = nothing particles = nothing
dir = "exports"
if length(parent_dir) > 0 if length(parent_dir) > 0
dir *= "/$parent_dir" exports_dir *= "/$parent_dir"
end end
dir *= "/$(now())_N=$(n_particles)_v=$(v)_#$(rand(1000:9999))"
start_datetime = now()
exports_dir *= "/$(start_datetime)_N=$(sim_consts.n_particles)_v=$(sim_consts.v₀)_#$(rand(1000:9999))"
if length(comment) > 0 if length(comment) > 0
dir *= "_$comment" exports_dir *= "_$comment"
end end
mkpath(dir) mkpath(exports_dir)
open("$dir/sim_consts.json", "w") do f task = @async write_struct_to_json(sim_consts, "$exports_dir/sim_consts")
JSON3.write(f, sim_consts)
end
save_bundle(dir, bundle, 1, 0.0) save_bundle(exports_dir, bundle, 1, 0.0)
return dir runs_dir = "$exports_dir/runs"
mkpath(runs_dir)
wait(task)
return exports_dir
end
function init_sim(;
n_particles::Int64,
v₀::Float64,
δt::Float64=DEFAULT_δt,
packing_ratio::Float64=DEFAULT_PACKING_RATIO,
skin_to_interaction_r_ratio::Float64=DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
exports_dir::String=DEFAULT_EXPORTS_DIR,
parent_dir::String=DEFAULT_PARENT_DIR,
comment::String=DEFAULT_COMMENT,
)
sim_consts = gen_sim_consts(
n_particles, v₀, δt, packing_ratio, skin_to_interaction_r_ratio
)
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)
end end

View file

@ -40,7 +40,9 @@ function update_verlet_lists!(args, cl)
return cl return cl
end end
function euler!(args) function euler!(
args, integration_hook::F, actions::Vector{Tuple{Float64,Float64}}
) where {F<:Function}
for id1 in 1:(args.n_particles - 1) for id1 in 1:(args.n_particles - 1)
p1 = args.particles[id1] p1 = args.particles[id1]
p1_c = p1.c p1_c = p1.c
@ -66,21 +68,23 @@ function euler!(args)
@simd for p in args.particles @simd for p in args.particles
si, co = sincos(p.φ) si, co = sincos(p.φ)
p.tmp_c += SVector( p.tmp_c += SVector(
args.vδt * co + args.c₃ * rand_normal01(), args.vδt * co + args.c₃ * rand_normal01(),
args.vδt * si + args.c₃ * rand_normal01(), args.vδt * si + args.c₃ * rand_normal01(),
) )
p.φ += args.c₄ * rand_normal01() p.φ += args.c₄ * rand_normal01()
restrict_coordinates!(p, args.half_box_len) restrict_coordinates!(p, args.half_box_len)
integration_hook(p, actions)
p.c = p.tmp_c p.c = p.tmp_c
end end
return nothing return nothing
end end
wait(n::Nothing) = n wait(::Nothing) = nothing
function simulate( function simulate(
args, args,
@ -92,7 +96,11 @@ function simulate(
n_bundles::Int64, n_bundles::Int64,
dir::String, dir::String,
save_data::Bool, save_data::Bool,
) rl_params::Union{ReCoRL.Params,Nothing},
pre_integration_hook!::F,
integration_hook::F,
post_integration_hook::F,
) where {F<:Function}
bundle_snapshot_counter = 0 bundle_snapshot_counter = 0
task::Union{Task,Nothing} = nothing task::Union{Task,Nothing} = nothing
@ -100,6 +108,8 @@ function simulate(
cl = CellList(args.particles_c, args.box; parallel=false) cl = CellList(args.particles_c, args.box; parallel=false)
cl = update_verlet_lists!(args, cl) cl = update_verlet_lists!(args, cl)
update_actions = true
start_time = now() start_time = now()
println("Started simulation at $start_time.") println("Started simulation at $start_time.")
@ -124,7 +134,17 @@ function simulate(
cl = update_verlet_lists!(args, cl) cl = update_verlet_lists!(args, cl)
end end
euler!(args) update_actions = integration_step % rl_params.n_steps_before_actions_update == 0
if update_actions
pre_integration_hook!(rl_params)
end
euler!(args, integration_hook, rl.params.actions)
if update_actions
post_integration_hook(rl_params)
end
end end
wait(task) wait(task)

View file

@ -1,16 +0,0 @@
using JSON3: JSON3
using OrderedCollections: OrderedDict
function struct_to_ordered_dict(s)
return OrderedDict(key => getfield(s, key) for key in propertynames(s))
end
function write_struct_to_json(s, path_without_extension::String)
ordered_dict = struct_to_ordered_dict(s)
open("$path_without_extension.json", "w") do f
JSON3.write(f, ordered_dict)
end
return nothing
end