From 9e3b71197c3ea31086e758ba58c36bfcbff5e75c Mon Sep 17 00:00:00 2001 From: Mo8it Date: Mon, 31 Jan 2022 01:47:48 +0100 Subject: [PATCH] Added NearestNeighbourEnv --- .../LocalCOMWithAdditionalShapeRewardEnv.jl | 9 +- src/RL/Envs/NearestNeighbourEnv.jl | 140 ++++++++++++++++++ src/RL/Envs/OriginEnv.jl | 2 +- src/RL/Hooks.jl | 2 +- src/simulation.jl | 2 +- 5 files changed, 151 insertions(+), 4 deletions(-) create mode 100644 src/RL/Envs/NearestNeighbourEnv.jl diff --git a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl index 5b315ba..29d825f 100644 --- a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl +++ b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl @@ -26,7 +26,13 @@ struct LocalCOMWithAdditionalShapeRewardEnv <: Env n_states = n_distance_states * n_direction_angle_states + 1 # Last state is when no particle is in the skin radius - shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states)) + state_spaces_labels = gen_state_spaces_labels( + ("d", "\\theta"), (distance_state_space, direction_angle_state_space) + ) + + shared = EnvSharedProps( + n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels + ) return new(shared, distance_state_space, direction_angle_state_space) end @@ -95,6 +101,7 @@ function state_update_helper_hook!( id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, + distance²::Float64, ) env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂ env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂ diff --git a/src/RL/Envs/NearestNeighbourEnv.jl b/src/RL/Envs/NearestNeighbourEnv.jl new file mode 100644 index 0000000..a12eab0 --- /dev/null +++ b/src/RL/Envs/NearestNeighbourEnv.jl @@ -0,0 +1,140 @@ +export NearestNeighbourEnv + +using ..ReCo: ReCo + +struct NearestNeighbourEnv <: Env + shared::EnvSharedProps + + distance_state_space::Vector{Interval} + direction_angle_state_space::Vector{Interval} + + function NearestNeighbourEnv(; + n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args + ) + @assert n_distance_states > 1 + @assert n_direction_angle_states > 1 + + direction_angle_state_space = gen_angle_state_space(n_direction_angle_states) + + min_distance = 0.0 + max_distance = args.skin_radius + + distance_state_space = gen_distance_state_space( + min_distance, max_distance, n_distance_states + ) + + n_states = n_distance_states * n_direction_angle_states + 1 + # Last state is when no particle is in the skin radius + + shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states)) + + return new(shared, distance_state_space, direction_angle_state_space) + end +end + +mutable struct NearestNeighbourEnvHelper <: EnvHelper + shared::EnvHelperSharedProps + + vecs_to_neighbour::Vector{SVector{2,Float64}} + sq_distances_to_neighbour::Vector{Float64} + + current_κ::Float64 + goal_κ::Float64 + max_distance_to_goal_κ::Float64 + + half_box_len::Float64 + + function NearestNeighbourEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64) + goal_κ = 0.4 + max_distance_to_goal_κ = max(1 - goal_κ, goal_κ) + + return new( + shared, + fill(SVector(0.0, 0.0), shared.n_particles), + zeros(shared.n_particles), + 1.0, + goal_κ, + max_distance_to_goal_κ, + half_box_len, + ) + end +end + +function gen_env_helper( + ::NearestNeighbourEnv, env_helper_shared::EnvHelperSharedProps; args +) + return NearestNeighbourEnvHelper(env_helper_shared, args.half_box_len) +end + +function pre_integration_hook!(env_helper::NearestNeighbourEnvHelper) + @simd for particle_id in 1:(env_helper.shared.n_particles) + env_helper.sq_distances_to_neighbour[particle_id] = Inf64 + end + + return nothing +end + +function state_update_helper_hook!( + env_helper::NearestNeighbourEnvHelper, + id1::Int64, + id2::Int64, + r⃗₁₂::SVector{2,Float64}, + distance²::Float64, +) + if distance² < env_helper.sq_distances_to_neighbour[id1] + env_helper.vecs_to_neighbour[id1] = r⃗₁₂ + env_helper.sq_distances_to_neighbour[id1] = distance² + end + + if distance² < env_helper.sq_distances_to_neighbour[id2] + env_helper.vecs_to_neighbour[id2] = -r⃗₁₂ + env_helper.sq_distances_to_neighbour[id2] = distance² + end + + return nothing +end + +function state_update_hook!( + env_helper::NearestNeighbourEnvHelper, particles::Vector{ReCo.Particle} +) + n_particles = env_helper.shared.n_particles + + env = env_helper.shared.env + + for particle_id in 1:n_particles + sq_distance = env_helper.sq_distances_to_neighbour[particle_id] + if sq_distance == Inf64 + state_id = env.shared.n_states + else + distance = sqrt(sq_distance) + distance_state_ind = find_state_ind(distance, env.distance_state_space) + + si, co = sincos(particles[particle_id].φ) + direction_angle = ReCo.angle2( + SVector(co, si), env_helper.vecs_to_neighbour[particle_id] + ) + direction_state_ind = find_state_ind( + direction_angle, env.direction_angle_state_space + ) + + state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind] + end + + env_helper.shared.states_id[particle_id] = state_id + end + + env_helper.current_κ = ReCo.gyration_tensor_eigvals_ratio( + particles, env_helper.half_box_len + ) + + return nothing +end + +function update_reward!( + env::NearestNeighbourEnv, env_helper::NearestNeighbourEnvHelper, particle::ReCo.Particle +) + reward = minimizing_reward(env_helper.current_κ, env_helper.max_distance_to_goal_κ) + set_normalized_reward!(env, reward, env_helper) + + return nothing +end \ No newline at end of file diff --git a/src/RL/Envs/OriginEnv.jl b/src/RL/Envs/OriginEnv.jl index f75983a..9ad38e8 100644 --- a/src/RL/Envs/OriginEnv.jl +++ b/src/RL/Envs/OriginEnv.jl @@ -61,7 +61,7 @@ function pre_integration_hook!(::OriginEnvHelper) end function state_update_helper_hook!( - ::OriginEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} + ::OriginEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64 ) return nothing end diff --git a/src/RL/Hooks.jl b/src/RL/Hooks.jl index 18a4756..f282f17 100644 --- a/src/RL/Hooks.jl +++ b/src/RL/Hooks.jl @@ -5,7 +5,7 @@ function pre_integration_hook!(::EnvHelper) end function state_update_helper_hook!( - ::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} + ::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64 ) return ReCo.method_not_implemented() end diff --git a/src/simulation.jl b/src/simulation.jl index 6a71558..4fd6213 100644 --- a/src/simulation.jl +++ b/src/simulation.jl @@ -54,7 +54,7 @@ function euler!( p1_c, p2.c, args.interaction_radius², args.half_box_len ) - state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂) + state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂, distance²) if overlapping factor = args.ϵσ⁶δtμₜ24 / (distance²^4) * (1.0 - args.σ⁶2 / (distance²^3))