1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-12-30 17:53:32 +00:00

Added NearestNeighbourEnv

This commit is contained in:
Mo8it 2022-01-31 01:47:48 +01:00
parent 9e80c22bdb
commit 9e3b71197c
5 changed files with 151 additions and 4 deletions

View file

@ -26,7 +26,13 @@ struct LocalCOMWithAdditionalShapeRewardEnv <: Env
n_states = n_distance_states * n_direction_angle_states + 1 n_states = n_distance_states * n_direction_angle_states + 1
# Last state is when no particle is in the skin radius # Last state is when no particle is in the skin radius
shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states)) state_spaces_labels = gen_state_spaces_labels(
("d", "\\theta"), (distance_state_space, direction_angle_state_space)
)
shared = EnvSharedProps(
n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels
)
return new(shared, distance_state_space, direction_angle_state_space) return new(shared, distance_state_space, direction_angle_state_space)
end end
@ -95,6 +101,7 @@ function state_update_helper_hook!(
id1::Int64, id1::Int64,
id2::Int64, id2::Int64,
r⃗₁₂::SVector{2,Float64}, r⃗₁₂::SVector{2,Float64},
distance²::Float64,
) )
env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂ env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂
env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂ env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂

View file

@ -0,0 +1,140 @@
export NearestNeighbourEnv
using ..ReCo: ReCo
struct NearestNeighbourEnv <: Env
shared::EnvSharedProps
distance_state_space::Vector{Interval}
direction_angle_state_space::Vector{Interval}
function NearestNeighbourEnv(;
n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args
)
@assert n_distance_states > 1
@assert n_direction_angle_states > 1
direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)
min_distance = 0.0
max_distance = args.skin_radius
distance_state_space = gen_distance_state_space(
min_distance, max_distance, n_distance_states
)
n_states = n_distance_states * n_direction_angle_states + 1
# Last state is when no particle is in the skin radius
shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states))
return new(shared, distance_state_space, direction_angle_state_space)
end
end
mutable struct NearestNeighbourEnvHelper <: EnvHelper
shared::EnvHelperSharedProps
vecs_to_neighbour::Vector{SVector{2,Float64}}
sq_distances_to_neighbour::Vector{Float64}
current_κ::Float64
goal_κ::Float64
max_distance_to_goal_κ::Float64
half_box_len::Float64
function NearestNeighbourEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64)
goal_κ = 0.4
max_distance_to_goal_κ = max(1 - goal_κ, goal_κ)
return new(
shared,
fill(SVector(0.0, 0.0), shared.n_particles),
zeros(shared.n_particles),
1.0,
goal_κ,
max_distance_to_goal_κ,
half_box_len,
)
end
end
function gen_env_helper(
::NearestNeighbourEnv, env_helper_shared::EnvHelperSharedProps; args
)
return NearestNeighbourEnvHelper(env_helper_shared, args.half_box_len)
end
function pre_integration_hook!(env_helper::NearestNeighbourEnvHelper)
@simd for particle_id in 1:(env_helper.shared.n_particles)
env_helper.sq_distances_to_neighbour[particle_id] = Inf64
end
return nothing
end
function state_update_helper_hook!(
env_helper::NearestNeighbourEnvHelper,
id1::Int64,
id2::Int64,
r⃗₁₂::SVector{2,Float64},
distance²::Float64,
)
if distance² < env_helper.sq_distances_to_neighbour[id1]
env_helper.vecs_to_neighbour[id1] = r⃗₁₂
env_helper.sq_distances_to_neighbour[id1] = distance²
end
if distance² < env_helper.sq_distances_to_neighbour[id2]
env_helper.vecs_to_neighbour[id2] = -r⃗₁₂
env_helper.sq_distances_to_neighbour[id2] = distance²
end
return nothing
end
function state_update_hook!(
env_helper::NearestNeighbourEnvHelper, particles::Vector{ReCo.Particle}
)
n_particles = env_helper.shared.n_particles
env = env_helper.shared.env
for particle_id in 1:n_particles
sq_distance = env_helper.sq_distances_to_neighbour[particle_id]
if sq_distance == Inf64
state_id = env.shared.n_states
else
distance = sqrt(sq_distance)
distance_state_ind = find_state_ind(distance, env.distance_state_space)
si, co = sincos(particles[particle_id].φ)
direction_angle = ReCo.angle2(
SVector(co, si), env_helper.vecs_to_neighbour[particle_id]
)
direction_state_ind = find_state_ind(
direction_angle, env.direction_angle_state_space
)
state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind]
end
env_helper.shared.states_id[particle_id] = state_id
end
env_helper.current_κ = ReCo.gyration_tensor_eigvals_ratio(
particles, env_helper.half_box_len
)
return nothing
end
function update_reward!(
env::NearestNeighbourEnv, env_helper::NearestNeighbourEnvHelper, particle::ReCo.Particle
)
reward = minimizing_reward(env_helper.current_κ, env_helper.max_distance_to_goal_κ)
set_normalized_reward!(env, reward, env_helper)
return nothing
end

View file

@ -61,7 +61,7 @@ function pre_integration_hook!(::OriginEnvHelper)
end end
function state_update_helper_hook!( function state_update_helper_hook!(
::OriginEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} ::OriginEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64
) )
return nothing return nothing
end end

View file

@ -5,7 +5,7 @@ function pre_integration_hook!(::EnvHelper)
end end
function state_update_helper_hook!( function state_update_helper_hook!(
::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} ::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64
) )
return ReCo.method_not_implemented() return ReCo.method_not_implemented()
end end

View file

@ -54,7 +54,7 @@ function euler!(
p1_c, p2.c, args.interaction_radius², args.half_box_len p1_c, p2.c, args.interaction_radius², args.half_box_len
) )
state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂) state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂, distance²)
if overlapping if overlapping
factor = args.ϵσ⁶δtμₜ24 / (distance²^4) * (1.0 - args.σ⁶2 / (distance²^3)) factor = args.ϵσ⁶δtμₜ24 / (distance²^4) * (1.0 - args.σ⁶2 / (distance²^3))