mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Added reinforcement_learning.jl
This commit is contained in:
parent
50feaee469
commit
8dd608222f
2 changed files with 79 additions and 0 deletions
|
@ -4,6 +4,7 @@ export init_sim, run_sim
|
|||
|
||||
include("PreVector.jl")
|
||||
include("Particle.jl")
|
||||
include("reinforcement_learning.jl")
|
||||
include("data.jl")
|
||||
include("setup.jl")
|
||||
include("simulation.jl")
|
||||
|
|
78
src/reinforcement_learning.jl
Normal file
78
src/reinforcement_learning.jl
Normal file
|
@ -0,0 +1,78 @@
|
|||
using ReinforcementLearning
|
||||
|
||||
"""
    ReCoEnvParams(n_particles, half_box_len, skin_r, n_v_actions, n_ω_actions;
                  max_v=80.0, max_ω=float(π))

Shared, mutable parameters of a `ReCoEnv` reinforcement-learning environment.

Builds the discrete action space as the Cartesian product of `n_v_actions`
evenly spaced translational speeds in `[0, max_v]` and `n_ω_actions` evenly
spaced angular velocities in `[-max_ω, max_ω]`, and the discrete state space
as the product of distance bins (`:big`, `:medium`, `:small`) and direction
bins (`:before`, `:behind`, `:left`, `:right`) plus one extra `(:none, :none)`
state (no neighbour in range). `reward` is initialised to `0.0`.

# Arguments
- `n_particles::Int64`: number of particles in the simulation.
- `half_box_len::Float64`: half of the (square) box edge length; must be > 0.
- `skin_r::Float64`: skin radius for neighbour detection; must be > 0.
- `n_v_actions::Int64`: number of speed bins; must be > 1.
- `n_ω_actions::Int64`: number of angular-velocity bins; must be > 1.

# Keywords
- `max_v::Float64=80.0`: maximum translational speed; must be > 0.
- `max_ω::Float64=float(π)`: maximum absolute angular velocity; must be > 0.
"""
mutable struct ReCoEnvParams
    n_particles::Int64
    half_box_len::Float64
    skin_r::Float64
    action_space::Vector{Tuple{Float64,Float64}}
    state_space::Vector{Tuple{Symbol,Symbol}}
    reward::Float64

    function ReCoEnvParams(
        n_particles::Int64,
        half_box_len::Float64,
        skin_r::Float64,
        n_v_actions::Int64,
        n_ω_actions::Int64;
        max_v::Float64=80.0,
        max_ω::Float64=float(π),
    )
        @assert half_box_len > 0
        @assert skin_r > 0
        @assert n_v_actions > 1
        # Bug fix: was `@assert n_ω_actoins > 1` — an undefined name (typo),
        # which made every construction throw UndefVarError.
        @assert n_ω_actions > 1
        @assert max_v > 0
        @assert max_ω > 0

        # n - 1 steps so that both endpoints (0/max_v, ±max_ω) are included.
        v_action_space = 0.0:(max_v / (n_v_actions - 1)):max_v
        ω_action_space = (-max_ω):(2 * max_ω / (n_ω_actions - 1)):max_ω

        n_actions = n_v_actions * n_ω_actions

        action_space = Vector{Tuple{Float64,Float64}}(undef, n_actions)

        ind = 1
        for v in v_action_space
            for ω in ω_action_space
                action_space[ind] = (v, ω)
                ind += 1
            end
        end

        distance_state_space = (:big, :medium, :small)
        direction_state_space = (:before, :behind, :left, :right)

        # Bug fix: was `n_states = undef, length(...) + 1`, which bound a
        # tuple `(undef, N)`; combined with `Vector{...}(n_states)` (missing
        # the `undef` argument) this threw a MethodError. The `+ 1` accounts
        # for the extra (:none, :none) state.
        n_states = length(distance_state_space) * length(direction_state_space) + 1

        state_space = Vector{Tuple{Symbol,Symbol}}(undef, n_states)

        ind = 1
        for distance in distance_state_space
            for direction in direction_state_space
                state_space[ind] = (distance, direction)
                ind += 1
            end
        end
        state_space[ind] = (:none, :none)

        return new(n_particles, half_box_len, skin_r, action_space, state_space, 0.0)
    end
end
|
||||
|
||||
# Reinforcement-learning environment for one particle, implementing the
# ReinforcementLearning.jl `AbstractEnv` interface (see the `RLBase.*`
# methods defined below).
mutable struct ReCoEnv <: AbstractEnv
    # Shared environment parameters (action/state spaces, reward, box geometry).
    params::ReCoEnvParams
    # The particle controlled by this environment.
    # NOTE(review): `Particle` is defined in Particle.jl — included before this
    # file in ReCo.jl.
    particle::Particle
    # Current (distance, direction) state; (:none, :none) means no neighbour
    # in range (the extra state appended to `params.state_space`).
    state::Tuple{Symbol,Symbol}

    # Start every environment in the neutral (:none, :none) state.
    function ReCoEnv(params::ReCoEnvParams, particle::Particle)
        return new(params, particle, (:none, :none))
    end
end
|
||||
|
||||
RLBase.state_space(env::ReCoEnv) = env.state_space
|
||||
|
||||
RLBase.state(env::ReCoEnv) = env.state
|
||||
|
||||
RLBase.action_space(env::ReCoEnv) = env.params.action_space
|
||||
|
||||
RLBase.reward(env::ReCoEnv) = env.params.reward
|
Loading…
Reference in a new issue