From d6900d9a92aa23e42d01f663d740059b8d36bb88 Mon Sep 17 00:00:00 2001 From: Mo8it Date: Mon, 31 Jan 2022 17:14:24 +0100 Subject: [PATCH] Added mean_kappa.jl --- Manifest.toml | 4 +- analysis/mean_kappa.jl | 96 +++++++++++++++++++ analysis/mean_squared_displacement.jl | 2 +- .../radial_distribution_function.jl | 2 +- src/RL/Envs/LocalCOMEnv.jl | 2 +- .../LocalCOMWithAdditionalShapeRewardEnv.jl | 4 +- .../LocalCOMWithAdditionalShapeRewardEnv2.jl | 4 +- src/RL/Envs/NearestNeighbourEnv.jl | 2 +- src/RL/Envs/OriginEnv.jl | 4 +- src/RL/RL.jl | 2 +- src/Visualization/Animation.jl | 4 +- 11 files changed, 114 insertions(+), 12 deletions(-) create mode 100644 analysis/mean_kappa.jl diff --git a/Manifest.toml b/Manifest.toml index 8e37f5b..a16bfc4 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -294,9 +294,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "38bcc22b6e358e88a7715ad0db446dfd3a4fea47" +git-tree-sha1 = "c6dd4a56078a7760c04b882d9d94a08a4669598d" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.43" +version = "0.25.44" [[deps.DocStringExtensions]] deps = ["LibGit2"] diff --git a/analysis/mean_kappa.jl b/analysis/mean_kappa.jl new file mode 100644 index 0000000..08a6e6e --- /dev/null +++ b/analysis/mean_kappa.jl @@ -0,0 +1,96 @@ +using CairoMakie +using LaTeXStrings: @L_str +using Statistics: Statistics + +using ReCo: ReCo + +includet("../src/Visualization/common_CairoMakie.jl") + +function plot_mean_kappa(; rl_dir::String, n_last_episodes::Int64) + dir_content = readdir(rl_dir; join=true, sort=true) + n_content = length(dir_content) + + sim_dirs = Vector{String}(undef, n_last_episodes) + sim_dir_counter = 1 + + # Skip first sim_dir for the case that the simulation is still running + skipped_first_sim_dir = false + + for file_or_dir_ind in n_content:-1:1 + file_or_dir = dir_content[file_or_dir_ind] + + if isdir(file_or_dir) + if skipped_first_sim_dir + sim_dirs[sim_dir_counter] = file_or_dir + sim_dir_counter += 1 + if sim_dir_counter > n_last_episodes + break + end + else + skipped_first_sim_dir = true + end + end + end + + if sim_dir_counter < n_last_episodes + error("The rl_dir does not have n_last_episodes + 1 dirs!") + end + + sim_consts = ReCo.load_sim_consts(sim_dirs[1]) + half_box_len = sim_consts.half_box_len + total_n_snapshots = ReCo.BundlesInfo(sim_dirs[1]).total_n_snapshots + + snapshot_κs = zeros(Float64, total_n_snapshots) + + for sim_dir in sim_dirs + bundles_info = ReCo.BundlesInfo(sim_dir) + total_n_snapshots = bundles_info.total_n_snapshots + + for snapshot_ind in 1:total_n_snapshots + bundle, bundle_snapshot = ReCo.get_bundle_to_snapshot( + bundles_info, snapshot_ind + ) + + cs_view = view(bundle.c, :, bundle_snapshot) + + κ = ReCo.gyration_tensor_eigvals_ratio(cs_view, half_box_len) + snapshot_κs[snapshot_ind] += κ + end + end + + snapshot_κs ./= n_last_episodes + + mean_κ = Statistics.mean(snapshot_κs) + + init_cairomakie!() + fig = gen_figure(; padding=9) + + ax = Axis( + fig[1, 1]; + xlabel="Frame", + ylabel=L"\kappa", + limits=(1, total_n_snapshots, 0.0, 1.04), + title="Averaged over last $n_last_episodes episodes", + ) + + lines!(ax, 1:total_n_snapshots, snapshot_κs; label=L"\kappa") + + rounded_mean_κ = round(mean_κ; digits=2) + + lines!( + ax, + [1, total_n_snapshots], + [mean_κ, mean_κ]; + label=L"Mean $\tilde{\kappa} = %$rounded_mean_κ$", + linestyle=:dash, + color=:red, + ) + + axislegend(ax; position=:lb, padding=3, rowgap=-3) + + set_gaps!(fig) + + save_fig("mean_kappa.pdf", fig; parent_dir=rl_dir) + + return nothing +end \ No newline at end of file diff --git a/analysis/mean_squared_displacement.jl b/analysis/mean_squared_displacement.jl index 9166c74..43b2ac7 100644 --- a/analysis/mean_squared_displacement.jl +++ b/analysis/mean_squared_displacement.jl @@ -90,7 +90,7 @@ function mean_squared_displacement(; (ts,), (:t,), sim_dirs[1, 1]; particle_slice=1, snapshot_slice=:, first_bundle=2 ) # Skip the first bundle to avoid t = 0 - mean_sq_displacements = zeros((length(ts), n_v₀s)) + mean_sq_displacements = zeros(Float64, (length(ts), n_v₀s)) @simd for v₀_ind in 1:n_v₀s for sim_ind in 1:n_simulations diff --git a/analysis/radial_distribution_function/radial_distribution_function.jl b/analysis/radial_distribution_function/radial_distribution_function.jl index d146b3b..629861e 100644 --- a/analysis/radial_distribution_function/radial_distribution_function.jl +++ b/analysis/radial_distribution_function/radial_distribution_function.jl @@ -102,7 +102,7 @@ function radial_distribution(; error("snapshot_conunter != n_last_snapshots") end - g = zeros(n_radii) + g = zeros(Float64, n_radii) for snapshot_ind in 1:n_last_snapshots for p1_ind in 1:n_particles diff --git a/src/RL/Envs/LocalCOMEnv.jl b/src/RL/Envs/LocalCOMEnv.jl index 91cdca2..5d6e8af 100644 --- a/src/RL/Envs/LocalCOMEnv.jl +++ b/src/RL/Envs/LocalCOMEnv.jl @@ -58,7 +58,7 @@ mutable struct LocalCOMEnvHelper <: EnvHelper shared, fill(SVector(0.0, 0.0), shared.n_particles), fill(0, shared.n_particles), - zeros(shared.n_particles), + zeros(Float64, shared.n_particles), max_distance_to_local_center_of_mass, half_box_len, ) diff --git a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl index fe1a604..3181228 100644 --- a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl +++ b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl @@ -67,7 +67,7 @@ mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper shared, fill(SVector(0.0, 0.0), shared.n_particles), fill(0, shared.n_particles), - zeros(shared.n_particles), + zeros(Float64, shared.n_particles), max_distance_to_local_center_of_mass, false, SVector(0.0, 0.0), @@ -197,6 +197,8 @@ function update_reward!( ) end + reward /= 2 + set_normalized_reward!(env, reward, env_helper) end diff --git a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl index b0b0f7e..51be11f 100644 --- a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl +++ b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl @@ -67,7 +67,7 @@ mutable struct LocalCOMWithAdditionalShapeRewardEnv2Helper <: EnvHelper shared, fill(SVector(0.0, 0.0), shared.n_particles), fill(0, shared.n_particles), - zeros(shared.n_particles), + zeros(Float64, shared.n_particles), max_distance_to_local_center_of_mass, false, 1.0, @@ -180,6 +180,8 @@ function update_reward!( ) end + reward /= 2 + set_normalized_reward!(env, reward, env_helper) end diff --git a/src/RL/Envs/NearestNeighbourEnv.jl b/src/RL/Envs/NearestNeighbourEnv.jl index b96179b..5460686 100644 --- a/src/RL/Envs/NearestNeighbourEnv.jl +++ b/src/RL/Envs/NearestNeighbourEnv.jl @@ -57,7 +57,7 @@ mutable struct NearestNeighbourEnvHelper <: EnvHelper return new( shared, fill(SVector(0.0, 0.0), shared.n_particles), - zeros(shared.n_particles), + zeros(Float64, shared.n_particles), 1.0, goal_κ, max_distance_to_goal_κ, diff --git a/src/RL/Envs/OriginEnv.jl b/src/RL/Envs/OriginEnv.jl index 9ad38e8..118007d 100644 --- a/src/RL/Envs/OriginEnv.jl +++ b/src/RL/Envs/OriginEnv.jl @@ -48,7 +48,9 @@ mutable struct OriginEnvHelper <: EnvHelper function OriginEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64) max_distance_to_origin = sqrt(2) * half_box_len - return new(shared, zeros(shared.n_particles), max_distance_to_origin, half_box_len) + return new( + shared, zeros(Float64, shared.n_particles), max_distance_to_origin, half_box_len + ) end end diff --git a/src/RL/RL.jl b/src/RL/RL.jl index 977e545..ee37711 100644 --- a/src/RL/RL.jl +++ b/src/RL/RL.jl @@ -70,7 +70,7 @@ function run_rl(; update_actions_at::Float64=0.1, n_particles::Int64=100, seed::Int64=42, - ϵ_stable::Float64=0.0001, + ϵ_stable::Float64=0.00001, skin_to_interaction_radius_ratio::Float64=ReCo.DEFAULT_SKIN_TO_INTERACTION_RADIUS_RATIO, packing_ratio::Float64=0.15, show_progress::Bool=true, diff --git a/src/Visualization/Animation.jl b/src/Visualization/Animation.jl index 6cce2d8..46a7cf8 100644 --- a/src/Visualization/Animation.jl +++ b/src/Visualization/Animation.jl @@ -204,8 +204,8 @@ function animate( end if show_frame_diff - segment_xs = Observable(zeros(2 * n_particles)) - segment_ys = Observable(zeros(2 * n_particles)) + segment_xs = Observable(zeros(Float64, 2 * n_particles)) + segment_ys = Observable(zeros(Float64, 2 * n_particles)) end bundle_paths = ReCo.sorted_bundle_paths(sim_dir)