1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-11-08 22:21:08 +00:00
This commit is contained in:
MoBit 2021-12-13 00:19:18 +01:00
parent dc6a4ee7b7
commit 3036c5e65a
6 changed files with 69 additions and 56 deletions

View file

@ -10,6 +10,7 @@ CellListMap = "69e1c6dd-3888-40e6-b3c8-31ac5f578864"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a" GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5" Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"

View file

@ -28,9 +28,9 @@ function animate_bundle!(args)
if args.debug if args.debug
args.interaction_circles[][i] = Circle( args.interaction_circles[][i] = Circle(
Point2(bundle_c[i, frame]), args.interaction_r Point2(c[1], c[2]), args.interaction_r
) )
args.skin_circles[][i] = Circle(Point2(bundle_c[i, frame]), args.skin_r) args.skin_circles[][i] = Circle(Point2(c[1], c[2]), args.skin_r)
args.interaction_colors[][i] = RGBAf(color, 0.08) args.interaction_colors[][i] = RGBAf(color, 0.08)
args.skin_colors[][i] = RGBAf(color, 0.04) args.skin_colors[][i] = RGBAf(color, 0.04)
@ -146,16 +146,11 @@ function animate_after_loading(dir, animation_path, sim_consts, framerate, debug
bundle_paths = bundle_paths[sort_perm] bundle_paths = bundle_paths[sort_perm]
sort_perm = nothing sort_perm = nothing
skin_r = 0.0 skin_r = sim_consts.skin_r
@showprogress 1 for (n_bundle, bundle_path) in enumerate(bundle_paths) @showprogress 1 for (n_bundle, bundle_path) in enumerate(bundle_paths)
bundle::Bundle = JLD2.load_object(bundle_path) bundle::Bundle = JLD2.load_object(bundle_path)
run_params_file = "$dir/runs/run_params_$n_bundle.json"
if debug && isfile(run_params_file)
skin_r::Float64 = JSON3.read(read(run_params_file, String)).skin_r
end
args = (; args = (;
io, io,
ax, ax,

View file

@ -1,9 +1,12 @@
module ReCoRL module RL
using ReinforcementLearning using ReinforcementLearning
using Flux: InvDecay using Flux: InvDecay
using Intervals using Intervals
using StaticArrays: SVector using StaticArrays: SVector
using Random: Random
using ..ReCo
import Base: run import Base: run
@ -16,7 +19,7 @@ struct DistanceState{L<:Bound}
end end
struct DirectionState struct DirectionState
interval::Interval{Float64,Open,Closed} interval::Interval{Float64,Closed,Open}
function DirectionState(lower::Float64, upper::Float64) function DirectionState(lower::Float64, upper::Float64)
return new(Interval{Float64,Closed,Open}(lower, upper)) return new(Interval{Float64,Closed,Open}(lower, upper))
@ -43,7 +46,7 @@ mutable struct EnvParams
@assert min_distance > 0.0 @assert min_distance > 0.0
@assert max_distance > min_distance @assert max_distance > min_distance
@assert n_v_actions > 1 @assert n_v_actions > 1
@assert n_ω_actoins > 1 @assert n_ω_actions > 1
@assert max_v > 0 @assert max_v > 0
@assert max_ω > 0 @assert max_ω > 0
@ -89,6 +92,8 @@ mutable struct EnvParams
) )
end end
n_states = n_distance_states * n_direction_states + 1
state_space = Vector{ state_space = Vector{
Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}} Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
}( }(
@ -112,10 +117,10 @@ end
mutable struct Env <: AbstractEnv mutable struct Env <: AbstractEnv
params::EnvParams params::EnvParams
particle::Particle particle::ReCo.Particle
state::Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}} state::Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
function Env(params::EnvParams, particle::Particle) function Env(params::EnvParams, particle::ReCo.Particle)
return new(params, particle, (nothing, nothing)) return new(params, particle, (nothing, nothing))
end end
end end
@ -145,16 +150,14 @@ struct Params{H<:AbstractHook}
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64 n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
) where {H<:AbstractHook} ) where {H<:AbstractHook}
policies = [ policies = [
gen_policy( gen_policy(length(env_params.state_space), length(env_params.action_space)) for
length(rl_params.env_params.state_space), i in 1:n_particles
length(rl_params.env_params.action_space),
) for i in 1:n_particles
] ]
agents = [ agents = [
Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
] ]
return new( return new(
[Env(env_params, gen_tmp_particle()) for i in 1:n_particles], [Env(env_params, ReCo.gen_tmp_particle()) for i in 1:n_particles],
agents, agents,
[H() for i in 1:n_particles], [H() for i in 1:n_particles],
Vector{Tuple{Float64,Float64}}(undef, n_particles), Vector{Tuple{Float64,Float64}}(undef, n_particles),
@ -206,7 +209,7 @@ function state_hook(
return nothing return nothing
end end
function integration_hook(particle::Particle, rl_params::Params, δt::Float64) function integration_hook(particle::ReCo.Particle, rl_params::Params, δt::Float64)
action = rl_params.actions[particle.id] action = rl_params.actions[particle.id]
particle.tmp_c += action[1] * δt particle.tmp_c += action[1] * δt
@ -216,7 +219,7 @@ function integration_hook(particle::Particle, rl_params::Params, δt::Float64)
end end
function post_integration_hook( function post_integration_hook(
rl_params::Params, n_particles::Int64, particles::Vector{Particle} rl_params::Params, n_particles::Int64, particles::Vector{ReCo.Particle}
) )
env_direction_state = rl_params.env_params.direction_state_space[1] env_direction_state = rl_params.env_params.direction_state_space[1]
@ -266,7 +269,7 @@ function post_integration_hook(
return nothing return nothing
end end
function run( function run(;
n_episodes::Int64=100, n_episodes::Int64=100,
episode_duration::Float64=5.0, episode_duration::Float64=5.0,
update_actions_at::Float64=0.1, update_actions_at::Float64=0.1,
@ -275,15 +278,17 @@ function run(
@assert n_episodes > 0 @assert n_episodes > 0
@assert episode_duration > 0 @assert episode_duration > 0
@assert update_actions_at in 0.01:0.01:episode_duration @assert update_actions_at in 0.01:0.01:episode_duration
@assert episode_duration % update_actions_at == 0
@assert n_particles > 0 @assert n_particles > 0
Random.seed!(42) Random.seed!(42)
sim_consts = gen_sim_consts(n_particles, v₀) sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=3.5)
n_particles = sim_consts.n_particles
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r) env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
rl_params = Params{TotalRewardPerEpisode}( rl_params = Params{TotalRewardPerEpisode}(
n_particles, env_params, n_steps_before_actions_update n_particles, env_params, n_steps_before_actions_update
) )
@ -297,7 +302,7 @@ function run(
end end
for episode in 1:n_episodes for episode in 1:n_episodes
dir, particles = init_sim_with_sim_consts(; sim_consts, parent_dir="RL") dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")
for i in 1:n_particles for i in 1:n_particles
env = rl_params.envs[i] env = rl_params.envs[i]

View file

@ -11,12 +11,11 @@ function run_sim(
snapshot_at::Float64=0.1, snapshot_at::Float64=0.1,
seed::Int64=42, seed::Int64=42,
n_bundle_snapshots::Int64=100, n_bundle_snapshots::Int64=100,
rl_params::Union{ReCo.Params,Nothing}=nothing, rl_params::Union{RL.Params,Nothing}=nothing,
) )
@assert length(dir) > 0 @assert length(dir) > 0
@assert duration > 0 @assert duration > 0
@assert snapshot_at in 0.001:0.001:duration @assert snapshot_at in 0.001:0.001:duration
@assert duration % snapshot_at == 0
@assert seed > 0 @assert seed > 0
@assert n_bundle_snapshots >= 0 @assert n_bundle_snapshots >= 0
@ -29,6 +28,8 @@ function run_sim(
n_steps_before_snapshot = round(Int64, snapshot_at / sim_consts.δt) n_steps_before_snapshot = round(Int64, snapshot_at / sim_consts.δt)
n_snapshots = floor(Int64, integration_steps / n_steps_before_snapshot) + 1 n_snapshots = floor(Int64, integration_steps / n_steps_before_snapshot) + 1
@assert (n_snapshots - 1) * snapshot_at == duration
n_bundle_snapshots = min(n_snapshots, n_bundle_snapshots) n_bundle_snapshots = min(n_snapshots, n_bundle_snapshots)
sim_state = JSON3.read(read("$dir/sim_state.json", String)) sim_state = JSON3.read(read("$dir/sim_state.json", String))
@ -107,9 +108,9 @@ function run_sim(
) )
if !isnothing(rl_params) if !isnothing(rl_params)
pre_integration_hook! = ReCoRL.pre_integration_hook pre_integration_hook! = RL.pre_integration_hook!
integration_hook = ReCoRL.integration_hook integration_hook = RL.integration_hook
post_integration_hook = ReCoRL.post_integration_hook post_integration_hook = RL.post_integration_hook
else else
pre_integration_hook! = empty_hook pre_integration_hook! = empty_hook
integration_hook = empty_hook integration_hook = empty_hook
@ -118,10 +119,9 @@ function run_sim(
simulate( simulate(
args, args,
sim_consts.δt,
T0, T0,
T, T,
sim.consts.n_steps_before_verlet_list_update, sim_consts.n_steps_before_verlet_list_update,
n_steps_before_snapshot, n_steps_before_snapshot,
n_bundles, n_bundles,
dir, dir,

View file

@ -3,7 +3,7 @@ using Dates: now
const DEFAULT_PACKING_RATIO = 0.5 const DEFAULT_PACKING_RATIO = 0.5
const DEFAULT_δt = 1e-5 const DEFAULT_δt = 1e-5
const DEFAULT_SKIN_TO_INTERACTION_R_RATIO = 3.0 const DEFAULT_SKIN_TO_INTERACTION_R_RATIO = 1.5
const DEFAULT_EXPORTS_DIR = "exports" const DEFAULT_EXPORTS_DIR = "exports"
const DEFAULT_PARENT_DIR = "" const DEFAULT_PARENT_DIR = ""
const DEFAULT_COMMENT = "" const DEFAULT_COMMENT = ""
@ -47,7 +47,7 @@ end
function gen_sim_consts( function gen_sim_consts(
n_particles::Int64, n_particles::Int64,
v₀::Float64, v₀::Float64;
δt::Float64=DEFAULT_δt, δt::Float64=DEFAULT_δt,
packing_ratio::Float64=DEFAULT_PACKING_RATIO, packing_ratio::Float64=DEFAULT_PACKING_RATIO,
skin_to_interaction_r_ratio::Float64=DEFAULT_SKIN_TO_INTERACTION_R_RATIO, skin_to_interaction_r_ratio::Float64=DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
@ -59,19 +59,26 @@ function gen_sim_consts(
μ = 1.0 μ = 1.0
D₀ = 1.0 D₀ = 1.0
particle_radius = 1.0 particle_radius = 0.5
Dᵣ = 3 * D₀ / ((2 * particle_radius)^2) Dᵣ = 3 * D₀ / ((2 * particle_radius)^2)
σ = 1.0 σ = 1.0
ϵ = 100.0 ϵ = 100.0
interaction_r = 2^(1 / 6) * σ interaction_r = 2^(1 / 6) * σ
buffer = 6.0 if v₀ != 0.0
buffer = 1.8
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
@assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step @assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
skin_r = skin_to_interaction_r_ratio * interaction_r skin_r = skin_to_interaction_r_ratio * interaction_r
n_steps_before_verlet_list_update = n_steps_before_verlet_list_update = round(
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step Int64,
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step,
)
else
skin_r = 1.5 * interaction_r
n_steps_before_verlet_list_update = 100
end
grid_n = round(Int64, ceil(sqrt(n_particles))) grid_n = round(Int64, ceil(sqrt(n_particles)))
@ -103,7 +110,7 @@ function gen_sim_consts(
end end
function init_sim_with_sim_consts( function init_sim_with_sim_consts(
sim_consts, sim_consts;
exports_dir::String=DEFAULT_EXPORTS_DIR, exports_dir::String=DEFAULT_EXPORTS_DIR,
parent_dir::String=DEFAULT_PARENT_DIR, parent_dir::String=DEFAULT_PARENT_DIR,
comment::String=DEFAULT_COMMENT, comment::String=DEFAULT_COMMENT,
@ -111,7 +118,7 @@ function init_sim_with_sim_consts(
particles = gen_particles( particles = gen_particles(
sim_consts.grid_n, sim_consts.grid_box_width, sim_consts.half_box_len sim_consts.grid_n, sim_consts.grid_box_width, sim_consts.half_box_len
) )
bundle = Bundle(n_particles, 1) bundle = Bundle(sim_consts.n_particles, 1)
save_snapshot!(bundle, 1, 0.0, particles) save_snapshot!(bundle, 1, 0.0, particles)
dir = exports_dir dir = exports_dir
@ -152,8 +159,14 @@ function init_sim(;
comment::String=DEFAULT_COMMENT, comment::String=DEFAULT_COMMENT,
) )
sim_consts = gen_sim_consts( sim_consts = gen_sim_consts(
n_particles, v₀, δt, packing_ratio, skin_to_interaction_r_ratio n_particles,
v₀;
δt=δt,
packing_ratio=packing_ratio,
skin_to_interaction_r_ratio=skin_to_interaction_r_ratio,
) )
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)[1] return init_sim_with_sim_consts(
sim_consts; exports_dir=exports_dir, parent_dir=parent_dir, comment=comment
)[1]
end end

View file

@ -41,7 +41,7 @@ function update_verlet_lists!(args, cl)
end end
function euler!( function euler!(
args, state_hook::F, integration_hook::F, rl_params::Union{ReCoRL.Params,Nothing} args, state_hook::F, integration_hook::F, rl_params::Union{RL.Params,Nothing}
) where {F<:Function} ) where {F<:Function}
for id1 in 1:(args.n_particles - 1) for id1 in 1:(args.n_particles - 1)
p1 = args.particles[id1] p1 = args.particles[id1]
@ -90,14 +90,13 @@ wait(::Nothing) = nothing
gen_run_hooks(::Nothing, args...) = false gen_run_hooks(::Nothing, args...) = false
function gen_run_hooks(rl_params::ReCoRL.Params, integration_step::Int64) function gen_run_hooks(rl_params::RL.Params, integration_step::Int64)
return (itegration_step == 1) || return (integration_step == 1) ||
(integration_step % rl_params.n_steps_before_actions_update == 0) (integration_step % rl_params.n_steps_before_actions_update == 0)
end end
function simulate( function simulate(
args, args,
δt::Float64,
T0::Float64, T0::Float64,
T::Float64, T::Float64,
n_steps_before_verlet_list_update::Int64, n_steps_before_verlet_list_update::Int64,
@ -105,11 +104,11 @@ function simulate(
n_bundles::Int64, n_bundles::Int64,
dir::String, dir::String,
save_data::Bool, save_data::Bool,
rl_params::Union{ReCoRL.Params,Nothing}, rl_params::Union{RL.Params,Nothing},
pre_integration_hook!::F, pre_integration_hook!::Function,
integration_hook::F, integration_hook::Function,
post_integration_hook::F, post_integration_hook::Function,
) where {F<:Function} )
bundle_snapshot_counter = 0 bundle_snapshot_counter = 0
task::Union{Task,Nothing} = nothing task::Union{Task,Nothing} = nothing
@ -123,7 +122,7 @@ function simulate(
start_time = now() start_time = now()
println("Started simulation at $start_time.") println("Started simulation at $start_time.")
@showprogress 0.6 for (integration_step, t) in enumerate(T0:δt:T) @showprogress 0.6 for (integration_step, t) in enumerate(T0:(args.δt):T)
if (integration_step % n_steps_before_snapshot == 0) && save_data if (integration_step % n_steps_before_snapshot == 0) && save_data
wait(task) wait(task)
@ -148,10 +147,10 @@ function simulate(
if run_hooks if run_hooks
pre_integration_hook!(rl_params, args.n_particles) pre_integration_hook!(rl_params, args.n_particles)
state_hook = ReCoRL.state_hook state_hook = RL.state_hook
end end
euler!(args, state_hook, integration_hook, rl.params.actions) euler!(args, state_hook, integration_hook, rl_params)
if run_hooks if run_hooks
post_integration_hook(rl_params, args.n_particles, args.particles) post_integration_hook(rl_params, args.n_particles, args.particles)