mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Added NearestNeighbourEnv
This commit is contained in:
parent
9e80c22bdb
commit
9e3b71197c
5 changed files with 151 additions and 4 deletions
|
@ -26,7 +26,13 @@ struct LocalCOMWithAdditionalShapeRewardEnv <: Env
|
|||
n_states = n_distance_states * n_direction_angle_states + 1
|
||||
# Last state is when no particle is in the skin radius
|
||||
|
||||
shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states))
|
||||
state_spaces_labels = gen_state_spaces_labels(
|
||||
("d", "\\theta"), (distance_state_space, direction_angle_state_space)
|
||||
)
|
||||
|
||||
shared = EnvSharedProps(
|
||||
n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels
|
||||
)
|
||||
|
||||
return new(shared, distance_state_space, direction_angle_state_space)
|
||||
end
|
||||
|
@ -95,6 +101,7 @@ function state_update_helper_hook!(
|
|||
id1::Int64,
|
||||
id2::Int64,
|
||||
r⃗₁₂::SVector{2,Float64},
|
||||
distance²::Float64,
|
||||
)
|
||||
env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂
|
||||
env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂
|
||||
|
|
140
src/RL/Envs/NearestNeighbourEnv.jl
Normal file
140
src/RL/Envs/NearestNeighbourEnv.jl
Normal file
|
@ -0,0 +1,140 @@
|
|||
export NearestNeighbourEnv
|
||||
|
||||
using ..ReCo: ReCo
|
||||
|
||||
"""
    NearestNeighbourEnv(; n_distance_states=3, n_direction_angle_states=3, args)

Environment whose state space is the product of a distance state space and a
direction-angle state space, plus one additional state that is used when no particle
is in the skin radius.

# Keywords
- `n_distance_states::Int64=3`: number of distance intervals; must be greater than 1.
- `n_direction_angle_states::Int64=3`: number of direction-angle intervals; must be
  greater than 1.
- `args`: object providing `skin_radius`, which is used as the upper bound of the
  distance state space (the lower bound is `0.0`).
"""
struct NearestNeighbourEnv <: Env
    shared::EnvSharedProps

    distance_state_space::Vector{Interval}
    direction_angle_state_space::Vector{Interval}

    function NearestNeighbourEnv(;
        n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args
    )
        @assert n_distance_states > 1
        @assert n_direction_angle_states > 1

        direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)

        min_distance = 0.0
        max_distance = args.skin_radius

        distance_state_space = gen_distance_state_space(
            min_distance, max_distance, n_distance_states
        )

        n_states = n_distance_states * n_direction_angle_states + 1
        # Last state is when no particle is in the skin radius

        # Pass the state space labels to `EnvSharedProps` in the same way as the other
        # distance/direction-angle environment does, so both agree on the constructor
        # form.
        state_spaces_labels = gen_state_spaces_labels(
            ("d", "\\theta"), (distance_state_space, direction_angle_state_space)
        )

        shared = EnvSharedProps(
            n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels
        )

        return new(shared, distance_state_space, direction_angle_state_space)
    end
end
|
||||
|
||||
"""
    NearestNeighbourEnvHelper(shared, half_box_len)

Mutable working data used together with `NearestNeighbourEnv`.

# Fields
- `shared`: the shared environment-helper properties.
- `vecs_to_neighbour`: one 2D vector per particle, initialised to `(0.0, 0.0)`.
- `sq_distances_to_neighbour`: one squared distance per particle, initialised to zero.
- `current_κ`: initialised to `1.0`.
- `goal_κ`: fixed to `0.4` by the constructor.
- `max_distance_to_goal_κ`: `max(1 - goal_κ, goal_κ)`.
- `half_box_len`: stored as given.
"""
mutable struct NearestNeighbourEnvHelper <: EnvHelper
    shared::EnvHelperSharedProps

    vecs_to_neighbour::Vector{SVector{2,Float64}}
    sq_distances_to_neighbour::Vector{Float64}

    current_κ::Float64
    goal_κ::Float64
    max_distance_to_goal_κ::Float64

    half_box_len::Float64

    function NearestNeighbourEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64)
        n_particles = shared.n_particles

        # Per-particle buffers, one entry for every particle.
        vecs_to_neighbour = fill(SVector(0.0, 0.0), n_particles)
        sq_distances_to_neighbour = zeros(n_particles)

        initial_κ = 1.0
        goal_κ = 0.4
        max_distance_to_goal_κ = max(1 - goal_κ, goal_κ)

        return new(
            shared,
            vecs_to_neighbour,
            sq_distances_to_neighbour,
            initial_κ,
            goal_κ,
            max_distance_to_goal_κ,
            half_box_len,
        )
    end
end
|
||||
|
||||
"""
    gen_env_helper(::NearestNeighbourEnv, env_helper_shared; args)

Construct the `NearestNeighbourEnvHelper` for a `NearestNeighbourEnv`, taking the half
box length from `args.half_box_len`.
"""
function gen_env_helper(
    ::NearestNeighbourEnv, env_helper_shared::EnvHelperSharedProps; args
)
    half_box_len = args.half_box_len

    return NearestNeighbourEnvHelper(env_helper_shared, half_box_len)
