mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-30 17:53:32 +00:00
Added NearestNeighbourEnv
This commit is contained in:
parent
9e80c22bdb
commit
9e3b71197c
5 changed files with 151 additions and 4 deletions
|
@ -26,7 +26,13 @@ struct LocalCOMWithAdditionalShapeRewardEnv <: Env
|
||||||
n_states = n_distance_states * n_direction_angle_states + 1
|
n_states = n_distance_states * n_direction_angle_states + 1
|
||||||
# Last state is when no particle is in the skin radius
|
# Last state is when no particle is in the skin radius
|
||||||
|
|
||||||
shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states))
|
state_spaces_labels = gen_state_spaces_labels(
|
||||||
|
("d", "\\theta"), (distance_state_space, direction_angle_state_space)
|
||||||
|
)
|
||||||
|
|
||||||
|
shared = EnvSharedProps(
|
||||||
|
n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels
|
||||||
|
)
|
||||||
|
|
||||||
return new(shared, distance_state_space, direction_angle_state_space)
|
return new(shared, distance_state_space, direction_angle_state_space)
|
||||||
end
|
end
|
||||||
|
@ -95,6 +101,7 @@ function state_update_helper_hook!(
|
||||||
id1::Int64,
|
id1::Int64,
|
||||||
id2::Int64,
|
id2::Int64,
|
||||||
r⃗₁₂::SVector{2,Float64},
|
r⃗₁₂::SVector{2,Float64},
|
||||||
|
distance²::Float64,
|
||||||
)
|
)
|
||||||
env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂
|
env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂
|
||||||
env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂
|
env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂
|
||||||
|
|
140
src/RL/Envs/NearestNeighbourEnv.jl
Normal file
140
src/RL/Envs/NearestNeighbourEnv.jl
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
export NearestNeighbourEnv
|
||||||
|
|
||||||
|
using ..ReCo: ReCo
|
||||||
|
|
||||||
|
struct NearestNeighbourEnv <: Env
|
||||||
|
shared::EnvSharedProps
|
||||||
|
|
||||||
|
distance_state_space::Vector{Interval}
|
||||||
|
direction_angle_state_space::Vector{Interval}
|
||||||
|
|
||||||
|
function NearestNeighbourEnv(;
|
||||||
|
n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args
|
||||||
|
)
|
||||||
|
@assert n_distance_states > 1
|
||||||
|
@assert n_direction_angle_states > 1
|
||||||
|
|
||||||
|
direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)
|
||||||
|
|
||||||
|
min_distance = 0.0
|
||||||
|
max_distance = args.skin_radius
|
||||||
|
|
||||||
|
distance_state_space = gen_distance_state_space(
|
||||||
|
min_distance, max_distance, n_distance_states
|
||||||
|
)
|
||||||
|
|
||||||
|
n_states = n_distance_states * n_direction_angle_states + 1
|
||||||
|
# Last state is when no particle is in the skin radius
|
||||||
|
|
||||||
|
shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states))
|
||||||
|
|
||||||
|
return new(shared, distance_state_space, direction_angle_state_space)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
mutable struct NearestNeighbourEnvHelper <: EnvHelper
|
||||||
|
shared::EnvHelperSharedProps
|
||||||
|
|
||||||
|
vecs_to_neighbour::Vector{SVector{2,Float64}}
|
||||||
|
sq_distances_to_neighbour::Vector{Float64}
|
||||||
|
|
||||||
|
current_κ::Float64
|
||||||
|
goal_κ::Float64
|
||||||
|
max_distance_to_goal_κ::Float64
|
||||||
|
|
||||||
|
half_box_len::Float64
|
||||||
|
|
||||||
|
function NearestNeighbourEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64)
|
||||||
|
goal_κ = 0.4
|
||||||
|
max_distance_to_goal_κ = max(1 - goal_κ, goal_κ)
|
||||||
|
|
||||||
|
return new(
|
||||||
|
shared,
|
||||||
|
fill(SVector(0.0, 0.0), shared.n_particles),
|
||||||
|
zeros(shared.n_particles),
|
||||||
|
1.0,
|
||||||
|
goal_κ,
|
||||||
|
max_distance_to_goal_κ,
|
||||||
|
half_box_len,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function gen_env_helper(
|
||||||
|
::NearestNeighbourEnv, env_helper_shared::EnvHelperSharedProps; args
|
||||||
|
)
|
||||||
|
return NearestNeighbourEnvHelper(env_helper_shared, args.half_box_len)
|
||||||
|
end
|
||||||
|
|
||||||
|
function pre_integration_hook!(env_helper::NearestNeighbourEnvHelper)
|
||||||
|
@simd for particle_id in 1:(env_helper.shared.n_particles)
|
||||||
|
env_helper.sq_distances_to_neighbour[particle_id] = Inf64
|
||||||
|
end
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function state_update_helper_hook!(
|
||||||
|
env_helper::NearestNeighbourEnvHelper,
|
||||||
|
id1::Int64,
|
||||||
|
id2::Int64,
|
||||||
|
r⃗₁₂::SVector{2,Float64},
|
||||||
|
distance²::Float64,
|
||||||
|
)
|
||||||
|
if distance² < env_helper.sq_distances_to_neighbour[id1]
|
||||||
|
env_helper.vecs_to_neighbour[id1] = r⃗₁₂
|
||||||
|
env_helper.sq_distances_to_neighbour[id1] = distance²
|
||||||
|
end
|
||||||
|
|
||||||
|
if distance² < env_helper.sq_distances_to_neighbour[id2]
|
||||||
|
env_helper.vecs_to_neighbour[id2] = -r⃗₁₂
|
||||||
|
env_helper.sq_distances_to_neighbour[id2] = distance²
|
||||||
|
end
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function state_update_hook!(
|
||||||
|
env_helper::NearestNeighbourEnvHelper, particles::Vector{ReCo.Particle}
|
||||||
|
)
|
||||||
|
n_particles = env_helper.shared.n_particles
|
||||||
|
|
||||||
|
env = env_helper.shared.env
|
||||||
|
|
||||||
|
for particle_id in 1:n_particles
|
||||||
|
sq_distance = env_helper.sq_distances_to_neighbour[particle_id]
|
||||||
|
if sq_distance == Inf64
|
||||||
|
state_id = env.shared.n_states
|
||||||
|
else
|
||||||
|
distance = sqrt(sq_distance)
|
||||||
|
distance_state_ind = find_state_ind(distance, env.distance_state_space)
|
||||||
|
|
||||||
|
si, co = sincos(particles[particle_id].φ)
|
||||||
|
direction_angle = ReCo.angle2(
|
||||||
|
SVector(co, si), env_helper.vecs_to_neighbour[particle_id]
|
||||||
|
)
|
||||||
|
direction_state_ind = find_state_ind(
|
||||||
|
direction_angle, env.direction_angle_state_space
|
||||||
|
)
|
||||||
|
|
||||||
|
state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind]
|
||||||
|
end
|
||||||
|
|
||||||
|
env_helper.shared.states_id[particle_id] = state_id
|
||||||
|
end
|
||||||
|
|
||||||
|
env_helper.current_κ = ReCo.gyration_tensor_eigvals_ratio(
|
||||||
|
particles, env_helper.half_box_len
|
||||||
|
)
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function update_reward!(
|
||||||
|
env::NearestNeighbourEnv, env_helper::NearestNeighbourEnvHelper, particle::ReCo.Particle
|
||||||
|
)
|
||||||
|
reward = minimizing_reward(env_helper.current_κ, env_helper.max_distance_to_goal_κ)
|
||||||
|
set_normalized_reward!(env, reward, env_helper)
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
|
@ -61,7 +61,7 @@ function pre_integration_hook!(::OriginEnvHelper)
|
||||||
end
|
end
|
||||||
|
|
||||||
function state_update_helper_hook!(
|
function state_update_helper_hook!(
|
||||||
::OriginEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
::OriginEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64
|
||||||
)
|
)
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
|
@ -5,7 +5,7 @@ function pre_integration_hook!(::EnvHelper)
|
||||||
end
|
end
|
||||||
|
|
||||||
function state_update_helper_hook!(
|
function state_update_helper_hook!(
|
||||||
::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64
|
||||||
)
|
)
|
||||||
return ReCo.method_not_implemented()
|
return ReCo.method_not_implemented()
|
||||||
end
|
end
|
||||||
|
|
|
@ -54,7 +54,7 @@ function euler!(
|
||||||
p1_c, p2.c, args.interaction_radius², args.half_box_len
|
p1_c, p2.c, args.interaction_radius², args.half_box_len
|
||||||
)
|
)
|
||||||
|
|
||||||
state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂)
|
state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂, distance²)
|
||||||
|
|
||||||
if overlapping
|
if overlapping
|
||||||
factor = args.ϵσ⁶δtμₜ24 / (distance²^4) * (1.0 - args.σ⁶2 / (distance²^3))
|
factor = args.ϵσ⁶δtμₜ24 / (distance²^4) * (1.0 - args.σ⁶2 / (distance²^3))
|
||||||
|
|
Loading…
Reference in a new issue