1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-12-21 00:51:21 +00:00
ReCo.jl/src/RL/Env.jl

121 lines
3.2 KiB
Julia
Raw Normal View History

2022-01-11 18:00:41 +00:00
"""
    Env

Abstract supertype (itself a subtype of `AbstractEnv`) for the environments handled in
this file.

The generic methods defined here for `Env` (`reset!`, `RLBase.state`,
`RLBase.state_space`, `RLBase.action_space`, `RLBase.reward`, `RLBase.is_terminated`)
all read or write `env.shared`, so a concrete subtype must provide a `shared` property
exposing `terminated`, `state_id`, `state_id_space`, `action_id_space` and `reward`.
NOTE(review): `shared` is presumably an `EnvSharedProps` instance — confirm against the
concrete subtypes.
"""
abstract type Env <: AbstractEnv end
2022-01-15 20:27:15 +00:00
"""
    EnvSharedProps{n_state_dims}

Mutable container for the action/state bookkeeping of an environment.

# Fields
- `n_actions`: number of discrete actions (`n_v_actions * n_ω_actions`).
- `action_space`: every `(v, ω)` combination as an `SVector{2,Float64}`.
- `action_id_space`: `OneTo(n_actions)`.
- `n_states`: number of states as passed to the constructor.
- `state_id_tensor`: array of size `state_id_tensor_dims` whose entries are the
  consecutive integers `1, 2, …` in linear-index order.
- `state_id_space`: `OneTo(n_states)`.
- `state_id`: current state id, initialised to `INITIAL_STATE_IND`.
- `action_spaces_labels`: one vector of LaTeX labels per action dimension (`v`, `ω`).
- `state_spaces_labels`: LaTeX labels passed to the constructor, stored unchanged.
- `reward`: last reward, initialised to `INITIAL_REWARD`.
- `terminated`: termination flag, initialised to `false`.

# Constructor
    EnvSharedProps(n_states, state_id_tensor_dims, state_spaces_labels;
                   n_v_actions=2, n_ω_actions=3, max_v=40.0, max_ω=π / 2)

`v` is discretised into `n_v_actions` evenly spaced values on `[0, max_v]` and `ω` into
`n_ω_actions` evenly spaced values on `[-max_ω, max_ω]`. Both counts must be greater
than 1 and both maxima must be positive (checked with `@assert`).
"""
mutable struct EnvSharedProps{n_state_dims}
    n_actions::Int64
    action_space::Vector{SVector{2,Float64}}
    action_id_space::OneTo{Int64}

    n_states::Int64
    state_id_tensor::Array{Int64,n_state_dims}
    state_id_space::OneTo{Int64}
    state_id::Int64

    action_spaces_labels::Vector{Vector{LaTeXStrings.LaTeXString}}
    state_spaces_labels::Vector{Vector{LaTeXStrings.LaTeXString}}

    reward::Float64
    terminated::Bool

    function EnvSharedProps(
        # Stored as given and used for `state_id_space`; it is never checked against
        # the size implied by `state_id_tensor_dims`, so the two may differ.
        n_states::Int64,
        state_id_tensor_dims::NTuple{n_state_dims,Int64},
        state_spaces_labels::Vector{Vector{LaTeXStrings.LaTeXString}};
        n_v_actions::Int64=2,
        n_ω_actions::Int64=3,
        max_v::Float64=40.0,
        max_ω::Float64=π / 2,
    ) where {n_state_dims}
        @assert n_v_actions > 1
        @assert n_ω_actions > 1
        @assert max_v > 0
        @assert max_ω > 0

        # Evenly spaced discretisation of each action dimension.
        v_values = range(; start=0.0, stop=max_v, length=n_v_actions)
        ω_values = range(; start=-max_ω, stop=max_ω, length=n_ω_actions)

        # Cartesian product of both dimensions; ω varies fastest.
        n_actions = n_v_actions * n_ω_actions
        action_space = SVector{2,Float64}[
            SVector(v, ω) for v in v_values for ω in ω_values
        ]
        action_id_space = OneTo(n_actions)

        action_spaces_labels = gen_action_spaces_labels(
            ("v", "\\omega"), (v_values, ω_values)
        )

        # Number the tensor entries 1, 2, … following the array's linear index order.
        state_id_tensor = Array{Int64,n_state_dims}(undef, state_id_tensor_dims)
        for (id, ind) in enumerate(eachindex(state_id_tensor))
            state_id_tensor[ind] = id
        end
        state_id_space = OneTo(n_states)

        return new{n_state_dims}(
            n_actions,
            action_space,
            action_id_space,
            n_states,
            state_id_tensor,
            state_id_space,
            INITIAL_STATE_IND,
            action_spaces_labels,
            state_spaces_labels,
            INITIAL_REWARD,
            false,
        )
    end
end
"""
    reset!(env::Env)

Clear the termination flag stored in `env.shared` and return `nothing`.

Only `terminated` is touched by this generic method; other shared properties are left
as they are.
"""
function reset!(env::Env)
    shared = env.shared
    shared.terminated = false

    return nothing
end
2022-01-18 01:17:52 +00:00
"""
    RLBase.state_space(env::Env)

Return the state id space stored in `env.shared.state_id_space`.
"""
RLBase.state_space(env::Env) = env.shared.state_id_space
2022-01-11 18:00:41 +00:00
2022-01-18 01:17:52 +00:00
"""
    RLBase.state(env::Env)

Return the current state id stored in `env.shared.state_id`.
"""
RLBase.state(env::Env) = env.shared.state_id
2022-01-11 18:00:41 +00:00
2022-01-18 01:17:52 +00:00
"""
    RLBase.action_space(env::Env)

Return the action id space stored in `env.shared.action_id_space`.
"""
RLBase.action_space(env::Env) = env.shared.action_id_space
2022-01-11 18:00:41 +00:00
2022-01-18 01:17:52 +00:00
"""
    RLBase.reward(env::Env)

Return the reward currently stored in `env.shared.reward`.
"""
RLBase.reward(env::Env) = env.shared.reward
2022-01-11 18:00:41 +00:00
2022-01-18 01:17:52 +00:00
"""
    RLBase.is_terminated(env::Env)

Return the termination flag stored in `env.shared.terminated`.
"""
RLBase.is_terminated(env::Env) = env.shared.terminated
"""
    gen_action_space_labels(action_label::String, action_space::AbstractRange)

Build one LaTeX label per value of `action_space`.

Each label has the form `\$<action_label>\$=<value>`, where `<value>` is the action
rounded to two digits. The labels are returned as a `Vector{LaTeXString}` in the same
order as `action_space`.
"""
function gen_action_space_labels(action_label::String, action_space::AbstractRange)
    return LaTeXStrings.LaTeXString[
        LaTeXStrings.latexstring(
            "\$" * action_label * "\$=$(round(action; digits=2))"
        ) for action in action_space
    ]
end
"""
    gen_action_spaces_labels(actions_labels, action_spaces)

Generate the label vectors for several action dimensions at once.

`actions_labels` and `action_spaces` are tuples of equal length `N`; the `i`-th entry of
the returned vector is `gen_action_space_labels(actions_labels[i], action_spaces[i])`.
"""
function gen_action_spaces_labels(
    actions_labels::NTuple{N,String}, action_spaces::NTuple{N,AbstractRange}
) where {N}
    # Both tuples have length N, so zipping them pairs each label with its range.
    labelled_spaces = zip(actions_labels, action_spaces)

    return [gen_action_space_labels(label, space) for (label, space) in labelled_spaces]
end