module ReCoRL

using Random
using ReinforcementLearning
using Intervals

import Base: run

# `Particle`, `gen_sim_consts`, `init_sim_with_sim_consts` and `run_sim` are
# expected to be provided by the enclosing simulation package.

struct DistanceState{L<:Bound}
    interval::Interval{Float64,L,Closed}

    function DistanceState{L}(lower::Float64, upper::Float64) where {L<:Bound}
        return new(Interval{Float64,L,Closed}(lower, upper))
    end
end

struct DirectionState
    # Direction bins are half-open, [lower, upper), so they tile [0, 2π)
    # without overlap.
    interval::Interval{Float64,Closed,Open}

    function DirectionState(lower::Float64, upper::Float64)
        return new(Interval{Float64,Closed,Open}(lower, upper))
    end
end

mutable struct EnvParams
    action_space::Vector{Tuple{Float64,Float64}}
    distance_state_space::Vector{DistanceState}
    direction_state_space::Vector{DirectionState}
    state_space::Vector{Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}}
    reward::Float64

    function EnvParams(
        min_distance::Float64,
        max_distance::Float64;
        n_v_actions::Int64=2,
        n_ω_actions::Int64=3,
        max_v::Float64=20.0,
        max_ω::Float64=π / 1.5,
        n_distance_states::Int64=3,
        n_direction_states::Int64=4,
    )
        @assert min_distance > 0.0
        @assert max_distance > min_distance
        @assert n_v_actions > 1
        @assert n_ω_actions > 1
        @assert max_v > 0
        @assert max_ω > 0

        # Action space: Cartesian product of the discretised translational
        # speeds and angular velocities.
        v_action_space = 0.0:(max_v / (n_v_actions - 1)):max_v
        ω_action_space = (-max_ω):(2 * max_ω / (n_ω_actions - 1)):max_ω

        n_actions = n_v_actions * n_ω_actions
        action_space = Vector{Tuple{Float64,Float64}}(undef, n_actions)

        ind = 1
        for v in v_action_space
            for ω in ω_action_space
                action_space[ind] = (v, ω)
                ind += 1
            end
        end

        # Distance bins: the first is closed at both ends, the rest are open
        # at the lower end, so the bins tile [min_distance, max_distance]
        # without overlap.
        distance_range =
            min_distance:((max_distance - min_distance) / n_distance_states):max_distance

        distance_state_space = Vector{DistanceState}(undef, n_distance_states)

        for i in 1:n_distance_states
            if i == 1
                bound = Closed
            else
                bound = Open
            end

            distance_state_space[i] = DistanceState{bound}(
                distance_range[i], distance_range[i + 1]
            )
        end

        direction_range = 0.0:(2 * π / n_direction_states):(2 * π)

        direction_state_space = Vector{DirectionState}(undef, n_direction_states)

        for i in 1:n_direction_states
            direction_state_space[i] = DirectionState(
                direction_range[i], direction_range[i + 1]
            )
        end

        # One state per (distance, direction) pair plus one sentinel state
        # (nothing, nothing) for particles with no neighbour in range.
        n_states = n_distance_states * n_direction_states + 1
        state_space = Vector{Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}}(
            undef, n_states
        )

        ind = 1
        for distance_state in distance_state_space
            for direction_state in direction_state_space
                state_space[ind] = (distance_state, direction_state)
                ind += 1
            end
        end

        state_space[ind] = (nothing, nothing)

        return new(
            action_space, distance_state_space, direction_state_space, state_space, 0.0
        )
    end
end

mutable struct Env <: AbstractEnv
    params::EnvParams
    particle::Particle
    state::Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}

    function Env(params::EnvParams, particle::Particle)
        return new(params, particle, (nothing, nothing))
    end
end

struct Params
    envs::Vector{Env} # agents
    actions::Vector{Tuple{Float64,Float64}}
    env_params::EnvParams
    n_steps_before_actions_update::Int64

    function Params(
        n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
    )
        return new(
            Vector{Env}(undef, n_particles),
            Vector{Tuple{Float64,Float64}}(undef, n_particles),
            env_params,
            n_steps_before_actions_update,
        )
    end
end

RLBase.state_space(env::Env) = env.params.state_space
RLBase.state(env::Env) = env.state
RLBase.action_space(env::Env) = env.params.action_space
RLBase.reward(env::Env) = env.params.reward

function pre_integration_hook!() end

function integration_hook() end

function post_integration_hook() end
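# Illustrative sketch (not part of the original code): how a continuous
# (distance, direction) observation could be mapped onto the discrete state
# space built by EnvParams. The helper name `find_state` and the fallback to
# the sentinel state (nothing, nothing) for out-of-range distances are
# assumptions.
function find_state(env_params::EnvParams, distance::Float64, direction::Float64)
    # Wrap the direction into [0, 2π), matching the half-open direction bins.
    direction = mod2pi(direction)

    for distance_state in env_params.distance_state_space
        if distance in distance_state.interval
            # The direction bins tile [0, 2π), so an in-range distance always
            # has exactly one matching direction bin.
            for direction_state in env_params.direction_state_space
                if direction in direction_state.interval
                    return (distance_state, direction_state)
                end
            end
        end
    end

    # No neighbour within [min_distance, max_distance]: sentinel state.
    return (nothing, nothing)
end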
function run(
    n_episodes::Int64=100,
    episode_duration::Float64=5.0,
    update_actions_at::Float64=0.1,
    n_particles::Int64=100;
    v₀::Float64, # self-propulsion speed forwarded to gen_sim_consts (assumed parameter)
)
    @assert n_episodes > 0
    @assert episode_duration > 0
    @assert update_actions_at in 0.01:0.01:episode_duration
    # The episode duration has to be a whole multiple of the action-update
    # interval. An approximate check is used because an exact Float64
    # remainder check fails for values like 5.0 % 0.1.
    n_updates = episode_duration / update_actions_at
    @assert isapprox(n_updates, round(n_updates); atol=1e-9)
    @assert n_particles > 0

    Random.seed!(42)

    # envs
    # agents

    # pre_experiment
    sim_consts = gen_sim_consts(n_particles, v₀)
    env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)

    # Assumption: `sim_consts` exposes the integration time step as `δt`.
    n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
    rl_params = Params(n_particles, env_params, n_steps_before_actions_update)

    for episode in 1:n_episodes
        # reset
        # pre_episode
        dir = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")

        run_sim(
            dir;
            duration=episode_duration,
            seed=rand(1:typemax(Int64)),
            rl_params=rl_params,
            skin_r=sim_consts.skin_r,
        )
    end

    return nothing
end

end # module
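# Example usage (a sketch; the value of v₀ is an arbitrary assumption, and the
# enclosing package must provide gen_sim_consts, init_sim_with_sim_consts and
# run_sim):
#
#   using .ReCoRL
#
#   ReCoRL.run(100, 5.0, 0.1, 100; v₀=40.0)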