1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-12-21 00:51:21 +00:00

Local centers of mass

This commit is contained in:
Mo8it 2021-12-28 23:39:24 +01:00
parent 6f774ea6d0
commit 9353621edb
3 changed files with 30 additions and 54 deletions

View file

@ -211,8 +211,6 @@ function animate(
)
println("Generating animation...")
set_window_config!(; framerate=framerate)
sim_consts = JSON3.read(read("$dir/sim_consts.json", String))
animate_with_sim_consts(dir, sim_consts, framerate, show_center_of_mass, debug)

View file

@ -33,16 +33,16 @@ mutable struct Env <: AbstractEnv
center_of_mass::SVector{2,Float64}
function Env(
min_distance::Float64,
max_distance::Float64;
min_distance::Float64=0.0,
n_v_actions::Int64=3,
n_ω_actions::Int64=5,
max_v::Float64=60.0,
n_ω_actions::Int64=3,
max_v::Float64=40.0,
max_ω::Float64=π / 2,
n_distance_states::Int64=3,
n_angle_states::Int64=4,
)
@assert min_distance > 0.0
@assert min_distance >= 0.0
@assert max_distance > min_distance
@assert n_v_actions > 1
@assert n_ω_actions > 1
@ -168,12 +168,12 @@ struct Params{H<:AbstractHook}
goal_shape_ratio::Float64
n_particles::Int64
min_sq_distances::Vector{Float64}
vecs_r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
half_box_len::Float64
max_elliptic_distance::Float64
local_centers_of_mass::Vector{SVector{2,Float64}}
updated_local_center_of_mass::Vector{Bool}
function Params(
env::Env,
agent::Agent,
@ -198,36 +198,31 @@ struct Params{H<:AbstractHook}
n_steps_before_actions_update,
goal_shape_ratio,
n_particles,
fill(Inf64, n_particles),
fill(SVector(0.0, 0.0), n_particles),
half_box_len,
max_elliptic_distance,
fill(SVector(0.0, 0.0), n_particles),
falses(n_particles),
)
end
end
function pre_integration_hook(rl_params::Params)
@turbo for i in 1:(rl_params.n_particles)
rl_params.min_sq_distances[i] = Inf64
@simd for id in 1:(rl_params.n_particles)
rl_params.local_centers_of_mass[id] = SVector(0.0, 0.0)
rl_params.updated_local_center_of_mass[id] = false
end
return nothing
end
function state_update_helper_hook(
rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64
rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
)
if rl_params.min_sq_distances[id1] > distance²
rl_params.min_sq_distances[id1] = distance²
rl_params.local_centers_of_mass[id1] += r⃗₁₂
rl_params.local_centers_of_mass[id2] -= r⃗₁₂
rl_params.vecs_r⃗₁₂_to_min_distance_particle[id1] = r⃗₁₂
end
if rl_params.min_sq_distances[id2] > distance²
rl_params.min_sq_distances[id2] = distance²
rl_params.vecs_r⃗₁₂_to_min_distance_particle[id2] = -r⃗₁₂
end
rl_params.updated_local_center_of_mass[id1] = true
rl_params.updated_local_center_of_mass[id2] = true
return nothing
end
@ -243,54 +238,38 @@ function state_update_hook(rl_params::Params, particles::Vector{Particle})
env = rl_params.env
n_states = env.n_states
env_distance_state = env.distance_state_space[1]
env_angle_state = env.angle_state_space[1]
state_space = env.state_space
state_ind = 0
for id in 1:(rl_params.n_particles)
env_distance_state::Union{Interval,Nothing} = nothing
if !rl_params.updated_local_center_of_mass[id]
state_ind = env.n_states
else
local_center_of_mass = rl_params.local_centers_of_mass[id]
min_sq_distance = rl_params.min_sq_distances[id]
min_distance = sqrt(min_sq_distance)
distance = sqrt(local_center_of_mass[1]^2 + local_center_of_mass[2]^2)
if !isinf(min_sq_distance)
for distance_state in env.distance_state_space
if min_distance in distance_state
if distance in distance_state
env_distance_state = distance_state
break
end
end
end
# (nothing, nothing)
state_ind = n_states
if !isnothing(env_distance_state)
r⃗₁₂ = rl_params.vecs_r⃗₁₂_to_min_distance_particle[id]
si, co = sincos(particles[id].φ)
#=
Angle between two vectors
e = (co, si)
angle = acos(dot(r⃗₁₂, e) / (norm(r⃗₁₂) * norm(e)))
norm(r⃗₁₂) == min_distance
norm(e) == 1
min_distance is not infinite, because otherwise
env_angle_state would be nothing and this else block will not be called
=#
angle = angle2(SVector(co, si), r⃗₁₂)
angle = angle2(SVector(co, si), local_center_of_mass)
for angle_state in env.angle_state_space
if angle in angle_state
env_angle_state = angle_state
break
end
end
state = SVector{2,Interval}(env_distance_state, env_angle_state)
state_ind = get_state_ind(state, state_space)
state_ind = get_state_ind(state, env.state_space)
end
rl_params.states_ind[id] = state_ind
@ -402,7 +381,7 @@ function run_rl(;
)
n_particles = sim_consts.n_particles
env = Env(sim_consts.particle_radius, sim_consts.skin_r)
env = Env(sim_consts.skin_r)
agent = gen_agent(env.n_states, env.n_actions, ϵ)

View file

@ -52,7 +52,7 @@ function euler!(
p1_c, p2.c, args.interaction_r², args.half_box_len
)
state_update_helper_hook(rl_params, id1, id2, r⃗₁₂, distance²)
state_update_helper_hook(rl_params, id1, id2, r⃗₁₂)
if overlapping
factor = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
@ -116,7 +116,6 @@ function simulate(
first_integration_step = true
run_hooks = false
state_update_helper_hook =
state_update_hook = update_table_and_actions_hook = empty_hook