diff --git a/src/RL/Envs/LocalCOMEnv.jl b/src/RL/Envs/LocalCOMEnv.jl new file mode 100644 index 0000000..91cdca2 --- /dev/null +++ b/src/RL/Envs/LocalCOMEnv.jl @@ -0,0 +1,146 @@ +export LocalCOMEnv + +using ..ReCo: ReCo + +struct LocalCOMEnv <: Env + shared::EnvSharedProps + + distance_state_space::Vector{Interval} + direction_angle_state_space::Vector{Interval} + + function LocalCOMEnv(; + n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args + ) + @assert n_distance_states > 1 + @assert n_direction_angle_states > 1 + + direction_angle_state_space = gen_angle_state_space(n_direction_angle_states) + + min_distance = 0.0 + max_distance = args.skin_radius + + distance_state_space = gen_distance_state_space( + min_distance, max_distance, n_distance_states + ) + + n_states = n_distance_states * n_direction_angle_states + 1 + # Last state is when no particle is in the skin radius + + state_spaces_labels = gen_state_spaces_labels( + ("d", "\\theta"), (distance_state_space, direction_angle_state_space) + ) + + shared = EnvSharedProps( + n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels + ) + + return new(shared, distance_state_space, direction_angle_state_space) + end +end + +mutable struct LocalCOMEnvHelper <: EnvHelper + shared::EnvHelperSharedProps + + vec_to_neighbour_sums::Vector{SVector{2,Float64}} + n_neighbours::Vector{Int64} + + distances_to_local_center_of_mass::Vector{Float64} + max_distance_to_local_center_of_mass::Float64 + + half_box_len::Float64 + + function LocalCOMEnvHelper( + shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius::Float64 + ) + max_distance_to_local_center_of_mass = skin_radius + + return new( + shared, + fill(SVector(0.0, 0.0), shared.n_particles), + fill(0, shared.n_particles), + zeros(shared.n_particles), + max_distance_to_local_center_of_mass, + half_box_len, + ) + end +end + +function gen_env_helper(::LocalCOMEnv, env_helper_shared::EnvHelperSharedProps; args) + return LocalCOMEnvHelper(env_helper_shared, args.half_box_len, args.skin_radius) +end + +function pre_integration_hook!(env_helper::LocalCOMEnvHelper) + @simd for id in 1:(env_helper.shared.n_particles) + env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0) + env_helper.n_neighbours[id] = 0 + end + + return nothing +end + +function state_update_helper_hook!( + env_helper::LocalCOMEnvHelper, + id1::Int64, + id2::Int64, + r⃗₁₂::SVector{2,Float64}, + distance²::Float64, +) + env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂ + env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂ + + env_helper.n_neighbours[id1] += 1 + env_helper.n_neighbours[id2] += 1 + + return nothing +end + +function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{ReCo.Particle}) + n_particles = env_helper.shared.n_particles + + env = env_helper.shared.env + + for particle_id in 1:n_particles + n_neighbours = env_helper.n_neighbours[particle_id] + + if n_neighbours == 0 + state_id = env.shared.n_states + else + vec_to_local_center_of_mass = + env_helper.vec_to_neighbour_sums[particle_id] / n_neighbours + distance = ReCo.norm2d(vec_to_local_center_of_mass) + env_helper.distances_to_local_center_of_mass[particle_id] = distance + distance_state_ind = find_state_ind(distance, env.distance_state_space) + + si, co = sincos(particles[particle_id].φ) + direction_angle = ReCo.angle2(SVector(co, si), vec_to_local_center_of_mass) + direction_state_ind = find_state_ind( + direction_angle, env.direction_angle_state_space + ) + + state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind] + end + + env_helper.shared.states_id[particle_id] = state_id + end + + return nothing +end + +function update_reward!( + env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::ReCo.Particle +) + n_neighbours = env_helper.n_neighbours[particle.id] + + if n_neighbours == 0 + env.shared.reward = 0.0 + else + reward = minimizing_reward( + env_helper.distances_to_local_center_of_mass[particle.id], + env_helper.max_distance_to_local_center_of_mass, + ) + + set_normalized_reward!(env, reward, env_helper) + end + + return nothing +end \ No newline at end of file diff --git a/src/RL/RL.jl b/src/RL/RL.jl index 503ddad..d373f83 100644 --- a/src/RL/RL.jl +++ b/src/RL/RL.jl @@ -1,6 +1,7 @@ module RL -export run_rl, LocalCOMWithAdditionalShapeRewardEnv, OriginEnv, NearestNeighbourEnv +export run_rl, + LocalCOMWithAdditionalShapeRewardEnv, OriginEnv, NearestNeighbourEnv, LocalCOMEnv using Base: OneTo @@ -191,5 +192,6 @@ end include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl") include("Envs/OriginEnv.jl") include("Envs/NearestNeighbourEnv.jl") +include("Envs/LocalCOMEnv.jl") end # module \ No newline at end of file diff --git a/src/ReCo.jl b/src/ReCo.jl index 8c44e7f..f7a3bab 100644 --- a/src/ReCo.jl +++ b/src/ReCo.jl @@ -7,7 +7,8 @@ export init_sim, plot_snapshot, LocalCOMWithAdditionalShapeRewardEnv, OriginEnv, - NearestNeighbourEnv + NearestNeighbourEnv, + LocalCOMEnv using StaticArrays: SVector using JLD2: JLD2