diff --git a/src/RL/LocalCOMEnv.jl b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl similarity index 84% rename from src/RL/LocalCOMEnv.jl rename to src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl index 75df014..52078e7 100644 --- a/src/RL/LocalCOMEnv.jl +++ b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl @@ -1,14 +1,14 @@ -export LocalCOMEnv +export LocalCOMWithAdditionalShapeRewardEnv using ..ReCo: Particle -struct LocalCOMEnv <: Env +struct LocalCOMWithAdditionalShapeRewardEnv <: Env shared::EnvSharedProps distance_state_space::Vector{Interval} direction_angle_state_space::Vector{Interval} - function LocalCOMEnv(; + function LocalCOMWithAdditionalShapeRewardEnv(; n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args ) @assert n_distance_states > 1 @@ -32,7 +32,7 @@ struct LocalCOMEnv <: Env end end -mutable struct LocalCOMEnvHelper <: EnvHelper +mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper shared::EnvHelperSharedProps vec_to_neighbour_sums::Vector{SVector{2,Float64}} @@ -50,7 +50,7 @@ mutable struct LocalCOMEnvHelper <: EnvHelper half_box_len::Float64 max_elliptical_distance::Float64 - function LocalCOMEnvHelper( + function LocalCOMWithAdditionalShapeRewardEnvHelper( shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius ) max_elliptical_distance = sqrt(2) * half_box_len / shared.elliptical_a_b_ratio @@ -73,11 +73,15 @@ mutable struct LocalCOMEnvHelper <: EnvHelper end end -function gen_env_helper(::LocalCOMEnv, env_helper_shared::EnvHelperSharedProps; args) - return LocalCOMEnvHelper(env_helper_shared, args.half_box_len, args.skin_radius) +function gen_env_helper( + ::LocalCOMWithAdditionalShapeRewardEnv, env_helper_shared::EnvHelperSharedProps; args +) + return LocalCOMWithAdditionalShapeRewardEnvHelper( + env_helper_shared, args.half_box_len, args.skin_radius + ) end -function pre_integration_hook!(env_helper::LocalCOMEnvHelper) +function pre_integration_hook!(env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper) @simd for id in 1:(env_helper.shared.n_particles) env_helper.vec_to_neighbour_sums[id] = SVector(0.0, 0.0) env_helper.n_neighbours[id] = 0 @@ -87,7 +91,10 @@ function pre_integration_hook!(env_helper::LocalCOMEnvHelper) end function state_update_helper_hook!( - env_helper::LocalCOMEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} + env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, + id1::Int64, + id2::Int64, + r⃗₁₂::SVector{2,Float64}, ) env_helper.vec_to_neighbour_sums[id1] += r⃗₁₂ env_helper.vec_to_neighbour_sums[id2] -= r⃗₁₂ @@ -98,7 +105,9 @@ function state_update_helper_hook!( return nothing end -function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{Particle}) +function state_update_hook!( + env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{Particle} +) n_particles = env_helper.shared.n_particles @turbo for id in 1:n_particles @@ -161,7 +170,11 @@ function state_update_hook!(env_helper::LocalCOMEnvHelper, particles::Vector{Par return nothing end -function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle) +function update_reward!( + env::LocalCOMWithAdditionalShapeRewardEnv, + env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, + particle::Particle, +) id = particle.id normalization = env_helper.shared.n_particles # TODO: Add factor from steps diff --git a/src/RL/RL.jl b/src/RL/RL.jl index 75febee..5570df1 100644 --- a/src/RL/RL.jl +++ b/src/RL/RL.jl @@ -1,6 +1,6 @@ module RL -export run_rl, LocalCOMEnv +export run_rl, LocalCOMWithAdditionalShapeRewardEnv using Base: OneTo @@ -140,6 +140,6 @@ function run_rl(; return env_helper end -include("LocalCOMEnv.jl") +include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl") end # module \ No newline at end of file diff --git a/src/ReCo.jl b/src/ReCo.jl index 7ce2bcd..a87325e 100644 --- a/src/ReCo.jl +++ b/src/ReCo.jl @@ -1,6 +1,7 @@ module ReCo -export init_sim, run_sim, run_rl, animate, plot_snapshot, LocalCOMEnv +export init_sim, + run_sim, run_rl, animate, plot_snapshot, LocalCOMWithAdditionalShapeRewardEnv using StaticArrays: SVector using JLD2: JLD2