diff --git a/src/RL/Envs/COMCompass.jl b/src/RL/Envs/COMCompass.jl new file mode 100644 index 0000000..2a926e7 --- /dev/null +++ b/src/RL/Envs/COMCompass.jl @@ -0,0 +1,157 @@ +using ..ReCo: ReCo + +struct COMCompassEnv <: Env + shared::EnvSharedProps + + distance_state_space::Vector{Interval} + direction_angle_state_space::Vector{Interval} + position_angle_state_space::Vector{Interval} + + function COMCompassEnv(; + n_distance_states::Int64=3, + n_direction_angle_states::Int64=3, + n_position_angle_states::Int64=8, + args, + ) + @assert n_distance_states > 1 + @assert n_direction_angle_states > 1 + @assert n_position_angle_states > 1 + + direction_angle_state_space = gen_angle_state_space(n_direction_angle_states) + position_angle_state_space = gen_angle_state_space(n_position_angle_states) + + min_distance = 0.0 + max_distance = sqrt(2) * args.half_box_len + + distance_state_space = gen_distance_state_space( + min_distance, max_distance, n_distance_states + ) + + n_states = n_distance_states * n_direction_angle_states * n_position_angle_states + + state_spaces_labels = gen_state_spaces_labels( + ("d", "\\theta", "\\alpha"), + (distance_state_space, direction_angle_state_space, position_angle_state_space), + ) + + shared = EnvSharedProps( + n_states, + (n_distance_states, n_direction_angle_states, n_position_angle_states), + state_spaces_labels, + ) + + return new( + shared, + distance_state_space, + direction_angle_state_space, + position_angle_state_space, + ) + end +end + +mutable struct COMCompassEnvHelper <: EnvHelper + shared::EnvHelperSharedProps + + center_of_mass::SVector{2,Float64} + gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Float64} + gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Float64} + + half_box_len::Float64 + max_elliptical_distance::Float64 + + function COMCompassEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64) + max_elliptical_distance = sqrt( + half_box_len^2 + (half_box_len / shared.elliptical_a_b_ratio)^2 + ) + + return new( + shared, + SVector(0.0, 0.0), + SVector(0.0, 0.0), + SVector(0.0, 0.0), + half_box_len, + max_elliptical_distance, + ) + end +end + +function gen_env_helper(::COMCompassEnv, env_helper_shared::EnvHelperSharedProps; args) + return COMCompassEnvHelper(env_helper_shared, args.half_box_len) +end + +function pre_integration_hook!(::COMCompassEnvHelper) + return nothing +end + +function state_update_helper_hook!( + ::COMCompassEnvHelper, + id1::Int64, + id2::Int64, + r⃗₁₂::SVector{2,Float64}, + distance²::Float64, +) + return nothing +end + +function state_update_hook!( + env_helper::COMCompassEnvHelper, particles::Vector{ReCo.Particle} +) + n_particles = env_helper.shared.n_particles + + env = env_helper.shared.env + + env_helper.center_of_mass = ReCo.center_of_mass(particles, env_helper.half_box_len) + + for particle_id in 1:n_particles + vec_to_center_of_mass = ReCo.restrict_coordinates( + env_helper.center_of_mass - particles[particle_id].c, env_helper.half_box_len + ) + distance_to_center_of_mass = ReCo.norm2d(vec_to_center_of_mass) + distance_state_ind = find_state_ind( + distance_to_center_of_mass, env.distance_state_space + ) + + si, co = sincos(particles[particle_id].φ) + direction_angle = ReCo.angle2(SVector(co, si), vec_to_center_of_mass) + direction_state_ind = find_state_ind( + direction_angle, env.direction_angle_state_space + ) + + position_angle = atan(si, co) + position_state_ind = find_state_ind(position_angle, env.position_angle_state_space) + + state_id = env.shared.state_id_tensor[ + distance_state_ind, direction_state_ind, position_state_ind + ] + + env_helper.shared.states_id[particle_id] = state_id + end + + v1, v2 = ReCo.gyration_tensor_eigvecs( + particles, env_helper.half_box_len, env_helper.center_of_mass + ) + + env_helper.gyration_tensor_eigvec_to_smaller_eigval = v1 + env_helper.gyration_tensor_eigvec_to_bigger_eigval = v2 + + return nothing +end + +function update_reward!( + env::COMCompassEnv, env_helper::COMCompassEnvHelper, particle::ReCo.Particle +) + elliptical_distance = ReCo.elliptical_distance( + particle.c, + env_helper.center_of_mass, + env_helper.gyration_tensor_eigvec_to_smaller_eigval, + env_helper.gyration_tensor_eigvec_to_bigger_eigval, + env_helper.shared.elliptical_a_b_ratio, + env_helper.half_box_len, + ) + + reward = minimizing_reward(elliptical_distance, env_helper.max_elliptical_distance) + + set_normalized_reward!(env, reward, env_helper) + + return nothing +end \ No newline at end of file