1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00

Added LocalCOMEnv

This commit is contained in:
Mo8it 2022-01-31 02:34:49 +01:00
parent a647188e09
commit 44398370e4
3 changed files with 151 additions and 2 deletions

146
src/RL/Envs/LocalCOMEnv.jl Normal file
View file

@ -0,0 +1,146 @@
export LocalCOMEnv
using ..ReCo: ReCo
struct LocalCOMEnv <: Env
shared::EnvSharedProps
distance_state_space::Vector{Interval}
direction_angle_state_space::Vector{Interval}
function LocalCOMEnv(;
n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args
)
@assert n_distance_states > 1
@assert n_direction_angle_states > 1
direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)
min_distance = 0.0
max_distance = args.skin_radius
distance_state_space = gen_distance_state_space(
min_distance, max_distance, n_distance_states
)
n_states = n_distance_states * n_direction_angle_states + 1
# Last state is when no particle is in the skin radius
state_spaces_labels = gen_state_spaces_labels(
("d", "\\theta"), (distance_state_space, direction_angle_state_space)
)
shared = EnvSharedProps(
n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels
)
return new(shared, distance_state_space, direction_angle_state_space)
end
end
mutable struct LocalCOMEnvHelper <: EnvHelper
shared::EnvHelperSharedProps
vec_to_neighbour_sums::Vector{SVector{2,Float64}}
n_neighbours::Vector{Int64}
distances_to_local_center_of_mass::Vector{Float64}
max_distance_to_local_center_of_mass::Float64
half_box_len::Float64
function LocalCOMEnvHelper(
shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius::Float64
)
max_distance_to_local_center_of_mass = skin_radius
return new(
shared,
fill(SVector(0.0, 0.0), shared.n_particles),
fill(0, shared.n_particles),
zeros(shared.n_particles),
max_distance_to_local_center_of_mass,
half_box_len,
)
end
end
function gen_env_helper(::LocalCOMEnv, env_helper_shared::EnvHelperSharedProps; args)
return LocalCOMEnvHelper(env_helper_shared, args.half_box_len, args.skin_radius)
end
function pre_integration_hook!(env_helper::LocalCOMEnvHelper)
@simd for id in 1:(env_helper.shared.n_particles)
env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
env_helper.n_neighbours[id] = 0
end
return nothing
end
function state_update_helper_hook!(
env_helper::LocalCOMEnvHelper,
id1::Int64,
id2::Int64,
r⃗₁₂::SVector{2,Float64},
distance²::Float64,
)
env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂
env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂
env_helper.n_neighbours[id1] += 1
env_helper.n_neighbours[id2] += 1
return nothing
end
function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{ReCo.Particle})
n_particles = env_helper.shared.n_particles
env = env_helper.shared.env
for particle_id in 1:n_particles
n_neighbours = env_helper.n_neighbours[particle_id]
if n_neighbours == 0
state_id = env.shared.n_states
else
vec_to_local_center_of_mass =
env_helper.vec_to_neighbour_sums[particle_id] / n_neighbours
distance = ReCo.norm2d(vec_to_local_center_of_mass)
env_helper.distances_to_local_center_of_mass[particle_id] = distance
distance_state_ind = find_state_ind(distance, env.distance_state_space)
si, co = sincos(particles[particle_id].φ)
direction_angle = ReCo.angle2(SVector(co, si), vec_to_local_center_of_mass)
direction_state_ind = find_state_ind(
direction_angle, env.direction_angle_state_space
)
state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind]
end
env_helper.shared.states_id[particle_id] = state_id
end
return nothing
end
function update_reward!(
env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::ReCo.Particle
)
n_neighbours = env_helper.n_neighbours[particle.id]
if n_neighbours == 0
env.shared.reward = 0.0
else
reward = minimizing_reward(
env_helper.distances_to_local_center_of_mass[particle.id],
env_helper.max_distance_to_local_center_of_mass,
)
set_normalized_reward!(env, reward, env_helper)
end
return nothing
end

View file

@ -1,6 +1,7 @@
module RL
export run_rl, LocalCOMWithAdditionalShapeRewardEnv, OriginEnv, NearestNeighbourEnv
export run_rl,
LocalCOMWithAdditionalShapeRewardEnv, OriginEnv, NearestNeighbourEnv, LocalCOMEnv
using Base: OneTo
@ -191,5 +192,6 @@ end
include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl")
include("Envs/OriginEnv.jl")
include("Envs/NearestNeighbourEnv.jl")
include("Envs/LocalCOMEnv.jl")
end # module

View file

@ -7,7 +7,8 @@ export init_sim,
plot_snapshot,
LocalCOMWithAdditionalShapeRewardEnv,
OriginEnv,
NearestNeighbourEnv
NearestNeighbourEnv,
LocalCOMEnv
using StaticArrays: SVector
using JLD2: JLD2