From a647188e09b152d9fc86542de0e30008bebc3814 Mon Sep 17 00:00:00 2001 From: Mo8it Date: Mon, 31 Jan 2022 02:04:16 +0100 Subject: [PATCH] Fixes for NearestNeighbourEnv --- src/RL/Envs/NearestNeighbourEnv.jl | 13 +++++++++++-- src/RL/RL.jl | 2 ++ src/Shape.jl | 2 +- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/RL/Envs/NearestNeighbourEnv.jl b/src/RL/Envs/NearestNeighbourEnv.jl index a12eab0..b96179b 100644 --- a/src/RL/Envs/NearestNeighbourEnv.jl +++ b/src/RL/Envs/NearestNeighbourEnv.jl @@ -26,7 +26,13 @@ struct NearestNeighbourEnv <: Env n_states = n_distance_states * n_direction_angle_states + 1 # Last state is when no particle is in the skin radius - shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states)) + state_spaces_labels = gen_state_spaces_labels( + ("d", "\\theta"), (distance_state_space, direction_angle_state_space) + ) + + shared = EnvSharedProps( + n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels + ) return new(shared, distance_state_space, direction_angle_state_space) end @@ -103,6 +109,7 @@ function state_update_hook!( for particle_id in 1:n_particles sq_distance = env_helper.sq_distances_to_neighbour[particle_id] + if sq_distance == Inf64 state_id = env.shared.n_states else @@ -133,7 +140,9 @@ end function update_reward!( env::NearestNeighbourEnv, env_helper::NearestNeighbourEnvHelper, particle::ReCo.Particle ) - reward = minimizing_reward(env_helper.current_κ, env_helper.max_distance_to_goal_κ) + reward = minimizing_reward( + abs(env_helper.current_κ - env_helper.goal_κ), env_helper.max_distance_to_goal_κ + ) set_normalized_reward!(env, reward, env_helper) return nothing diff --git a/src/RL/RL.jl b/src/RL/RL.jl index 9a3ab5f..503ddad 100644 --- a/src/RL/RL.jl +++ b/src/RL/RL.jl @@ -164,6 +164,8 @@ function run_rl(; # Post experiment hook(POST_EXPERIMENT_STAGE, agent, env) + JLD2.save_object(env_helper_path, env_helper) + return env_helper end diff --git a/src/Shape.jl b/src/Shape.jl index 199f20b..5edbe12 100644 --- a/src/Shape.jl +++ b/src/Shape.jl @@ -115,7 +115,7 @@ function gyration_tensor( ) COM = center_of_mass(particles_or_centers, half_box_len) - return gyration_tensor(particles, half_box_len, COM) + return gyration_tensor(particles_or_centers, half_box_len, COM) end function eigvals_ratio(matrix)