diff --git a/src/RL.jl b/src/RL.jl index d9b6209..5aa17da 100644 --- a/src/RL.jl +++ b/src/RL.jl @@ -12,7 +12,7 @@ using LoopVectorization: @turbo using Random: Random using ProgressMeter: @showprogress -using ..ReCo: ReCo, Particle, angle2, center_of_mass +using ..ReCo: ReCo, Particle, angle2, Shape const INITIAL_REWARD = 0.0 const INITIAL_STATE_IND = 1 @@ -55,6 +55,8 @@ mutable struct Env <: AbstractEnv terminated::Bool center_of_mass::SVector{2,Float64} # TODO: Use or remove + gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Float64} + gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Float64} function Env(; max_distance::Float64, @@ -65,7 +67,7 @@ mutable struct Env <: AbstractEnv max_ω::Float64=π / 2, n_distance_states::Int64=4, n_direction_angle_states::Int64=3, - n_position_angle_states::Int64=4, + n_position_angle_states::Int64=8, ) @assert min_distance >= 0.0 @assert max_distance > min_distance @@ -182,7 +184,7 @@ struct Params{H<:AbstractHook} n_steps_before_actions_update::Int64 - goal_shape_ratio::Float64 + goal_gyration_tensor_eigvals_ratio::Float64 n_particles::Int64 half_box_len::Float64 @@ -193,7 +195,7 @@ struct Params{H<:AbstractHook} agent::Agent, hook::H, n_steps_before_actions_update::Int64, - goal_shape_ratio::Float64, + goal_gyration_tensor_eigvals_ratio::Float64, n_particles::Int64, half_box_len::Float64, ) where {H<:AbstractHook} @@ -210,7 +212,7 @@ struct Params{H<:AbstractHook} fill(SVector(0.0, 0.0), n_particles), fill(0, n_particles), n_steps_before_actions_update, - goal_shape_ratio, + goal_gyration_tensor_eigvals_ratio, n_particles, half_box_len, max_elliptic_distance, @@ -247,17 +249,23 @@ function state_update_hook(rl_params::Params, particles::Vector{Particle}) env = rl_params.env + env.center_of_mass = Shape.center_of_mass(particles, rl_params.half_box_len) + for id in 1:(rl_params.n_particles) particle = particles[id] - distance = sqrt(particle.c[1]^2 + particle.c[2]^2) + vec_to_center_of_mass = ReCo.minimum_image( + env.center_of_mass - particle.c, rl_params.half_box_len + ) + + distance = sqrt(vec_to_center_of_mass[1]^2 + vec_to_center_of_mass[2]^2) distance_state = find_state_interval(distance, env.distance_state_space) si, co = sincos(particles[id].φ) - direction_angle = angle2(SVector(co, si), -particle.c) - position_angle = atan(particle.c[2], particle.c[1]) + direction_angle = angle2(SVector(co, si), vec_to_center_of_mass) + position_angle = atan(-vec_to_center_of_mass[2], -vec_to_center_of_mass[1]) direction_angle_state = find_state_interval( direction_angle, env.direction_angle_state_space @@ -274,7 +282,10 @@ function state_update_hook(rl_params::Params, particles::Vector{Particle}) rl_params.states_ind[id] = state_ind end - env.center_of_mass = center_of_mass(particles, rl_params.half_box_len) + v1, v2 = Shape.gyration_tensor_eigvecs(particles, rl_params.half_box_len) # TODO: Reuse center_of_mass + + env.gyration_tensor_eigvec_to_smaller_eigval = v1 + env.gyration_tensor_eigvec_to_bigger_eigval = v2 return nothing end @@ -285,8 +296,12 @@ end function update_reward!(env::Env, rl_params::Params, particle::Particle) env.reward = - -(particle.c[1]^2 + particle.c[2]^2) / - (rl_params.max_elliptic_distance^2 * rl_params.n_particles) + -Shape.elliptical_distance( + particle, + env.gyration_tensor_eigvec_to_smaller_eigval, + env.gyration_tensor_eigvec_to_bigger_eigval, + rl_params.goal_gyration_tensor_eigvals_ratio, + ) / (rl_params.max_elliptic_distance^2 * rl_params.n_particles) return nothing end @@ -368,7 +383,7 @@ function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64) end function run_rl(; - goal_shape_ratio::Float64, + goal_gyration_tensor_eigvals_ratio::Float64, n_episodes::Int64=200, episode_duration::Float64=50.0, update_actions_at::Float64=0.1, @@ -377,7 +392,7 @@ function run_rl(; ϵ_stable::Float64=0.0001, parent_dir::String="", ) - @assert 0.0 <= goal_shape_ratio <= 1.0 + @assert 0.0 <= goal_gyration_tensor_eigvals_ratio <= 1.0 @assert n_episodes > 0 @assert episode_duration > 0 @assert update_actions_at in 0.001:0.001:episode_duration @@ -405,7 +420,7 @@ function run_rl(; agent, hook, n_steps_before_actions_update, - goal_shape_ratio, + goal_gyration_tensor_eigvals_ratio, n_particles, sim_consts.half_box_len, )