module RL

export run_rl, LocalCOMEnv

using Base: OneTo
using ReinforcementLearning
using Flux: InvDecay
using Intervals
using StaticArrays: SVector
using LoopVectorization: @turbo
using Random: Random
using ProgressMeter: @showprogress

using ..ReCo: ReCo, Particle, angle2, Shape, DEFAULT_SKIN_TO_INTERACTION_R_RATIO

const INITIAL_STATE_IND = 1
const INITIAL_REWARD = 0.0

method_not_implemented() = error("Method not implemented!")

function gen_angle_state_space(n_angle_states::Int64)
    angle_range = range(; start=-π, stop=π, length=n_angle_states + 1)

    angle_state_space = Vector{Interval}(undef, n_angle_states)

    @simd for i in 1:n_angle_states
        if i == 1
            bound = Closed
        else
            bound = Open
        end

        angle_state_space[i] = Interval{Float64,bound,Closed}(
            angle_range[i], angle_range[i + 1]
        )
    end

    return angle_state_space
end

function gen_distance_state_space(
    min_distance::Float64, max_distance::Float64, n_distance_states::Int64
)
    @assert min_distance >= 0.0
    @assert max_distance > min_distance
    @assert n_distance_states > 1

    distance_range = range(;
        start=min_distance, stop=max_distance, length=n_distance_states + 1
    )

    distance_state_space = Vector{Interval}(undef, n_distance_states)

    @simd for i in 1:n_distance_states
        if i == 1
            bound = Closed
        else
            bound = Open
        end

        distance_state_space[i] = Interval{Float64,bound,Closed}(
            distance_range[i], distance_range[i + 1]
        )
    end

    return distance_state_space
end
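# Illustrative sketch (comments only; the concrete number of states is an
# assumption for the example): with n_angle_states = 4, `gen_angle_state_space`
# partitions [-π, π] into
#     [-π, -π/2], (-π/2, 0.0], (0.0, π/2], (π/2, π]
# Only the first interval is closed on its lower bound, so every angle in
# [-π, π] falls into exactly one state. `gen_distance_state_space` applies the
# same open/closed pattern to [min_distance, max_distance].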
abstract type Env <: AbstractEnv end

mutable struct EnvParams{state_dims}
    n_actions::Int64
    action_space::Vector{SVector{2,Float64}}
    action_ind_space::OneTo{Int64}
    n_states::Int64
    state_space::Vector{SVector{state_dims,Interval}}
    state_ind_space::OneTo{Int64}
    state_ind::Int64
    reward::Float64
    terminated::Bool

    function EnvParams(
        n_states::Int64,
        state_space::Vector{SVector{state_dims,Interval}};
        n_v_actions::Int64=2,
        n_ω_actions::Int64=3,
        max_v::Float64=40.0,
        max_ω::Float64=π / 2,
    ) where {state_dims}
        @assert n_v_actions > 1
        @assert n_ω_actions > 1
        @assert max_v > 0
        @assert max_ω > 0

        v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions)
        ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions)

        n_actions = n_v_actions * n_ω_actions

        # The action space is the Cartesian product of the translational
        # and angular velocities
        action_space = Vector{SVector{2,Float64}}(undef, n_actions)

        ind = 1
        for v in v_action_space
            for ω in ω_action_space
                action_space[ind] = SVector(v, ω)
                ind += 1
            end
        end

        action_ind_space = OneTo(n_actions)
        state_ind_space = OneTo(n_states)

        return new{state_dims}(
            n_actions,
            action_space,
            action_ind_space,
            n_states,
            state_space,
            state_ind_space,
            INITIAL_STATE_IND,
            INITIAL_REWARD,
            false,
        )
    end
end

function reset!(env::Env)
    env.params.terminated = false
    return nothing
end

RLBase.state_space(env::Env) = env.params.state_ind_space
RLBase.state(env::Env) = env.params.state_ind
RLBase.action_space(env::Env) = env.params.action_ind_space
RLBase.reward(env::Env) = env.params.reward
RLBase.is_terminated(env::Env) = env.params.terminated

struct EnvHelperParams{H<:AbstractHook}
    env::Env
    agent::Agent
    hook::H
    n_steps_before_actions_update::Int64
    goal_gyration_tensor_eigvals_ratio::Float64
    n_particles::Int64
    old_states_ind::Vector{Int64}
    states_ind::Vector{Int64}
    actions::Vector{SVector{2,Float64}}
    actions_ind::Vector{Int64}

    function EnvHelperParams(
        env::Env,
        agent::Agent,
        hook::H,
        n_steps_before_actions_update::Int64,
        goal_gyration_tensor_eigvals_ratio::Float64,
        n_particles::Int64,
    ) where {H<:AbstractHook}
        return new{H}(
            env,
            agent,
            hook,
            n_steps_before_actions_update,
            goal_gyration_tensor_eigvals_ratio,
            n_particles,
            fill(0, n_particles),
            fill(0, n_particles),
            fill(SVector(0.0, 0.0), n_particles),
            fill(0, n_particles),
        )
    end
end

abstract type EnvHelper end

# Interface stubs: concrete environments have to provide these methods.
function gen_env_helper(::Env, env_helper_params::EnvHelperParams)
    return method_not_implemented()
end

function pre_integration_hook(::EnvHelper)
    return method_not_implemented()
end

function state_update_helper_hook(
    ::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
)
    return method_not_implemented()
end

function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector}
    return findfirst(x -> x == state, state_space)
end

function find_state_interval(value::Float64, state_space::Vector{Interval})::Interval
    for state in state_space
        if value in state
            return state
        end
    end

    # Fail loudly instead of falling through and implicitly returning
    # `nothing`, which would violate the declared return type.
    return error("No state interval found for value $value")
end

function state_update_hook(::EnvHelper, particles::Vector{Particle})
    return method_not_implemented()
end

function get_env_agent_hook(env_helper::EnvHelper)
    return (env_helper.params.env, env_helper.params.agent, env_helper.params.hook)
end

function update_reward!(::Env, ::EnvHelper, particle::Particle)
    return method_not_implemented()
end

function update_table_and_actions_hook(
    env_helper::EnvHelper, particle::Particle, first_integration_step::Bool
)
    env, agent, hook = get_env_agent_hook(env_helper)

    id = particle.id

    if !first_integration_step
        # Old state
        env.params.state_ind = env_helper.params.old_states_ind[id]

        action_ind = env_helper.params.actions_ind[id]

        # Pre act
        agent(PRE_ACT_STAGE, env, action_ind)
        hook(PRE_ACT_STAGE, agent, env, action_ind)

        # Update to current state
        env.params.state_ind = env_helper.params.states_ind[id]

        # Update reward
        update_reward!(env, env_helper, particle)

        # Post act
        agent(POST_ACT_STAGE, env)
        hook(POST_ACT_STAGE, agent, env)
    end

    # Update action
    action_ind = agent(env)
    action = env.params.action_space[action_ind]

    env_helper.params.actions[id] = action
    env_helper.params.actions_ind[id] = action_ind

    return nothing
end

# No-op fallback for simulations that run without an environment helper
act_hook(::Nothing, args...) = nothing

function act_hook(
    env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64
)
    # Apply action
    action = env_helper.params.actions[particle.id]

    vδt = action[1] * δt

    particle.tmp_c += SVector(vδt * co, vδt * si)
    particle.φ += action[2] * δt

    return nothing
end

function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
    # TODO: Optimize warmup and decay
    warmup_steps = 200_000
    decay_steps = 1_000_000

    policy = QBasedPolicy(;
        learner=MonteCarloLearner(;
            approximator=TabularQApproximator(;
                n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
            ),
        ),
        explorer=EpsilonGreedyExplorer(;
            kind=:linear,
            ϵ_init=1.0,
            ϵ_stable=ϵ_stable,
            warmup_steps=warmup_steps,
            decay_steps=decay_steps,
        ),
    )

    return Agent(; policy=policy, trajectory=VectorSARTTrajectory())
end
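# Exploration schedule sketch (comments only; the concrete numbers are just
# the defaults above): with kind=:linear, the explorer holds ϵ = ϵ_init for
# `warmup_steps` agent steps and then decays it linearly to `ϵ_stable` over
# `decay_steps` steps, i.e. fully random actions for the first 200_000 steps
# and near-greedy behavior after 1_200_000 steps in total.
#
#     agent = gen_agent(100, 6, 0.0001)
#     action_ind = agent(env)  # ϵ-greedy action index for the current state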
function run_rl(;
    EnvType::Type{E},
    parent_dir_appendix::String,
    goal_gyration_tensor_eigvals_ratio::Float64,
    n_episodes::Int64=200,
    episode_duration::Float64=50.0,
    update_actions_at::Float64=0.1,
    n_particles::Int64=100,
    seed::Int64=42,
    ϵ_stable::Float64=0.0001,
    skin_to_interaction_r_ratio::Float64=DEFAULT_SKIN_TO_INTERACTION_R_RATIO,
    packing_ratio=0.22,
) where {E<:Env}
    @assert 0.0 <= goal_gyration_tensor_eigvals_ratio <= 1.0
    @assert n_episodes > 0
    @assert episode_duration > 0
    @assert update_actions_at in 0.001:0.001:episode_duration
    @assert n_particles > 0
    @assert 0.0 < ϵ_stable < 1.0

    # Setup
    Random.seed!(seed)

    sim_consts = ReCo.gen_sim_consts(
        n_particles,
        0.0;
        skin_to_interaction_r_ratio=skin_to_interaction_r_ratio,
        packing_ratio=packing_ratio,
    )
    n_particles = sim_consts.n_particles  # This is not always equal to the input!

    env = EnvType(sim_consts)

    agent = gen_agent(env.params.n_states, env.params.n_actions, ϵ_stable)

    n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)

    hook = TotalRewardPerEpisode()

    env_helper_params = EnvHelperParams(
        env,
        agent,
        hook,
        n_steps_before_actions_update,
        goal_gyration_tensor_eigvals_ratio,
        n_particles,
    )

    env_helper = gen_env_helper(env, env_helper_params)

    parent_dir = "RL_" * parent_dir_appendix

    # Pre experiment
    hook(PRE_EXPERIMENT_STAGE, agent, env)
    agent(PRE_EXPERIMENT_STAGE, env)

    @showprogress 0.6 for episode in 1:n_episodes
        dir = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir=parent_dir)

        # Reset
        reset!(env)

        # Pre episode
        hook(PRE_EPISODE_STAGE, agent, env)
        agent(PRE_EPISODE_STAGE, env)

        # Episode
        ReCo.run_sim(
            dir;
            duration=episode_duration,
            seed=rand(1:typemax(Int64)),
            env_helper=env_helper,
        )

        env.params.terminated = true

        # Post episode
        hook(POST_EPISODE_STAGE, agent, env)
        agent(POST_EPISODE_STAGE, env)

        # TODO: Replace with live plot
        display(hook.rewards)
        display(agent.policy.explorer.step)
    end

    # Post experiment
    hook(POST_EXPERIMENT_STAGE, agent, env)

    return env_helper
end

include("LocalCOMEnv.jl")

end # module
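# Usage sketch (commented out; `LocalCOMEnv` is the concrete environment
# defined in LocalCOMEnv.jl, and the keyword values are only illustrative
# assumptions):
#
#     using ReCo
#
#     env_helper = ReCo.RL.run_rl(;
#         EnvType=ReCo.RL.LocalCOMEnv,
#         parent_dir_appendix="demo",
#         goal_gyration_tensor_eigvals_ratio=0.3,
#         n_episodes=10,
#         n_particles=50,
#     )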