diff --git a/src/RL.jl b/src/RL.jl
index 7601e85..d9b6209 100644
--- a/src/RL.jl
+++ b/src/RL.jl
@@ -2,6 +2,8 @@ module RL

 export run_rl

+using Base: OneTo
+
 using ReinforcementLearning
 using Flux: InvDecay
 using Intervals
@@ -13,34 +15,57 @@ using ProgressMeter: @showprogress
 using ..ReCo: ReCo, Particle, angle2, center_of_mass

 const INITIAL_REWARD = 0.0
+const INITIAL_STATE_IND = 1
+
+function angle_state_space(n_angle_states::Int64)
+    angle_range = range(; start=-π, stop=π, length=n_angle_states + 1)
+
+    angle_state_space = Vector{Interval}(undef, n_angle_states)
+
+    @simd for i in 1:n_angle_states
+        if i == 1
+            bound = Closed
+        else
+            bound = Open
+        end
+
+        angle_state_space[i] = Interval{Float64,bound,Closed}(
+            angle_range[i], angle_range[i + 1]
+        )
+    end
+
+    return angle_state_space
+end

 mutable struct Env <: AbstractEnv
     n_actions::Int64
     action_space::Vector{SVector{2,Float64}}
-    action_ind_space::Vector{Int64}
+    action_ind_space::OneTo{Int64}

     distance_state_space::Vector{Interval}
-    angle_state_space::Vector{Interval}
+    direction_angle_state_space::Vector{Interval}
+    position_angle_state_space::Vector{Interval}
     n_states::Int64
-    state_space::Vector{SVector{2,Interval}}
-    state_ind_space::Vector{Int64}
+    state_space::Vector{SVector{3,Interval}}
+    state_ind_space::OneTo{Int64}
     state_ind::Int64

     reward::Float64
     terminated::Bool

-    center_of_mass::SVector{2,Float64}
+    center_of_mass::SVector{2,Float64} # TODO: Use or remove

-    function Env(
-        max_distance::Float64;
+    function Env(;
+        max_distance::Float64,
         min_distance::Float64=0.0,
-        n_v_actions::Int64=3,
+        n_v_actions::Int64=2,
         n_ω_actions::Int64=3,
         max_v::Float64=40.0,
         max_ω::Float64=π / 2,
-        n_distance_states::Int64=3,
-        n_angle_states::Int64=4,
+        n_distance_states::Int64=4,
+        n_direction_angle_states::Int64=3,
+        n_position_angle_states::Int64=4,
     )
         @assert min_distance >= 0.0
         @assert max_distance > min_distance
@@ -48,9 +73,12 @@ mutable struct Env <: AbstractEnv
         @assert n_ω_actions > 1
         @assert max_v > 0
         @assert max_ω > 0
+        @assert n_distance_states > 1
+        @assert n_direction_angle_states > 1
+        @assert n_position_angle_states > 1

-        v_action_space = 0.0:(max_v / (n_v_actions - 1)):max_v
-        ω_action_space = (-max_ω):(2 * max_ω / (n_ω_actions - 1)):max_ω
+        v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions)
+        ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions)

         n_actions = n_v_actions * n_ω_actions

@@ -64,10 +92,11 @@ mutable struct Env <: AbstractEnv
             end
         end

-        action_ind_space = collect(1:n_actions)
+        action_ind_space = OneTo(n_actions)

-        distance_range =
-            min_distance:((max_distance - min_distance) / n_distance_states):max_distance
+        distance_range = range(;
+            start=min_distance, stop=max_distance, length=n_distance_states + 1
+        )

         distance_state_space = Vector{Interval}(undef, n_distance_states)

@@ -83,50 +112,38 @@ mutable struct Env <: AbstractEnv
             )
         end

-        angle_range = (-π):(2 * π / n_angle_states):π
+        direction_angle_state_space = angle_state_space(n_direction_angle_states)
+        position_angle_state_space = angle_state_space(n_position_angle_states)

-        angle_state_space = Vector{Interval}(undef, n_angle_states)
+        n_states = n_distance_states * n_direction_angle_states * n_position_angle_states

-        @simd for i in 1:n_angle_states
-            if i == 1
-                bound = Closed
-            else
-                bound = Open
-            end
-
-            angle_state_space[i] = Interval{Float64,bound,Closed}(
-                angle_range[i], angle_range[i + 1]
-            )
-        end
-
-        n_states = n_distance_states * n_angle_states + 1
-
-        state_space = Vector{SVector{2,Interval}}(undef, n_states - 1)
+        state_space = Vector{SVector{3,Interval}}(undef, n_states)

         ind = 1

         for distance_state in distance_state_space
-            for angle_state in angle_state_space
-                state_space[ind] = SVector(distance_state, angle_state)
-                ind += 1
+            for direction_angle_state in direction_angle_state_space
+                for position_angle_state in position_angle_state_space
+                    state_space[ind] = SVector(
+                        distance_state, direction_angle_state, position_angle_state
+                    )
+                    ind += 1
+                end
             end
         end

-        # Last state is SVector(nothing, nothing)
-        state_ind_space = collect(1:n_states)
-
-        # initial_state = SVector(nothing, nothing)
-        initial_state_ind = n_states
+        state_ind_space = OneTo(n_states)

         return new(
             n_actions,
             action_space,
             action_ind_space,
             distance_state_space,
-            angle_state_space,
+            direction_angle_state_space,
+            position_angle_state_space,
             n_states,
             state_space,
             state_ind_space,
-            initial_state_ind,
+            INITIAL_STATE_IND,
             INITIAL_REWARD,
             false,
             SVector(0.0, 0.0),
@@ -171,9 +188,6 @@ struct Params{H<:AbstractHook}
     half_box_len::Float64
     max_elliptic_distance::Float64

-    local_centers_of_mass::Vector{SVector{2,Float64}}
-    updated_local_center_of_mass::Vector{Bool}
-
     function Params(
         env::Env,
         agent::Agent,
@@ -200,37 +214,32 @@ struct Params{H<:AbstractHook}
             n_particles,
             half_box_len,
             max_elliptic_distance,
-            fill(SVector(0.0, 0.0), n_particles),
-            falses(n_particles),
         )
     end
 end

 function pre_integration_hook(rl_params::Params)
-    @simd for id in 1:(rl_params.n_particles)
-        rl_params.local_centers_of_mass[id] = SVector(0.0, 0.0)
-        rl_params.updated_local_center_of_mass[id] = false
-    end
-
     return nothing
 end

 function state_update_helper_hook(
     rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
 )
-    rl_params.local_centers_of_mass[id1] += r⃗₁₂
-    rl_params.local_centers_of_mass[id2] -= r⃗₁₂
-
-    rl_params.updated_local_center_of_mass[id1] = true
-    rl_params.updated_local_center_of_mass[id2] = true
-
     return nothing
 end

-function get_state_ind(state::S, state_space::Vector{S}) where {S<:SVector{2,Interval}}
+function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector{3,Interval}}
     return findfirst(x -> x == state, state_space)
 end

+function find_state_interval(value::Float64, state_space::Vector{Interval})::Interval
+    for state in state_space
+        if value in state
+            return state
+        end
+    end
+end
+
 function state_update_hook(rl_params::Params, particles::Vector{Particle})
     @turbo for id in 1:(rl_params.n_particles)
         rl_params.old_states_ind[id] = rl_params.states_ind[id]
@@ -238,39 +247,29 @@ function state_update_hook(rl_params::Params, particles::Vector{Particle})
     env = rl_params.env

-    env_distance_state = env.distance_state_space[1]
-    env_angle_state = env.angle_state_space[1]
-    state_ind = 0
-
     for id in 1:(rl_params.n_particles)
-        if !rl_params.updated_local_center_of_mass[id]
-            state_ind = env.n_states
-        else
-            local_center_of_mass = rl_params.local_centers_of_mass[id]
+        particle = particles[id]

-            distance = sqrt(local_center_of_mass[1]^2 + local_center_of_mass[2]^2)
+        distance = sqrt(particle.c[1]^2 + particle.c[2]^2)

-            for distance_state in env.distance_state_space
-                if distance in distance_state
-                    env_distance_state = distance_state
-                    break
-                end
-            end
+        distance_state = find_state_interval(distance, env.distance_state_space)

-            si, co = sincos(particles[id].φ)
+        si, co = sincos(particles[id].φ)

-            angle = angle2(SVector(co, si), local_center_of_mass)
+        direction_angle = angle2(SVector(co, si), -particle.c)
+        position_angle = atan(particle.c[2], particle.c[1])

-            for angle_state in env.angle_state_space
-                if angle in angle_state
-                    env_angle_state = angle_state
-                    break
-                end
-            end
+        direction_angle_state = find_state_interval(
+            direction_angle, env.direction_angle_state_space
+        )
+        position_angle_state = find_state_interval(
+            position_angle, env.position_angle_state_space
+        )

-            state = SVector{2,Interval}(env_distance_state, env_angle_state)
-            state_ind = get_state_ind(state, env.state_space)
-        end
+        state = SVector{3,Interval}(
+            distance_state, direction_angle_state, position_angle_state
+        )
+        state_ind = find_state_ind(state, env.state_space)

         rl_params.states_ind[id] = state_ind
     end
@@ -284,6 +283,14 @@ function get_env_agent_hook(rl_params::Params)
     return (rl_params.env, rl_params.agent, rl_params.hook)
 end

+function update_reward!(env::Env, rl_params::Params, particle::Particle)
+    env.reward =
+        -(particle.c[1]^2 + particle.c[2]^2) /
+        (rl_params.max_elliptic_distance^2 * rl_params.n_particles)
+
+    return nothing
+end
+
 function update_table_and_actions_hook(
     rl_params::Params, particle::Particle, first_integration_step::Bool
 )
@@ -305,13 +312,7 @@
         env.state_ind = rl_params.states_ind[id]

         # Update reward
-        vec_to_center_of_mass = ReCo.minimum_image(
-            particle.c - env.center_of_mass, rl_params.half_box_len
-        )
-
-        env.reward =
-            -(vec_to_center_of_mass[1]^2 + vec_to_center_of_mass[2]^2) /
-            rl_params.max_elliptic_distance / rl_params.n_particles
+        update_reward!(env, rl_params, particle)

         # Post act
         agent(POST_ACT_STAGE, env)
@@ -343,14 +344,24 @@ function act_hook(
     return nothing
 end

-function gen_agent(n_states::Int64, n_actions::Int64, ϵ::Float64)
+function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
+    # TODO: Optimize warmup and decay
+    warmup_steps = 200_000
+    decay_steps = 1_000_000
+
     policy = QBasedPolicy(;
         learner=MonteCarloLearner(;
             approximator=TabularQApproximator(;
                 n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
             ),
         ),
-        explorer=EpsilonGreedyExplorer(ϵ),
+        explorer=EpsilonGreedyExplorer(;
+            kind=:linear,
+            ϵ_init=1.0,
+            ϵ_stable=ϵ_stable,
+            warmup_steps=warmup_steps,
+            decay_steps=decay_steps,
+        ),
     )

     return Agent(; policy=policy, trajectory=VectorSARTTrajectory())
@@ -363,7 +374,7 @@ function run_rl(;
     update_actions_at::Float64=0.1,
     n_particles::Int64=100,
     seed::Int64=42,
-    ϵ::Float64=0.01,
+    ϵ_stable::Float64=0.0001,
     parent_dir::String="",
 )
     @assert 0.0 <= goal_shape_ratio <= 1.0
@@ -371,19 +382,19 @@
     @assert episode_duration > 0
     @assert update_actions_at in 0.001:0.001:episode_duration
     @assert n_particles > 0
-    @assert 0.0 < ϵ < 1.0
+    @assert 0.0 < ϵ_stable < 1.0

     # Setup
     Random.seed!(seed)

     sim_consts = ReCo.gen_sim_consts(
-        n_particles, 0.0; skin_to_interaction_r_ratio=1.8, packing_ratio=0.15
+        n_particles, 0.0; skin_to_interaction_r_ratio=1.5, packing_ratio=0.22
     )
     n_particles = sim_consts.n_particles

-    env = Env(sim_consts.skin_r)
+    env = Env(; max_distance=sqrt(2) * sim_consts.half_box_len)

-    agent = gen_agent(env.n_states, env.n_actions, ϵ)
+    agent = gen_agent(env.n_states, env.n_actions, ϵ_stable)

     n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
@@ -426,7 +437,9 @@
         hook(POST_EPISODE_STAGE, agent, env)
         agent(POST_EPISODE_STAGE, env)

+        # TODO: Replace with live plot
         display(hook.rewards)
+        display(agent.policy.explorer.step)
     end

     # Post experiment
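
Reviewer note (not part of the patch): a minimal standalone sketch of the new state discretization. It reproduces the patch's angle_state_space construction and an equivalent find_state_interval lookup to show that the intervals tile [-π, π] with no gaps and no overlap: every interval is right-closed, and only the first is also left-closed, so each angle falls into exactly one state. The assertions below are illustrative only; like the patch's version, the lookup fails if the value lies outside the covered range.

using Intervals

# Same construction as the patch: n right-closed intervals tiling [-π, π];
# only the first interval is also closed on the left, so -π is covered.
function angle_state_space(n_angle_states::Int64)
    angle_range = range(; start=-π, stop=π, length=n_angle_states + 1)
    space = Vector{Interval}(undef, n_angle_states)
    for i in 1:n_angle_states
        bound = i == 1 ? Closed : Open
        space[i] = Interval{Float64,bound,Closed}(angle_range[i], angle_range[i + 1])
    end
    return space
end

# Equivalent to the patch's find_state_interval: the unique interval containing value.
find_state_interval(value::Float64, space::Vector{Interval}) =
    space[findfirst(interval -> value in interval, space)]

space = angle_state_space(4)
@assert find_state_interval(-π, space) === space[1]     # left edge is covered
@assert count(interval -> 0.0 in interval, space) == 1  # inner bounds do not overlap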
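Reviewer note (not part of the patch): the explorer change swaps the constant-ϵ EpsilonGreedyExplorer(ϵ) for a warmup-then-linear-decay schedule, so early episodes act almost fully at random while the Q-table is still empty. Below is a hedged sketch of the schedule these parameters are assumed to produce; assumed_epsilon is a hypothetical helper for illustration, not a ReinforcementLearning.jl API, and the actual curve is computed inside EpsilonGreedyExplorer.

# Assumed shape of the :linear schedule: hold ϵ_init for warmup_steps,
# then decay linearly to ϵ_stable over decay_steps and stay there.
function assumed_epsilon(
    step::Int64;
    ϵ_init::Float64=1.0,
    ϵ_stable::Float64=0.0001,
    warmup_steps::Int64=200_000,
    decay_steps::Int64=1_000_000,
)
    if step <= warmup_steps
        return ϵ_init          # pure exploration during warmup
    elseif step >= warmup_steps + decay_steps
        return ϵ_stable        # long-run exploration floor
    else
        # Linear interpolation from ϵ_init down to ϵ_stable.
        fraction = (step - warmup_steps) / decay_steps
        return ϵ_init + fraction * (ϵ_stable - ϵ_init)
    end
end

@assert assumed_epsilon(0) == 1.0
@assert assumed_epsilon(700_000) ≈ 0.50005  # halfway through the decay
@assert assumed_epsilon(2_000_000) == 0.0001

The display(agent.policy.explorer.step) added at the end of each episode presumably lets one track where the run currently sits on this curve.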