diff --git a/src/Error.jl b/src/Error.jl new file mode 100644 index 0000000..d91e1cb --- /dev/null +++ b/src/Error.jl @@ -0,0 +1,7 @@ +module Error + +export method_not_implemented + +method_not_implemented() = error("Method not implemented!") + +end # module \ No newline at end of file diff --git a/src/RL/Env.jl b/src/RL/Env.jl new file mode 100644 index 0000000..44229b4 --- /dev/null +++ b/src/RL/Env.jl @@ -0,0 +1,76 @@ +abstract type Env <: AbstractEnv end + +mutable struct EnvSharedProps{state_dims} + n_actions::Int64 + action_space::Vector{SVector{2,Float64}} + action_ind_space::OneTo{Int64} + + n_states::Int64 + state_space::Vector{SVector{state_dims,Interval}} + state_ind_space::OneTo{Int64} + state_ind::Int64 + + reward::Float64 + terminated::Bool + + function EnvSharedProps( + n_states::Int64, + state_space::Vector{SVector{state_dims,Interval}}; + n_v_actions::Int64=2, + n_ω_actions::Int64=3, + max_v::Float64=40.0, + max_ω::Float64=π / 2, + ) where {state_dims} + @assert n_v_actions > 1 + @assert n_ω_actions > 1 + @assert max_v > 0 + @assert max_ω > 0 + + v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions) + ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions) + + n_actions = n_v_actions * n_ω_actions + + action_space = Vector{SVector{2,Float64}}(undef, n_actions) + + ind = 1 + for v in v_action_space + for ω in ω_action_space + action_space[ind] = SVector(v, ω) + ind += 1 + end + end + + action_ind_space = OneTo(n_actions) + + state_ind_space = OneTo(n_states) + + return new{state_dims}( + n_actions, + action_space, + action_ind_space, + n_states, + state_space, + state_ind_space, + INITIAL_STATE_IND, + INITIAL_REWARD, + false, + ) + end +end + +function reset!(env::Env) + env.shared.terminated = false + + return nothing +end + +RLBase.state_space(env::Env) = env.shared.state_ind_space + +RLBase.state(env::Env) = env.shared.state_ind + +RLBase.action_space(env::Env) = env.shared.action_ind_space + +RLBase.reward(env::Env) = env.shared.reward + +RLBase.is_terminated(env::Env) = env.shared.terminated \ No newline at end of file diff --git a/src/RL/EnvHelper.jl b/src/RL/EnvHelper.jl new file mode 100644 index 0000000..a840987 --- /dev/null +++ b/src/RL/EnvHelper.jl @@ -0,0 +1,49 @@ +abstract type EnvHelper end + +struct EnvHelperSharedProps{H<:AbstractHook} + env::Env + agent::Agent + hook::H + + n_steps_before_actions_update::Int64 + + goal_gyration_tensor_eigvals_ratio::Float64 + + n_particles::Int64 + + old_states_ind::Vector{Int64} + states_ind::Vector{Int64} + + actions::Vector{SVector{2,Float64}} + actions_ind::Vector{Int64} + + function EnvHelperSharedProps( + env::Env, + agent::Agent, + hook::H, + n_steps_before_actions_update::Int64, + goal_gyration_tensor_eigvals_ratio::Float64, + n_particles::Int64, + ) where {H<:AbstractHook} + return new{H}( + env, + agent, + hook, + n_steps_before_actions_update, + goal_gyration_tensor_eigvals_ratio, + n_particles, + fill(0, n_particles), + fill(0, n_particles), + fill(SVector(0.0, 0.0), n_particles), + fill(0, n_particles), + ) + end +end + +function gen_env_helper(::Env, env_helper_params::EnvHelperSharedProps) + return method_not_implemented() +end + +function get_env_agent_hook(env_helper::EnvHelper) + return (env_helper.shared.env, env_helper.shared.agent, env_helper.shared.hook) +end \ No newline at end of file diff --git a/src/RL/Hooks.jl b/src/RL/Hooks.jl new file mode 100644 index 0000000..a5311d7 --- /dev/null +++ b/src/RL/Hooks.jl @@ -0,0 +1,70 @@ +function pre_integration_hook(::EnvHelper) + return method_not_implemented() +end + +function state_update_helper_hook( + ::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} +) + return method_not_implemented() +end + +function state_update_hook(::EnvHelper, particles::Vector{Particle}) + return method_not_implemented() +end + +function update_reward!(::Env, ::EnvHelper, particle::Particle) + return method_not_implemented() +end + +function update_table_and_actions_hook( + env_helper::EnvHelper, particle::Particle, first_integration_step::Bool +) + env, agent, hook = get_env_agent_hook(env_helper) + + id = particle.id + + if !first_integration_step + # Old state + env.shared.state_ind = env_helper.shared.old_states_ind[id] + + action_ind = env_helper.shared.actions_ind[id] + + # Pre act + agent(PRE_ACT_STAGE, env, action_ind) + hook(PRE_ACT_STAGE, agent, env, action_ind) + + # Update to current state + env.shared.state_ind = env_helper.shared.states_ind[id] + + # Update reward + update_reward!(env, env_helper, particle) + + # Post act + agent(POST_ACT_STAGE, env) + hook(POST_ACT_STAGE, agent, env) + end + + # Update action + action_ind = agent(env) + action = env.shared.action_space[action_ind] + + env_helper.shared.actions[id] = action + env_helper.shared.actions_ind[id] = action_ind + + return nothing +end + +act_hook(::Nothing, args...) = nothing + +function act_hook( + env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64 +) + # Apply action + action = env_helper.shared.actions[particle.id] + + vδt = action[1] * δt + particle.tmp_c += SVector(vδt * co, vδt * si) + particle.φ += action[2] * δt + + return nothing +end \ No newline at end of file diff --git a/src/RL/RL.jl b/src/RL/RL.jl index 2ad332a..783c4c9 100644 --- a/src/RL/RL.jl +++ b/src/RL/RL.jl @@ -13,270 +13,23 @@ using Random: Random using ProgressMeter: @showprogress using ..ReCo: - ReCo, Particle, angle2, norm2d, sq_norm2d, Shape, DEFAULT_SKIN_TO_INTERACTION_R_RATIO + ReCo, + Particle, + angle2, + norm2d, + sq_norm2d, + Shape, + DEFAULT_SKIN_TO_INTERACTION_R_RATIO, + method_not_implemented const INITIAL_STATE_IND = 1 const INITIAL_REWARD = 0.0 -method_not_implemented() = error("Method not implemented!") +include("Env.jl") +include("EnvHelper.jl") -function gen_angle_state_space(n_angle_states::Int64) - angle_range = range(; start=-π, stop=π, length=n_angle_states + 1) - - angle_state_space = Vector{Interval}(undef, n_angle_states) - - @simd for i in 1:n_angle_states - if i == 1 - bound = Closed - else - bound = Open - end - - angle_state_space[i] = Interval{Float64,bound,Closed}( - angle_range[i], angle_range[i + 1] - ) - end - - return angle_state_space -end - -function gen_distance_state_space( - min_distance::Float64, max_distance::Float64, n_distance_states::Int64 -) - @assert min_distance >= 0.0 - @assert max_distance > min_distance - @assert n_distance_states > 1 - - distance_range = range(; - start=min_distance, stop=max_distance, length=n_distance_states + 1 - ) - - distance_state_space = Vector{Interval}(undef, n_distance_states) - - @simd for i in 1:n_distance_states - if i == 1 - bound = Closed - else - bound = Open - end - - distance_state_space[i] = Interval{Float64,bound,Closed}( - distance_range[i], distance_range[i + 1] - ) - end - - return distance_state_space -end - -abstract type Env <: AbstractEnv end - -mutable struct EnvSharedProps{state_dims} - n_actions::Int64 - action_space::Vector{SVector{2,Float64}} - action_ind_space::OneTo{Int64} - - n_states::Int64 - state_space::Vector{SVector{state_dims,Interval}} - state_ind_space::OneTo{Int64} - state_ind::Int64 - - reward::Float64 - terminated::Bool - - function EnvSharedProps( - n_states::Int64, - state_space::Vector{SVector{state_dims,Interval}}; - n_v_actions::Int64=2, - n_ω_actions::Int64=3, - max_v::Float64=40.0, - max_ω::Float64=π / 2, - ) where {state_dims} - @assert n_v_actions > 1 - @assert n_ω_actions > 1 - @assert max_v > 0 - @assert max_ω > 0 - - v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions) - ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions) - - n_actions = n_v_actions * n_ω_actions - - action_space = Vector{SVector{2,Float64}}(undef, n_actions) - - ind = 1 - for v in v_action_space - for ω in ω_action_space - action_space[ind] = SVector(v, ω) - ind += 1 - end - end - - action_ind_space = OneTo(n_actions) - - state_ind_space = OneTo(n_states) - - return new{state_dims}( - n_actions, - action_space, - action_ind_space, - n_states, - state_space, - state_ind_space, - INITIAL_STATE_IND, - INITIAL_REWARD, - false, - ) - end -end - -function reset!(env::Env) - env.shared.terminated = false - - return nothing -end - -RLBase.state_space(env::Env) = env.shared.state_ind_space - -RLBase.state(env::Env) = env.shared.state_ind - -RLBase.action_space(env::Env) = env.shared.action_ind_space - -RLBase.reward(env::Env) = env.shared.reward - -RLBase.is_terminated(env::Env) = env.shared.terminated - -struct EnvHelperSharedProps{H<:AbstractHook} - env::Env - agent::Agent - hook::H - - n_steps_before_actions_update::Int64 - - goal_gyration_tensor_eigvals_ratio::Float64 - - n_particles::Int64 - - old_states_ind::Vector{Int64} - states_ind::Vector{Int64} - - actions::Vector{SVector{2,Float64}} - actions_ind::Vector{Int64} - - function EnvHelperSharedProps( - env::Env, - agent::Agent, - hook::H, - n_steps_before_actions_update::Int64, - goal_gyration_tensor_eigvals_ratio::Float64, - n_particles::Int64, - ) where {H<:AbstractHook} - return new{H}( - env, - agent, - hook, - n_steps_before_actions_update, - goal_gyration_tensor_eigvals_ratio, - n_particles, - fill(0, n_particles), - fill(0, n_particles), - fill(SVector(0.0, 0.0), n_particles), - fill(0, n_particles), - ) - end -end - -abstract type EnvHelper end - -function gen_env_helper(::Env, env_helper_params::EnvHelperSharedProps) - return method_not_implemented() -end - -function pre_integration_hook(::EnvHelper) - return method_not_implemented() -end - -function state_update_helper_hook( - ::EnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64} -) - return method_not_implemented() -end - -function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector} - return findfirst(x -> x == state, state_space) -end - -function find_state_interval(value::Float64, state_space::Vector{Interval})::Interval - for state in state_space - if value in state - return state - end - end -end - -function state_update_hook(::EnvHelper, particles::Vector{Particle}) - return method_not_implemented() -end - -function get_env_agent_hook(env_helper::EnvHelper) - return (env_helper.shared.env, env_helper.shared.agent, env_helper.shared.hook) -end - -function update_reward!(::Env, ::EnvHelper, particle::Particle) - return method_not_implemented() -end - -function update_table_and_actions_hook( - env_helper::EnvHelper, particle::Particle, first_integration_step::Bool -) - env, agent, hook = get_env_agent_hook(env_helper) - - id = particle.id - - if !first_integration_step - # Old state - env.shared.state_ind = env_helper.shared.old_states_ind[id] - - action_ind = env_helper.shared.actions_ind[id] - - # Pre act - agent(PRE_ACT_STAGE, env, action_ind) - hook(PRE_ACT_STAGE, agent, env, action_ind) - - # Update to current state - env.shared.state_ind = env_helper.shared.states_ind[id] - - # Update reward - update_reward!(env, env_helper, particle) - - # Post act - agent(POST_ACT_STAGE, env) - hook(POST_ACT_STAGE, agent, env) - end - - # Update action - action_ind = agent(env) - action = env.shared.action_space[action_ind] - - env_helper.shared.actions[id] = action - env_helper.shared.actions_ind[id] = action_ind - - return nothing -end - -act_hook(::Nothing, args...) = nothing - -function act_hook( - env_helper::EnvHelper, particle::Particle, δt::Float64, si::Float64, co::Float64 -) - # Apply action - action = env_helper.shared.actions[particle.id] - - vδt = action[1] * δt - particle.tmp_c += SVector(vδt * co, vδt * si) - particle.φ += action[2] * δt - - return nothing -end +include("States.jl") +include("Hooks.jl") function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64) # TODO: Optimize warmup and decay @@ -334,7 +87,7 @@ function run_rl(; skin_to_interaction_r_ratio=skin_to_interaction_r_ratio, packing_ratio=packing_ratio, ) - n_particles = sim_consts.n_particles # This not always equal to the input! + n_particles = sim_consts.n_particles # Not always equal to the input! env = EnvType(sim_consts) diff --git a/src/RL/States.jl b/src/RL/States.jl new file mode 100644 index 0000000..3bdf00c --- /dev/null +++ b/src/RL/States.jl @@ -0,0 +1,59 @@ +function gen_angle_state_space(n_angle_states::Int64) + angle_range = range(; start=-π, stop=π, length=n_angle_states + 1) + + angle_state_space = Vector{Interval}(undef, n_angle_states) + + @simd for i in 1:n_angle_states + if i == 1 + bound = Closed + else + bound = Open + end + + angle_state_space[i] = Interval{Float64,bound,Closed}( + angle_range[i], angle_range[i + 1] + ) + end + + return angle_state_space +end + +function gen_distance_state_space( + min_distance::Float64, max_distance::Float64, n_distance_states::Int64 +) + @assert min_distance >= 0.0 + @assert max_distance > min_distance + @assert n_distance_states > 1 + + distance_range = range(; + start=min_distance, stop=max_distance, length=n_distance_states + 1 + ) + + distance_state_space = Vector{Interval}(undef, n_distance_states) + + @simd for i in 1:n_distance_states + if i == 1 + bound = Closed + else + bound = Open + end + + distance_state_space[i] = Interval{Float64,bound,Closed}( + distance_range[i], distance_range[i + 1] + ) + end + + return distance_state_space +end + +function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector} + return findfirst(x -> x == state, state_space) +end + +function find_state_interval(value::Float64, state_space::Vector{Interval})::Interval + for state in state_space + if value in state + return state + end + end +end \ No newline at end of file diff --git a/src/ReCo.jl b/src/ReCo.jl index a0cae4c..920367b 100644 --- a/src/ReCo.jl +++ b/src/ReCo.jl @@ -13,6 +13,9 @@ using CellListMap: Box, CellList, map_pairwise!, UpdateCellList! using Random: Random using Dates: Dates, now +include("Error.jl") +using .Error + include("PreVectors.jl") using .PreVectors