end
|
||||
|
||||
"""
    pre_integration_hook!(env_helper::NearestNeighbourEnvHelper)

Reset the stored squared distance to the neighbour of every particle to `Inf64`.
"""
function pre_integration_hook!(env_helper::NearestNeighbourEnvHelper)
    particle_ids = 1:(env_helper.shared.n_particles)

    # In-place broadcast over exactly the particle index range.
    env_helper.sq_distances_to_neighbour[particle_ids] .= Inf64

    return nothing
end
|
||||
|
||||
"""
    state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂, distance²)

Record the pair `(id1, id2)` for each of the two particles if `distance²` is smaller
than the squared distance currently stored for that particle. Particle `id1` stores
`r⃗₁₂` and particle `id2` stores `-r⃗₁₂`; the stored squared distance becomes
`distance²`.
"""
function state_update_helper_hook!(
    env_helper::NearestNeighbourEnvHelper,
    id1::Int64,
    id2::Int64,
    r⃗₁₂::SVector{2,Float64},
    distance²::Float64,
)
    # Both particles are handled the same way; only the sign of the vector differs.
    for (id, vec_to_other) in ((id1, r⃗₁₂), (id2, -r⃗₁₂))
        if env_helper.sq_distances_to_neighbour[id] > distance²
            env_helper.vecs_to_neighbour[id] = vec_to_other
            env_helper.sq_distances_to_neighbour[id] = distance²
        end
    end

    return nothing
end
|
||||
|
||||
"""
    state_update_hook!(env_helper::NearestNeighbourEnvHelper, particles)

Write a state id for every particle into `env_helper.shared.states_id` and refresh
`env_helper.current_κ`.

A particle whose stored squared neighbour distance is `Inf64` gets the last state id,
`env.shared.n_states`. Otherwise the state id is looked up in
`env.shared.state_id_tensor` from the distance state index and the direction state
index, where the direction angle is taken between the particle's orientation vector
`(cos φ, sin φ)` and its stored vector to the neighbour.
"""
function state_update_hook!(
    env_helper::NearestNeighbourEnvHelper, particles::Vector{ReCo.Particle}
)
    helper_shared = env_helper.shared
    n_particles = helper_shared.n_particles
    env = helper_shared.env

    for particle_id in 1:n_particles
        sq_distance = env_helper.sq_distances_to_neighbour[particle_id]

        helper_shared.states_id[particle_id] = if sq_distance == Inf64
            # No neighbour was recorded for this particle.
            env.shared.n_states
        else
            distance_ind = find_state_ind(sqrt(sq_distance), env.distance_state_space)

            sin_φ, cos_φ = sincos(particles[particle_id].φ)
            orientation = SVector(cos_φ, sin_φ)
            vec_to_neighbour = env_helper.vecs_to_neighbour[particle_id]

            direction_angle = ReCo.angle2(orientation, vec_to_neighbour)
            direction_ind = find_state_ind(
                direction_angle, env.direction_angle_state_space
            )

            env.shared.state_id_tensor[distance_ind, direction_ind]
        end
    end

    env_helper.current_κ = ReCo.gyration_tensor_eigvals_ratio(
        particles, env_helper.half_box_len
    )

    return nothing
end
|
||||
|
||||
"""
    update_reward!(env::NearestNeighbourEnv, env_helper, particle)

Compute the reward from how far `env_helper.current_κ` is from `env_helper.goal_κ` and
hand it to `set_normalized_reward!`.

The value passed to `minimizing_reward` is the absolute distance
`abs(current_κ - goal_κ)`, which matches its upper bound `max_distance_to_goal_κ`
(constructed as `max(1 - goal_κ, goal_κ)`). Passing the raw `current_κ` would ignore
`goal_κ` entirely.
"""
function update_reward!(
    env::NearestNeighbourEnv, env_helper::NearestNeighbourEnvHelper, particle::ReCo.Particle
)
    distance_to_goal_κ = abs(env_helper.current_κ - env_helper.goal_κ)

    reward = minimizing_reward(distance_to_goal_κ, env_helper.max_distance_to_goal_κ)
    set_normalized_reward!(env, reward, env_helper)

    return nothing
end
|
|
@ -61,7 +61,7 @@ function pre_integration_hook!(::OriginEnvHelper)
|
|||
end
|
||||
|
||||
function state_update_helper_hook!(
|
||||
::OriginEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
||||
::OriginEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64
|
||||
)
|
||||
return nothing
|
||||
end
|
||||
|
|
|
@ -5,7 +5,7 @@ function pre_integration_hook!(::EnvHelper)
|
|||
end
|
||||
|
||||
function state_update_helper_hook!(
|
||||
::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
||||
::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64
|
||||
)
|
||||
return ReCo.method_not_implemented()
|
||||
end
|
||||
|
|
|
@ -54,7 +54,7 @@ function euler!(
|
|||
p1_c, p2.c, args.interaction_radius², args.half_box_len
|
||||
)
|
||||
|
||||
state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂)
|
||||
state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂, distance²)
|
||||
|
||||
if overlapping
|
||||
factor = args.ϵσ⁶δtμₜ24 / (distance²^4) * (1.0 - args.σ⁶2 / (distance²^3))
|
||||
|
|
Loading…
Reference in a new issue