mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Added state_id_tensor
This commit is contained in:
parent
28fd6bab95
commit
614311b080
5 changed files with 47 additions and 53 deletions
|
@ -1,26 +1,26 @@
|
||||||
abstract type Env <: AbstractEnv end
|
abstract type Env <: AbstractEnv end
|
||||||
|
|
||||||
mutable struct EnvSharedProps{state_dims}
|
mutable struct EnvSharedProps{n_state_dims}
|
||||||
n_actions::Int64
|
n_actions::Int64
|
||||||
action_space::Vector{SVector{2,Float64}}
|
action_space::Vector{SVector{2,Float64}}
|
||||||
action_ind_space::OneTo{Int64}
|
action_id_space::OneTo{Int64}
|
||||||
|
|
||||||
n_states::Int64
|
n_states::Int64
|
||||||
state_space::Vector{SVector{state_dims,Interval}}
|
state_id_tensor::Array{Int64,n_state_dims}
|
||||||
state_ind_space::OneTo{Int64}
|
state_id_space::OneTo{Int64}
|
||||||
state_ind::Int64
|
state_id::Int64
|
||||||
|
|
||||||
reward::Float64
|
reward::Float64
|
||||||
terminated::Bool
|
terminated::Bool
|
||||||
|
|
||||||
function EnvSharedProps(
|
function EnvSharedProps(
|
||||||
n_states::Int64,
|
n_states::Int64, # Can be different from the sum of state_id_tensor_dims
|
||||||
state_space::Vector{SVector{state_dims,Interval}};
|
state_id_tensor_dims::NTuple{n_state_dims,Int64};
|
||||||
n_v_actions::Int64=2,
|
n_v_actions::Int64=2,
|
||||||
n_ω_actions::Int64=3,
|
n_ω_actions::Int64=3,
|
||||||
max_v::Float64=40.0,
|
max_v::Float64=40.0,
|
||||||
max_ω::Float64=π / 2,
|
max_ω::Float64=π / 2,
|
||||||
) where {state_dims}
|
) where {n_state_dims}
|
||||||
@assert n_v_actions > 1
|
@assert n_v_actions > 1
|
||||||
@assert n_ω_actions > 1
|
@assert n_ω_actions > 1
|
||||||
@assert max_v > 0
|
@assert max_v > 0
|
||||||
|
@ -41,17 +41,25 @@ mutable struct EnvSharedProps{state_dims}
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
action_ind_space = OneTo(n_actions)
|
action_id_space = OneTo(n_actions)
|
||||||
|
|
||||||
state_ind_space = OneTo(n_states)
|
state_id_tensor = Array{Int64,n_state_dims}(undef, state_id_tensor_dims)
|
||||||
|
|
||||||
return new{state_dims}(
|
id = 1
|
||||||
|
for ind in eachindex(state_id_tensor)
|
||||||
|
state_id_tensor[ind] = id
|
||||||
|
id += 1
|
||||||
|
end
|
||||||
|
|
||||||
|
state_id_space = OneTo(n_states)
|
||||||
|
|
||||||
|
return new{n_state_dims}(
|
||||||
n_actions,
|
n_actions,
|
||||||
action_space,
|
action_space,
|
||||||
action_ind_space,
|
action_id_space,
|
||||||
n_states,
|
n_states,
|
||||||
state_space,
|
state_id_tensor,
|
||||||
state_ind_space,
|
state_id_space,
|
||||||
INITIAL_STATE_IND,
|
INITIAL_STATE_IND,
|
||||||
INITIAL_REWARD,
|
INITIAL_REWARD,
|
||||||
false,
|
false,
|
||||||
|
@ -65,11 +73,11 @@ function reset!(env::Env)
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
RLBase.state_space(env::Env) = env.shared.state_ind_space
|
RLBase.state_space(env::Env) = env.shared.state_id_space
|
||||||
|
|
||||||
RLBase.state(env::Env) = env.shared.state_ind
|
RLBase.state(env::Env) = env.shared.state_id
|
||||||
|
|
||||||
RLBase.action_space(env::Env) = env.shared.action_ind_space
|
RLBase.action_space(env::Env) = env.shared.action_id_space
|
||||||
|
|
||||||
RLBase.reward(env::Env) = env.shared.reward
|
RLBase.reward(env::Env) = env.shared.reward
|
||||||
|
|
||||||
|
|
|
@ -11,11 +11,11 @@ struct EnvHelperSharedProps{H<:AbstractHook}
|
||||||
|
|
||||||
n_particles::Int64
|
n_particles::Int64
|
||||||
|
|
||||||
old_states_ind::Vector{Int64}
|
old_states_id::Vector{Int64}
|
||||||
states_ind::Vector{Int64}
|
states_id::Vector{Int64}
|
||||||
|
|
||||||
actions::Vector{SVector{2,Float64}}
|
actions::Vector{SVector{2,Float64}}
|
||||||
actions_ind::Vector{Int64}
|
actions_id::Vector{Int64}
|
||||||
|
|
||||||
function EnvHelperSharedProps(
|
function EnvHelperSharedProps(
|
||||||
env::Env,
|
env::Env,
|
||||||
|
|
|
@ -27,16 +27,16 @@ function update_table_and_actions_hook(
|
||||||
|
|
||||||
if !first_integration_step
|
if !first_integration_step
|
||||||
# Old state
|
# Old state
|
||||||
env.shared.state_ind = env_helper.shared.old_states_ind[id]
|
env.shared.state_id = env_helper.shared.old_states_id[id]
|
||||||
|
|
||||||
action_ind = env_helper.shared.actions_ind[id]
|
action_id = env_helper.shared.actions_id[id]
|
||||||
|
|
||||||
# Pre act
|
# Pre act
|
||||||
agent(PRE_ACT_STAGE, env, action_ind)
|
agent(PRE_ACT_STAGE, env, action_id)
|
||||||
hook(PRE_ACT_STAGE, agent, env, action_ind)
|
hook(PRE_ACT_STAGE, agent, env, action_id)
|
||||||
|
|
||||||
# Update to current state
|
# Update to current state
|
||||||
env.shared.state_ind = env_helper.shared.states_ind[id]
|
env.shared.state_id = env_helper.shared.states_id[id]
|
||||||
|
|
||||||
# Update reward
|
# Update reward
|
||||||
update_reward!(env, env_helper, particle)
|
update_reward!(env, env_helper, particle)
|
||||||
|
@ -47,11 +47,11 @@ function update_table_and_actions_hook(
|
||||||
end
|
end
|
||||||
|
|
||||||
# Update action
|
# Update action
|
||||||
action_ind = agent(env)
|
action_id = agent(env)
|
||||||
action = env.shared.action_space[action_ind]
|
action = env.shared.action_space[action_id]
|
||||||
|
|
||||||
env_helper.shared.actions[id] = action
|
env_helper.shared.actions[id] = action
|
||||||
env_helper.shared.actions_ind[id] = action_ind
|
env_helper.shared.actions_id[id] = action_id
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
|
@ -11,6 +11,7 @@ struct LocalCOMEnv <: Env
|
||||||
function LocalCOMEnv(;
|
function LocalCOMEnv(;
|
||||||
n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args
|
n_distance_states::Int64=3, n_direction_angle_states::Int64=3, args
|
||||||
)
|
)
|
||||||
|
@assert n_distance_states > 1
|
||||||
@assert n_direction_angle_states > 1
|
@assert n_direction_angle_states > 1
|
||||||
|
|
||||||
direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)
|
direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)
|
||||||
|
@ -23,19 +24,9 @@ struct LocalCOMEnv <: Env
|
||||||
)
|
)
|
||||||
|
|
||||||
n_states = n_distance_states * n_direction_angle_states + 1
|
n_states = n_distance_states * n_direction_angle_states + 1
|
||||||
|
|
||||||
state_space = Vector{SVector{2,Interval}}(undef, n_states - 1)
|
|
||||||
|
|
||||||
ind = 1
|
|
||||||
for distance_state in distance_state_space
|
|
||||||
for direction_angle_state in direction_angle_state_space
|
|
||||||
state_space[ind] = SVector(distance_state, direction_angle_state)
|
|
||||||
ind += 1
|
|
||||||
end
|
|
||||||
end
|
|
||||||
# Last state is when no particle is in the skin radius
|
# Last state is when no particle is in the skin radius
|
||||||
|
|
||||||
shared = EnvSharedProps(n_states, state_space)
|
shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states))
|
||||||
|
|
||||||
return new(shared, distance_state_space, direction_angle_state_space)
|
return new(shared, distance_state_space, direction_angle_state_space)
|
||||||
end
|
end
|
||||||
|
@ -110,7 +101,7 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
|
||||||
n_particles = env_helper.shared.n_particles
|
n_particles = env_helper.shared.n_particles
|
||||||
|
|
||||||
@turbo for id in 1:n_particles
|
@turbo for id in 1:n_particles
|
||||||
env_helper.shared.old_states_ind[id] = env_helper.shared.states_ind[id]
|
env_helper.shared.old_states_id[id] = env_helper.shared.states_id[id]
|
||||||
end
|
end
|
||||||
|
|
||||||
env = env_helper.shared.env
|
env = env_helper.shared.env
|
||||||
|
@ -121,7 +112,7 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
|
||||||
n_neighbours = env_helper.n_neighbours[id]
|
n_neighbours = env_helper.n_neighbours[id]
|
||||||
|
|
||||||
if n_neighbours == 0
|
if n_neighbours == 0
|
||||||
state_ind = env.shared.n_states
|
state_id = env.shared.n_states
|
||||||
|
|
||||||
distance_to_local_center_of_mass_sum +=
|
distance_to_local_center_of_mass_sum +=
|
||||||
env_helper.max_distance_to_local_center_of_mass
|
env_helper.max_distance_to_local_center_of_mass
|
||||||
|
@ -135,21 +126,20 @@ function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Part
|
||||||
|
|
||||||
distance_to_local_center_of_mass_sum += distance
|
distance_to_local_center_of_mass_sum += distance
|
||||||
|
|
||||||
distance_state = find_state_interval(distance, env.distance_state_space)
|
distance_state_ind = find_state_ind(distance, env.distance_state_space)
|
||||||
|
|
||||||
si, co = sincos(particles[id].φ)
|
si, co = sincos(particles[id].φ)
|
||||||
|
|
||||||
direction_angle = ReCo.angle2(SVector(co, si), vec_to_local_center_of_mass)
|
direction_angle = ReCo.angle2(SVector(co, si), vec_to_local_center_of_mass)
|
||||||
|
|
||||||
direction_angle_state = find_state_interval(
|
direction_state_ind = find_state_ind(
|
||||||
direction_angle, env.direction_angle_state_space
|
direction_angle, env.direction_angle_state_space
|
||||||
)
|
)
|
||||||
|
|
||||||
state = SVector{2,Interval}(distance_state, direction_angle_state)
|
state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind]
|
||||||
state_ind = find_state_ind(state, env.shared.state_space)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
env_helper.shared.states_ind[id] = state_ind
|
env_helper.shared.states_id[id] = state_id
|
||||||
end
|
end
|
||||||
|
|
||||||
mean_distance_to_local_center_of_mass =
|
mean_distance_to_local_center_of_mass =
|
||||||
|
|
|
@ -46,14 +46,10 @@ function gen_distance_state_space(
|
||||||
return distance_state_space
|
return distance_state_space
|
||||||
end
|
end
|
||||||
|
|
||||||
function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector}
|
function find_state_ind(value::Float64, state_space::Vector{Interval})::Int64
|
||||||
return findfirst(x -> x == state, state_space)
|
for (ind, state) in enumerate(state_space)
|
||||||
end
|
|
||||||
|
|
||||||
function find_state_interval(value::Float64, state_space::Vector{Interval})::Interval
|
|
||||||
for state in state_space
|
|
||||||
if value in state
|
if value in state
|
||||||
return state
|
return ind
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
Loading…
Reference in a new issue