1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-11-08 22:21:08 +00:00

Fixes for NearestNeighbourEnv

This commit is contained in:
Mo8it 2022-01-31 02:04:16 +01:00
parent e8a5d4f684
commit a647188e09
3 changed files with 14 additions and 3 deletions

View file

@ -26,7 +26,13 @@ struct NearestNeighbourEnv <: Env
n_states = n_distance_states * n_direction_angle_states + 1 n_states = n_distance_states * n_direction_angle_states + 1
# Last state is when no particle is in the skin radius # Last state is when no particle is in the skin radius
shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states)) state_spaces_labels = gen_state_spaces_labels(
("d", "\\theta"), (distance_state_space, direction_angle_state_space)
)
shared = EnvSharedProps(
n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels
)
return new(shared, distance_state_space, direction_angle_state_space) return new(shared, distance_state_space, direction_angle_state_space)
end end
@ -103,6 +109,7 @@ function state_update_hook!(
for particle_id in 1:n_particles for particle_id in 1:n_particles
sq_distance = env_helper.sq_distances_to_neighbour[particle_id] sq_distance = env_helper.sq_distances_to_neighbour[particle_id]
if sq_distance == Inf64 if sq_distance == Inf64
state_id = env.shared.n_states state_id = env.shared.n_states
else else
@ -133,7 +140,9 @@ end
function update_reward!( function update_reward!(
env::NearestNeighbourEnv, env_helper::NearestNeighbourEnvHelper, particle::ReCo.Particle env::NearestNeighbourEnv, env_helper::NearestNeighbourEnvHelper, particle::ReCo.Particle
) )
reward = minimizing_reward(env_helper.current_κ, env_helper.max_distance_to_goal_κ) reward = minimizing_reward(
abs(env_helper.current_κ - env_helper.goal_κ), env_helper.max_distance_to_goal_κ
)
set_normalized_reward!(env, reward, env_helper) set_normalized_reward!(env, reward, env_helper)
return nothing return nothing

View file

@ -164,6 +164,8 @@ function run_rl(;
# Post experiment # Post experiment
hook(POST_EXPERIMENT_STAGE, agent, env) hook(POST_EXPERIMENT_STAGE, agent, env)
JLD2.save_object(env_helper_path, env_helper)
return env_helper return env_helper
end end

View file

@ -115,7 +115,7 @@ function gyration_tensor(
) )
COM = center_of_mass(particles_or_centers, half_box_len) COM = center_of_mass(particles_or_centers, half_box_len)
return gyration_tensor(particles, half_box_len, COM) return gyration_tensor(particles_or_centers, half_box_len, COM)
end end
function eigvals_ratio(matrix) function eigvals_ratio(matrix)