mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Final mean squared displacement
This commit is contained in:
parent
e50336d6e4
commit
b8ceb8b340
4 changed files with 20 additions and 24 deletions
|
@ -5,18 +5,21 @@ using Random: Random
|
||||||
using StaticArrays: SVector
|
using StaticArrays: SVector
|
||||||
using JLD2: JLD2
|
using JLD2: JLD2
|
||||||
using CellListMap: CellListMap
|
using CellListMap: CellListMap
|
||||||
|
using ProgressMeter: ProgressMeter
|
||||||
|
|
||||||
using ReCo: ReCo
|
using ReCo: ReCo
|
||||||
|
|
||||||
# IMPORTANT: Disable the periodic boundary conditions
|
# IMPORTANT: Disable the periodic boundary conditions
|
||||||
# The arguments types have to match for the function to be overwritten!
|
# The arguments types have to match for the function to be overwritten!
|
||||||
ReCo.update_verlet_lists!(::Any, ::CellListMap.CellList) = nothing
|
|
||||||
ReCo.gen_cell_list(::Vector{SVector{2,Float64}}, ::CellListMap.Box) = nothing
|
|
||||||
ReCo.gen_cell_list_box(::Float64, ::Float64) = nothing
|
|
||||||
ReCo.push_to_verlet_list!(::Any, ::Any, ::Any) = nothing
|
ReCo.push_to_verlet_list!(::Any, ::Any, ::Any) = nothing
|
||||||
|
ReCo.update_verlet_lists!(::Any, ::CellListMap.CellList) = nothing
|
||||||
|
ReCo.update_verlet_lists!(::Any, ::Nothing) = nothing
|
||||||
|
ReCo.gen_cell_list(::Vector{SVector{2,Float64}}, ::CellListMap.Box) = nothing
|
||||||
|
ReCo.gen_cell_list(::Vector{SVector{2,Float64}}, ::Nothing) = nothing
|
||||||
|
ReCo.gen_cell_list_box(::Float64, ::Float64) = nothing
|
||||||
ReCo.restrict_coordinate(value::Float64, ::Float64) = value
|
ReCo.restrict_coordinate(value::Float64, ::Float64) = value
|
||||||
ReCo.restrict_coordinates(v::SVector{2,Float64}, ::Float64) = v
|
ReCo.restrict_coordinates(v::SVector{2,Float64}, ::Float64) = v
|
||||||
ReCo.restrict_coordinates!(::Particle, ::Float64) = nothing
|
ReCo.restrict_coordinates!(::ReCo.Particle, ::Float64) = nothing
|
||||||
ReCo.minimum_image_coordinate(value::Float64, ::Float64) = value
|
ReCo.minimum_image_coordinate(value::Float64, ::Float64) = value
|
||||||
ReCo.minimum_image(v::SVector{2,Float64}, ::Float64) = v
|
ReCo.minimum_image(v::SVector{2,Float64}, ::Float64) = v
|
||||||
|
|
||||||
|
@ -27,7 +30,7 @@ function mean_squared_displacement(;
|
||||||
|
|
||||||
n_v₀s = length(v₀s)
|
n_v₀s = length(v₀s)
|
||||||
|
|
||||||
δt = ReCo.DEFAULT_δt
|
δt = 1e-4
|
||||||
Dₜ = ReCo.DEFAULT_Dₜ
|
Dₜ = ReCo.DEFAULT_Dₜ
|
||||||
|
|
||||||
main_parent_dir = "mean_squared_displacement_$(Dates.now())"
|
main_parent_dir = "mean_squared_displacement_$(Dates.now())"
|
||||||
|
@ -35,6 +38,8 @@ function mean_squared_displacement(;
|
||||||
|
|
||||||
sim_dirs = Matrix{String}(undef, (n_simulations, n_v₀s))
|
sim_dirs = Matrix{String}(undef, (n_simulations, n_v₀s))
|
||||||
|
|
||||||
|
progress = ProgressMeter.Progress(n_v₀s * n_simulations; dt=3, desc="MSD: ")
|
||||||
|
|
||||||
for (v₀_ind, v₀) in enumerate(v₀s)
|
for (v₀_ind, v₀) in enumerate(v₀s)
|
||||||
max_possible_displacement = T * v₀ + T / δt * sqrt(2 * Dₜ * δt)
|
max_possible_displacement = T * v₀ + T / δt * sqrt(2 * Dₜ * δt)
|
||||||
|
|
||||||
|
@ -55,9 +60,11 @@ function mean_squared_displacement(;
|
||||||
dir;
|
dir;
|
||||||
duration=T,
|
duration=T,
|
||||||
seed=rand(1:typemax(Int64)),
|
seed=rand(1:typemax(Int64)),
|
||||||
snapshot_at=0.01,
|
snapshot_at=0.5,
|
||||||
n_bundle_snapshots=200,
|
n_bundle_snapshots=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ProgressMeter.next!(progress; showvalues=[(:v₀, v₀)])
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -141,7 +148,7 @@ function plot_mean_sq_displacement_with_expectation(
|
||||||
lines!(ax, t_linrange, expected_mean_sq_displacements)
|
lines!(ax, t_linrange, expected_mean_sq_displacements)
|
||||||
end
|
end
|
||||||
|
|
||||||
Legend(fig[1, 2], v₀_scatter_plots, [L"v₀ = %$v₀" for v₀ in v₀s])
|
Legend(fig[1, 2], v₀_scatter_plots, [L"v_0 = %$v₀" for v₀ in v₀s])
|
||||||
|
|
||||||
colgap!(fig.layout, 5)
|
colgap!(fig.layout, 5)
|
||||||
rowgap!(fig.layout, 5)
|
rowgap!(fig.layout, 5)
|
||||||
|
@ -158,7 +165,7 @@ function run_analysis()
|
||||||
v₀s = SVector(0.0, 20.0, 40.0, 60.0, 80.0)
|
v₀s = SVector(0.0, 20.0, 40.0, 60.0, 80.0)
|
||||||
|
|
||||||
ts, mean_sq_displacements = mean_squared_displacement(;
|
ts, mean_sq_displacements = mean_squared_displacement(;
|
||||||
n_simulations=3 * Threads.nthreads(), v₀s=v₀s, T=10.0
|
n_simulations=200 * Threads.nthreads(), v₀s=v₀s, T=100.0
|
||||||
)
|
)
|
||||||
|
|
||||||
plot_mean_sq_displacement_with_expectation(ts, mean_sq_displacements, v₀s)
|
plot_mean_sq_displacement_with_expectation(ts, mean_sq_displacements, v₀s)
|
||||||
|
|
|
@ -186,7 +186,7 @@ function animate_with_sim_consts(
|
||||||
bundle_paths = ReCo.sorted_bundle_paths(dir)
|
bundle_paths = ReCo.sorted_bundle_paths(dir)
|
||||||
|
|
||||||
progress = ProgressMeter.Progress(
|
progress = ProgressMeter.Progress(
|
||||||
length(bundle_paths); dt=1, enabled=show_progress, desc="Animation: "
|
length(bundle_paths); dt=2, enabled=show_progress, desc="Animation: "
|
||||||
)
|
)
|
||||||
|
|
||||||
for (n_bundle, bundle_path) in enumerate(bundle_paths)
|
for (n_bundle, bundle_path) in enumerate(bundle_paths)
|
||||||
|
|
|
@ -111,7 +111,7 @@ function run_rl(;
|
||||||
hook(PRE_EXPERIMENT_STAGE, agent, env)
|
hook(PRE_EXPERIMENT_STAGE, agent, env)
|
||||||
agent(PRE_EXPERIMENT_STAGE, env)
|
agent(PRE_EXPERIMENT_STAGE, env)
|
||||||
|
|
||||||
progress = ProgressMeter.Progress(n_episodes; dt=1, enabled=show_progress, desc="RL: ")
|
progress = ProgressMeter.Progress(n_episodes; dt=2, enabled=show_progress, desc="RL: ")
|
||||||
|
|
||||||
for episode in 1:n_episodes
|
for episode in 1:n_episodes
|
||||||
dir = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir=parent_dir)
|
dir = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir=parent_dir)
|
||||||
|
@ -137,11 +137,7 @@ function run_rl(;
|
||||||
hook(POST_EPISODE_STAGE, agent, env)
|
hook(POST_EPISODE_STAGE, agent, env)
|
||||||
agent(POST_EPISODE_STAGE, env)
|
agent(POST_EPISODE_STAGE, env)
|
||||||
|
|
||||||
# TODO: Replace with live plot
|
ProgressMeter.next!(progress; showvalues=[(:rewards, hook.rewards)])
|
||||||
@show hook.rewards
|
|
||||||
# @show agent.policy.explorer.step
|
|
||||||
|
|
||||||
ProgressMeter.next!(progress)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
# Post experiment
|
# Post experiment
|
||||||
|
|
|
@ -129,13 +129,10 @@ function simulate!(
|
||||||
state_update_helper_hook! =
|
state_update_helper_hook! =
|
||||||
state_update_hook! = update_table_and_actions_hook! = empty_hook
|
state_update_hook! = update_table_and_actions_hook! = empty_hook
|
||||||
|
|
||||||
start_time = Dates.now()
|
|
||||||
println("Started simulation at $start_time.")
|
|
||||||
|
|
||||||
time_range = T0:(args.δt):T
|
time_range = T0:(args.δt):T
|
||||||
|
|
||||||
progress = ProgressMeter.Progress(
|
progress = ProgressMeter.Progress(
|
||||||
length(time_range); dt=1, enabled=args.show_progress, desc="Simulation: "
|
length(time_range); dt=2, enabled=args.show_progress, desc="Simulation: "
|
||||||
)
|
)
|
||||||
|
|
||||||
for (integration_step, t) in enumerate(time_range)
|
for (integration_step, t) in enumerate(time_range)
|
||||||
|
@ -198,9 +195,5 @@ function simulate!(
|
||||||
save_bundle(dir, bundle, n_bundles, T)
|
save_bundle(dir, bundle, n_bundles, T)
|
||||||
end
|
end
|
||||||
|
|
||||||
end_time = Dates.now()
|
|
||||||
elapsed_time = Dates.canonicalize(Dates.CompoundPeriod(end_time - start_time))
|
|
||||||
println("Simulation done at $end_time and took $elapsed_time.")
|
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue