mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Renaming and new Envs dir
This commit is contained in:
parent
f4d47c7d2d
commit
bcf760243c
3 changed files with 28 additions and 14 deletions
|
@ -1,14 +1,14 @@
|
||||||
export LocalCOMEnv
|
export LocalCOMWithAdditionalShapeRewardEnv
|
||||||
|
|
||||||
using ..ReCo: Particle
|
using ..ReCo: Particle
|
||||||
|
|
||||||
struct LocalCOMEnv <: Env
|
struct LocalCOMWithAdditionalShapeRewardEnv <: Env
|
||||||
shared::EnvSharedProps
|
shared::EnvSharedProps
|
||||||
|
|
||||||
distance_state_space::Vector{Interval}
|
distance_state_space::Vector{Interval}
|
||||||
direction_angle_state_space::Vector{Interval}
|
direction_angle_state_space::Vector{Interval}
|
||||||
|
|
||||||
function LocalCOMEnv(;
|
function LocalCOMWithAdditionalShapeRewardEnv(;
|
||||||
n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args
|
n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args
|
||||||
)
|
)
|
||||||
@assert n_distance_states > 1
|
@assert n_distance_states > 1
|
||||||
|
@ -32,7 +32,7 @@ struct LocalCOMEnv <: Env
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
mutable struct LocalCOMEnvHelper <: EnvHelper
|
mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper
|
||||||
shared::EnvHelperSharedProps
|
shared::EnvHelperSharedProps
|
||||||
|
|
||||||
vec_to_neighbour_sums::Vector{SVector{2,Float64}}
|
vec_to_neighbour_sums::Vector{SVector{2,Float64}}
|
||||||
|
@ -50,7 +50,7 @@ mutable struct LocalCOMEnvHelper <: EnvHelper
|
||||||
half_box_len::Float64
|
half_box_len::Float64
|
||||||
max_elliptical_distance::Float64
|
max_elliptical_distance::Float64
|
||||||
|
|
||||||
function LocalCOMEnvHelper(
|
function LocalCOMWithAdditionalShapeRewardEnvHelper(
|
||||||
shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius
|
shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius
|
||||||
)
|
)
|
||||||
max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio
|
max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio
|
||||||
|
@ -73,11 +73,15 @@ mutable struct LocalCOMEnvHelper <: EnvHelper
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function gen_env_helper(::LocalCOMEnv, env_helper_shared::EnvHelperSharedProps; args)
|
function gen_env_helper(
|
||||||
return LocalCOMEnvHelper(env_helper_shared, args.half_box_len, args.skin_radius)
|
::LocalCOMWithAdditionalShapeRewardEnv, env_helper_shared::EnvHelperSharedProps; args
|
||||||
|
)
|
||||||
|
return LocalCOMWithAdditionalShapeRewardEnvHelper(
|
||||||
|
env_helper_shared, args.half_box_len, args.skin_radius
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
function pre_integration_hook!(env_helper::LocalCOMEnvHelper)
|
function pre_integration_hook!(env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper)
|
||||||
@simd for id in 1:(env_helper.shared.n_particles)
|
@simd for id in 1:(env_helper.shared.n_particles)
|
||||||
env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
|
env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
|
||||||
env_helper.n_neighbours[id] = 0
|
env_helper.n_neighbours[id] = 0
|
||||||
|
@ -87,7 +91,10 @@ function pre_integration_hook!(env_helper::LocalCOMEnvHelper)
|
||||||
end
|
end
|
||||||
|
|
||||||
function state_update_helper_hook!(
|
function state_update_helper_hook!(
|
||||||
env_helper::LocalCOMEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
|
||||||
|
id1::Int64,
|
||||||
|
id2::Int64,
|
||||||
|
r⃗₁₂::SVector{2,Float64},
|
||||||
)
|
)
|
||||||
env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂
|
env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂
|
||||||
env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂
|
env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂
|
||||||
|
@ -98,7 +105,9 @@ function state_update_helper_hook!(
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})
|
function state_update_hook!(
|
||||||
|
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{Particle}
|
||||||
|
)
|
||||||
n_particles = env_helper.shared.n_particles
|
n_particles = env_helper.shared.n_particles
|
||||||
|
|
||||||
@turbo for id in 1:n_particles
|
@turbo for id in 1:n_particles
|
||||||
|
@ -161,7 +170,11 @@ function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{Par
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
|
function update_reward!(
|
||||||
|
env::LocalCOMWithAdditionalShapeRewardEnv,
|
||||||
|
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
|
||||||
|
particle::Particle,
|
||||||
|
)
|
||||||
id = particle.id
|
id = particle.id
|
||||||
|
|
||||||
normalization = env_helper.shared.n_particles # TODO: Add factor from steps
|
normalization = env_helper.shared.n_particles # TODO: Add factor from steps
|
|
@ -1,6 +1,6 @@
|
||||||
module RL
|
module RL
|
||||||
|
|
||||||
export run_rl, LocalCOMEnv
|
export run_rl, LocalCOMWithAdditionalShapeRewardEnv
|
||||||
|
|
||||||
using Base: OneTo
|
using Base: OneTo
|
||||||
|
|
||||||
|
@ -140,6 +140,6 @@ function run_rl(;
|
||||||
return env_helper
|
return env_helper
|
||||||
end
|
end
|
||||||
|
|
||||||
include("LocalCOMEnv.jl")
|
include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl")
|
||||||
|
|
||||||
end # module
|
end # module
|
|
@ -1,6 +1,7 @@
|
||||||
module ReCo
|
module ReCo
|
||||||
|
|
||||||
export init_sim, run_sim, run_rl, animate, plot_snapshot, LocalCOMEnv
|
export init_sim,
|
||||||
|
run_sim, run_rl, animate, plot_snapshot, LocalCOMWithAdditionalShapeRewardEnv
|
||||||
|
|
||||||
using StaticArrays: SVector
|
using StaticArrays: SVector
|
||||||
using JLD2: JLD2
|
using JLD2: JLD2
|
||||||
|
|
Loading…
Reference in a new issue