1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00

Save env_helper and add state and action labels

This commit is contained in:
Mo8it 2022-01-30 01:22:37 +01:00
parent 3e713fd44c
commit 839e766206
4 changed files with 64 additions and 8 deletions

View file

@ -10,12 +10,16 @@ mutable struct EnvSharedProps{n_state_dims}
state_id_space::OneTo{Int64} state_id_space::OneTo{Int64}
state_id::Int64 state_id::Int64
action_spaces_labels::Vector{Vector{LaTeXStrings.LaTeXString}}
state_spaces_labels::Vector{Vector{LaTeXStrings.LaTeXString}}
reward::Float64 reward::Float64
terminated::Bool terminated::Bool
function EnvSharedProps( function EnvSharedProps(
n_states::Int64, # Can be different from the sum of state_id_tensor_dims n_states::Int64, # Can be different from the sum of state_id_tensor_dims
state_id_tensor_dims::NTuple{n_state_dims,Int64}; state_id_tensor_dims::NTuple{n_state_dims,Int64},
state_spaces_labels::Vector{Vector{LaTeXStrings.LaTeXString}};
n_v_actions::Int64=2, n_v_actions::Int64=2,
n_ω_actions::Int64=3, n_ω_actions::Int64=3,
max_v::Float64=40.0, max_v::Float64=40.0,
@ -29,6 +33,10 @@ mutable struct EnvSharedProps{n_state_dims}
v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions) v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions)
ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions) ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions)
action_spaces_labels = gen_action_spaces_labels(
("v", "\\omega"), (v_action_space, ω_action_space)
)
n_actions = n_v_actions * n_ω_actions n_actions = n_v_actions * n_ω_actions
action_space = Vector{SVector{2,Float64}}(undef, n_actions) action_space = Vector{SVector{2,Float64}}(undef, n_actions)
@ -61,6 +69,8 @@ mutable struct EnvSharedProps{n_state_dims}
state_id_tensor, state_id_tensor,
state_id_space, state_id_space,
INITIAL_STATE_IND, INITIAL_STATE_IND,
action_spaces_labels,
state_spaces_labels,
INITIAL_REWARD, INITIAL_REWARD,
false, false,
) )
@ -91,4 +101,21 @@ end
"""
    RLBase.is_terminated(env::Env)

Return the `terminated` flag stored in `env.shared`.
"""
RLBase.is_terminated(env::Env) = env.shared.terminated
"""
    gen_action_space_labels(action_label::String, action_space::AbstractRange)

Build one LaTeX label per value of `action_space`.

Each label is `action_label`, followed by `" = "` and the value rounded to two
digits, passed through `LaTeXStrings.latexstring`. The returned vector has the same
length and order as `action_space`.
"""
function gen_action_space_labels(action_label::String, action_space::AbstractRange)
    # Typed comprehension keeps the element type fixed even for an empty range.
    return LaTeXStrings.LaTeXString[
        LaTeXStrings.latexstring(action_label * " = $(round(value; digits=2))")
        for value in action_space
    ]
end
"""
    gen_action_spaces_labels(actions_labels, action_spaces)

Generate the labels of several action spaces at once.

`actions_labels[i]` is paired with `action_spaces[i]`, and each pair is handed to
`gen_action_space_labels`. The result is a vector with one label vector per action
space, in the order of the input tuples.
"""
function gen_action_spaces_labels(
    actions_labels::NTuple{N,String}, action_spaces::NTuple{N,AbstractRange}
) where {N}
    # Both tuples have length N, so zipping visits exactly the pairs 1:N.
    return [
        gen_action_space_labels(label, space)
        for (label, space) in zip(actions_labels, action_spaces)
    ]
end

View file

@ -9,7 +9,7 @@ struct OriginEnv <: Env
direction_angle_state_space::Vector{Interval} direction_angle_state_space::Vector{Interval}
function OriginEnv(; function OriginEnv(;
n_distance_states::Int64=4, n_direction_angle_states::Int64=3, args n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args
) )
@assert n_distance_states > 1 @assert n_distance_states > 1
@assert n_direction_angle_states > 1 @assert n_direction_angle_states > 1
@ -25,7 +25,13 @@ struct OriginEnv <: Env
n_states = n_distance_states * n_direction_angle_states n_states = n_distance_states * n_direction_angle_states
shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states)) state_spaces_labels = gen_state_spaces_labels(
("d", "\\theta"), (distance_state_space, direction_angle_state_space)
)
shared = EnvSharedProps(
n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels
)
return new(shared, distance_state_space, direction_angle_state_space) return new(shared, distance_state_space, direction_angle_state_space)
end end

View file

