diff --git a/src/ReCo.jl b/src/ReCo.jl
index cff1ce0..cd1eb19 100644
--- a/src/ReCo.jl
+++ b/src/ReCo.jl
@@ -4,6 +4,7 @@ export init_sim, run_sim
 
 include("PreVector.jl")
 include("Particle.jl")
+include("reinforcement_learning.jl")
 include("data.jl")
 include("setup.jl")
 include("simulation.jl")
diff --git a/src/reinforcement_learning.jl b/src/reinforcement_learning.jl
new file mode 100644
index 0000000..558fcf0
--- /dev/null
+++ b/src/reinforcement_learning.jl
@@ -0,0 +1,87 @@
+using ReinforcementLearning
+
+"""
+    ReCoEnvParams(n_particles, half_box_len, skin_r, n_v_actions, n_ω_actions;
+                  max_v=80.0, max_ω=float(π))
+
+Parameters of a `ReCoEnv`: its discrete action space, its discrete state space and the
+current reward.
+
+The action space holds every `(v, ω)` pair formed from `n_v_actions` evenly spaced values
+in `[0, max_v]` and `n_ω_actions` evenly spaced values in `[-max_ω, max_ω]`; `ω` varies
+fastest. The state space holds every `(distance, direction)` pair from
+`(:big, :medium, :small)` × `(:before, :behind, :left, :right)`, followed by the extra
+state `(:none, :none)`. `reward` starts at `0.0`; `n_particles`, `half_box_len` and
+`skin_r` are stored as given.
+
+# Throws
+- `AssertionError` if `half_box_len`, `skin_r`, `max_v` or `max_ω` is not positive, or if
+  `n_v_actions` or `n_ω_actions` is not greater than 1.
+"""
+mutable struct ReCoEnvParams
+    n_particles::Int64
+    half_box_len::Float64
+    skin_r::Float64
+    action_space::Vector{Tuple{Float64,Float64}}
+    state_space::Vector{Tuple{Symbol,Symbol}}
+    reward::Float64
+
+    function ReCoEnvParams(
+        n_particles::Int64,
+        half_box_len::Float64,
+        skin_r::Float64,
+        n_v_actions::Int64,
+        n_ω_actions::Int64;
+        max_v::Float64=80.0,
+        max_ω::Float64=float(π),
+    )
+        @assert half_box_len > 0
+        @assert skin_r > 0
+        @assert n_v_actions > 1
+        @assert n_ω_actions > 1
+        @assert max_v > 0
+        @assert max_ω > 0
+
+        # Fixing the length (rather than a floating-point step) guarantees exactly
+        # n_v_actions and n_ω_actions grid points, end points included.
+        v_action_space = range(0.0, max_v; length=n_v_actions)
+        ω_action_space = range(-max_ω, max_ω; length=n_ω_actions)
+
+        # v is the outer loop, so ω varies fastest in the flattened action list.
+        action_space = [(v, ω) for v in v_action_space for ω in ω_action_space]
+
+        distances = (:big, :medium, :small)
+        directions = (:before, :behind, :left, :right)
+
+        # Every (distance, direction) combination, then the extra (:none, :none) state.
+        state_space = [(dist, dir) for dist in distances for dir in directions]
+        push!(state_space, (:none, :none))
+
+        return new(n_particles, half_box_len, skin_r, action_space, state_space, 0.0)
+    end
+end
+
+"""
+    ReCoEnv(params::ReCoEnvParams, particle::Particle)
+
+Reinforcement-learning environment holding the shared `params` and one `particle`.
+The state starts as `(:none, :none)`.
+"""
+mutable struct ReCoEnv <: AbstractEnv
+    params::ReCoEnvParams
+    particle::Particle
+    state::Tuple{Symbol,Symbol}
+
+    function ReCoEnv(params::ReCoEnvParams, particle::Particle)
+        return new(params, particle, (:none, :none))
+    end
+end
+
+# The state space is stored in `params`; `ReCoEnv` has no `state_space` field of its own.
+RLBase.state_space(env::ReCoEnv) = env.params.state_space
+
+RLBase.state(env::ReCoEnv) = env.state
+
+RLBase.action_space(env::ReCoEnv) = env.params.action_space
+
+RLBase.reward(env::ReCoEnv) = env.params.reward
\ No newline at end of file