mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-10-11 20:34:22 +00:00
Fixes
This commit is contained in:
parent
dc6a4ee7b7
commit
3036c5e65a
6 changed files with 69 additions and 56 deletions
|
@ -10,6 +10,7 @@ CellListMap = "69e1c6dd-3888-40e6-b3c8-31ac5f578864"
|
|||
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
|
||||
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
|
||||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
|
||||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
|
||||
Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
|
||||
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
|
||||
|
|
|
@ -28,9 +28,9 @@ function animate_bundle!(args)
|
|||
|
||||
if args.debug
|
||||
args.interaction_circles[][i] = Circle(
|
||||
Point2(bundle_c[i, frame]), args.interaction_r
|
||||
Point2(c[1], c[2]), args.interaction_r
|
||||
)
|
||||
args.skin_circles[][i] = Circle(Point2(bundle_c[i, frame]), args.skin_r)
|
||||
args.skin_circles[][i] = Circle(Point2(c[1], c[2]), args.skin_r)
|
||||
|
||||
args.interaction_colors[][i] = RGBAf(color, 0.08)
|
||||
args.skin_colors[][i] = RGBAf(color, 0.04)
|
||||
|
@ -146,16 +146,11 @@ function animate_after_loading(dir, animation_path, sim_consts, framerate, debug
|
|||
bundle_paths = bundle_paths[sort_perm]
|
||||
sort_perm = nothing
|
||||
|
||||
skin_r = 0.0
|
||||
skin_r = sim_consts.skin_r
|
||||
|
||||
@showprogress 1 for (n_bundle, bundle_path) in enumerate(bundle_paths)
|
||||
bundle::Bundle = JLD2.load_object(bundle_path)
|
||||
|
||||
run_params_file = "$dir/runs/run_params_$n_bundle.json"
|
||||
if debug && isfile(run_params_file)
|
||||
skin_r::Float64 = JSON3.read(read(run_params_file, String)).skin_r
|
||||
end
|
||||
|
||||
args = (;
|
||||
io,
|
||||
ax,
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
module ReCoRL
|
||||
module RL
|
||||
|
||||
using ReinforcementLearning
|
||||
using Flux: InvDecay
|
||||
using Intervals
|
||||
using StaticArrays: SVector
|
||||
using Random: Random
|
||||
|
||||
using ..ReCo
|
||||
|
||||
import Base: run
|
||||
|
||||
|
@ -16,7 +19,7 @@ struct DistanceState{L<:Bound}
|
|||
end
|
||||
|
||||
struct DirectionState
|
||||
interval::Interval{Float64,Open,Closed}
|
||||
interval::Interval{Float64,Closed,Open}
|
||||
|
||||
function DirectionState(lower::Float64, upper::Float64)
|
||||
return new(Interval{Float64,Closed,Open}(lower, upper))
|
||||
|
@ -43,7 +46,7 @@ mutable struct EnvParams
|
|||
@assert min_distance > 0.0
|
||||
@assert max_distance > min_distance
|
||||
@assert n_v_actions > 1
|
||||
@assert n_ω_actoins > 1
|
||||
@assert n_ω_actions > 1
|
||||
@assert max_v > 0
|
||||
@assert max_ω > 0
|
||||
|
||||
|
@ -89,6 +92,8 @@ mutable struct EnvParams
|
|||
)
|
||||
end
|
||||
|
||||
n_states = n_distance_states * n_direction_states + 1
|
||||
|
||||
state_space = Vector{
|
||||
Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
|
||||
}(
|
||||
|
@ -112,10 +117,10 @@ end
|
|||
|
||||
mutable struct Env <: AbstractEnv
|
||||
params::EnvParams
|
||||
particle::Particle
|
||||
particle::ReCo.Particle
|
||||
state::Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
|
||||
|
||||
function Env(params::EnvParams, particle::Particle)
|
||||
function Env(params::EnvParams, particle::ReCo.Particle)
|
||||
return new(params, particle, (nothing, nothing))
|
||||
end
|
||||
end
|
||||
|
@ -145,16 +150,14 @@ struct Params{H<:AbstractHook}
|
|||
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
|
||||
) where {H<:AbstractHook}
|
||||
policies = [
|
||||
gen_policy(
|
||||
length(rl_params.env_params.state_space),
|
||||
length(rl_params.env_params.action_space),
|
||||
) for i in 1:n_particles
|
||||
gen_policy(length(env_params.state_space), length(env_params.action_space)) for
|
||||
i in 1:n_particles
|
||||
]
|
||||
agents = [
|
||||
Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
|
||||
]
|
||||
return new(
|
||||
[Env(env_params, gen_tmp_particle()) for i in 1:n_particles],
|
||||
[Env(env_params, ReCo.gen_tmp_particle()) for i in 1:n_particles],
|
||||
agents,
|
||||
[H() for i in 1:n_particles],
|
||||
Vector{Tuple{Float64,Float64}}(undef, n_particles),
|
||||
|
@ -206,7 +209,7 @@ function state_hook(
|
|||
return nothing
|
||||
end
|
||||
|
||||
function integration_hook(particle::Particle, rl_params::Params, δt::Float64)
|
||||
function integration_hook(particle::ReCo.Particle, rl_params::Params, δt::Float64)
|
||||
action = rl_params.actions[particle.id]
|
||||
|
||||
particle.tmp_c += action[1] * δt
|
||||
|
@ -216,7 +219,7 @@ function integration_hook(particle::Particle, rl_params::Params, δt::Float64)
|
|||
end
|
||||
|
||||
function post_integration_hook(
|
||||
rl_params::Params, n_particles::Int64, particles::Vector{Particle}
|
||||
rl_params::Params, n_particles::Int64, particles::Vector{ReCo.Particle}
|
||||
)
|
||||
env_direction_state = rl_params.env_params.direction_state_space[1]
|
||||
|
||||
|
@ -266,7 +269,7 @@ function post_integration_hook(
|
|||
return nothing
|
||||
end
|
||||
|
||||
function run(
|
||||
function run(;
|
||||
n_episodes::Int64=100,
|
||||
episode_duration::Float64=5.0,
|
||||
update_actions_at::Float64=0.1,
|
||||
|
@ -275,15 +278,17 @@ function run(
|
|||
@assert n_episodes > 0
|
||||
@assert episode_duration > 0
|
||||
@assert update_actions_at in 0.01:0.01:episode_duration
|
||||
@assert episode_duration % update_actions_at == 0
|
||||
@assert n_particles > 0
|
||||
|
||||
Random.seed!(42)
|
||||
|
||||
sim_consts = gen_sim_consts(n_particles, v₀)
|
||||
sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=3.5)
|
||||
n_particles = sim_consts.n_particles
|
||||
|
||||
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
|
||||
|
||||
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
||||
|
||||
rl_params = Params{TotalRewardPerEpisode}(
|
||||
n_particles, env_params, n_steps_before_actions_update
|
||||
)
|
||||
|
@ -297,7 +302,7 @@ function run(
|
|||
end
|
||||
|
||||
for episode in 1:n_episodes
|
||||
dir, particles = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")
|
||||
dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")
|
||||
|
||||
for i in 1:n_particles
|
||||
env = rl_params.envs[i]
|
||||
|
|
14
src/run.jl
14
src/run.jl
|
@ -11,12 +11,11 @@ function run_sim(
|
|||
snapshot_at::Float64=0.1,
|
||||
seed::Int64=42,
|
||||
n_bundle_snapshots::Int64=100,
|
||||
rl_params::Union{ReCo.Params,Nothing}=nothing,
|
||||
rl_params::Union{RL.Params,Nothing}=nothing,
|
||||
)
|
||||
@assert length(dir) > 0
|
||||
@assert duration > 0
|
||||
@assert snapshot_at in 0.001:0.001:duration
|
||||
@assert duration % snapshot_at == 0
|
||||
@assert seed > 0
|
||||
@assert n_bundle_snapshots >= 0
|
||||
|
||||
|
@ -29,6 +28,8 @@ function run_sim(
|
|||
n_steps_before_snapshot = round(Int64, snapshot_at / sim_consts.δt)
|
||||
n_snapshots = floor(Int64, integration_steps / n_steps_before_snapshot) + 1
|
||||
|
||||
@assert (n_snapshots - 1) * snapshot_at == duration
|
||||
|
||||
n_bundle_snapshots = min(n_snapshots, n_bundle_snapshots)
|
||||
|
||||
sim_state = JSON3.read(read("$dir/sim_state.json", String))
|
||||
|
@ -107,9 +108,9 @@ function run_sim(
|
|||
)
|
||||
|
||||
if !isnothing(rl_params)
|
||||
pre_integration_hook! = ReCoRL.pre_integration_hook
|
||||
integration_hook = ReCoRL.integration_hook
|
||||
post_integration_hook = ReCoRL.post_integration_hook
|
||||
pre_integration_hook! = RL.pre_integration_hook!
|
||||
integration_hook = RL.integration_hook
|
||||
post_integration_hook = RL.post_integration_hook
|
||||
else
|
||||
pre_integration_hook! = empty_hook
|
||||
integration_hook = empty_hook
|
||||
|
@ -118,10 +119,9 @@ function run_sim(
|
|||
|
||||
simulate(
|
||||
args,
|
||||
sim_consts.δt,
|
||||
T0,
|
||||
T,
|
||||
sim.consts.n_steps_before_verlet_list_update,
|
||||
sim_consts.n_steps_before_verlet_list_update,
|
||||
n_steps_before_snapshot,
|
||||
n_bundles,
|
||||
dir,
|
||||
|
|
39
src/setup.jl
39
src/setup.jl
|
@ -3,7 +3,7 @@ using Dates: now
|
|||
|
||||
const DEFAULT_PACKING_RATIO = 0.5
|
||||
const DEFAULT_δt = 1e-5
|
||||
const DEFAULT_SKIN_TO_INTERACTION_R_RATIO = 3.0
|
||||
const DEFAULT_SKIN_TO_INTERACTION_R_RATIO = 1.5
|
||||
const DEFAULT_EXPORTS_DIR = "exports"
|
||||
const DEFAULT_PARENT_DIR = ""
|
||||
const DEFAULT_COMMENT = ""
|
||||
|
@ -47,7 +47,7 @@ end
|
|||
|
||||
function gen_sim_consts(
|
||||
n_particles::Int64,
|
||||
v₀::Float64,
|
||||
v₀::Float64;
|
||||
δt::Float64=DEFAULT_δt,
|
||||
packing_ratio::Float64=DEFAULT_PACKING_RATIO,
|
||||
skin_to_interaction_r_ratio::Float64=DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
|
||||
|
@ -59,19 +59,26 @@ function gen_sim_consts(
|
|||
|
||||
μ = 1.0
|
||||
D₀ = 1.0
|
||||
particle_radius = 1.0
|
||||
particle_radius = 0.5
|
||||
Dᵣ = 3 * D₀ / ((2 * particle_radius)^2)
|
||||
σ = 1.0
|
||||
ϵ = 100.0
|
||||
interaction_r = 2^(1 / 6) * σ
|
||||
|
||||
buffer = 6.0
|
||||
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
|
||||
@assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
|
||||
if v₀ != 0.0
|
||||
buffer = 1.8
|
||||
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
|
||||
@assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
|
||||
|
||||
skin_r = skin_to_interaction_r_ratio * interaction_r
|
||||
n_steps_before_verlet_list_update =
|
||||
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step
|
||||
skin_r = skin_to_interaction_r_ratio * interaction_r
|
||||
n_steps_before_verlet_list_update = round(
|
||||
Int64,
|
||||
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step,
|
||||
)
|
||||
else
|
||||
skin_r = 1.5 * interaction_r
|
||||
n_steps_before_verlet_list_update = 100
|
||||
end
|
||||
|
||||
grid_n = round(Int64, ceil(sqrt(n_particles)))
|
||||
|
||||
|
@ -103,7 +110,7 @@ function gen_sim_consts(
|
|||
end
|
||||
|
||||
function init_sim_with_sim_consts(
|
||||
sim_consts,
|
||||
sim_consts;
|
||||
exports_dir::String=DEFAULT_EXPORTS_DIR,
|
||||
parent_dir::String=DEFAULT_PARENT_DIR,
|
||||
comment::String=DEFAULT_COMMENT,
|
||||
|
@ -111,7 +118,7 @@ function init_sim_with_sim_consts(
|
|||
particles = gen_particles(
|
||||
sim_consts.grid_n, sim_consts.grid_box_width, sim_consts.half_box_len
|
||||
)
|
||||
bundle = Bundle(n_particles, 1)
|
||||
bundle = Bundle(sim_consts.n_particles, 1)
|
||||
save_snapshot!(bundle, 1, 0.0, particles)
|
||||
|
||||
dir = exports_dir
|
||||
|
@ -152,8 +159,14 @@ function init_sim(;
|
|||
comment::String=DEFAULT_COMMENT,
|
||||
)
|
||||
sim_consts = gen_sim_consts(
|
||||
n_particles, v₀, δt, packing_ratio, skin_to_interaction_r_ratio
|
||||
n_particles,
|
||||
v₀;
|
||||
δt=δt,
|
||||
packing_ratio=packing_ratio,
|
||||
skin_to_interaction_r_ratio=skin_to_interaction_r_ratio,
|
||||
)
|
||||
|
||||
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)[1]
|
||||
return init_sim_with_sim_consts(
|
||||
sim_consts; exports_dir=exports_dir, parent_dir=parent_dir, comment=comment
|
||||
)[1]
|
||||
end
|
|
@ -41,7 +41,7 @@ function update_verlet_lists!(args, cl)
|
|||
end
|
||||
|
||||
function euler!(
|
||||
args, state_hook::F, integration_hook::F, rl_params::Union{ReCoRL.Params,Nothing}
|
||||
args, state_hook::F, integration_hook::F, rl_params::Union{RL.Params,Nothing}
|
||||
) where {F<:Function}
|
||||
for id1 in 1:(args.n_particles - 1)
|
||||
p1 = args.particles[id1]
|
||||
|
@ -90,14 +90,13 @@ wait(::Nothing) = nothing
|
|||
|
||||
gen_run_hooks(::Nothing, args...) = false
|
||||
|
||||
function gen_run_hooks(rl_params::ReCoRL.Params, integration_step::Int64)
|
||||
return (itegration_step == 1) ||
|
||||
function gen_run_hooks(rl_params::RL.Params, integration_step::Int64)
|
||||
return (integration_step == 1) ||
|
||||
(integration_step % rl_params.n_steps_before_actions_update == 0)
|
||||
end
|
||||
|
||||
function simulate(
|
||||
args,
|
||||
δt::Float64,
|
||||
T0::Float64,
|
||||
T::Float64,
|
||||
n_steps_before_verlet_list_update::Int64,
|
||||
|
@ -105,11 +104,11 @@ function simulate(
|
|||
n_bundles::Int64,
|
||||
dir::String,
|
||||
save_data::Bool,
|
||||
rl_params::Union{ReCoRL.Params,Nothing},
|
||||
pre_integration_hook!::F,
|
||||
integration_hook::F,
|
||||
post_integration_hook::F,
|
||||
) where {F<:Function}
|
||||
rl_params::Union{RL.Params,Nothing},
|
||||
pre_integration_hook!::Function,
|
||||
integration_hook::Function,
|
||||
post_integration_hook::Function,
|
||||
)
|
||||
bundle_snapshot_counter = 0
|
||||
|
||||
task::Union{Task,Nothing} = nothing
|
||||
|
@ -123,7 +122,7 @@ function simulate(
|
|||
start_time = now()
|
||||
println("Started simulation at $start_time.")
|
||||
|
||||
@showprogress 0.6 for (integration_step, t) in enumerate(T0:δt:T)
|
||||
@showprogress 0.6 for (integration_step, t) in enumerate(T0:(args.δt):T)
|
||||
if (integration_step % n_steps_before_snapshot == 0) && save_data
|
||||
wait(task)
|
||||
|
||||
|
@ -148,10 +147,10 @@ function simulate(
|
|||
|
||||
if run_hooks
|
||||
pre_integration_hook!(rl_params, args.n_particles)
|
||||
state_hook = ReCoRL.state_hook
|
||||
state_hook = RL.state_hook
|
||||
end
|
||||
|
||||
euler!(args, state_hook, integration_hook, rl.params.actions)
|
||||
euler!(args, state_hook, integration_hook, rl_params)
|
||||
|
||||
if run_hooks
|
||||
post_integration_hook(rl_params, args.n_particles, args.particles)
|
||||
|
|
Loading…
Reference in a new issue