From 8ad67229a8c955ad6b5ebcc6183b2bd86c5d1504 Mon Sep 17 00:00:00 2001
From: MoBit
Date: Tue, 14 Dec 2021 04:03:14 +0100
Subject: [PATCH] Fixes incl. state

---
 src/ReCo.jl                   |   2 +-
 src/reinforcement_learning.jl | 257 +++++++++++++++++++++-------------
 src/run.jl                    |   8 +-
 src/setup.jl                  |   8 +-
 src/simulation.jl             |  12 +-
 5 files changed, 170 insertions(+), 117 deletions(-)

diff --git a/src/ReCo.jl b/src/ReCo.jl
index cd1eb19..27d93ef 100644
--- a/src/ReCo.jl
+++ b/src/ReCo.jl
@@ -1,6 +1,6 @@
 module ReCo
 
-export init_sim, run_sim
+export init_sim, run_sim, RL
 
 include("PreVector.jl")
 include("Particle.jl")
diff --git a/src/reinforcement_learning.jl b/src/reinforcement_learning.jl
index 79997fe..4daa622 100644
--- a/src/reinforcement_learning.jl
+++ b/src/reinforcement_learning.jl
@@ -1,9 +1,12 @@
 module RL
 
+export run
+
 using ReinforcementLearning
 using Flux: InvDecay
 using Intervals
 using StaticArrays: SVector
+using LoopVectorization: @turbo
 
 using Random: Random
 using ProgressMeter: @showprogress
@@ -11,6 +14,8 @@ using ..ReCo
 
 import Base: run
 
+const INITIAL_REWARD = 0.0
+
 struct DistanceState{L<:Bound}
     interval::Interval{Float64,L,Closed}
 
@@ -29,22 +34,25 @@ end
 mutable struct EnvParams
     action_space::Vector{Tuple{Float64,Float64}}
-    action_space_ind::Vector{Int64}
+    action_ind_space::Vector{Int64}
+
     distance_state_space::Vector{DistanceState}
     direction_state_space::Vector{DirectionState}
 
-    state_space::Vector{Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}}
-    state_space_ind::Vector{Int64}
+    state_space::Vector{Union{Tuple{DistanceState,DirectionState},Tuple{Nothing,Nothing}}}
+    state_ind_space::Vector{Int64}
+    n_states::Int64
+
     reward::Float64
 
     function EnvParams(
         min_distance::Float64,
         max_distance::Float64;
-        n_v_actions::Int64=5,
-        n_ω_actions::Int64=5,
+        n_v_actions::Int64=3,
+        n_ω_actions::Int64=3,
         max_v::Float64=80.0,
         max_ω::Float64=π / 1.5,
-        n_distance_states::Int64=3,
-        n_direction_states::Int64=4,
+        n_distance_states::Int64=2,
+        n_direction_states::Int64=2,
     )
         @assert min_distance > 0.0
         @assert max_distance > min_distance
@@ -68,14 +76,14 @@ mutable struct EnvParams
             end
         end
 
-        action_space_ind = collect(1:n_actions)
+        action_ind_space = collect(1:n_actions)
 
         distance_range =
            min_distance:((max_distance - min_distance) / n_distance_states):max_distance
 
         distance_state_space = Vector{DistanceState}(undef, n_distance_states)
 
-        for i in 1:n_distance_states
+        @simd for i in 1:n_distance_states
             if i == 1
                 bound = Closed
             else
@@ -91,7 +99,7 @@ mutable struct EnvParams
 
         direction_state_space = Vector{DirectionState}(undef, n_direction_states)
 
-        for i in 1:n_direction_states
+        @simd for i in 1:n_direction_states
             direction_state_space[i] = DirectionState(
                 direction_range[i], direction_range[i + 1]
             )
@@ -100,7 +108,7 @@ mutable struct EnvParams
         n_states = n_distance_states * n_direction_states + 1
 
         state_space = Vector{
-            Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
+            Union{Tuple{DistanceState,DirectionState},Tuple{Nothing,Nothing}}
         }(
             undef, n_states
         )
@@ -114,22 +122,27 @@ end
 
         state_space[ind] = (nothing, nothing)
 
-        state_space_ind = collect(1:n_states)
-
-        initial_reward = 0.0
+        state_ind_space = collect(1:n_states)
 
         return new(
             action_space,
-            action_space_ind,
+            action_ind_space,
             distance_state_space,
             direction_state_space,
             state_space,
-            state_space_ind,
-            initial_reward,
+            state_ind_space,
+            n_states,
+            INITIAL_REWARD,
         )
     end
 end
+
+function reset!(env_params::EnvParams)
+    env_params.reward = INITIAL_REWARD
+
+    return nothing
+end
+
 
 mutable struct Env <: AbstractEnv
     params::EnvParams
     particle::ReCo.Particle
@@ -137,12 +150,29 @@ mutable struct Env <: AbstractEnv
 
     function Env(params::EnvParams, particle::ReCo.Particle)
         # initial_state = (nothing, nothing)
-        initial_state_ind = length(params.state_space_ind)
+        initial_state_ind = params.n_states
 
         return new(params, particle, initial_state_ind)
     end
 end
 
+function reset!(env::Env, particle::ReCo.Particle)
+    env.particle = particle
+    env.state_ind = env.params.n_states
+
+    return nothing
+end
+
+RLBase.state_space(env::Env) = env.params.state_ind_space
+
+RLBase.state(env::Env) = env.state_ind
+
+RLBase.action_space(env::Env) = env.params.action_ind_space
+
+RLBase.reward(env::Env) = env.params.reward
+
+RLBase.is_terminated(::Env) = false
+
 function gen_policy(n_states::Int64, n_actions::Int64)
     return QBasedPolicy(;
         learner=MonteCarloLearner(;
@@ -161,8 +191,8 @@ struct Params{H<:AbstractHook}
     actions::Vector{Tuple{Float64,Float64}}
     env_params::EnvParams
     n_steps_before_actions_update::Int64
-    min_distance²::Vector{Float64}
-    r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
+    min_sq_distances::Vector{Float64}
+    vecs_r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
     goal_shape_ratio::Float64
 
     function Params{H}(
@@ -171,48 +201,57 @@ struct Params{H<:AbstractHook}
         n_steps_before_actions_update::Int64,
         goal_shape_ratio::Float64,
     ) where {H<:AbstractHook}
-        policies = [
-            gen_policy(length(env_params.state_space), length(env_params.action_space)) for
-            i in 1:n_particles
-        ]
+        envs = [Env(env_params, ReCo.gen_tmp_particle()) for i in 1:n_particles]
+
         agents = [
-            Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
+            Agent(;
+                policy=gen_policy(env_params.n_states, length(env_params.action_space)),
+                trajectory=VectorSARTTrajectory(),
+            ) for i in 1:n_particles
         ]
+
+        hooks = [H() for i in 1:n_particles]
+
+        actions = Vector{Tuple{Float64,Float64}}(undef, n_particles)
+
+        min_sq_distances = fill(Inf64, n_particles)
+
+        vecs_r⃗₁₂_to_min_distance_particle = fill(SVector(0.0, 0.0), n_particles)
+
         return new(
-            [Env(env_params, ReCo.gen_tmp_particle()) for i in 1:n_particles],
+            envs,
             agents,
-            [H() for i in 1:n_particles],
-            Vector{Tuple{Float64,Float64}}(undef, n_particles),
+            hooks,
+            actions,
             env_params,
             n_steps_before_actions_update,
-            zeros(n_particles),
-            fill(SVector(0.0, 0.0), n_particles),
+            min_sq_distances,
+            vecs_r⃗₁₂_to_min_distance_particle,
             goal_shape_ratio,
         )
     end
 end
 
-RLBase.state_space(env::Env) = env.params.state_space_ind
-
-RLBase.state(env::Env) = env.state_ind
-
-RLBase.action_space(env::Env) = env.params.action_space_ind
-
-RLBase.reward(env::Env) = env.params.reward
-
-RLBase.is_terminated(::Env) = false
+function get_env_agent_hook(rl_params::Params, ind::Int64)
+    return (rl_params.envs[ind], rl_params.agents[ind], rl_params.hooks[ind])
+end
 
 function pre_integration_hook!(rl_params::Params, n_particles::Int64)
-    for i in 1:n_particles
-        env = rl_params.envs[i]
-        agent = rl_params.agents[i]
+    @simd for i in 1:n_particles
+        env, agent, hook = get_env_agent_hook(rl_params, i)
 
+        # Update action
         action_ind = agent(env)
         action = rl_params.env_params.action_space[action_ind]
         rl_params.actions[i] = action
 
+        # Pre act
         agent(PRE_ACT_STAGE, env, action_ind)
-        rl_params.hooks[i](PRE_ACT_STAGE, agent, env, action_ind)
+        hook(PRE_ACT_STAGE, agent, env, action_ind)
+    end
+
+    @turbo for i in 1:n_particles
+        rl_params.min_sq_distances[i] = Inf64
     end
 
     return nothing
 end
 
 function state_hook(
     id1::Int64,
     id2::Int64,
     r⃗₁₂::SVector{2,Float64},
     distance²::Float64,
     rl_params::Params
 )
-    if rl_params.min_distance²[id1] > distance²
-        rl_params.min_distance²[id1] = distance²
+    if rl_params.min_sq_distances[id1] > distance²
+        rl_params.min_sq_distances[id1] = distance²
 
-        rl_params.r⃗₁₂_to_min_distance_particle[id1] = r⃗₁₂
+        rl_params.vecs_r⃗₁₂_to_min_distance_particle[id1] = r⃗₁₂
     end
 
-    if rl_params.min_distance²[id2] > distance²
-        rl_params.min_distance²[id2] = distance²
+    if rl_params.min_sq_distances[id2] > distance²
+        rl_params.min_sq_distances[id2] = distance²
 
-        rl_params.r⃗₁₂_to_min_distance_particle[id2] = -r⃗₁₂
+        rl_params.vecs_r⃗₁₂_to_min_distance_particle[id2] = -r⃗₁₂
     end
 
     return nothing
 end
 
-function integration_hook(
+function integration_hook!(
     particle::ReCo.Particle, rl_params::Params, δt::Float64, si::Float64, co::Float64
 )
+    # Apply action
     action = rl_params.actions[particle.id]
 
     vδt = action[1] * δt
@@ -248,10 +288,12 @@ function integration_hook(
     return nothing
 end
 
-function get_state_ind(
-    state::T, states::Vector{T}
-) where {T<:Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}}
-    return findfirst(x -> x == state, states)
+function get_state_ind(state::Tuple{DistanceState,DirectionState}, env_params::EnvParams)
+    return findfirst(x -> x == state, env_params.state_space)
+end
+
+function get_state_ind(::Tuple{Nothing,Nothing}, env_params::EnvParams)
+    return env_params.n_states
 end
 
 function post_integration_hook(
@@ -260,28 +302,41 @@ function post_integration_hook(
     particles::Vector{ReCo.Particle},
     half_box_len::Float64,
 )
+    # Update reward
+    rl_params.env_params.reward =
+        1 -
+        (
+            ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len) -
+            rl_params.goal_shape_ratio
+        )^2
+
+    # Update states
+    n_states = rl_params.env_params.n_states
+
     env_direction_state = rl_params.env_params.direction_state_space[1]
 
     for i in 1:n_particles
-        env = rl_params.envs[i]
-        agent = rl_params.agents[i]
-
-        min_distance = sqrt(rl_params.min_distance²[i])
+        env, agent, hook = get_env_agent_hook(rl_params, i)
 
         env_distance_state::Union{DistanceState,Nothing} = nothing
 
-        for distance_state in rl_params.env_params.distance_state_space
-            if min_distance in distance_state.interval
-                env_distance_state = distance_state
-                break
+        min_sq_distance = rl_params.min_sq_distances[i]
+        min_distance = sqrt(min_sq_distance)
+
+        if !isinf(min_sq_distance)
+            for distance_state in rl_params.env_params.distance_state_space
+                if min_distance in distance_state.interval
+                    env_distance_state = distance_state
+                    break
+                end
             end
         end
 
         if isnothing(env_distance_state)
            # (nothing, nothing)
-            env.state_ind = length(env.params.state_space)
+            env.state_ind = n_states
         else
-            r⃗₁₂ = rl_params.r⃗₁₂_to_min_distance_particle[i]
+            r⃗₁₂ = rl_params.vecs_r⃗₁₂_to_min_distance_particle[i]
 
             si, co = sincos(particles[i].φ)
 
             #=
@@ -290,28 +345,25 @@ function post_integration_hook(
                angle = acos(dot(r⃗₁₂, e) / (norm(r⃗₁₂) * norm(e)))
                norm(r⃗₁₂) == min_distance
                norm(e) == 1
+
+                min_distance is not infinite, because otherwise
+                env_distance_state would be nothing and this else branch would not be reached
            =#
             direction = acos((r⃗₁₂[1] * co + r⃗₁₂[2] * si) / min_distance)
 
             for direction_state in rl_params.env_params.direction_state_space
-                if direction in direction_state
+                if direction in direction_state.interval
                     env_direction_state = direction_state
                 end
             end
 
             state = (env_distance_state, env_direction_state)
-            env.state_ind = get_state_ind(state, env.params.state_space)
+            env.state_ind = get_state_ind(state, env.params)
         end
 
-        env.params.reward =
-            1 -
-            (
-                ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len) -
-                rl_params.goal_shape_ratio
-            )^2
-
+        # Post act
         agent(POST_ACT_STAGE, env)
-        rl_params.hooks[i](POST_ACT_STAGE, agent, env)
+        hook(POST_ACT_STAGE, agent, env)
     end
 
     return nothing
 end
@@ -320,9 +372,10 @@
 function run(;
     goal_shape_ratio::Float64,
     n_episodes::Int64=100,
-    episode_duration::Float64=100.0,
-    update_actions_at::Float64=0.1,
+    episode_duration::Float64=50.0,
+    update_actions_at::Float64=0.2,
     n_particles::Int64=100,
+    seed::Int64=42,
 )
     @assert 0.0 <= goal_shape_ratio <= 1.0
     @assert n_episodes > 0
@@ -330,9 +383,10 @@ function run(;
     @assert update_actions_at in 0.01:0.01:episode_duration
     @assert n_particles > 0
 
-    Random.seed!(42)
+    # Setup
+    Random.seed!(seed)
 
-    sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=4.0)
+    sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=3.0)
     n_particles = sim_consts.n_particles
 
     env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
@@ -343,50 +397,51 @@ function run(;
         n_particles, env_params, n_steps_before_actions_update, goal_shape_ratio
     )
 
-    for i in 1:n_particles
-        env = rl_params.envs[i]
-        agent = rl_params.agents[i]
+    # Pre experiment
+    @simd for i in 1:n_particles
+        env, agent, hook = get_env_agent_hook(rl_params, i)
 
-        rl_params.hooks[i](PRE_EXPERIMENT_STAGE, agent, env)
+        hook(PRE_EXPERIMENT_STAGE, agent, env)
         agent(PRE_EXPERIMENT_STAGE, env)
     end
 
     @showprogress 0.6 for episode in 1:n_episodes
         dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")
 
-        for i in 1:n_particles
-            env = rl_params.envs[i]
-            agent = rl_params.agents[i]
+        # Reset
+        @simd for i in 1:n_particles
+            reset!(rl_params.envs[i], particles[i])
+        end
 
-            rl_params.hooks[i](PRE_EPISODE_STAGE, agent, env)
+        reset!(rl_params.env_params)
+
+        # Pre episode
+        @simd for i in 1:n_particles
+            env, agent, hook = get_env_agent_hook(rl_params, i)
+
+            hook(PRE_EPISODE_STAGE, agent, env)
             agent(PRE_EPISODE_STAGE, env)
         end
 
-        for i in 1:n_particles
-            rl_params.envs[i].particle = particles[i]
-            rl_params.envs[i].state_ind = length(rl_params.env_params.state_space)
-        end
-
-        rl_params.env_params.reward = 0.0
-
+        # Episode
         run_sim(
             dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
         )
 
-        for i in 1:n_particles
-            env = rl_params.envs[i]
-            agent = rl_params.agents[i]
+        # Post episode
+        @simd for i in 1:n_particles
+            env, agent, hook = get_env_agent_hook(rl_params, i)
 
-            rl_params.hooks[i](POST_EPISODE_STAGE, agent, env)
+            hook(POST_EPISODE_STAGE, agent, env)
            agent(POST_EPISODE_STAGE, env)
         end
     end
 
-    for i in 1:n_particles
-        env = rl_params.envs[i]
-        agent = rl_params.agents[i]
+    # Post experiment
+    @simd for i in 1:n_particles
+        env, agent, hook = get_env_agent_hook(rl_params, i)
 
-        rl_params.hooks[i](POST_EXPERIMENT_STAGE, agent, env)
+        hook(POST_EXPERIMENT_STAGE, agent, env)
     end
 
     return rl_params
diff --git a/src/run.jl b/src/run.jl
index 7f9fe6e..893cf67 100644
--- a/src/run.jl
+++ b/src/run.jl
@@ -109,12 +109,10 @@ function run_sim(
 
     if !isnothing(rl_params)
         pre_integration_hook! = RL.pre_integration_hook!
-        integration_hook = RL.integration_hook
+        integration_hook! = RL.integration_hook!
         post_integration_hook = RL.post_integration_hook
     else
-        pre_integration_hook! = empty_hook
-        integration_hook = empty_hook
-        post_integration_hook = empty_hook
+        pre_integration_hook! = integration_hook! = post_integration_hook = empty_hook
     end
 
     simulate(
@@ -128,7 +126,7 @@
         save_data,
         rl_params,
         pre_integration_hook!,
-        integration_hook,
+        integration_hook!,
         post_integration_hook,
     )
 
diff --git a/src/setup.jl b/src/setup.jl
index 05a7c41..0b76004 100644
--- a/src/setup.jl
+++ b/src/setup.jl
@@ -65,11 +65,11 @@ function gen_sim_consts(
     ϵ = 100.0
     interaction_r = 2^(1 / 6) * σ
 
-    if v₀ != 0.0
-        buffer = 1.8
-        max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
-        @assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
+    buffer = 1.8
+    max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
+    @assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
 
+    if v₀ != 0.0
         n_steps_before_verlet_list_update = round(
             Int64,
             (skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step,
         )
diff --git a/src/simulation.jl b/src/simulation.jl
index 68dc554..5b855a1 100644
--- a/src/simulation.jl
+++ b/src/simulation.jl
@@ -43,7 +43,7 @@ end
 function euler!(
     args,
     state_hook::Function,
-    integration_hook::Function,
+    integration_hook!::Function,
     rl_params::Union{RL.Params,Nothing},
 )
     for id1 in 1:(args.n_particles - 1)
@@ -79,7 +79,7 @@ function euler!(
 
         restrict_coordinates!(p, args.half_box_len)
 
-        integration_hook(p, rl_params, args.δt, si, co)
+        integration_hook!(p, rl_params, args.δt, si, co)
 
         p.φ += args.c₄ * rand_normal01()
 
@@ -94,8 +94,8 @@ wait(::Nothing) = nothing
 gen_run_hooks(::Nothing, args...) = false
 
 function gen_run_hooks(rl_params::RL.Params, integration_step::Int64)
-    return (integration_step == 1) ||
-           (integration_step % rl_params.n_steps_before_actions_update == 0)
+    return (integration_step % rl_params.n_steps_before_actions_update == 0) ||
+           (integration_step == 1)
 end
 
 function simulate(
@@ -109,7 +109,7 @@ function simulate(
     save_data::Bool,
     rl_params::Union{RL.Params,Nothing},
     pre_integration_hook!::Function,
-    integration_hook::Function,
+    integration_hook!::Function,
     post_integration_hook::Function,
 )
     bundle_snapshot_counter = 0
@@ -153,7 +153,7 @@ end
             state_hook = RL.state_hook
         end
 
-        euler!(args, state_hook, integration_hook, rl_params)
+        euler!(args, state_hook, integration_hook!, rl_params)
 
        if run_hooks
            post_integration_hook(
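
A minimal usage sketch for the `run` entry point exported from the `RL` module in this patch. The keyword values below are illustrative, not prescribed: `goal_shape_ratio=0.5` is an arbitrary value in the required range [0, 1], and the remaining keywords simply mirror the defaults introduced above.

    using ReCo

    # Train the per-particle agents toward the target gyration-tensor eigenvalue ratio
    # and get back the Params object (envs, agents, hooks) returned by run.
    rl_params = ReCo.RL.run(;
        goal_shape_ratio=0.5,
        n_episodes=100,
        episode_duration=50.0,
        update_actions_at=0.2,
        n_particles=100,
        seed=42,
    )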