1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-11-08 22:21:08 +00:00

Local centers of mass

This commit is contained in:
Mo8it 2021-12-28 23:39:24 +01:00
parent 6f774ea6d0
commit 9353621edb
3 changed files with 30 additions and 54 deletions

View file

@ -211,8 +211,6 @@ function animate(
) )
println("Generating animation...") println("Generating animation...")
set_window_config!(; framerate=framerate)
sim_consts = JSON3.read(read("$dir/sim_consts.json", String)) sim_consts = JSON3.read(read("$dir/sim_consts.json", String))
animate_with_sim_consts(dir, sim_consts, framerate, show_center_of_mass, debug) animate_with_sim_consts(dir, sim_consts, framerate, show_center_of_mass, debug)

View file

@ -33,16 +33,16 @@ mutable struct Env <: AbstractEnv
center_of_mass::SVector{2,Float64} center_of_mass::SVector{2,Float64}
function Env( function Env(
min_distance::Float64,
max_distance::Float64; max_distance::Float64;
min_distance::Float64=0.0,
n_v_actions::Int64=3, n_v_actions::Int64=3,
n_ω_actions::Int64=5, n_ω_actions::Int64=3,
max_v::Float64=60.0, max_v::Float64=40.0,
max_ω::Float64=π / 2, max_ω::Float64=π / 2,
n_distance_states::Int64=3, n_distance_states::Int64=3,
n_angle_states::Int64=4, n_angle_states::Int64=4,
) )
@assert min_distance > 0.0 @assert min_distance >= 0.0
@assert max_distance > min_distance @assert max_distance > min_distance
@assert n_v_actions > 1 @assert n_v_actions > 1
@assert n_ω_actions > 1 @assert n_ω_actions > 1
@ -168,12 +168,12 @@ struct Params{H<:AbstractHook}
goal_shape_ratio::Float64 goal_shape_ratio::Float64
n_particles::Int64 n_particles::Int64
min_sq_distances::Vector{Float64}
vecs_r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
half_box_len::Float64 half_box_len::Float64
max_elliptic_distance::Float64 max_elliptic_distance::Float64
local_centers_of_mass::Vector{SVector{2,Float64}}
updated_local_center_of_mass::Vector{Bool}
function Params( function Params(
env::Env, env::Env,
agent::Agent, agent::Agent,
@ -198,36 +198,31 @@ struct Params{H<:AbstractHook}
n_steps_before_actions_update, n_steps_before_actions_update,
goal_shape_ratio, goal_shape_ratio,
n_particles, n_particles,
fill(Inf64, n_particles),
fill(SVector(0.0, 0.0), n_particles),
half_box_len, half_box_len,
max_elliptic_distance, max_elliptic_distance,
fill(SVector(0.0, 0.0), n_particles),
falses(n_particles),
) )
end end
end end
function pre_integration_hook(rl_params::Params) function pre_integration_hook(rl_params::Params)
@turbo for i in 1:(rl_params.n_particles) @simd for id in 1:(rl_params.n_particles)
rl_params.min_sq_distances[i] = Inf64 rl_params.local_centers_of_mass[id] = SVector(0.0, 0.0)
rl_params.updated_local_center_of_mass[id] = false
end end
return nothing return nothing
end end
function state_update_helper_hook( function state_update_helper_hook(
rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64 rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
) )
if rl_params.min_sq_distances[id1] > distance² rl_params.local_centers_of_mass[id1] += r⃗₁₂
rl_params.min_sq_distances[id1] = distance² rl_params.local_centers_of_mass[id2] -= r⃗₁₂
rl_params.vecs_r⃗₁₂_to_min_distance_particle[id1] = r⃗₁₂ rl_params.updated_local_center_of_mass[id1] = true
end rl_params.updated_local_center_of_mass[id2] = true
if rl_params.min_sq_distances[id2] > distance²
rl_params.min_sq_distances[id2] = distance²
rl_params.vecs_r⃗₁₂_to_min_distance_particle[id2] = -r⃗₁₂
end
return nothing return nothing
end end
@ -243,54 +238,38 @@ function state_update_hook(rl_params::Params, particles::Vector{Particle})
env = rl_params.env env = rl_params.env
n_states = env.n_states env_distance_state = env.distance_state_space[1]
env_angle_state = env.angle_state_space[1] env_angle_state = env.angle_state_space[1]
state_ind = 0
state_space = env.state_space
for id in 1:(rl_params.n_particles) for id in 1:(rl_params.n_particles)
env_distance_state::Union{Interval,Nothing} = nothing if !rl_params.updated_local_center_of_mass[id]
state_ind = env.n_states
else
local_center_of_mass = rl_params.local_centers_of_mass[id]
min_sq_distance = rl_params.min_sq_distances[id] distance = sqrt(local_center_of_mass[1]^2 + local_center_of_mass[2]^2)
min_distance = sqrt(min_sq_distance)
if !isinf(min_sq_distance)
for distance_state in env.distance_state_space for distance_state in env.distance_state_space
if min_distance in distance_state if distance in distance_state
env_distance_state = distance_state env_distance_state = distance_state
break break
end end
end end
end
# (nothing, nothing)
state_ind = n_states
if !isnothing(env_distance_state)
r⃗₁₂ = rl_params.vecs_r⃗₁₂_to_min_distance_particle[id]
si, co = sincos(particles[id].φ) si, co = sincos(particles[id].φ)
#= angle = angle2(SVector(co, si), local_center_of_mass)
Angle between two vectors
e = (co, si)
angle = acos(dot(r⃗₁₂, e) / (norm(r⃗₁₂) * norm(e)))
norm(r⃗₁₂) == min_distance
norm(e) == 1
min_distance is not infinite, because otherwise
env_angle_state would be nothing and this else block will not be called
=#
angle = angle2(SVector(co, si), r⃗₁₂)
for angle_state in env.angle_state_space for angle_state in env.angle_state_space
if angle in angle_state if angle in angle_state
env_angle_state = angle_state env_angle_state = angle_state
break
end end
end end
state = SVector{2,Interval}(env_distance_state, env_angle_state) state = SVector{2,Interval}(env_distance_state, env_angle_state)
state_ind = get_state_ind(state, state_space) state_ind = get_state_ind(state, env.state_space)
end end
rl_params.states_ind[id] = state_ind rl_params.states_ind[id] = state_ind
@ -402,7 +381,7 @@ function run_rl(;
) )
n_particles = sim_consts.n_particles n_particles = sim_consts.n_particles
env = Env(sim_consts.particle_radius, sim_consts.skin_r) env = Env(sim_consts.skin_r)
agent = gen_agent(env.n_states, env.n_actions, ϵ) agent = gen_agent(env.n_states, env.n_actions, ϵ)

View file

@ -52,7 +52,7 @@ function euler!(
p1_c, p2.c, args.interaction_r², args.half_box_len p1_c, p2.c, args.interaction_r², args.half_box_len
) )
state_update_helper_hook(rl_params, id1, id2, r⃗₁₂, distance²) state_update_helper_hook(rl_params, id1, id2, r⃗₁₂)
if overlapping if overlapping
factor = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0) factor = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
@ -116,7 +116,6 @@ function simulate(
first_integration_step = true first_integration_step = true
run_hooks = false
state_update_helper_hook = state_update_helper_hook =
state_update_hook = update_table_and_actions_hook = empty_hook state_update_hook = update_table_and_actions_hook = empty_hook