mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-11-08 22:21:08 +00:00
Fixes
This commit is contained in:
parent
dc6a4ee7b7
commit
3036c5e65a
6 changed files with 69 additions and 56 deletions
|
@ -10,6 +10,7 @@ CellListMap = "69e1c6dd-3888-40e6-b3c8-31ac5f578864"
|
||||||
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
|
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
|
||||||
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
|
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
|
||||||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
|
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
|
||||||
|
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||||
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
|
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
|
||||||
Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
|
Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
|
||||||
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
|
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
|
||||||
|
|
|
@ -28,9 +28,9 @@ function animate_bundle!(args)
|
||||||
|
|
||||||
if args.debug
|
if args.debug
|
||||||
args.interaction_circles[][i] = Circle(
|
args.interaction_circles[][i] = Circle(
|
||||||
Point2(bundle_c[i, frame]), args.interaction_r
|
Point2(c[1], c[2]), args.interaction_r
|
||||||
)
|
)
|
||||||
args.skin_circles[][i] = Circle(Point2(bundle_c[i, frame]), args.skin_r)
|
args.skin_circles[][i] = Circle(Point2(c[1], c[2]), args.skin_r)
|
||||||
|
|
||||||
args.interaction_colors[][i] = RGBAf(color, 0.08)
|
args.interaction_colors[][i] = RGBAf(color, 0.08)
|
||||||
args.skin_colors[][i] = RGBAf(color, 0.04)
|
args.skin_colors[][i] = RGBAf(color, 0.04)
|
||||||
|
@ -146,16 +146,11 @@ function animate_after_loading(dir, animation_path, sim_consts, framerate, debug
|
||||||
bundle_paths = bundle_paths[sort_perm]
|
bundle_paths = bundle_paths[sort_perm]
|
||||||
sort_perm = nothing
|
sort_perm = nothing
|
||||||
|
|
||||||
skin_r = 0.0
|
skin_r = sim_consts.skin_r
|
||||||
|
|
||||||
@showprogress 1 for (n_bundle, bundle_path) in enumerate(bundle_paths)
|
@showprogress 1 for (n_bundle, bundle_path) in enumerate(bundle_paths)
|
||||||
bundle::Bundle = JLD2.load_object(bundle_path)
|
bundle::Bundle = JLD2.load_object(bundle_path)
|
||||||
|
|
||||||
run_params_file = "$dir/runs/run_params_$n_bundle.json"
|
|
||||||
if debug && isfile(run_params_file)
|
|
||||||
skin_r::Float64 = JSON3.read(read(run_params_file, String)).skin_r
|
|
||||||
end
|
|
||||||
|
|
||||||
args = (;
|
args = (;
|
||||||
io,
|
io,
|
||||||
ax,
|
ax,
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
module ReCoRL
|
module RL
|
||||||
|
|
||||||
using ReinforcementLearning
|
using ReinforcementLearning
|
||||||
using Flux: InvDecay
|
using Flux: InvDecay
|
||||||
using Intervals
|
using Intervals
|
||||||
using StaticArrays: SVector
|
using StaticArrays: SVector
|
||||||
|
using Random: Random
|
||||||
|
|
||||||
|
using ..ReCo
|
||||||
|
|
||||||
import Base: run
|
import Base: run
|
||||||
|
|
||||||
|
@ -16,7 +19,7 @@ struct DistanceState{L<:Bound}
|
||||||
end
|
end
|
||||||
|
|
||||||
struct DirectionState
|
struct DirectionState
|
||||||
interval::Interval{Float64,Open,Closed}
|
interval::Interval{Float64,Closed,Open}
|
||||||
|
|
||||||
function DirectionState(lower::Float64, upper::Float64)
|
function DirectionState(lower::Float64, upper::Float64)
|
||||||
return new(Interval{Float64,Closed,Open}(lower, upper))
|
return new(Interval{Float64,Closed,Open}(lower, upper))
|
||||||
|
@ -43,7 +46,7 @@ mutable struct EnvParams
|
||||||
@assert min_distance > 0.0
|
@assert min_distance > 0.0
|
||||||
@assert max_distance > min_distance
|
@assert max_distance > min_distance
|
||||||
@assert n_v_actions > 1
|
@assert n_v_actions > 1
|
||||||
@assert n_ω_actoins > 1
|
@assert n_ω_actions > 1
|
||||||
@assert max_v > 0
|
@assert max_v > 0
|
||||||
@assert max_ω > 0
|
@assert max_ω > 0
|
||||||
|
|
||||||
|
@ -89,6 +92,8 @@ mutable struct EnvParams
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
n_states = n_distance_states * n_direction_states + 1
|
||||||
|
|
||||||
state_space = Vector{
|
state_space = Vector{
|
||||||
Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
|
Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
|
||||||
}(
|
}(
|
||||||
|
@ -112,10 +117,10 @@ end
|
||||||
|
|
||||||
mutable struct Env <: AbstractEnv
|
mutable struct Env <: AbstractEnv
|
||||||
params::EnvParams
|
params::EnvParams
|
||||||
particle::Particle
|
particle::ReCo.Particle
|
||||||
state::Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
|
state::Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
|
||||||
|
|
||||||
function Env(params::EnvParams, particle::Particle)
|
function Env(params::EnvParams, particle::ReCo.Particle)
|
||||||
return new(params, particle, (nothing, nothing))
|
return new(params, particle, (nothing, nothing))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -145,16 +150,14 @@ struct Params{H<:AbstractHook}
|
||||||
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
|
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
|
||||||
) where {H<:AbstractHook}
|
) where {H<:AbstractHook}
|
||||||
policies = [
|
policies = [
|
||||||
gen_policy(
|
gen_policy(length(env_params.state_space), length(env_params.action_space)) for
|
||||||
length(rl_params.env_params.state_space),
|
i in 1:n_particles
|
||||||
length(rl_params.env_params.action_space),
|
|
||||||
) for i in 1:n_particles
|
|
||||||
]
|
]
|
||||||
agents = [
|
agents = [
|
||||||
Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
|
Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
|
||||||
]
|
]
|
||||||
return new(
|
return new(
|
||||||
[Env(env_params, gen_tmp_particle()) for i in 1:n_particles],
|
[Env(env_params, ReCo.gen_tmp_particle()) for i in 1:n_particles],
|
||||||
agents,
|
agents,
|
||||||
[H() for i in 1:n_particles],
|
[H() for i in 1:n_particles],
|
||||||
Vector{Tuple{Float64,Float64}}(undef, n_particles),
|
Vector{Tuple{Float64,Float64}}(undef, n_particles),
|
||||||
|
@ -206,7 +209,7 @@ function state_hook(
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function integration_hook(particle::Particle, rl_params::Params, δt::Float64)
|
function integration_hook(particle::ReCo.Particle, rl_params::Params, δt::Float64)
|
||||||
action = rl_params.actions[particle.id]
|
action = rl_params.actions[particle.id]
|
||||||
|
|
||||||
particle.tmp_c += action[1] * δt
|
particle.tmp_c += action[1] * δt
|
||||||
|
@ -216,7 +219,7 @@ function integration_hook(particle::Particle, rl_params::Params, δt::Float64)
|
||||||
end
|
end
|
||||||
|
|
||||||
function post_integration_hook(
|
function post_integration_hook(
|
||||||
rl_params::Params, n_particles::Int64, particles::Vector{Particle}
|
rl_params::Params, n_particles::Int64, particles::Vector{ReCo.Particle}
|
||||||
)
|
)
|
||||||
env_direction_state = rl_params.env_params.direction_state_space[1]
|
env_direction_state = rl_params.env_params.direction_state_space[1]
|
||||||
|
|
||||||
|
@ -266,7 +269,7 @@ function post_integration_hook(
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function run(
|
function run(;
|
||||||
n_episodes::Int64=100,
|
n_episodes::Int64=100,
|
||||||
episode_duration::Float64=5.0,
|
episode_duration::Float64=5.0,
|
||||||
update_actions_at::Float64=0.1,
|
update_actions_at::Float64=0.1,
|
||||||
|
@ -275,15 +278,17 @@ function run(
|
||||||
@assert n_episodes > 0
|
@assert n_episodes > 0
|
||||||
@assert episode_duration > 0
|
@assert episode_duration > 0
|
||||||
@assert update_actions_at in 0.01:0.01:episode_duration
|
@assert update_actions_at in 0.01:0.01:episode_duration
|
||||||
@assert episode_duration % update_actions_at == 0
|
|
||||||
@assert n_particles > 0
|
@assert n_particles > 0
|
||||||
|
|
||||||
Random.seed!(42)
|
Random.seed!(42)
|
||||||
|
|
||||||
sim_consts = gen_sim_consts(n_particles, v₀)
|
sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=3.5)
|
||||||
|
n_particles = sim_consts.n_particles
|
||||||
|
|
||||||
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
|
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
|
||||||
|
|
||||||
|
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
||||||
|
|
||||||
rl_params = Params{TotalRewardPerEpisode}(
|
rl_params = Params{TotalRewardPerEpisode}(
|
||||||
n_particles, env_params, n_steps_before_actions_update
|
n_particles, env_params, n_steps_before_actions_update
|
||||||
)
|
)
|
||||||
|
@ -297,7 +302,7 @@ function run(
|
||||||
end
|
end
|
||||||
|
|
||||||
for episode in 1:n_episodes
|
for episode in 1:n_episodes
|
||||||
dir, particles = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")
|
dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")
|
||||||
|
|
||||||
for i in 1:n_particles
|
for i in 1:n_particles
|
||||||
env = rl_params.envs[i]
|
env = rl_params.envs[i]
|
||||||
|
|
14
src/run.jl
14
src/run.jl
|
@ -11,12 +11,11 @@ function run_sim(
|
||||||
snapshot_at::Float64=0.1,
|
snapshot_at::Float64=0.1,
|
||||||
seed::Int64=42,
|
seed::Int64=42,
|
||||||
n_bundle_snapshots::Int64=100,
|
n_bundle_snapshots::Int64=100,
|
||||||
rl_params::Union{ReCo.Params,Nothing}=nothing,
|
rl_params::Union{RL.Params,Nothing}=nothing,
|
||||||
)
|
)
|
||||||
@assert length(dir) > 0
|
@assert length(dir) > 0
|
||||||
@assert duration > 0
|
@assert duration > 0
|
||||||
@assert snapshot_at in 0.001:0.001:duration
|
@assert snapshot_at in 0.001:0.001:duration
|
||||||
@assert duration % snapshot_at == 0
|
|
||||||
@assert seed > 0
|
@assert seed > 0
|
||||||
@assert n_bundle_snapshots >= 0
|
@assert n_bundle_snapshots >= 0
|
||||||
|
|
||||||
|
@ -29,6 +28,8 @@ function run_sim(
|
||||||
n_steps_before_snapshot = round(Int64, snapshot_at / sim_consts.δt)
|
n_steps_before_snapshot = round(Int64, snapshot_at / sim_consts.δt)
|
||||||
n_snapshots = floor(Int64, integration_steps / n_steps_before_snapshot) + 1
|
n_snapshots = floor(Int64, integration_steps / n_steps_before_snapshot) + 1
|
||||||
|
|
||||||
|
@assert (n_snapshots - 1) * snapshot_at == duration
|
||||||
|
|
||||||
n_bundle_snapshots = min(n_snapshots, n_bundle_snapshots)
|
n_bundle_snapshots = min(n_snapshots, n_bundle_snapshots)
|
||||||
|
|
||||||
sim_state = JSON3.read(read("$dir/sim_state.json", String))
|
sim_state = JSON3.read(read("$dir/sim_state.json", String))
|
||||||
|
@ -107,9 +108,9 @@ function run_sim(
|
||||||
)
|
)
|
||||||
|
|
||||||
if !isnothing(rl_params)
|
if !isnothing(rl_params)
|
||||||
pre_integration_hook! = ReCoRL.pre_integration_hook
|
pre_integration_hook! = RL.pre_integration_hook!
|
||||||
integration_hook = ReCoRL.integration_hook
|
integration_hook = RL.integration_hook
|
||||||
post_integration_hook = ReCoRL.post_integration_hook
|
post_integration_hook = RL.post_integration_hook
|
||||||
else
|
else
|
||||||
pre_integration_hook! = empty_hook
|
pre_integration_hook! = empty_hook
|
||||||
integration_hook = empty_hook
|
integration_hook = empty_hook
|
||||||
|
@ -118,10 +119,9 @@ function run_sim(
|
||||||
|
|
||||||
simulate(
|
simulate(
|
||||||
args,
|
args,
|
||||||
sim_consts.δt,
|
|
||||||
T0,
|
T0,
|
||||||
T,
|
T,
|
||||||
sim.consts.n_steps_before_verlet_list_update,
|
sim_consts.n_steps_before_verlet_list_update,
|
||||||
n_steps_before_snapshot,
|
n_steps_before_snapshot,
|
||||||
n_bundles,
|
n_bundles,
|
||||||
dir,
|
dir,
|
||||||
|
|
33
src/setup.jl
33
src/setup.jl
|
@ -3,7 +3,7 @@ using Dates: now
|
||||||
|
|
||||||
const DEFAULT_PACKING_RATIO = 0.5
|
const DEFAULT_PACKING_RATIO = 0.5
|
||||||
const DEFAULT_δt = 1e-5
|
const DEFAULT_δt = 1e-5
|
||||||
const DEFAULT_SKIN_TO_INTERACTION_R_RATIO = 3.0
|
const DEFAULT_SKIN_TO_INTERACTION_R_RATIO = 1.5
|
||||||
const DEFAULT_EXPORTS_DIR = "exports"
|
const DEFAULT_EXPORTS_DIR = "exports"
|
||||||
const DEFAULT_PARENT_DIR = ""
|
const DEFAULT_PARENT_DIR = ""
|
||||||
const DEFAULT_COMMENT = ""
|
const DEFAULT_COMMENT = ""
|
||||||
|
@ -47,7 +47,7 @@ end
|
||||||
|
|
||||||
function gen_sim_consts(
|
function gen_sim_consts(
|
||||||
n_particles::Int64,
|
n_particles::Int64,
|
||||||
v₀::Float64,
|
v₀::Float64;
|
||||||
δt::Float64=DEFAULT_δt,
|
δt::Float64=DEFAULT_δt,
|
||||||
packing_ratio::Float64=DEFAULT_PACKING_RATIO,
|
packing_ratio::Float64=DEFAULT_PACKING_RATIO,
|
||||||
skin_to_interaction_r_ratio::Float64=DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
|
skin_to_interaction_r_ratio::Float64=DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
|
||||||
|
@ -59,19 +59,26 @@ function gen_sim_consts(
|
||||||
|
|
||||||
μ = 1.0
|
μ = 1.0
|
||||||
D₀ = 1.0
|
D₀ = 1.0
|
||||||
particle_radius = 1.0
|
particle_radius = 0.5
|
||||||
Dᵣ = 3 * D₀ / ((2 * particle_radius)^2)
|
Dᵣ = 3 * D₀ / ((2 * particle_radius)^2)
|
||||||
σ = 1.0
|
σ = 1.0
|
||||||
ϵ = 100.0
|
ϵ = 100.0
|
||||||
interaction_r = 2^(1 / 6) * σ
|
interaction_r = 2^(1 / 6) * σ
|
||||||
|
|
||||||
buffer = 6.0
|
if v₀ != 0.0
|
||||||
|
buffer = 1.8
|
||||||
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
|
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
|
||||||
@assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
|
@assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
|
||||||
|
|
||||||
skin_r = skin_to_interaction_r_ratio * interaction_r
|
skin_r = skin_to_interaction_r_ratio * interaction_r
|
||||||
n_steps_before_verlet_list_update =
|
n_steps_before_verlet_list_update = round(
|
||||||
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step
|
Int64,
|
||||||
|
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step,
|
||||||
|
)
|
||||||
|
else
|
||||||
|
skin_r = 1.5 * interaction_r
|
||||||
|
n_steps_before_verlet_list_update = 100
|
||||||
|
end
|
||||||
|
|
||||||
grid_n = round(Int64, ceil(sqrt(n_particles)))
|
grid_n = round(Int64, ceil(sqrt(n_particles)))
|
||||||
|
|
||||||
|
@ -103,7 +110,7 @@ function gen_sim_consts(
|
||||||
end
|
end
|
||||||
|
|
||||||
function init_sim_with_sim_consts(
|
function init_sim_with_sim_consts(
|
||||||
sim_consts,
|
sim_consts;
|
||||||
exports_dir::String=DEFAULT_EXPORTS_DIR,
|
exports_dir::String=DEFAULT_EXPORTS_DIR,
|
||||||
parent_dir::String=DEFAULT_PARENT_DIR,
|
parent_dir::String=DEFAULT_PARENT_DIR,
|
||||||
comment::String=DEFAULT_COMMENT,
|
comment::String=DEFAULT_COMMENT,
|
||||||
|
@ -111,7 +118,7 @@ function init_sim_with_sim_consts(
|
||||||
particles = gen_particles(
|
particles = gen_particles(
|
||||||
sim_consts.grid_n, sim_consts.grid_box_width, sim_consts.half_box_len
|
sim_consts.grid_n, sim_consts.grid_box_width, sim_consts.half_box_len
|
||||||
)
|
)
|
||||||
bundle = Bundle(n_particles, 1)
|
bundle = Bundle(sim_consts.n_particles, 1)
|
||||||
save_snapshot!(bundle, 1, 0.0, particles)
|
save_snapshot!(bundle, 1, 0.0, particles)
|
||||||
|
|
||||||
dir = exports_dir
|
dir = exports_dir
|
||||||
|
@ -152,8 +159,14 @@ function init_sim(;
|
||||||
comment::String=DEFAULT_COMMENT,
|
comment::String=DEFAULT_COMMENT,
|
||||||
)
|
)
|
||||||
sim_consts = gen_sim_consts(
|
sim_consts = gen_sim_consts(
|
||||||
n_particles, v₀, δt, packing_ratio, skin_to_interaction_r_ratio
|
n_particles,
|
||||||
|
v₀;
|
||||||
|
δt=δt,
|
||||||
|
packing_ratio=packing_ratio,
|
||||||
|
skin_to_interaction_r_ratio=skin_to_interaction_r_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)[1]
|
return init_sim_with_sim_consts(
|
||||||
|
sim_consts; exports_dir=exports_dir, parent_dir=parent_dir, comment=comment
|
||||||
|
)[1]
|
||||||
end
|
end
|
|
@ -41,7 +41,7 @@ function update_verlet_lists!(args, cl)
|
||||||
end
|
end
|
||||||
|
|
||||||
function euler!(
|
function euler!(
|
||||||
args, state_hook::F, integration_hook::F, rl_params::Union{ReCoRL.Params,Nothing}
|
args, state_hook::F, integration_hook::F, rl_params::Union{RL.Params,Nothing}
|
||||||
) where {F<:Function}
|
) where {F<:Function}
|
||||||
for id1 in 1:(args.n_particles - 1)
|
for id1 in 1:(args.n_particles - 1)
|
||||||
p1 = args.particles[id1]
|
p1 = args.particles[id1]
|
||||||
|
@ -90,14 +90,13 @@ wait(::Nothing) = nothing
|
||||||
|
|
||||||
gen_run_hooks(::Nothing, args...) = false
|
gen_run_hooks(::Nothing, args...) = false
|
||||||
|
|
||||||
function gen_run_hooks(rl_params::ReCoRL.Params, integration_step::Int64)
|
function gen_run_hooks(rl_params::RL.Params, integration_step::Int64)
|
||||||
return (itegration_step == 1) ||
|
return (integration_step == 1) ||
|
||||||
(integration_step % rl_params.n_steps_before_actions_update == 0)
|
(integration_step % rl_params.n_steps_before_actions_update == 0)
|
||||||
end
|
end
|
||||||
|
|
||||||
function simulate(
|
function simulate(
|
||||||
args,
|
args,
|
||||||
δt::Float64,
|
|
||||||
T0::Float64,
|
T0::Float64,
|
||||||
T::Float64,
|
T::Float64,
|
||||||
n_steps_before_verlet_list_update::Int64,
|
n_steps_before_verlet_list_update::Int64,
|
||||||
|
@ -105,11 +104,11 @@ function simulate(
|
||||||
n_bundles::Int64,
|
n_bundles::Int64,
|
||||||
dir::String,
|
dir::String,
|
||||||
save_data::Bool,
|
save_data::Bool,
|
||||||
rl_params::Union{ReCoRL.Params,Nothing},
|
rl_params::Union{RL.Params,Nothing},
|
||||||
pre_integration_hook!::F,
|
pre_integration_hook!::Function,
|
||||||
integration_hook::F,
|
integration_hook::Function,
|
||||||
post_integration_hook::F,
|
post_integration_hook::Function,
|
||||||
) where {F<:Function}
|
)
|
||||||
bundle_snapshot_counter = 0
|
bundle_snapshot_counter = 0
|
||||||
|
|
||||||
task::Union{Task,Nothing} = nothing
|
task::Union{Task,Nothing} = nothing
|
||||||
|
@ -123,7 +122,7 @@ function simulate(
|
||||||
start_time = now()
|
start_time = now()
|
||||||
println("Started simulation at $start_time.")
|
println("Started simulation at $start_time.")
|
||||||
|
|
||||||
@showprogress 0.6 for (integration_step, t) in enumerate(T0:δt:T)
|
@showprogress 0.6 for (integration_step, t) in enumerate(T0:(args.δt):T)
|
||||||
if (integration_step % n_steps_before_snapshot == 0) && save_data
|
if (integration_step % n_steps_before_snapshot == 0) && save_data
|
||||||
wait(task)
|
wait(task)
|
||||||
|
|
||||||
|
@ -148,10 +147,10 @@ function simulate(
|
||||||
|
|
||||||
if run_hooks
|
if run_hooks
|
||||||
pre_integration_hook!(rl_params, args.n_particles)
|
pre_integration_hook!(rl_params, args.n_particles)
|
||||||
state_hook = ReCoRL.state_hook
|
state_hook = RL.state_hook
|
||||||
end
|
end
|
||||||
|
|
||||||
euler!(args, state_hook, integration_hook, rl.params.actions)
|
euler!(args, state_hook, integration_hook, rl_params)
|
||||||
|
|
||||||
if run_hooks
|
if run_hooks
|
||||||
post_integration_hook(rl_params, args.n_particles, args.particles)
|
post_integration_hook(rl_params, args.n_particles, args.particles)
|
||||||
|
|
Loading…
Reference in a new issue