# Mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git
# (synced 2024-12-21 00:51:21 +00:00; 76 lines, 1.9 KiB, Julia)

"""
    Env <: AbstractEnv

Abstract supertype of this package's environments.

The generic methods in this file (`reset!`, `RLBase.state`, `RLBase.reward`, ...) read and
write `env.shared`, so concrete subtypes are expected to provide a `shared` field holding
an `EnvSharedProps`. NOTE(review): confirm every concrete subtype follows this contract.
"""
abstract type Env <: AbstractEnv end

"""
    EnvSharedProps{state_dims}

Mutable bookkeeping shared by environments.

It holds the discrete action and state spaces, plus the current state index, reward and
termination flag.

# Constructor

    EnvSharedProps(n_states, state_space;
                   n_v_actions=2, n_ω_actions=3, max_v=40.0, max_ω=π / 2)

`n_states` and `state_space` are stored as given; they are not checked against each other.

The action space is the Cartesian product of:
- `n_v_actions` evenly spaced `v` values in `[0, max_v]`;
- `n_ω_actions` evenly spaced `ω` values in `[-max_ω, max_ω]`.

Throws `AssertionError` if `n_v_actions <= 1`, `n_ω_actions <= 1`, `max_v <= 0` or
`max_ω <= 0`.

# Fields
- `n_actions::Int64`: number of discrete actions, `n_v_actions * n_ω_actions`.
- `action_space::Vector{SVector{2,Float64}}`: every `(v, ω)` pair. `v` varies slowest, so
  action `(i_v - 1) * n_ω_actions + i_ω` is `SVector(v[i_v], ω[i_ω])`.
- `action_ind_space::OneTo{Int64}`: `1:n_actions`.
- `n_states::Int64`: number of discrete states, as passed to the constructor.
- `state_space::Vector{SVector{state_dims,Interval}}`: as passed to the constructor.
- `state_ind_space::OneTo{Int64}`: `1:n_states`.
- `state_ind::Int64`: current state index; initialised to `INITIAL_STATE_IND`.
- `reward::Float64`: current reward; initialised to `INITIAL_REWARD`.
- `terminated::Bool`: termination flag; initialised to `false`.
"""
mutable struct EnvSharedProps{state_dims}
    n_actions::Int64
    action_space::Vector{SVector{2,Float64}}
    action_ind_space::OneTo{Int64}

    n_states::Int64
    state_space::Vector{SVector{state_dims,Interval}}
    state_ind_space::OneTo{Int64}
    state_ind::Int64

    reward::Float64
    terminated::Bool

    function EnvSharedProps(
        n_states::Int64,
        state_space::Vector{SVector{state_dims,Interval}};
        n_v_actions::Int64=2,
        n_ω_actions::Int64=3,
        max_v::Float64=40.0,
        max_ω::Float64=π / 2,
    ) where {state_dims}
        @assert n_v_actions > 1 "n_v_actions must be > 1, got $n_v_actions"
        @assert n_ω_actions > 1 "n_ω_actions must be > 1, got $n_ω_actions"
        @assert max_v > 0 "max_v must be positive, got $max_v"
        @assert max_ω > 0 "max_ω must be positive, got $max_ω"

        v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions)
        ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions)

        n_actions = n_v_actions * n_ω_actions

        # `v` is the outer generator and `ω` the inner one. This enumerates all ω values
        # for the first v, then for the next v, and so on (the original nested-loop order).
        # The typed comprehension pins the element type to the field's declared type.
        action_space = SVector{2,Float64}[
            SVector(v, ω) for v in v_action_space for ω in ω_action_space
        ]

        return new{state_dims}(
            n_actions,
            action_space,
            OneTo(n_actions),
            n_states,
            state_space,
            OneTo(n_states),
            INITIAL_STATE_IND,
            INITIAL_REWARD,
            false,
        )
    end
end

"""
    reset!(env::Env)

Clear the shared termination flag of `env` by setting `env.shared.terminated = false`.

Only `terminated` is touched; `env.shared.state_ind` and `env.shared.reward` are left
unchanged. Returns `nothing`.
"""
function reset!(env::Env)
    env.shared.terminated = false

    return nothing
end

# The state space exposed to RLBase is the range of state indices `1:n_states`.
RLBase.state_space(env::Env) = env.shared.state_ind_space

# The current state is reported as its integer index.
RLBase.state(env::Env) = env.shared.state_ind

# Actions are exposed as indices `1:n_actions` into `env.shared.action_space`.
RLBase.action_space(env::Env) = env.shared.action_ind_space

# Returns the reward currently stored in the shared properties.
RLBase.reward(env::Env) = env.shared.reward

# Returns the shared termination flag (cleared by `reset!`).
RLBase.is_terminated(env::Env) = env.shared.terminated