mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Fixes for NearestNeighbourEnv
This commit is contained in:
parent
e8a5d4f684
commit
a647188e09
3 changed files with 14 additions and 3 deletions
|
@ -26,7 +26,13 @@ struct NearestNeighbourEnv <: Env
|
||||||
n_states = n_distance_states * n_direction_angle_states + 1
|
n_states = n_distance_states * n_direction_angle_states + 1
|
||||||
# Last state is when no particle is in the skin radius
|
# Last state is when no particle is in the skin radius
|
||||||
|
|
||||||
shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states))
|
state_spaces_labels = gen_state_spaces_labels(
|
||||||
|
("d", "\\theta"), (distance_state_space, direction_angle_state_space)
|
||||||
|
)
|
||||||
|
|
||||||
|
shared = EnvSharedProps(
|
||||||
|
n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels
|
||||||
|
)
|
||||||
|
|
||||||
return new(shared, distance_state_space, direction_angle_state_space)
|
return new(shared, distance_state_space, direction_angle_state_space)
|
||||||
end
|
end
|
||||||
|
@ -103,6 +109,7 @@ function state_update_hook!(
|
||||||
|
|
||||||
for particle_id in 1:n_particles
|
for particle_id in 1:n_particles
|
||||||
sq_distance = env_helper.sq_distances_to_neighbour[particle_id]
|
sq_distance = env_helper.sq_distances_to_neighbour[particle_id]
|
||||||
|
|
||||||
if sq_distance == Inf64
|
if sq_distance == Inf64
|
||||||
state_id = env.shared.n_states
|
state_id = env.shared.n_states
|
||||||
else
|
else
|
||||||
|
@ -133,7 +140,9 @@ end
|
||||||
function update_reward!(
|
function update_reward!(
|
||||||
env::NearestNeighbourEnv, env_helper::NearestNeighbourEnvHelper, particle::ReCo.Particle
|
env::NearestNeighbourEnv, env_helper::NearestNeighbourEnvHelper, particle::ReCo.Particle
|
||||||
)
|
)
|
||||||
reward = minimizing_reward(env_helper.current_κ, env_helper.max_distance_to_goal_κ)
|
reward = minimizing_reward(
|
||||||
|
abs(env_helper.current_κ - env_helper.goal_κ), env_helper.max_distance_to_goal_κ
|
||||||
|
)
|
||||||
set_normalized_reward!(env, reward, env_helper)
|
set_normalized_reward!(env, reward, env_helper)
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
|
|
|
@ -164,6 +164,8 @@ function run_rl(;
|
||||||
# Post experiment
|
# Post experiment
|
||||||
hook(POST_EXPERIMENT_STAGE, agent, env)
|
hook(POST_EXPERIMENT_STAGE, agent, env)
|
||||||
|
|
||||||
|
JLD2.save_object(env_helper_path, env_helper)
|
||||||
|
|
||||||
return env_helper
|
return env_helper
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -115,7 +115,7 @@ function gyration_tensor(
|
||||||
)
|
)
|
||||||
COM = center_of_mass(particles_or_centers, half_box_len)
|
COM = center_of_mass(particles_or_centers, half_box_len)
|
||||||
|
|
||||||
return gyration_tensor(particles, half_box_len, COM)
|
return gyration_tensor(particles_or_centers, half_box_len, COM)
|
||||||
end
|
end
|
||||||
|
|
||||||
function eigvals_ratio(matrix)
|
function eigvals_ratio(matrix)
|
||||||
|
|
Loading…
Reference in a new issue