1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00

Added state_id_tensor

This commit is contained in:
Mo8it 2022-01-15 21:27:15 +01:00
parent 28fd6bab95
commit 614311b080
5 changed files with 47 additions and 53 deletions

View file

@ -1,26 +1,26 @@
abstract type Env <: AbstractEnv end abstract type Env <: AbstractEnv end
mutable struct EnvSharedProps{state_dims} mutable struct EnvSharedProps{n_state_dims}
n_actions::Int64 n_actions::Int64
action_space::Vector{SVector{2,Float64}} action_space::Vector{SVector{2,Float64}}
action_ind_space::OneTo{Int64} action_id_space::OneTo{Int64}
n_states::Int64 n_states::Int64
state_space::Vector{SVector{state_dims,Interval}} state_id_tensor::Array{Int64,n_state_dims}
state_ind_space::OneTo{Int64} state_id_space::OneTo{Int64}
state_ind::Int64 state_id::Int64
reward::Float64 reward::Float64
terminated::Bool terminated::Bool
function EnvSharedProps( function EnvSharedProps(
n_states::Int64, n_states::Int64, # Can be different from the sum of state_id_tensor_dims
state_space::Vector{SVector{state_dims,Interval}}; state_id_tensor_dims::NTuple{n_state_dims,Int64};
n_v_actions::Int64=2, n_v_actions::Int64=2,
n_ω_actions::Int64=3, n_ω_actions::Int64=3,
max_v::Float64=40.0, max_v::Float64=40.0,
max_ω::Float64=π / 2, max_ω::Float64=π / 2,
) where {state_dims} ) where {n_state_dims}
@assert n_v_actions > 1 @assert n_v_actions > 1
@assert n_ω_actions > 1 @assert n_ω_actions > 1
@assert max_v > 0 @assert max_v > 0
@ -41,17 +41,25 @@ mutable struct EnvSharedProps{state_dims}
end end
end end
action_ind_space = OneTo(n_actions) action_id_space = OneTo(n_actions)
state_ind_space = OneTo(n_states) state_id_tensor = Array{Int64,n_state_dims}(undef, state_id_tensor_dims)
return new{state_dims}( id = 1
for ind in eachindex(state_id_tensor)
state_id_tensor[ind] = id
id += 1
end
state_id_space = OneTo(n_states)
return new{n_state_dims}(
n_actions, n_actions,
action_space, action_space,
action_ind_space, action_id_space,
n_states, n_states,
state_space, state_id_tensor,
state_ind_space, state_id_space,
INITIAL_STATE_IND, INITIAL_STATE_IND,
INITIAL_REWARD, INITIAL_REWARD,
false, false,
@ -65,11 +73,11 @@ function reset!(env::Env)
return nothing return nothing
end end
RLBase.state_space(env::Env) = env.shared.state_ind_space RLBase.state_space(env::Env) = env.shared.state_id_space
RLBase.state(env::Env) = env.shared.state_ind RLBase.state(env::Env) = env.shared.state_id
RLBase.action_space(env::Env) = env.shared.action_ind_space RLBase.action_space(env::Env) = env.shared.action_id_space
RLBase.reward(env::Env) = env.shared.reward RLBase.reward(env::Env) = env.shared.reward

View file

@ -11,11 +11,11 @@ struct EnvHelperSharedProps{H<:AbstractHook}
n_particles::Int64 n_particles::Int64
old_states_ind::Vector{Int64} old_states_id::Vector{Int64}
states_ind::Vector{Int64} states_id::Vector{Int64}
actions::Vector{SVector{2,Float64}} actions::Vector{SVector{2,Float64}}
actions_ind::Vector{Int64} actions_id::Vector{Int64}
function EnvHelperSharedProps( function EnvHelperSharedProps(
env::Env, env::Env,

View file

@ -27,16 +27,16 @@ function update_table_and_actions_hook(
if !first_integration_step if !first_integration_step
# Old state # Old state
env.shared.state_ind = env_helper.shared.old_states_ind[id] env.shared.state_id = env_helper.shared.old_states_id[id]
action_ind = env_helper.shared.actions_ind[id] action_id = env_helper.shared.actions_id[id]
# Pre act # Pre act
agent(PRE_ACT_STAGE, env, action_ind) agent(PRE_ACT_STAGE, env, action_id)
hook(PRE_ACT_STAGE, agent, env, action_ind) hook(PRE_ACT_STAGE, agent, env, action_id)
# Update to current state # Update to current state
env.shared.state_ind = env_helper.shared.states_ind[id] env.shared.state_id = env_helper.shared.states_id[id]
# Update reward # Update reward
update_reward!(env, env_helper, particle) update_reward!(env, env_helper, particle)
@ -47,11 +47,11 @@ function update_table_and_actions_hook(
end end
# Update action # Update action
action_ind = agent(env) action_id = agent(env)
action = env.shared.action_space[action_ind] action = env.shared.action_space[action_id]
env_helper.shared.actions[id] = action env_helper.shared.actions[id] = action
env_helper.shared.actions_ind[id] = action_ind env_helper.shared.actions_id[id] = action_id
return nothing return nothing
end end

View file

@ -11,6 +11,7 @@ struct LocalCOMEnv <: Env
function LocalCOMEnv(; function LocalCOMEnv(;
n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args
) )
@assert n_distance_states > 1
@assert n_direction_angle_states > 1 @assert n_direction_angle_states > 1
direction_angle_state_space = gen_angle_state_space(n_direction_angle_states) direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)
@ -23,19 +24,9 @@ struct LocalCOMEnv <: Env
) )
n_states = n_distance_states * n_direction_angle_states + 1 n_states = n_distance_states * n_direction_angle_states + 1
state_space = Vector{SVector{2,Interval}}(undef, n_states - 1)
ind = 1
for distance_state in distance_state_space
for direction_angle_state in direction_angle_state_space
state_space[ind] = SVector(distance_state, direction_angle_state)
ind += 1
end
end
# Last state is when no particle is in the skin radius # Last state is when no particle is in the skin radius
shared = EnvSharedProps(n_states, state_space) shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states))
return new(shared, distance_state_space, direction_angle_state_space) return new(shared, distance_state_space, direction_angle_state_space)
end end
@ -110,7 +101,7 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
n_particles = env_helper.shared.n_particles n_particles = env_helper.shared.n_particles
@turbo for id in 1:n_particles @turbo for id in 1:n_particles
env_helper.shared.old_states_ind[id] = env_helper.shared.states_ind[id] env_helper.shared.old_states_id[id] = env_helper.shared.states_id[id]
end end
env = env_helper.shared.env env = env_helper.shared.env
@ -121,7 +112,7 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
n_neighbours = env_helper.n_neighbours[id] n_neighbours = env_helper.n_neighbours[id]
if n_neighbours == 0 if n_neighbours == 0
state_ind = env.shared.n_states state_id = env.shared.n_states
distance_to_local_center_of_mass_sum += distance_to_local_center_of_mass_sum +=
env_helper.max_distance_to_local_center_of_mass env_helper.max_distance_to_local_center_of_mass
@ -135,21 +126,20 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
distance_to_local_center_of_mass_sum += distance distance_to_local_center_of_mass_sum += distance
distance_state = find_state_interval(distance, env.distance_state_space) distance_state_ind = find_state_ind(distance, env.distance_state_space)
si, co = sincos(particles[id].φ) si, co = sincos(particles[id].φ)
direction_angle = ReCo.angle2(SVector(co, si), vec_to_local_center_of_mass) direction_angle = ReCo.angle2(SVector(co, si), vec_to_local_center_of_mass)
direction_angle_state = find_state_interval( direction_state_ind = find_state_ind(
direction_angle, env.direction_angle_state_space direction_angle, env.direction_angle_state_space
) )
state = SVector{2,Interval}(distance_state, direction_angle_state) state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind]
state_ind = find_state_ind(state, env.shared.state_space)
end end
env_helper.shared.states_ind[id] = state_ind env_helper.shared.states_id[id] = state_id
end end
mean_distance_to_local_center_of_mass = mean_distance_to_local_center_of_mass =

View file

@ -46,14 +46,10 @@ function gen_distance_state_space(
return distance_state_space return distance_state_space
end end
function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector} function find_state_ind(value::Float64, state_space::Vector{Interval})::Int64
return findfirst(x -> x == state, state_space) for (ind, state) in enumerate(state_space)
end
function find_state_interval(value::Float64, state_space::Vector{Interval})::Interval
for state in state_space
if value in state if value in state
return state return ind
end end
end end
end end