@ -11,6 +11,8 @@ using StaticArrays: SVector
using LoopVectorization: @turbo using LoopVectorization: @turbo
using Random: Random using Random: Random
using ProgressMeter: ProgressMeter using ProgressMeter: ProgressMeter
using JLD2: JLD2
using LaTeXStrings: LaTeXStrings, @L_str
using ..ReCo: ReCo using ..ReCo: ReCo
@ -54,7 +56,7 @@ end
function run_rl(; function run_rl(;
EnvType::Type{E}, EnvType::Type{E},
parent_dir_appendix::String, process_dir::String,
elliptical_a_b_ratio::Float64=1.0, elliptical_a_b_ratio::Float64=1.0,
n_episodes::Int64=200, n_episodes::Int64=200,
episode_duration::Float64=50.0, episode_duration::Float64=50.0,
@ -85,7 +87,6 @@ function run_rl(;
n_particles = sim_consts.n_particles # Not always equal to the input! n_particles = sim_consts.n_particles # Not always equal to the input!
env_args = (skin_radius=sim_consts.skin_radius, half_box_len=sim_consts.half_box_len) env_args = (skin_radius=sim_consts.skin_radius, half_box_len=sim_consts.half_box_len)
env = EnvType(; args=env_args) env = EnvType(; args=env_args)
agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable) agent = gen_agent(env.shared.n_states, env.shared.n_actions, ϵ_stable)
@ -104,7 +105,7 @@ function run_rl(;
env_helper = gen_env_helper(env, env_helper_shared; args=env_helper_args) env_helper = gen_env_helper(env, env_helper_shared; args=env_helper_args)
parent_dir = "RL_" * parent_dir_appendix parent_dir = "RL/" * process_dir
# Pre experiment # Pre experiment
hook(PRE_EXPERIMENT_STAGE, agent, env) hook(PRE_EXPERIMENT_STAGE, agent, env)
@ -137,9 +138,31 @@ function run_rl(;
# Post experiment # Post experiment
hook(POST_EXPERIMENT_STAGE, agent, env) hook(POST_EXPERIMENT_STAGE, agent, env)
process_dir = ReCo.DEFAULT_EXPORTS_DIR * "/$parent_dir"
JLD2.save_object("$process_dir/env_helper.jld2", env_helper)
return env_helper return env_helper
end end
"""
    gen_state_space_labels(state_label::String, state_space::Vector{Interval})

Build one LaTeX label per interval of `state_space`.

Each label is `state_label`, followed by `" = "` and the interval's `first` and
`last` fields (each rounded to two digits) separated by a colon, passed through
`LaTeXStrings.latexstring`. The returned vector has the same length and order as
`state_space`.
"""
function gen_state_space_labels(state_label::String, state_space::Vector{Interval})
    labels = LaTeXStrings.LaTeXString[]
    sizehint!(labels, length(state_space))

    for interval in state_space
        lower = round(interval.first; digits=2)
        upper = round(interval.last; digits=2)
        push!(labels, LaTeXStrings.latexstring(state_label * " = $(lower):$(upper)"))
    end

    return labels
end
"""
    gen_state_spaces_labels(states_labels, state_spaces)

Generate the labels of several state spaces at once.

`states_labels[i]` is paired with `state_spaces[i]`, and each pair is handed to
`gen_state_space_labels`. The result is a vector with one label vector per state
space, in the order of the input tuples.
"""
function gen_state_spaces_labels(
    states_labels::NTuple{N,String}, state_spaces::NTuple{N,Vector{Interval}}
) where {N}
    # Both tuples have length N, so zipping visits exactly the pairs 1:N.
    return [
        gen_state_space_labels(label, space)
        for (label, space) in zip(states_labels, state_spaces)
    ]
end
include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl") include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl")
include("Envs/OriginEnv.jl") include("Envs/OriginEnv.jl")

View file

@ -1,6 +1,6 @@
const DEFAULT_PACKING_RATIO = 0.5 const DEFAULT_PACKING_RATIO = 0.5
const DEFAULT_δt = 1e-5 const DEFAULT_δt = 1e-5
const DEFAULT_SKIN_TO_INTERACTION_R_RATIO = 2.0 const DEFAULT_SKIN_TO_INTERACTION_R_RATIO = 2.5
const DEFAULT_EXPORTS_DIR = "exports" const DEFAULT_EXPORTS_DIR = "exports"
const DEFAULT_PARENT_DIR = "" const DEFAULT_PARENT_DIR = ""
const DEFAULT_COMMENT = "" const DEFAULT_COMMENT = ""
@ -69,7 +69,7 @@ function gen_sim_consts(
skin_radius = skin_to_interaction_radius_ratio * interaction_radius skin_radius = skin_to_interaction_radius_ratio * interaction_radius
buffer = 2.5 buffer = 3
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) max_approach_after_one_integration_step = buffer * (2 * v₀ * δt)
@assert skin_radius >= interaction_radius + max_approach_after_one_integration_step @assert skin_radius >= interaction_radius + max_approach_after_one_integration_step