mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Fixes for NearestNeighbourEnv
This commit is contained in:
parent
e8a5d4f684
commit
a647188e09
3 changed files with 14 additions and 3 deletions
|
@ -26,7 +26,13 @@ struct NearestNeighbourEnv <: Env
|
|||
n_states = n_distance_states * n_direction_angle_states + 1
|
||||
# Last state is when no particle is in the skin radius
|
||||
|
||||
shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states))
|
||||
state_spaces_labels = gen_state_spaces_labels(
|
||||
("d", "\\theta"), (distance_state_space, direction_angle_state_space)
|
||||
)
|
||||
|
||||
shared = EnvSharedProps(
|
||||
n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels
|
||||
)
|
||||
|
||||
return new(shared, distance_state_space, direction_angle_state_space)
|
||||
end
|
||||
|
@ -103,6 +109,7 @@ function state_update_hook!(
|
|||
|
||||
for particle_id in 1:n_particles
|
||||
sq_distance = env_helper.sq_distances_to_neighbour[particle_id]
|
||||
|
||||
if sq_distance == Inf64
|
||||
state_id = env.shared.n_states
|
||||
else
|
||||
|
@ -133,7 +140,9 @@ end
|
|||
function update_reward!(
|
||||
env::NearestNeighbourEnv, env_helper::NearestNeighbourEnvHelper, particle::ReCo.Particle
|
||||
)
|
||||
reward = minimizing_reward(env_helper.current_κ, env_helper.max_distance_to_goal_κ)
|
||||
reward = minimizing_reward(
|
||||
abs(env_helper.current_κ - env_helper.goal_κ), env_helper.max_distance_to_goal_κ
|
||||
)
|
||||
set_normalized_reward!(env, reward, env_helper)
|
||||
|
||||
return nothing
|
||||
|
|
|
@ -164,6 +164,8 @@ function run_rl(;
|
|||
# Post experiment
|
||||
hook(POST_EXPERIMENT_STAGE, agent, env)
|
||||
|
||||
JLD2.save_object(env_helper_path, env_helper)
|
||||
|
||||
return env_helper
|
||||
end
|
||||
|
||||
|
|
|
@ -115,7 +115,7 @@ function gyration_tensor(
|
|||
)
|
||||
COM = center_of_mass(particles_or_centers, half_box_len)
|
||||
|
||||
return gyration_tensor(particles, half_box_len, COM)
|
||||
return gyration_tensor(particles_or_centers, half_box_len, COM)
|
||||
end
|
||||
|
||||
function eigvals_ratio(matrix)
|
||||
|
|
Loading…
Reference in a new issue