2022-01-11 01:31:30 +01:00
|
|
|
export LocalCOMEnv
|
|
|
|
|
|
|
|
"""
    LocalCOMEnv(sim_consts; n_distance_states=3, n_direction_angle_states=3)

Environment whose state describes the vector from a particle to the local center of
mass (COM) of its neighbours within the skin radius `sim_consts.skin_r`.

The state space is the product of a distance discretisation of `[0, sim_consts.skin_r]`
(built by `gen_distance_state_space`) and a direction-angle discretisation (built by
`gen_angle_state_space`). It is stored distance-major: the angle interval varies fastest.
The code assumes these helpers return `n_distance_states` and `n_direction_angle_states`
intervals respectively.

One extra state, with index `n_distance_states * n_direction_angle_states + 1` (the last
index), represents "no particle in the skin radius". It has no entry in `state_space`.

# Fields
- `shared::EnvSharedProps`: properties common to all environments, built from the number
  of states and `state_space`.
- `distance_state_space::Vector{Interval}`: distance intervals.
- `direction_angle_state_space::Vector{Interval}`: direction-angle intervals.
- `max_distance::Float64`: upper bound of the distance discretisation (`sim_consts.skin_r`).

# Throws
- `ArgumentError` if `n_distance_states < 1` or `n_direction_angle_states < 2`.
"""
struct LocalCOMEnv <: Env
    shared::EnvSharedProps

    distance_state_space::Vector{Interval}
    direction_angle_state_space::Vector{Interval}

    max_distance::Float64

    function LocalCOMEnv(
        sim_consts; n_distance_states::Int64=3, n_direction_angle_states::Int64=3
    )
        # Validate user input with a real exception: `@assert` is for internal
        # invariants and may be disabled.
        if n_distance_states < 1
            throw(
                ArgumentError(
                    "n_distance_states must be at least 1, got $n_distance_states"
                ),
            )
        end
        if n_direction_angle_states < 2
            throw(
                ArgumentError(
                    "n_direction_angle_states must be at least 2, " *
                    "got $n_direction_angle_states",
                ),
            )
        end

        direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)

        min_distance = 0.0
        max_distance = sim_consts.skin_r

        distance_state_space = gen_distance_state_space(
            min_distance, max_distance, n_distance_states
        )

        # +1 for the extra "no particle in the skin radius" state. It is the last state
        # index and therefore has no entry in `state_space`.
        n_states = n_distance_states * n_direction_angle_states + 1

        state_space = Vector{SVector{2,Interval}}(undef, n_states - 1)

        # Distance-major ordering: the angle interval varies fastest.
        ind = 1
        for distance_state in distance_state_space
            for direction_angle_state in direction_angle_state_space
                state_space[ind] = SVector(distance_state, direction_angle_state)
                ind += 1
            end
        end

        shared = EnvSharedProps(n_states, state_space)

        return new(shared, distance_state_space, direction_angle_state_space, max_distance)
    end
end
|
|
|
|
|
|
|
|
"""
    LocalCOMEnvHelper(shared::EnvHelperSharedProps)

Per-particle scratch buffers used by `LocalCOMEnv` to compute each particle's state.
Every buffer has length `shared.n_particles` and starts zeroed.

# Fields
- `shared::EnvHelperSharedProps`: properties common to all environment helpers.
- `vec_to_neighbour_sums`: per particle, the running sum of pair vectors.
  `state_update_helper_hook` adds `r⃗₁₂` for `id1` and subtracts it for `id2`.
- `n_neighbours`: per particle, the number of pairs accumulated.
- `sq_norm2d_vec_to_local_center_of_mass`: per particle, the squared norm of the
  averaged vector. It is written by `state_update_hook` only for particles with neighbours.
"""
struct LocalCOMEnvHelper <: EnvHelper
    shared::EnvHelperSharedProps

    vec_to_neighbour_sums::Vector{SVector{2,Float64}}
    n_neighbours::Vector{Int64}
    sq_norm2d_vec_to_local_center_of_mass::Vector{Float64}

    function LocalCOMEnvHelper(shared::EnvHelperSharedProps)
        n = shared.n_particles

        neighbour_vec_sums = [SVector(0.0, 0.0) for _ in 1:n]
        neighbour_counts = zeros(Int64, n)
        sq_com_norms = fill(0.0, n)

        return new(shared, neighbour_vec_sums, neighbour_counts, sq_com_norms)
    end
