diff --git a/src/RL/Env.jl b/src/RL/Env.jl index 44229b4..f80fa22 100644 --- a/src/RL/Env.jl +++ b/src/RL/Env.jl @@ -1,26 +1,26 @@ abstract type Env <: AbstractEnv end -mutable struct EnvSharedProps{state_dims} +mutable struct EnvSharedProps{n_state_dims} n_actions::Int64 action_space::Vector{SVector{2,Float64}} - action_ind_space::OneTo{Int64} + action_id_space::OneTo{Int64} n_states::Int64 - state_space::Vector{SVector{state_dims,Interval}} - state_ind_space::OneTo{Int64} - state_ind::Int64 + state_id_tensor::Array{Int64,n_state_dims} + state_id_space::OneTo{Int64} + state_id::Int64 reward::Float64 terminated::Bool function EnvSharedProps( - n_states::Int64, - state_space::Vector{SVector{state_dims,Interval}}; + n_states::Int64, # Can be different from the sum of state_id_tensor_dims + state_id_tensor_dims::NTuple{n_state_dims,Int64}; n_v_actions::Int64=2, n_ω_actions::Int64=3, max_v::Float64=40.0, max_ω::Float64=π / 2, - ) where {state_dims} + ) where {n_state_dims} @assert n_v_actions > 1 @assert n_ω_actions > 1 @assert max_v > 0 @@ -41,17 +41,25 @@ mutable struct EnvSharedProps{state_dims} end end - action_ind_space = OneTo(n_actions) + action_id_space = OneTo(n_actions) - state_ind_space = OneTo(n_states) + state_id_tensor = Array{Int64,n_state_dims}(undef, state_id_tensor_dims) - return new{state_dims}( + id = 1 + for ind in eachindex(state_id_tensor) + state_id_tensor[ind] = id + id += 1 + end + + state_id_space = OneTo(n_states) + + return new{n_state_dims}( n_actions, action_space, - action_ind_space, + action_id_space, n_states, - state_space, - state_ind_space, + state_id_tensor, + state_id_space, INITIAL_STATE_IND, INITIAL_REWARD, false, @@ -65,11 +73,11 @@ function reset!(env::Env) return nothing end -RLBase.state_space(env::Env) = env.shared.state_ind_space +RLBase.state_space(env::Env) = env.shared.state_id_space -RLBase.state(env::Env) = env.shared.state_ind +RLBase.state(env::Env) = env.shared.state_id -RLBase.action_space(env::Env) = env.shared.action_ind_space +RLBase.action_space(env::Env) = env.shared.action_id_space RLBase.reward(env::Env) = env.shared.reward diff --git a/src/RL/EnvHelper.jl b/src/RL/EnvHelper.jl index bcb36f5..f0c47bb 100644 --- a/src/RL/EnvHelper.jl +++ b/src/RL/EnvHelper.jl @@ -11,11 +11,11 @@ struct EnvHelperSharedProps{H<:AbstractHook} n_particles::Int64 - old_states_ind::Vector{Int64} - states_ind::Vector{Int64} + old_states_id::Vector{Int64} + states_id::Vector{Int64} actions::Vector{SVector{2,Float64}} - actions_ind::Vector{Int64} + actions_id::Vector{Int64} function EnvHelperSharedProps( env::Env, diff --git a/src/RL/Hooks.jl b/src/RL/Hooks.jl index edd9ae5..82666fb 100644 --- a/src/RL/Hooks.jl +++ b/src/RL/Hooks.jl @@ -27,16 +27,16 @@ function update_table_and_actions_hook( if !first_integration_step # Old state - env.shared.state_ind = env_helper.shared.old_states_ind[id] + env.shared.state_id = env_helper.shared.old_states_id[id] - action_ind = env_helper.shared.actions_ind[id] + action_id = env_helper.shared.actions_id[id] # Pre act - agent(PRE_ACT_STAGE, env, action_ind) - hook(PRE_ACT_STAGE, agent, env, action_ind) + agent(PRE_ACT_STAGE, env, action_id) + hook(PRE_ACT_STAGE, agent, env, action_id) # Update to current state - env.shared.state_ind = env_helper.shared.states_ind[id] + env.shared.state_id = env_helper.shared.states_id[id] # Update reward update_reward!(env, env_helper, particle) @@ -47,11 +47,11 @@ function update_table_and_actions_hook( end # Update action - action_ind = agent(env) - action = env.shared.action_space[action_ind] + action_id = agent(env) + action = env.shared.action_space[action_id] env_helper.shared.actions[id] = action - env_helper.shared.actions_ind[id] = action_ind + env_helper.shared.actions_id[id] = action_id return nothing end diff --git a/src/RL/LocalCOMEnv.jl b/src/RL/LocalCOMEnv.jl index 66eb076..6b73030 100644 --- a/src/RL/LocalCOMEnv.jl +++ b/src/RL/LocalCOMEnv.jl @@ -11,6 +11,7 @@ struct LocalCOMEnv <: Env function LocalCOMEnv(; n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args ) + @assert n_distance_states > 1 @assert n_direction_angle_states > 1 direction_angle_state_space = gen_angle_state_space(n_direction_angle_states) @@ -23,19 +24,9 @@ struct LocalCOMEnv <: Env ) n_states = n_distance_states * n_direction_angle_states + 1 - - state_space = Vector{SVector{2,Interval}}(undef, n_states - 1) - - ind = 1 - for distance_state in distance_state_space - for direction_angle_state in direction_angle_state_space - state_space[ind] = SVector(distance_state, direction_angle_state) - ind += 1 - end - end # Last state is when no particle is in the skin radius - shared = EnvSharedProps(n_states, state_space) + shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states)) return new(shared, distance_state_space, direction_angle_state_space) end @@ -110,7 +101,7 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part n_particles = env_helper.shared.n_particles @turbo for id in 1:n_particles - env_helper.shared.old_states_ind[id] = env_helper.shared.states_ind[id] + env_helper.shared.old_states_id[id] = env_helper.shared.states_id[id] end env = env_helper.shared.env @@ -121,7 +112,7 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part n_neighbours = env_helper.n_neighbours[id] if n_neighbours == 0 - state_ind = env.shared.n_states + state_id = env.shared.n_states distance_to_local_center_of_mass_sum += env_helper.max_distance_to_local_center_of_mass @@ -135,21 +126,20 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part distance_to_local_center_of_mass_sum += distance - distance_state = find_state_interval(distance, env.distance_state_space) + distance_state_ind = find_state_ind(distance, env.distance_state_space) si, co = sincos(particles[id].φ) direction_angle = ReCo.angle2(SVector(co, si), vec_to_local_center_of_mass) - direction_angle_state = find_state_interval( + direction_state_ind = find_state_ind( direction_angle, env.direction_angle_state_space ) - state = SVector{2,Interval}(distance_state, direction_angle_state) - state_ind = find_state_ind(state, env.shared.state_space) + state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind] end - env_helper.shared.states_ind[id] = state_ind + env_helper.shared.states_id[id] = state_id end mean_distance_to_local_center_of_mass = diff --git a/src/RL/States.jl b/src/RL/States.jl index 3bdf00c..5424c4a 100644 --- a/src/RL/States.jl +++ b/src/RL/States.jl @@ -46,14 +46,10 @@ function gen_distance_state_space( return distance_state_space end -function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector} - return findfirst(x -> x == state, state_space) -end - -function find_state_interval(value::Float64, state_space::Vector{Interval})::Interval - for state in state_space +function find_state_ind(value::Float64, state_space::Vector{Interval})::Int64 + for (ind, state) in enumerate(state_space) if value in state - return state + return ind end end end \ No newline at end of file