1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-10-11 20:34:22 +00:00
This commit is contained in:
MoBit 2021-12-13 00:19:18 +01:00
parent dc6a4ee7b7
commit 3036c5e65a
6 changed files with 69 additions and 56 deletions

View file

@ -10,6 +10,7 @@ CellListMap = "69e1c6dd-3888-40e6-b3c8-31ac5f578864"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"

View file

@ -28,9 +28,9 @@ function animate_bundle!(args)
if args.debug
args.interaction_circles[][i] = Circle(
Point2(bundle_c[i, frame]), args.interaction_r
Point2(c[1], c[2]), args.interaction_r
)
args.skin_circles[][i] = Circle(Point2(bundle_c[i, frame]), args.skin_r)
args.skin_circles[][i] = Circle(Point2(c[1], c[2]), args.skin_r)
args.interaction_colors[][i] = RGBAf(color, 0.08)
args.skin_colors[][i] = RGBAf(color, 0.04)
@ -146,16 +146,11 @@ function animate_after_loading(dir, animation_path, sim_consts, framerate, debug
bundle_paths = bundle_paths[sort_perm]
sort_perm = nothing
skin_r = 0.0
skin_r = sim_consts.skin_r
@showprogress 1 for (n_bundle, bundle_path) in enumerate(bundle_paths)
bundle::Bundle = JLD2.load_object(bundle_path)
run_params_file = "$dir/runs/run_params_$n_bundle.json"
if debug && isfile(run_params_file)
skin_r::Float64 = JSON3.read(read(run_params_file, String)).skin_r
end
args = (;
io,
ax,

View file

@ -1,9 +1,12 @@
module ReCoRL
module RL
using ReinforcementLearning
using Flux: InvDecay
using Intervals
using StaticArrays: SVector
using Random: Random
using ..ReCo
import Base: run
@ -16,7 +19,7 @@ struct DistanceState{L<:Bound}
end
struct DirectionState
interval::Interval{Float64,Open,Closed}
interval::Interval{Float64,Closed,Open}
function DirectionState(lower::Float64, upper::Float64)
return new(Interval{Float64,Closed,Open}(lower, upper))
@ -43,7 +46,7 @@ mutable struct EnvParams
@assert min_distance > 0.0
@assert max_distance > min_distance
@assert n_v_actions > 1
@assert n_ω_actoins > 1
@assert n_ω_actions > 1
@assert max_v > 0
@assert max_ω > 0
@ -89,6 +92,8 @@ mutable struct EnvParams
)
end
n_states = n_distance_states * n_direction_states + 1
state_space = Vector{
Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
}(
@ -112,10 +117,10 @@ end
mutable struct Env <: AbstractEnv
params::EnvParams
particle::Particle
particle::ReCo.Particle
state::Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
function Env(params::EnvParams, particle::Particle)
function Env(params::EnvParams, particle::ReCo.Particle)
return new(params, particle, (nothing, nothing))
end
end
@ -145,16 +150,14 @@ struct Params{H<:AbstractHook}
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
) where {H<:AbstractHook}
policies = [
gen_policy(
length(rl_params.env_params.state_space),
length(rl_params.env_params.action_space),
) for i in 1:n_particles
gen_policy(length(env_params.state_space), length(env_params.action_space)) for
i in 1:n_particles
]
agents = [
Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
]
return new(
[Env(env_params, gen_tmp_particle()) for i in 1:n_particles],
[Env(env_params, ReCo.gen_tmp_particle()) for i in 1:n_particles],
agents,
[H() for i in 1:n_particles],
Vector{Tuple{Float64,Float64}}(undef, n_particles),
@ -206,7 +209,7 @@ function state_hook(
return nothing
end
function integration_hook(particle::Particle, rl_params::Params, δt::Float64)
function integration_hook(particle::ReCo.Particle, rl_params::Params, δt::Float64)
action = rl_params.actions[particle.id]
particle.tmp_c += action[1] * δt
@ -216,7 +219,7 @@ function integration_hook(particle::Particle, rl_params::Params, δt::Float64)
end
function post_integration_hook(
rl_params::Params, n_particles::Int64, particles::Vector{Particle}
rl_params::Params, n_particles::Int64, particles::Vector{ReCo.Particle}
)
env_direction_state = rl_params.env_params.direction_state_space[1]
@ -266,7 +269,7 @@ function post_integration_hook(
return nothing
end
function run(
function run(;
n_episodes::Int64=100,
episode_duration::Float64=5.0,
update_actions_at::Float64=0.1,
@ -275,15 +278,17 @@ function run(
@assert n_episodes > 0
@assert episode_duration > 0
@assert update_actions_at in 0.01:0.01:episode_duration
@assert episode_duration % update_actions_at == 0
@assert n_particles > 0
Random.seed!(42)
sim_consts = gen_sim_consts(n_particles, v₀)
sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=3.5)
n_particles = sim_consts.n_particles
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
rl_params = Params{TotalRewardPerEpisode}(
n_particles, env_params, n_steps_before_actions_update
)
@ -297,7 +302,7 @@ function run(
end
for episode in 1:n_episodes
dir, particles = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")
dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")
for i in 1:n_particles
env = rl_params.envs[i]

View file

@ -11,12 +11,11 @@ function run_sim(
snapshot_at::Float64=0.1,
seed::Int64=42,
n_bundle_snapshots::Int64=100,
rl_params::Union{ReCo.Params,Nothing}=nothing,
rl_params::Union{RL.Params,Nothing}=nothing,
)
@assert length(dir) > 0
@assert duration > 0
@assert snapshot_at in 0.001:0.001:duration
@assert duration % snapshot_at == 0
@assert seed > 0
@assert n_bundle_snapshots >= 0
@ -29,6 +28,8 @@ function run_sim(
n_steps_before_snapshot = round(Int64, snapshot_at / sim_consts.δt)
n_snapshots = floor(Int64, integration_steps / n_steps_before_snapshot) + 1
@assert (n_snapshots - 1) * snapshot_at == duration
n_bundle_snapshots = min(n_snapshots, n_bundle_snapshots)
sim_state = JSON3.read(read("$dir/sim_state.json", String))
@ -107,9 +108,9 @@ function run_sim(
)
if !isnothing(rl_params)
pre_integration_hook! = ReCoRL.pre_integration_hook
integration_hook = ReCoRL.integration_hook
post_integration_hook = ReCoRL.post_integration_hook
pre_integration_hook! = RL.pre_integration_hook!
integration_hook = RL.integration_hook
post_integration_hook = RL.post_integration_hook
else
pre_integration_hook! = empty_hook
integration_hook = empty_hook
@ -118,10 +119,9 @@ function run_sim(
simulate(
args,
sim_consts.δt,
T0,
T,
sim.consts.n_steps_before_verlet_list_update,
sim_consts.n_steps_before_verlet_list_update,
n_steps_before_snapshot,
n_bundles,
dir,

View file

@ -3,7 +3,7 @@ using Dates: now
const DEFAULT_PACKING_RATIO = 0.5
const DEFAULT_δt = 1e-5
const DEFAULT_SKIN_TO_INTERACTION_R_RATIO = 3.0
const DEFAULT_SKIN_TO_INTERACTION_R_RATIO = 1.5
const DEFAULT_EXPORTS_DIR = "exports"
const DEFAULT_PARENT_DIR = ""
const DEFAULT_COMMENT = ""
@ -47,7 +47,7 @@ end
function gen_sim_consts(
n_particles::Int64,
v₀::Float64,
v₀::Float64;
δt::Float64=DEFAULT_δt,
packing_ratio::Float64=DEFAULT_PACKING_RATIO,
skin_to_interaction_r_ratio::Float64=DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
@ -59,19 +59,26 @@ function gen_sim_consts(
μ = 1.0
D₀ = 1.0
particle_radius = 1.0
particle_radius = 0.5
Dᵣ = 3 * D₀ / ((2 * particle_radius)^2)
σ = 1.0
ϵ = 100.0
interaction_r = 2^(1 / 6) * σ
buffer = 6.0
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
@assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
if v₀ != 0.0
buffer = 1.8
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
@assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
skin_r = skin_to_interaction_r_ratio * interaction_r
n_steps_before_verlet_list_update =
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step
skin_r = skin_to_interaction_r_ratio * interaction_r
n_steps_before_verlet_list_update = round(
Int64,
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step,
)
else
skin_r = 1.5 * interaction_r
n_steps_before_verlet_list_update = 100
end
grid_n = round(Int64, ceil(sqrt(n_particles)))
@ -103,7 +110,7 @@ function gen_sim_consts(
end
function init_sim_with_sim_consts(
sim_consts,
sim_consts;
exports_dir::String=DEFAULT_EXPORTS_DIR,
parent_dir::String=DEFAULT_PARENT_DIR,
comment::String=DEFAULT_COMMENT,
@ -111,7 +118,7 @@ function init_sim_with_sim_consts(
particles = gen_particles(
sim_consts.grid_n, sim_consts.grid_box_width, sim_consts.half_box_len
)
bundle = Bundle(n_particles, 1)
bundle = Bundle(sim_consts.n_particles, 1)
save_snapshot!(bundle, 1, 0.0, particles)
dir = exports_dir
@ -152,8 +159,14 @@ function init_sim(;
comment::String=DEFAULT_COMMENT,
)
sim_consts = gen_sim_consts(
n_particles, v₀, δt, packing_ratio, skin_to_interaction_r_ratio
n_particles,
v₀;
δt=δt,
packing_ratio=packing_ratio,
skin_to_interaction_r_ratio=skin_to_interaction_r_ratio,
)
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)[1]
return init_sim_with_sim_consts(
sim_consts; exports_dir=exports_dir, parent_dir=parent_dir, comment=comment
)[1]
end

View file

@ -41,7 +41,7 @@ function update_verlet_lists!(args, cl)
end
function euler!(
args, state_hook::F, integration_hook::F, rl_params::Union{ReCoRL.Params,Nothing}
args, state_hook::F, integration_hook::F, rl_params::Union{RL.Params,Nothing}
) where {F<:Function}
for id1 in 1:(args.n_particles - 1)
p1 = args.particles[id1]
@ -90,14 +90,13 @@ wait(::Nothing) = nothing
gen_run_hooks(::Nothing, args...) = false
function gen_run_hooks(rl_params::ReCoRL.Params, integration_step::Int64)
return (itegration_step == 1) ||
function gen_run_hooks(rl_params::RL.Params, integration_step::Int64)
return (integration_step == 1) ||
(integration_step % rl_params.n_steps_before_actions_update == 0)
end
function simulate(
args,
δt::Float64,
T0::Float64,
T::Float64,
n_steps_before_verlet_list_update::Int64,
@ -105,11 +104,11 @@ function simulate(
n_bundles::Int64,
dir::String,
save_data::Bool,
rl_params::Union{ReCoRL.Params,Nothing},
pre_integration_hook!::F,
integration_hook::F,
post_integration_hook::F,
) where {F<:Function}
rl_params::Union{RL.Params,Nothing},
pre_integration_hook!::Function,
integration_hook::Function,
post_integration_hook::Function,
)
bundle_snapshot_counter = 0
task::Union{Task,Nothing} = nothing
@ -123,7 +122,7 @@ function simulate(
start_time = now()
println("Started simulation at $start_time.")
@showprogress 0.6 for (integration_step, t) in enumerate(T0:δt:T)
@showprogress 0.6 for (integration_step, t) in enumerate(T0:(args.δt):T)
if (integration_step % n_steps_before_snapshot == 0) && save_data
wait(task)
@ -148,10 +147,10 @@ function simulate(
if run_hooks
pre_integration_hook!(rl_params, args.n_particles)
state_hook = ReCoRL.state_hook
state_hook = RL.state_hook
end
euler!(args, state_hook, integration_hook, rl.params.actions)
euler!(args, state_hook, integration_hook, rl_params)
if run_hooks
post_integration_hook(rl_params, args.n_particles, args.particles)