1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00

Fixed angle state

This commit is contained in:
MoBit 2021-12-16 14:54:52 +01:00
parent 14302fbe4e
commit 5fc3df66cd
3 changed files with 44 additions and 45 deletions

View file

@ -4,6 +4,11 @@ export angle2
using StaticArrays: SVector using StaticArrays: SVector
"""
angle2(a::SVector{2,Float64}, b::SVector{2,Float64})
Returns the angle φ from vector a to b while φ [-π, π].
"""
function angle2(a::SVector{2,Float64}, b::SVector{2,Float64}) function angle2(a::SVector{2,Float64}, b::SVector{2,Float64})
θ_a = atan(a[2], a[1]) θ_a = atan(a[2], a[1])
θ_b = atan(b[2], b[1]) θ_b = atan(b[2], b[1])

View file

@ -14,29 +14,13 @@ using ..ReCo: ReCo, Particle, angle2
const INITIAL_REWARD = 0.0 const INITIAL_REWARD = 0.0
struct DistanceState{L<:Bound}
interval::Interval{Float64,L,Closed}
function DistanceState{L}(lower::Float64, upper::Float64) where {L<:Bound}
return new(Interval{Float64,L,Closed}(lower, upper))
end
end
struct DirectionState
interval::Interval{Float64,Closed,Open}
function DirectionState(lower::Float64, upper::Float64)
return new(Interval{Float64,Closed,Open}(lower, upper))
end
end
mutable struct EnvParams mutable struct EnvParams
action_space::Vector{Tuple{Float64,Float64}} action_space::Vector{Tuple{Float64,Float64}}
action_ind_space::Vector{Int64} action_ind_space::Vector{Int64}
distance_state_space::Vector{DistanceState} distance_state_space::Vector{Interval}
direction_state_space::Vector{DirectionState} angle_state_space::Vector{Interval}
state_space::Vector{Union{Tuple{DistanceState,DirectionState},Tuple{Nothing,Nothing}}} state_space::Vector{Union{Tuple{Interval,Interval},Tuple{Nothing,Nothing}}}
state_ind_space::Vector{Int64} state_ind_space::Vector{Int64}
n_states::Int64 n_states::Int64
@ -50,7 +34,7 @@ mutable struct EnvParams
max_v::Float64=80.0, max_v::Float64=80.0,
max_ω::Float64=π / 2, max_ω::Float64=π / 2,
n_distance_states::Int64=2, n_distance_states::Int64=2,
n_direction_states::Int64=2, n_angle_states::Int64=2,
) )
@assert min_distance > 0.0 @assert min_distance > 0.0
@assert max_distance > min_distance @assert max_distance > min_distance
@ -79,7 +63,7 @@ mutable struct EnvParams
distance_range = distance_range =
min_distance:((max_distance - min_distance) / n_distance_states):max_distance min_distance:((max_distance - min_distance) / n_distance_states):max_distance
distance_state_space = Vector{DistanceState}(undef, n_distance_states) distance_state_space = Vector{Interval}(undef, n_distance_states)
@simd for i in 1:n_distance_states @simd for i in 1:n_distance_states
if i == 1 if i == 1
@ -88,33 +72,37 @@ mutable struct EnvParams
bound = Open bound = Open
end end
distance_state_space[i] = DistanceState{bound}( distance_state_space[i] = Interval{Float64,bound,Closed}(
distance_range[i], distance_range[i + 1] distance_range[i], distance_range[i + 1]
) )
end end
direction_range = (-π):(2 * π / n_direction_states):π angle_range = (-π):(2 * π / n_angle_states):π
direction_state_space = Vector{DirectionState}(undef, n_direction_states) angle_state_space = Vector{Interval}(undef, n_angle_states)
@simd for i in 1:n_direction_states @simd for i in 1:n_angle_states
direction_state_space[i] = DirectionState( if i == 1
direction_range[i], direction_range[i + 1] bound = Closed
else
bound = Open
end
angle_state_space[i] = Interval{Float64,bound,Closed}(
angle_range[i], angle_range[i + 1]
) )
end end
n_states = n_distance_states * n_direction_states + 1 n_states = n_distance_states * n_angle_states + 1
state_space = Vector{ state_space = Vector{Union{Tuple{Interval,Interval},Tuple{Nothing,Nothing}}}(
Union{Tuple{DistanceState,DirectionState},Tuple{Nothing,Nothing}}
}(
undef, n_states undef, n_states
) )
ind = 1 ind = 1
for distance_state in distance_state_space for distance_state in distance_state_space
for direction_state in direction_state_space for angle_state in angle_state_space
state_space[ind] = (distance_state, direction_state) state_space[ind] = (distance_state, angle_state)
ind += 1 ind += 1
end end
end end
@ -126,7 +114,7 @@ mutable struct EnvParams
action_space, action_space,
action_ind_space, action_ind_space,
distance_state_space, distance_state_space,
direction_state_space, angle_state_space,
state_space, state_space,
state_ind_space, state_ind_space,
n_states, n_states,
@ -286,7 +274,7 @@ function integration_hook!(
return nothing return nothing
end end
function get_state_ind(state::Tuple{DistanceState,DirectionState}, env_params::EnvParams) function get_state_ind(state::Tuple{Interval,Interval}, env_params::EnvParams)
return findfirst(x -> x == state, env_params.state_space) return findfirst(x -> x == state, env_params.state_space)
end end
@ -311,19 +299,19 @@ function post_integration_hook(
# Update states # Update states
n_states = rl_params.env_params.n_states n_states = rl_params.env_params.n_states
env_direction_state = rl_params.env_params.direction_state_space[1] env_angle_state = rl_params.env_params.angle_state_space[1]
for i in 1:n_particles for i in 1:n_particles
env, agent, hook = get_env_agent_hook(rl_params, i) env, agent, hook = get_env_agent_hook(rl_params, i)
env_distance_state::Union{DistanceState,Nothing} = nothing env_distance_state::Union{Interval,Nothing} = nothing
min_sq_distance = rl_params.min_sq_distances[i] min_sq_distance = rl_params.min_sq_distances[i]
min_distance = sqrt(min_sq_distance) min_distance = sqrt(min_sq_distance)
if !isinf(min_sq_distance) if !isinf(min_sq_distance)
for distance_state in rl_params.env_params.distance_state_space for distance_state in rl_params.env_params.distance_state_space
if min_distance in distance_state.interval if min_distance in distance_state
env_distance_state = distance_state env_distance_state = distance_state
break break
end end
@ -345,17 +333,17 @@ function post_integration_hook(
norm(e) == 1 norm(e) == 1
min_distance is not infinite, because otherwise min_distance is not infinite, because otherwise
env_direction_state would be nothing and this else block will not be called env_angle_state would be nothing and this else block will not be called
=# =#
direction = angle2(SVector(co, si), r⃗₁₂) angle = angle2(SVector(co, si), r⃗₁₂)
for direction_state in rl_params.env_params.direction_state_space for angle_state in rl_params.env_params.angle_state_space
if direction in direction_state.interval if angle in angle_state
env_direction_state = direction_state env_angle_state = angle_state
end end
end end
state = (env_distance_state, env_direction_state) state = (env_distance_state, env_angle_state)
env.state_ind = get_state_ind(state, env.params) env.state_ind = get_state_ind(state, env.params)
end end
@ -384,7 +372,7 @@ function run_rl(;
# Setup # Setup
Random.seed!(seed) Random.seed!(seed)
sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=3.0) sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=1.6)
n_particles = sim_consts.n_particles n_particles = sim_consts.n_particles
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r) env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)

View file

@ -3,6 +3,12 @@ using ReCo
using StaticArrays: SVector using StaticArrays: SVector
@testset "Geometry.jl" begin
@test ReCo.Geometry.angle2(SVector(1.0, 0.0), SVector(0.0, 1.0)) / π 0.5
@test ReCo.Geometry.angle2(SVector(0.0, 1.0), SVector(1.0, 0.0)) / π -0.5
@test ReCo.Geometry.angle2(SVector(1.0, 0.0), SVector(1.0, 0.0)) / π 0.0
end
@testset "Particle.jl" begin @testset "Particle.jl" begin
half_box_len = 1.0 half_box_len = 1.0