1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00

Added OriginEnv

This commit is contained in:
Mo8it 2022-01-29 17:13:17 +01:00
parent 8850e5dd34
commit d568964eb4
3 changed files with 115 additions and 17 deletions

View file

@ -114,8 +114,8 @@ function state_update_hook!(
distance_to_local_center_of_mass_sum = 0.0 distance_to_local_center_of_mass_sum = 0.0
for id in 1:n_particles for particle_id in 1:n_particles
n_neighbours = env_helper.n_neighbours[id] n_neighbours = env_helper.n_neighbours[particle_id]
if n_neighbours == 0 if n_neighbours == 0
state_id = env.shared.n_states state_id = env.shared.n_states
@ -124,20 +124,14 @@ function state_update_hook!(
env_helper.max_distance_to_local_center_of_mass env_helper.max_distance_to_local_center_of_mass
else else
vec_to_local_center_of_mass = vec_to_local_center_of_mass =
env_helper.vec_to_neighbour_sums[id] / n_neighbours env_helper.vec_to_neighbour_sums[particle_id] / n_neighbours
distance = ReCo.norm2d(vec_to_local_center_of_mass) distance = ReCo.norm2d(vec_to_local_center_of_mass)
env_helper.distances_to_local_center_of_mass[particle_id] = distance
env_helper.distances_to_local_center_of_mass[id] = distance
distance_to_local_center_of_mass_sum += distance distance_to_local_center_of_mass_sum += distance
distance_state_ind = find_state_ind(distance, env.distance_state_space) distance_state_ind = find_state_ind(distance, env.distance_state_space)
si, co = sincos(particles[id].φ) si, co = sincos(particles[particle_id].φ)
direction_angle = ReCo.angle2(SVector(co, si), vec_to_local_center_of_mass) direction_angle = ReCo.angle2(SVector(co, si), vec_to_local_center_of_mass)
direction_state_ind = find_state_ind( direction_state_ind = find_state_ind(
direction_angle, env.direction_angle_state_space direction_angle, env.direction_angle_state_space
) )
@ -145,7 +139,7 @@ function state_update_hook!(
state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind] state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind]
end end
env_helper.shared.states_id[id] = state_id env_helper.shared.states_id[particle_id] = state_id
end end
mean_distance_to_local_center_of_mass = mean_distance_to_local_center_of_mass =
@ -171,16 +165,14 @@ function update_reward!(
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
particle::ReCo.Particle, particle::ReCo.Particle,
) )
id = particle.id
normalization = env_helper.shared.n_particles # TODO: Add factor from steps normalization = env_helper.shared.n_particles # TODO: Add factor from steps
n_neighbours = env_helper.n_neighbours[id] n_neighbours = env_helper.n_neighbours[particle.id]
if n_neighbours == 0 if n_neighbours == 0
env.shared.reward = 0.0 env.shared.reward = 0.0
else else
reward = minimizing_reward( reward = minimizing_reward(
env_helper.distances_to_local_center_of_mass[id], env_helper.distances_to_local_center_of_mass[particle.id],
env_helper.max_distance_to_local_center_of_mass, env_helper.max_distance_to_local_center_of_mass,
) )

105
src/RL/Envs/OriginEnv.jl Normal file
View file

@ -0,0 +1,105 @@
export OriginEnv
using ..ReCo: ReCo
struct OriginEnv <: Env
shared::EnvSharedProps
distance_state_space::Vector{Interval}
direction_angle_state_space::Vector{Interval}
function OriginEnv(;
n_distance_states::Int64=4, n_direction_angle_states::Int64=3, args
)
@assert n_distance_states > 1
@assert n_direction_angle_states > 1
direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)
min_distance = 0.0
max_distance = sqrt(2) * args.half_box_len
distance_state_space = gen_distance_state_space(
min_distance, max_distance, n_distance_states
)
n_states = n_distance_states * n_direction_angle_states
shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states))
return new(shared, distance_state_space, direction_angle_state_space)
end
end
mutable struct OriginEnvHelper <: EnvHelper
shared::EnvHelperSharedProps
distances_to_origin::Vector{Float64}
max_distance_to_origin::Float64
half_box_len::Float64
function OriginEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64)
max_distance_to_origin = sqrt(2) * half_box_len
return new(
shared,
fill(SVector(0.0, 0.0), shared.n_particles),
max_distance_to_origin,
half_box_len,
)
end
end
function gen_env_helper(::OriginEnv, env_helper_shared::EnvHelperSharedProps; args)
return OriginEnvHelper(env_helper_shared, args.half_box_len)
end
function pre_integration_hook!(::OriginEnvHelper)
return nothing
end
function state_update_helper_hook!(
::OriginEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
)
return nothing
end
function state_update_hook!(env_helper::OriginEnvHelper, particles::Vector{ReCo.Particle})
n_particles = env_helper.shared.n_particles
env = env_helper.shared.env
for particle_id in 1:n_particles
vec_to_origin = -particles[particle_id].c
distance_to_origin = ReCo.norm2d(vec_to_origin)
env_helper.distances_to_origin[particle_id] = distance_to_origin
distance_state_ind = find_state_ind(distance_to_origin, env.distance_state_space)
si, co = sincos(particles[particle_id].φ)
direction_angle = ReCo.angle2(SVector(co, si), vec_to_origin)
direction_state_ind = find_state_ind(
direction_angle, env.direction_angle_state_space
)
state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind]
env_helper.shared.states_id[particle_id] = state_id
end
return nothing
end
function update_reward!(
env::OriginEnv, env_helper::OriginEnvHelper, particle::ReCo.Particle
)
normalization = env_helper.shared.n_particles # TODO: Add factor from steps
reward = minimizing_reward(
env_helper.distances_to_origin[particle.id], env_helper.max_distance_to_origin
)
env.shared.reward = reward / normalization
return nothing
end

View file

@ -1,6 +1,6 @@
module RL module RL
export run_rl, LocalCOMWithAdditionalShapeRewardEnv export run_rl, LocalCOMWithAdditionalShapeRewardEnv, OriginEnv
using Base: OneTo using Base: OneTo
@ -141,5 +141,6 @@ function run_rl(;
end end
include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl") include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl")
include("Envs/OriginEnv.jl")
end # module end # module