mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Local centers of mass
This commit is contained in:
parent
6f774ea6d0
commit
9353621edb
3 changed files with 30 additions and 54 deletions
|
@ -211,8 +211,6 @@ function animate(
|
|||
)
|
||||
println("Generating animation...")
|
||||
|
||||
set_window_config!(; framerate=framerate)
|
||||
|
||||
sim_consts = JSON3.read(read("$dir/sim_consts.json", String))
|
||||
|
||||
animate_with_sim_consts(dir, sim_consts, framerate, show_center_of_mass, debug)
|
||||
|
|
79
src/RL.jl
79
src/RL.jl
|
@ -33,16 +33,16 @@ mutable struct Env <: AbstractEnv
|
|||
center_of_mass::SVector{2,Float64}
|
||||
|
||||
function Env(
|
||||
min_distance::Float64,
|
||||
max_distance::Float64;
|
||||
min_distance::Float64=0.0,
|
||||
n_v_actions::Int64=3,
|
||||
n_ω_actions::Int64=5,
|
||||
max_v::Float64=60.0,
|
||||
n_ω_actions::Int64=3,
|
||||
max_v::Float64=40.0,
|
||||
max_ω::Float64=π / 2,
|
||||
n_distance_states::Int64=3,
|
||||
n_angle_states::Int64=4,
|
||||
)
|
||||
@assert min_distance > 0.0
|
||||
@assert min_distance >= 0.0
|
||||
@assert max_distance > min_distance
|
||||
@assert n_v_actions > 1
|
||||
@assert n_ω_actions > 1
|
||||
|
@ -168,12 +168,12 @@ struct Params{H<:AbstractHook}
|
|||
goal_shape_ratio::Float64
|
||||
|
||||
n_particles::Int64
|
||||
min_sq_distances::Vector{Float64}
|
||||
vecs_r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
|
||||
|
||||
half_box_len::Float64
|
||||
max_elliptic_distance::Float64
|
||||
|
||||
local_centers_of_mass::Vector{SVector{2,Float64}}
|
||||
updated_local_center_of_mass::Vector{Bool}
|
||||
|
||||
function Params(
|
||||
env::Env,
|
||||
agent::Agent,
|
||||
|
@ -198,36 +198,31 @@ struct Params{H<:AbstractHook}
|
|||
n_steps_before_actions_update,
|
||||
goal_shape_ratio,
|
||||
n_particles,
|
||||
fill(Inf64, n_particles),
|
||||
fill(SVector(0.0, 0.0), n_particles),
|
||||
half_box_len,
|
||||
max_elliptic_distance,
|
||||
fill(SVector(0.0, 0.0), n_particles),
|
||||
falses(n_particles),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
function pre_integration_hook(rl_params::Params)
|
||||
@turbo for i in 1:(rl_params.n_particles)
|
||||
rl_params.min_sq_distances[i] = Inf64
|
||||
@simd for id in 1:(rl_params.n_particles)
|
||||
rl_params.local_centers_of_mass[id] = SVector(0.0, 0.0)
|
||||
rl_params.updated_local_center_of_mass[id] = false
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function state_update_helper_hook(
|
||||
rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64
|
||||
rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
||||
)
|
||||
if rl_params.min_sq_distances[id1] > distance²
|
||||
rl_params.min_sq_distances[id1] = distance²
|
||||
rl_params.local_centers_of_mass[id1] += r⃗₁₂
|
||||
rl_params.local_centers_of_mass[id2] -= r⃗₁₂
|
||||
|
||||
rl_params.vecs_r⃗₁₂_to_min_distance_particle[id1] = r⃗₁₂
|
||||
end
|
||||
|
||||
if rl_params.min_sq_distances[id2] > distance²
|
||||
rl_params.min_sq_distances[id2] = distance²
|
||||
|
||||
rl_params.vecs_r⃗₁₂_to_min_distance_particle[id2] = -r⃗₁₂
|
||||
end
|
||||
rl_params.updated_local_center_of_mass[id1] = true
|
||||
rl_params.updated_local_center_of_mass[id2] = true
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
@ -243,54 +238,38 @@ function state_update_hook(rl_params::Params, particles::Vector{Particle})
|
|||
|
||||
env = rl_params.env
|
||||
|
||||
n_states = env.n_states
|
||||
|
||||
env_distance_state = env.distance_state_space[1]
|
||||
env_angle_state = env.angle_state_space[1]
|
||||
|
||||
state_space = env.state_space
|
||||
state_ind = 0
|
||||
|
||||
for id in 1:(rl_params.n_particles)
|
||||
env_distance_state::Union{Interval,Nothing} = nothing
|
||||
if !rl_params.updated_local_center_of_mass[id]
|
||||
state_ind = env.n_states
|
||||
else
|
||||
local_center_of_mass = rl_params.local_centers_of_mass[id]
|
||||
|
||||
min_sq_distance = rl_params.min_sq_distances[id]
|
||||
min_distance = sqrt(min_sq_distance)
|
||||
distance = sqrt(local_center_of_mass[1]^2 + local_center_of_mass[2]^2)
|
||||
|
||||
if !isinf(min_sq_distance)
|
||||
for distance_state in env.distance_state_space
|
||||
if min_distance in distance_state
|
||||
if distance in distance_state
|
||||
env_distance_state = distance_state
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# (nothing, nothing)
|
||||
state_ind = n_states
|
||||
|
||||
if !isnothing(env_distance_state)
|
||||
r⃗₁₂ = rl_params.vecs_r⃗₁₂_to_min_distance_particle[id]
|
||||
si, co = sincos(particles[id].φ)
|
||||
|
||||
#=
|
||||
Angle between two vectors
|
||||
e = (co, si)
|
||||
angle = acos(dot(r⃗₁₂, e) / (norm(r⃗₁₂) * norm(e)))
|
||||
norm(r⃗₁₂) == min_distance
|
||||
norm(e) == 1
|
||||
|
||||
min_distance is not infinite, because otherwise
|
||||
env_angle_state would be nothing and this else block will not be called
|
||||
=#
|
||||
angle = angle2(SVector(co, si), r⃗₁₂)
|
||||
angle = angle2(SVector(co, si), local_center_of_mass)
|
||||
|
||||
for angle_state in env.angle_state_space
|
||||
if angle in angle_state
|
||||
env_angle_state = angle_state
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
state = SVector{2,Interval}(env_distance_state, env_angle_state)
|
||||
state_ind = get_state_ind(state, state_space)
|
||||
state_ind = get_state_ind(state, env.state_space)
|
||||
end
|
||||
|
||||
rl_params.states_ind[id] = state_ind
|
||||
|
@ -402,7 +381,7 @@ function run_rl(;
|
|||
)
|
||||
n_particles = sim_consts.n_particles
|
||||
|
||||
env = Env(sim_consts.particle_radius, sim_consts.skin_r)
|
||||
env = Env(sim_consts.skin_r)
|
||||
|
||||
agent = gen_agent(env.n_states, env.n_actions, ϵ)
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ function euler!(
|
|||
p1_c, p2.c, args.interaction_r², args.half_box_len
|
||||
)
|
||||
|
||||
state_update_helper_hook(rl_params, id1, id2, r⃗₁₂, distance²)
|
||||
state_update_helper_hook(rl_params, id1, id2, r⃗₁₂)
|
||||
|
||||
if overlapping
|
||||
factor = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
|
||||
|
@ -116,7 +116,6 @@ function simulate(
|
|||
|
||||
first_integration_step = true
|
||||
|
||||
run_hooks = false
|
||||
state_update_helper_hook =
|
||||
state_update_hook = update_table_and_actions_hook = empty_hook
|
||||
|
||||
|
|
Loading…
Reference in a new issue