diff --git a/src/RL/EnvHelper.jl b/src/RL/EnvHelper.jl index 5ab2fa4..85dacbc 100644 --- a/src/RL/EnvHelper.jl +++ b/src/RL/EnvHelper.jl @@ -47,7 +47,7 @@ struct EnvHelperSharedProps{H<:AbstractHook} end end -function gen_env_helper(::Env, env_helper_params::EnvHelperSharedProps; args) +function gen_env_helper(::Env, env_helper_params::EnvHelperSharedProps; kwargs...) return ReCo.method_not_implemented() end diff --git a/src/RL/Envs/COMCompassEnv.jl b/src/RL/Envs/COMCompassEnv.jl index 6a8baf2..89ba401 100644 --- a/src/RL/Envs/COMCompassEnv.jl +++ b/src/RL/Envs/COMCompassEnv.jl @@ -11,11 +11,11 @@ struct COMCompassEnv <: Env direction_angle_state_space::Vector{Interval} position_angle_state_space::Vector{Interval} - function COMCompassEnv(; + function COMCompassEnv( + args; n_distance_states::Int64=3, n_direction_angle_states::Int64=3, n_position_angle_states::Int64=8, - args, ) @assert n_distance_states > 1 @assert n_direction_angle_states > 1 @@ -63,7 +63,7 @@ mutable struct COMCompassEnvHelper <: EnvHelper half_box_len::Float64 max_elliptical_distance::Float64 - function COMCompassEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64) + function COMCompassEnvHelper(shared::EnvHelperSharedProps; half_box_len::Float64) max_elliptical_distance = sqrt( half_box_len^2 + (half_box_len / shared.elliptical_b_a_ratio)^2 ) @@ -79,8 +79,8 @@ mutable struct COMCompassEnvHelper <: EnvHelper end end -function gen_env_helper(::COMCompassEnv, env_helper_shared::EnvHelperSharedProps; args) - return COMCompassEnvHelper(env_helper_shared, args.half_box_len) +function gen_env_helper(::COMCompassEnv, env_helper_shared::EnvHelperSharedProps; kwargs...) + return COMCompassEnvHelper(env_helper_shared; kwargs...) end function pre_integration_hook!(::COMCompassEnvHelper) diff --git a/src/RL/Envs/LocalCOMEnv.jl b/src/RL/Envs/LocalCOMEnv.jl index 4f1845c..1bd9343 100644 --- a/src/RL/Envs/LocalCOMEnv.jl +++ b/src/RL/Envs/LocalCOMEnv.jl @@ -10,8 +10,8 @@ struct LocalCOMEnv <: Env distance_state_space::Vector{Interval} direction_angle_state_space::Vector{Interval} - function LocalCOMEnv(; - n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args + function LocalCOMEnv( + args; n_distance_states::Int64=3, n_direction_angle_states::Int64=3 ) @assert n_distance_states > 1 @assert n_direction_angle_states > 1 @@ -52,7 +52,7 @@ mutable struct LocalCOMEnvHelper <: EnvHelper half_box_len::Float64 function LocalCOMEnvHelper( - shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius::Float64 + shared::EnvHelperSharedProps; half_box_len::Float64, skin_radius::Float64 ) max_distance_to_local_center_of_mass = skin_radius @@ -67,8 +67,8 @@ mutable struct LocalCOMEnvHelper <: EnvHelper end end -function gen_env_helper(::LocalCOMEnv, env_helper_shared::EnvHelperSharedProps; args) - return LocalCOMEnvHelper(env_helper_shared, args.half_box_len, args.skin_radius) +function gen_env_helper(::LocalCOMEnv, env_helper_shared::EnvHelperSharedProps; kwargs...) + return LocalCOMEnvHelper(env_helper_shared; kwargs...) end function pre_integration_hook!(env_helper::LocalCOMEnvHelper) diff --git a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl index 83a0338..69ac911 100644 --- a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl +++ b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl @@ -5,7 +5,7 @@ The minimization variable of the additional reward term is the individual ellipt using ..ReCo: ReCo -const TRIGGER = 0.6 +const DEFAULT_TRIGGER = 0.35 struct LocalCOMWithAdditionalShapeRewardEnv <: Env shared::EnvSharedProps @@ -13,8 +13,8 @@ struct LocalCOMWithAdditionalShapeRewardEnv <: Env distance_state_space::Vector{Interval} direction_angle_state_space::Vector{Interval} - function LocalCOMWithAdditionalShapeRewardEnv(; - n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args + function LocalCOMWithAdditionalShapeRewardEnv( + args; n_distance_states::Int64=3, n_direction_angle_states::Int64=3 ) @assert n_distance_states > 1 @assert n_direction_angle_states > 1 @@ -61,8 +61,13 @@ mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper half_box_len::Float64 max_elliptical_distance::Float64 + trigger::Float64 + function LocalCOMWithAdditionalShapeRewardEnvHelper( - shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius::Float64 + shared::EnvHelperSharedProps; + half_box_len::Float64, + skin_radius::Float64, + trigger::Float64=DEFAULT_TRIGGER, ) max_elliptical_distance = sqrt( half_box_len^2 + (half_box_len / shared.elliptical_b_a_ratio)^2 @@ -82,16 +87,17 @@ mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper SVector(0.0, 0.0), half_box_len, max_elliptical_distance, + trigger, ) end end function gen_env_helper( - ::LocalCOMWithAdditionalShapeRewardEnv, env_helper_shared::EnvHelperSharedProps; args + ::LocalCOMWithAdditionalShapeRewardEnv, + env_helper_shared::EnvHelperSharedProps; + kwargs..., ) - return LocalCOMWithAdditionalShapeRewardEnvHelper( - env_helper_shared, args.half_box_len, args.skin_radius - ) + return LocalCOMWithAdditionalShapeRewardEnvHelper(env_helper_shared; kwargs...) end function pre_integration_hook!(env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper) @@ -160,7 +166,7 @@ function state_update_hook!( distance_to_local_center_of_mass_sum / n_particles env_helper.add_shape_reward_term = mean_distance_to_local_center_of_mass / - env_helper.max_distance_to_local_center_of_mass < TRIGGER + env_helper.max_distance_to_local_center_of_mass < env_helper.trigger if env_helper.add_shape_reward_term print("*") end diff --git a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl index c5a8613..24e77ac 100644 --- a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl +++ b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl @@ -5,7 +5,7 @@ The minimization variable of the additional reward term is the absolute differen using ..ReCo: ReCo -const TRIGGER = 0.6 +const DEFAULT_TRIGGER = 0.35 struct LocalCOMWithAdditionalShapeRewardEnv2 <: Env shared::EnvSharedProps @@ -13,8 +13,8 @@ struct LocalCOMWithAdditionalShapeRewardEnv2 <: Env distance_state_space::Vector{Interval} direction_angle_state_space::Vector{Interval} - function LocalCOMWithAdditionalShapeRewardEnv2(; - n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args + function LocalCOMWithAdditionalShapeRewardEnv2( + args; n_distance_states::Int64=3, n_direction_angle_states::Int64=3 ) @assert n_distance_states > 1 @assert n_direction_angle_states > 1 @@ -60,8 +60,13 @@ mutable struct LocalCOMWithAdditionalShapeRewardEnv2Helper <: EnvHelper half_box_len::Float64 + trigger::Float64 + function LocalCOMWithAdditionalShapeRewardEnv2Helper( - shared::EnvHelperSharedProps, half_box_len::Float64, skin_radius::Float64 + shared::EnvHelperSharedProps; + half_box_len::Float64, + skin_radius::Float64, + trigger::Float64=DEFAULT_TRIGGER, ) goal_κ = 0.4 max_distance_to_goal_κ = max(1 - goal_κ, goal_κ) @@ -79,16 +84,17 @@ mutable struct LocalCOMWithAdditionalShapeRewardEnv2Helper <: EnvHelper goal_κ, max_distance_to_goal_κ, half_box_len, + trigger, ) end end function gen_env_helper( - ::LocalCOMWithAdditionalShapeRewardEnv2, env_helper_shared::EnvHelperSharedProps; args + ::LocalCOMWithAdditionalShapeRewardEnv2, + env_helper_shared::EnvHelperSharedProps; + kwargs..., ) - return LocalCOMWithAdditionalShapeRewardEnv2Helper( - env_helper_shared, args.half_box_len, args.skin_radius - ) + return LocalCOMWithAdditionalShapeRewardEnv2Helper(env_helper_shared; kwargs...) end function pre_integration_hook!(env_helper::LocalCOMWithAdditionalShapeRewardEnv2Helper) @@ -158,7 +164,7 @@ function state_update_hook!( distance_to_local_center_of_mass_sum / n_particles env_helper.add_shape_reward_term = mean_distance_to_local_center_of_mass / - env_helper.max_distance_to_local_center_of_mass < TRIGGER + env_helper.max_distance_to_local_center_of_mass < env_helper.trigger return nothing end diff --git a/src/RL/Envs/NearestNeighborEnv.jl b/src/RL/Envs/NearestNeighborEnv.jl index 5aaa7ff..3f89dfe 100644 --- a/src/RL/Envs/NearestNeighborEnv.jl +++ b/src/RL/Envs/NearestNeighborEnv.jl @@ -10,8 +10,8 @@ struct NearestNeighborEnv <: Env distance_state_space::Vector{Interval} direction_angle_state_space::Vector{Interval} - function NearestNeighborEnv(; - n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args + function NearestNeighborEnv( + args; n_distance_states::Int64=3, n_direction_angle_states::Int64=3 ) @assert n_distance_states > 1 @assert n_direction_angle_states > 1 @@ -52,7 +52,7 @@ mutable struct NearestNeighborEnvHelper <: EnvHelper half_box_len::Float64 - function NearestNeighborEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64) + function NearestNeighborEnvHelper(shared::EnvHelperSharedProps; half_box_len::Float64) goal_κ = 0.4 max_distance_to_goal_κ = max(1 - goal_κ, goal_κ) @@ -68,8 +68,10 @@ mutable struct NearestNeighborEnvHelper <: EnvHelper end end -function gen_env_helper(::NearestNeighborEnv, env_helper_shared::EnvHelperSharedProps; args) - return NearestNeighborEnvHelper(env_helper_shared, args.half_box_len) +function gen_env_helper( + ::NearestNeighborEnv, env_helper_shared::EnvHelperSharedProps; kwargs... +) + return NearestNeighborEnvHelper(env_helper_shared; kwargs...) end function pre_integration_hook!(env_helper::NearestNeighborEnvHelper) diff --git a/src/RL/Envs/OriginCompassEnv.jl b/src/RL/Envs/OriginCompassEnv.jl index 8a4717a..d6f795f 100644 --- a/src/RL/Envs/OriginCompassEnv.jl +++ b/src/RL/Envs/OriginCompassEnv.jl @@ -11,11 +11,11 @@ struct OriginCompassEnv <: Env direction_angle_state_space::Vector{Interval} position_angle_state_space::Vector{Interval} - function OriginCompassEnv(; + function OriginCompassEnv( + args; n_distance_states::Int64=3, n_direction_angle_states::Int64=3, n_position_angle_states::Int64=8, - args, ) @assert n_distance_states > 1 @assert n_direction_angle_states > 1 @@ -63,7 +63,7 @@ mutable struct OriginCompassEnvHelper <: EnvHelper half_box_len::Float64 max_elliptical_distance::Float64 - function OriginCompassEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64) + function OriginCompassEnvHelper(shared::EnvHelperSharedProps; half_box_len::Float64) max_elliptical_distance = sqrt( half_box_len^2 + (half_box_len / shared.elliptical_b_a_ratio)^2 ) @@ -79,8 +79,10 @@ mutable struct OriginCompassEnvHelper <: EnvHelper end end -function gen_env_helper(::OriginCompassEnv, env_helper_shared::EnvHelperSharedProps; args) - return OriginCompassEnvHelper(env_helper_shared, args.half_box_len) +function gen_env_helper( + ::OriginCompassEnv, env_helper_shared::EnvHelperSharedProps; kwargs... +) + return OriginCompassEnvHelper(env_helper_shared; kwargs...) end function pre_integration_hook!(::OriginCompassEnvHelper) diff --git a/src/RL/Envs/OriginEnv.jl b/src/RL/Envs/OriginEnv.jl index 8e2cd24..00ab405 100644 --- a/src/RL/Envs/OriginEnv.jl +++ b/src/RL/Envs/OriginEnv.jl @@ -10,9 +10,7 @@ struct OriginEnv <: Env distance_state_space::Vector{Interval} direction_angle_state_space::Vector{Interval} - function OriginEnv(; - n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args - ) + function OriginEnv(args; n_distance_states::Int64=3, n_direction_angle_states::Int64=3) @assert n_distance_states > 1 @assert n_direction_angle_states > 1 @@ -47,7 +45,7 @@ mutable struct OriginEnvHelper <: EnvHelper half_box_len::Float64 - function OriginEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64) + function OriginEnvHelper(shared::EnvHelperSharedProps; half_box_len::Float64) max_distance_to_origin = sqrt(2) * half_box_len return new( @@ -56,8 +54,8 @@ mutable struct OriginEnvHelper <: EnvHelper end end -function gen_env_helper(::OriginEnv, env_helper_shared::EnvHelperSharedProps; args) - return OriginEnvHelper(env_helper_shared, args.half_box_len) +function gen_env_helper(::OriginEnv, env_helper_shared::EnvHelperSharedProps; kwargs...) + return OriginEnvHelper(env_helper_shared; kwargs...) end function pre_integration_hook!(::OriginEnvHelper)