1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-12-21 00:51:21 +00:00

Renaming and new Envs dir

This commit is contained in:
Mo8it 2022-01-29 15:48:13 +01:00
parent f4d47c7d2d
commit bcf760243c
3 changed files with 28 additions and 14 deletions

View file

@ -1,14 +1,14 @@
export LocalCOMEnv export LocalCOMWithAdditionalShapeRewardEnv
using ..ReCo: Particle using ..ReCo: Particle
struct LocalCOMEnv <: Env struct LocalCOMWithAdditionalShapeRewardEnv <: Env
shared::EnvSharedProps shared::EnvSharedProps
distance_state_space::Vector{Interval} distance_state_space::Vector{Interval}
direction_angle_state_space::Vector{Interval} direction_angle_state_space::Vector{Interval}
function LocalCOMEnv(; function LocalCOMWithAdditionalShapeRewardEnv(;
n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args
) )
@assert n_distance_states > 1 @assert n_distance_states > 1
@ -32,7 +32,7 @@ struct LocalCOMEnv <: Env
end end
end end
mutable struct LocalCOMEnvHelper <: EnvHelper mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper
shared::EnvHelperSharedProps shared::EnvHelperSharedProps
vec_to_neighbour_sums::Vector{SVector{2,Float64}} vec_to_neighbour_sums::Vector{SVector{2,Float64}}
@ -50,7 +50,7 @@ mutable struct LocalCOMEnvHelper <: EnvHelper
half_box_len::Float64 half_box_len::Float64
max_elliptical_distance::Float64 max_elliptical_distance::Float64
function LocalCOMEnvHelper( function LocalCOMWithAdditionalShapeRewardEnvHelper(
shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius
) )
max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio
@ -73,11 +73,15 @@ mutable struct LocalCOMEnvHelper <: EnvHelper
end end
end end
function gen_env_helper(::LocalCOMEnv, env_helper_shared::EnvHelperSharedProps; args) function gen_env_helper(
return LocalCOMEnvHelper(env_helper_shared, args.half_box_len, args.skin_radius) ::LocalCOMWithAdditionalShapeRewardEnv, env_helper_shared::EnvHelperSharedProps; args
)
return LocalCOMWithAdditionalShapeRewardEnvHelper(
env_helper_shared, args.half_box_len, args.skin_radius
)
end end
function pre_integration_hook!(env_helper::LocalCOMEnvHelper) function pre_integration_hook!(env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper)
@simd for id in 1:(env_helper.shared.n_particles) @simd for id in 1:(env_helper.shared.n_particles)
env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0) env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0)
env_helper.n_neighbours[id] = 0 env_helper.n_neighbours[id] = 0
@ -87,7 +91,10 @@ function pre_integration_hook!(env_helper::LocalCOMEnvHelper)
end end
function state_update_helper_hook!( function state_update_helper_hook!(
env_helper::LocalCOMEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
id1::Int64,
id2::Int64,
r⃗₁₂::SVector{2,Float64},
) )
env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂ env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂
env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂ env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂
@ -98,7 +105,9 @@ function state_update_helper_hook!(
return nothing return nothing
end end
function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{Particle}) function state_update_hook!(
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{Particle}
)
n_particles = env_helper.shared.n_particles n_particles = env_helper.shared.n_particles
@turbo for id in 1:n_particles @turbo for id in 1:n_particles
@ -161,7 +170,11 @@ function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{Par
return nothing return nothing
end end
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle) function update_reward!(
env::LocalCOMWithAdditionalShapeRewardEnv,
env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
particle::Particle,
)
id = particle.id id = particle.id
normalization = env_helper.shared.n_particles # TODO: Add factor from steps normalization = env_helper.shared.n_particles # TODO: Add factor from steps

View file

@ -1,6 +1,6 @@
module RL module RL
export run_rl, LocalCOMEnv export run_rl, LocalCOMWithAdditionalShapeRewardEnv
using Base: OneTo using Base: OneTo
@ -140,6 +140,6 @@ function run_rl(;
return env_helper return env_helper
end end
include("LocalCOMEnv.jl") include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl")
end # module end # module

View file

@ -1,6 +1,7 @@
module ReCo module ReCo
export init_sim, run_sim, run_rl, animate, plot_snapshot, LocalCOMEnv export init_sim,
run_sim, run_rl, animate, plot_snapshot, LocalCOMWithAdditionalShapeRewardEnv
using StaticArrays: SVector using StaticArrays: SVector
using JLD2: JLD2 using JLD2: JLD2