end
|
|
|
|
|
2022-01-11 18:39:38 +01:00
|
|
|
"""
    gen_env_helper(env::LocalCOMEnv, env_helper_params::EnvHelperSharedProps)

Build the `LocalCOMEnvHelper` for a `LocalCOMEnv`. The environment argument is used
only for dispatch.
"""
gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperSharedProps) =
    LocalCOMEnvHelper(env_helper_params)
|
|
|
|
|
|
|
|
"""
    pre_integration_hook(env_helper::LocalCOMEnvHelper)

Zero the per-particle neighbour vector sums and neighbour counts for particles
`1:env_helper.shared.n_particles`. This prepares them for a fresh round of
accumulation by `state_update_helper_hook`.

`sq_norm2d_vec_to_local_center_of_mass` is not reset here.
"""
function pre_integration_hook(env_helper::LocalCOMEnvHelper)
    vec_sums = env_helper.vec_to_neighbour_sums
    neighbour_counts = env_helper.n_neighbours
    origin = SVector(0.0, 0.0)

    @simd for i in 1:(env_helper.shared.n_particles)
        vec_sums[i] = origin
        neighbour_counts[i] = 0
    end

    return nothing
end
|
|
|
|
|
|
|
|
"""
    state_update_helper_hook(env_helper, id1, id2, r⃗₁₂)

Record one interacting pair `(id1, id2)` with pair vector `r⃗₁₂`. The vector is added to
the sum of `id1` and subtracted from the sum of `id2`. Both particles' neighbour counts
are incremented.
"""
function state_update_helper_hook(
    env_helper::LocalCOMEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
)
    sums = env_helper.vec_to_neighbour_sums
    counts = env_helper.n_neighbours

    # Antisymmetric contribution: each particle sees the pair vector from its own side.
    sums[id1] = sums[id1] + r⃗₁₂
    sums[id2] = sums[id2] - r⃗₁₂

    counts[id1] = counts[id1] + 1
    counts[id2] = counts[id2] + 1

    return nothing
end
|
|
|
|
|
|
|
|
"""
    state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})

Save the current state indices to `old_states_ind`, then assign every particle a new
state index from its accumulated neighbour data.

- A particle with no neighbours gets the last state index, `env.shared.n_states`.
- Otherwise the averaged vector `vec_to_neighbour_sums[id] / n_neighbours[id]` is formed.
  Its squared norm is cached in `sq_norm2d_vec_to_local_center_of_mass`. The state is
  the pair (distance interval of its norm, angle interval of `angle2` between the
  particle orientation `(cos φ, sin φ)` and that vector).
"""
function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})
    n_particles = env_helper.shared.n_particles

    # Keep the previous state indices before overwriting them.
    copyto!(env_helper.shared.old_states_ind, 1, env_helper.shared.states_ind, 1, n_particles)

    env = env_helper.shared.env

    for id in 1:n_particles
        n_neighbours = env_helper.n_neighbours[id]

        if n_neighbours == 0
            # Nobody in the skin radius: use the dedicated last state.
            env_helper.shared.states_ind[id] = env.shared.n_states
            continue
        end

        com_vec = env_helper.vec_to_neighbour_sums[id] / n_neighbours

        sq_dist = sq_norm2d(com_vec)
        env_helper.sq_norm2d_vec_to_local_center_of_mass[id] = sq_dist

        distance_state = find_state_interval(sqrt(sq_dist), env.distance_state_space)

        si, co = sincos(particles[id].φ)
        orientation = SVector(co, si)
        direction_angle_state = find_state_interval(
            angle2(orientation, com_vec), env.direction_angle_state_space
        )

        state = SVector{2,Interval}(distance_state, direction_angle_state)
        env_helper.shared.states_ind[id] = find_state_ind(state, env.shared.state_space)
    end

    return nothing
end
|
|
|
|
|
|
|
|
"""
    update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)

Set `env.shared.reward` for `particle` to minus a squared distance, divided by
`env.max_distance * env_helper.shared.n_particles`.

The squared distance is the cached `sq_norm2d_vec_to_local_center_of_mass` of the
particle. If the particle has no neighbours, `env.max_distance^2` is used instead, which
is the worst case.
"""
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
    id = particle.id

    # NOTE(review): the numerator is a squared distance but this normalization is
    # linear in max_distance, so the reward carries units of length. Confirm intended.
    normalization = (env.max_distance * env_helper.shared.n_particles)

    sq_distance = if env_helper.n_neighbours[id] == 0
        env.max_distance^2
    else
        env_helper.sq_norm2d_vec_to_local_center_of_mass[id]
    end

    env.shared.reward = -sq_distance / normalization # TODO: Add shape term

    return nothing
end
|