mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Added OriginCompassEnv
This commit is contained in:
parent
705bd462c3
commit
b359c4834a
9 changed files with 163 additions and 13 deletions
|
@ -7,7 +7,6 @@ using ReCo: ReCo
|
||||||
function gen_rdf_graphics()
|
function gen_rdf_graphics()
|
||||||
Random.seed!(1)
|
Random.seed!(1)
|
||||||
|
|
||||||
box_length = 100
|
|
||||||
box_length = 100
|
box_length = 100
|
||||||
|
|
||||||
graphics_export_dir = "exports/graphics"
|
graphics_export_dir = "exports/graphics"
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
export LocalCOMEnv
|
|
||||||
|
|
||||||
using ..ReCo: ReCo
|
using ..ReCo: ReCo
|
||||||
|
|
||||||
struct LocalCOMEnv <: Env
|
struct LocalCOMEnv <: Env
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
export LocalCOMWithAdditionalShapeRewardEnv
|
|
||||||
|
|
||||||
using ..ReCo: ReCo
|
using ..ReCo: ReCo
|
||||||
|
|
||||||
struct LocalCOMWithAdditionalShapeRewardEnv <: Env
|
struct LocalCOMWithAdditionalShapeRewardEnv <: Env
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
export LocalCOMWithAdditionalShapeRewardEnv2
|
|
||||||
|
|
||||||
using ..ReCo: ReCo
|
using ..ReCo: ReCo
|
||||||
|
|
||||||
struct LocalCOMWithAdditionalShapeRewardEnv2 <: Env
|
struct LocalCOMWithAdditionalShapeRewardEnv2 <: Env
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
export NearestNeighbourEnv
|
|
||||||
|
|
||||||
using ..ReCo: ReCo
|
using ..ReCo: ReCo
|
||||||
|
|
||||||
struct NearestNeighbourEnv <: Env
|
struct NearestNeighbourEnv <: Env
|
||||||
|
|
155
src/RL/Envs/OriginCompass.jl
Normal file
155
src/RL/Envs/OriginCompass.jl
Normal file
|
@ -0,0 +1,155 @@
|
||||||
|
using ..ReCo: ReCo
|
||||||
|
|
||||||
|
struct OriginCompassEnv <: Env
|
||||||
|
shared::EnvSharedProps
|
||||||
|
|
||||||
|
distance_state_space::Vector{Interval}
|
||||||
|
direction_angle_state_space::Vector{Interval}
|
||||||
|
position_angle_state_space::Vector{Interval}
|
||||||
|
|
||||||
|
function OriginCompassEnv(;
|
||||||
|
n_distance_states::Int64=3,
|
||||||
|
n_direction_angle_states::Int64=3,
|
||||||
|
n_position_angle_states::Int64=8,
|
||||||
|
args,
|
||||||
|
)
|
||||||
|
@assert n_distance_states > 1
|
||||||
|
@assert n_direction_angle_states > 1
|
||||||
|
@assert n_position_angle_states > 1
|
||||||
|
|
||||||
|
direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)
|
||||||
|
position_angle_state_space = gen_angle_state_space(n_position_angle_states)
|
||||||
|
|
||||||
|
min_distance = 0.0
|
||||||
|
max_distance = sqrt(2) * args.half_box_len
|
||||||
|
|
||||||
|
distance_state_space = gen_distance_state_space(
|
||||||
|
min_distance, max_distance, n_distance_states
|
||||||
|
)
|
||||||
|
|
||||||
|
n_states = n_distance_states * n_direction_angle_states * n_position_angle_states
|
||||||
|
|
||||||
|
state_spaces_labels = gen_state_spaces_labels(
|
||||||
|
("d", "\\theta", "\\alpha"),
|
||||||
|
(distance_state_space, direction_angle_state_space, position_angle_state_space),
|
||||||
|
)
|
||||||
|
|
||||||
|
shared = EnvSharedProps(
|
||||||
|
n_states,
|
||||||
|
(n_distance_states, n_direction_angle_states, n_position_angle_states),
|
||||||
|
state_spaces_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
return new(
|
||||||
|
shared,
|
||||||
|
distance_state_space,
|
||||||
|
direction_angle_state_space,
|
||||||
|
position_angle_state_space,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
mutable struct OriginCompassEnvHelper <: EnvHelper
|
||||||
|
shared::EnvHelperSharedProps
|
||||||
|
|
||||||
|
center_of_mass::SVector{2,Float64}
|
||||||
|
gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Float64}
|
||||||
|
gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Float64}
|
||||||
|
|
||||||
|
half_box_len::Float64
|
||||||
|
max_elliptical_distance::Float64
|
||||||
|
|
||||||
|
function OriginCompassEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64)
|
||||||
|
max_elliptical_distance = sqrt(
|
||||||
|
half_box_len^2 + (half_box_len / shared.elliptical_a_b_ratio)^2
|
||||||
|
)
|
||||||
|
|
||||||
|
return new(
|
||||||
|
shared,
|
||||||
|
SVector(0.0, 0.0),
|
||||||
|
SVector(0.0, 0.0),
|
||||||
|
SVector(0.0, 0.0),
|
||||||
|
half_box_len,
|
||||||
|
max_elliptical_distance,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function gen_env_helper(::OriginCompassEnv, env_helper_shared::EnvHelperSharedProps; args)
|
||||||
|
return OriginCompassEnvHelper(env_helper_shared, args.half_box_len)
|
||||||
|
end
|
||||||
|
|
||||||
|
function pre_integration_hook!(::OriginCompassEnvHelper)
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function state_update_helper_hook!(
|
||||||
|
::OriginCompassEnvHelper,
|
||||||
|
id1::Int64,
|
||||||
|
id2::Int64,
|
||||||
|
r⃗₁₂::SVector{2,Float64},
|
||||||
|
distance²::Float64,
|
||||||
|
)
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function state_update_hook!(
|
||||||
|
env_helper::OriginCompassEnvHelper, particles::Vector{ReCo.Particle}
|
||||||
|
)
|
||||||
|
n_particles = env_helper.shared.n_particles
|
||||||
|
|
||||||
|
env = env_helper.shared.env
|
||||||
|
|
||||||
|
env_helper.center_of_mass = ReCo.center_of_mass(particles, env_helper.half_box_len)
|
||||||
|
|
||||||
|
for particle_id in 1:n_particles
|
||||||
|
vec_to_origin = -particles[particle_id].c
|
||||||
|
distance_to_center_of_mass = ReCo.norm2d(vec_to_origin)
|
||||||
|
distance_state_ind = find_state_ind(
|
||||||
|
distance_to_center_of_mass, env.distance_state_space
|
||||||
|
)
|
||||||
|
|
||||||
|
si, co = sincos(particles[particle_id].φ)
|
||||||
|
direction_angle = ReCo.angle2(SVector(co, si), vec_to_origin)
|
||||||
|
direction_state_ind = find_state_ind(
|
||||||
|
direction_angle, env.direction_angle_state_space
|
||||||
|
)
|
||||||
|
|
||||||
|
position_angle = atan(si, co)
|
||||||
|
position_state_ind = find_state_ind(position_angle, env.position_angle_state_space)
|
||||||
|
|
||||||
|
state_id = env.shared.state_id_tensor[
|
||||||
|
distance_state_ind, direction_state_ind, position_state_ind
|
||||||
|
]
|
||||||
|
|
||||||
|
env_helper.shared.states_id[particle_id] = state_id
|
||||||
|
end
|
||||||
|
|
||||||
|
v1, v2 = ReCo.gyration_tensor_eigvecs(
|
||||||
|
particles, env_helper.half_box_len, env_helper.center_of_mass
|
||||||
|
)
|
||||||
|
|
||||||
|
env_helper.gyration_tensor_eigvec_to_smaller_eigval = v1
|
||||||
|
env_helper.gyration_tensor_eigvec_to_bigger_eigval = v2
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function update_reward!(
|
||||||
|
env::OriginCompassEnv, env_helper::OriginCompassEnvHelper, particle::ReCo.Particle
|
||||||
|
)
|
||||||
|
elliptical_distance = ReCo.elliptical_distance(
|
||||||
|
particle.c,
|
||||||
|
env_helper.center_of_mass,
|
||||||
|
env_helper.gyration_tensor_eigvec_to_smaller_eigval,
|
||||||
|
env_helper.gyration_tensor_eigvec_to_bigger_eigval,
|
||||||
|
env_helper.shared.elliptical_a_b_ratio,
|
||||||
|
env_helper.half_box_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
reward = minimizing_reward(elliptical_distance, env_helper.max_elliptical_distance)
|
||||||
|
|
||||||
|
set_normalized_reward!(env, reward, env_helper)
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
|
@ -1,5 +1,3 @@
|
||||||
export OriginEnv
|
|
||||||
|
|
||||||
using ..ReCo: ReCo
|
using ..ReCo: ReCo
|
||||||
|
|
||||||
struct OriginEnv <: Env
|
struct OriginEnv <: Env
|
||||||
|
|
|
@ -5,7 +5,9 @@ export run_rl,
|
||||||
LocalCOMWithAdditionalShapeRewardEnv2,
|
LocalCOMWithAdditionalShapeRewardEnv2,
|
||||||
OriginEnv,
|
OriginEnv,
|
||||||
NearestNeighbourEnv,
|
NearestNeighbourEnv,
|
||||||
LocalCOMEnv
|
LocalCOMEnv,
|
||||||
|
OriginCompassEnv,
|
||||||
|
COMCompassEnv
|
||||||
|
|
||||||
using Base: OneTo
|
using Base: OneTo
|
||||||
|
|
||||||
|
@ -198,5 +200,7 @@ include("Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl")
|
||||||
include("Envs/OriginEnv.jl")
|
include("Envs/OriginEnv.jl")
|
||||||
include("Envs/NearestNeighbourEnv.jl")
|
include("Envs/NearestNeighbourEnv.jl")
|
||||||
include("Envs/LocalCOMEnv.jl")
|
include("Envs/LocalCOMEnv.jl")
|
||||||
|
include("Envs/OriginCompass.jl")
|
||||||
|
include("Envs/COMCompass.jl")
|
||||||
|
|
||||||
end # module
|
end # module
|
|
@ -10,7 +10,9 @@ export init_sim,
|
||||||
LocalCOMWithAdditionalShapeRewardEnv2,
|
LocalCOMWithAdditionalShapeRewardEnv2,
|
||||||
OriginEnv,
|
OriginEnv,
|
||||||
NearestNeighbourEnv,
|
NearestNeighbourEnv,
|
||||||
LocalCOMEnv
|
LocalCOMEnv,
|
||||||
|
OriginCompassEnv,
|
||||||
|
COMCompassEnv
|
||||||
|
|
||||||
using StaticArrays: SVector
|
using StaticArrays: SVector
|
||||||
using JLD2: JLD2
|
using JLD2: JLD2
|
||||||
|
|
Loading…
Reference in a new issue