diff --git a/graphics/radial_distribution.jl b/graphics/radial_distribution.jl index 7622e58..9eda94d 100644 --- a/graphics/radial_distribution.jl +++ b/graphics/radial_distribution.jl @@ -7,7 +7,6 @@ using ReCo: ReCo function gen_rdf_graphics() Random.seed!(1) - box_length = 100 box_length = 100 graphics_export_dir = "exports/graphics" diff --git a/src/RL/Envs/LocalCOMEnv.jl b/src/RL/Envs/LocalCOMEnv.jl index 5d6e8af..f45adfa 100644 --- a/src/RL/Envs/LocalCOMEnv.jl +++ b/src/RL/Envs/LocalCOMEnv.jl @@ -1,5 +1,3 @@ -export LocalCOMEnv - using ..ReCo: ReCo struct LocalCOMEnv <: Env diff --git a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl index 539e5a0..c3bae35 100644 --- a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl +++ b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl @@ -1,5 +1,3 @@ -export LocalCOMWithAdditionalShapeRewardEnv - using ..ReCo: ReCo struct LocalCOMWithAdditionalShapeRewardEnv <: Env diff --git a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl index 51be11f..4372bce 100644 --- a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl +++ b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl @@ -1,5 +1,3 @@ -export LocalCOMWithAdditionalShapeRewardEnv2 - using ..ReCo: ReCo struct LocalCOMWithAdditionalShapeRewardEnv2 <: Env diff --git a/src/RL/Envs/NearestNeighbourEnv.jl b/src/RL/Envs/NearestNeighbourEnv.jl index 5460686..3a931af 100644 --- a/src/RL/Envs/NearestNeighbourEnv.jl +++ b/src/RL/Envs/NearestNeighbourEnv.jl @@ -1,5 +1,3 @@ -export NearestNeighbourEnv - using ..ReCo: ReCo struct NearestNeighbourEnv <: Env diff --git a/src/RL/Envs/OriginCompass.jl b/src/RL/Envs/OriginCompass.jl new file mode 100644 index 0000000..f2b4f94 --- /dev/null +++ b/src/RL/Envs/OriginCompass.jl @@ -0,0 +1,155 @@ +using ..ReCo: ReCo + +struct OriginCompassEnv <: Env + shared::EnvSharedProps + + distance_state_space::Vector{Interval} + direction_angle_state_space::Vector{Interval} + position_angle_state_space::Vector{Interval} + + function OriginCompassEnv(; + n_distance_states::Int64=3, + n_direction_angle_states::Int64=3, + n_position_angle_states::Int64=8, + args, + ) + @assert n_distance_states > 1 + @assert n_direction_angle_states > 1 + @assert n_position_angle_states > 1 + + direction_angle_state_space = gen_angle_state_space(n_direction_angle_states) + position_angle_state_space = gen_angle_state_space(n_position_angle_states) + + min_distance = 0.0 + max_distance = sqrt(2) * args.half_box_len + + distance_state_space = gen_distance_state_space( + min_distance, max_distance, n_distance_states + ) + + n_states = n_distance_states * n_direction_angle_states * n_position_angle_states + + state_spaces_labels = gen_state_spaces_labels( + ("d", "\\theta", "\\alpha"), + (distance_state_space, direction_angle_state_space, position_angle_state_space), + ) + + shared = EnvSharedProps( + n_states, + (n_distance_states, n_direction_angle_states, n_position_angle_states), + state_spaces_labels, + ) + + return new( + shared, + distance_state_space, + direction_angle_state_space, + position_angle_state_space, + ) + end +end + +mutable struct OriginCompassEnvHelper <: EnvHelper + shared::EnvHelperSharedProps + + center_of_mass::SVector{2,Float64} + gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Float64} + gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Float64} + + half_box_len::Float64 + max_elliptical_distance::Float64 + + function OriginCompassEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64) + max_elliptical_distance = sqrt( + half_box_len^2 + (half_box_len / shared.elliptical_a_b_ratio)^2 + ) + + return new( + shared, + SVector(0.0, 0.0), + SVector(0.0, 0.0), + SVector(0.0, 0.0), + half_box_len, + max_elliptical_distance, + ) + end +end + +function gen_env_helper(::OriginCompassEnv, env_helper_shared::EnvHelperSharedProps; args) + return OriginCompassEnvHelper(env_helper_shared, args.half_box_len) +end + +function pre_integration_hook!(::OriginCompassEnvHelper) + return nothing +end + +function state_update_helper_hook!( + ::OriginCompassEnvHelper, + id1::Int64, + id2::Int64, + r⃗₁₂::SVector{2,Float64}, + distance²::Float64, +) + return nothing +end + +function state_update_hook!( + env_helper::OriginCompassEnvHelper, particles::Vector{ReCo.Particle} +) + n_particles = env_helper.shared.n_particles + + env = env_helper.shared.env + + env_helper.center_of_mass = ReCo.center_of_mass(particles, env_helper.half_box_len) + + for particle_id in 1:n_particles + vec_to_origin = -particles[particle_id].c + distance_to_center_of_mass = ReCo.norm2d(vec_to_origin) + distance_state_ind = find_state_ind( + distance_to_center_of_mass, env.distance_state_space + ) + + si, co = sincos(particles[particle_id].φ) + direction_angle = ReCo.angle2(SVector(co, si), vec_to_origin) + direction_state_ind = find_state_ind( + direction_angle, env.direction_angle_state_space + ) + + position_angle = atan(si, co) + position_state_ind = find_state_ind(position_angle, env.position_angle_state_space) + + state_id = env.shared.state_id_tensor[ + distance_state_ind, direction_state_ind, position_state_ind + ] + + env_helper.shared.states_id[particle_id] = state_id + end + + v1, v2 = ReCo.gyration_tensor_eigvecs( + particles, env_helper.half_box_len, env_helper.center_of_mass + ) + + env_helper.gyration_tensor_eigvec_to_smaller_eigval = v1 + env_helper.gyration_tensor_eigvec_to_bigger_eigval = v2 + + return nothing +end + +function update_reward!( + env::OriginCompassEnv, env_helper::OriginCompassEnvHelper, particle::ReCo.Particle +) + elliptical_distance = ReCo.elliptical_distance( + particle.c, + env_helper.center_of_mass, + env_helper.gyration_tensor_eigvec_to_smaller_eigval, + env_helper.gyration_tensor_eigvec_to_bigger_eigval, + env_helper.shared.elliptical_a_b_ratio, + env_helper.half_box_len, + ) + + reward = minimizing_reward(elliptical_distance, env_helper.max_elliptical_distance) + + set_normalized_reward!(env, reward, env_helper) + + return nothing +end \ No newline at end of file diff --git a/src/RL/Envs/OriginEnv.jl b/src/RL/Envs/OriginEnv.jl index 118007d..429f8e2 100644 --- a/src/RL/Envs/OriginEnv.jl +++ b/src/RL/Envs/OriginEnv.jl @@ -1,5 +1,3 @@ -export OriginEnv - using ..ReCo: ReCo struct OriginEnv <: Env diff --git a/src/RL/RL.jl b/src/RL/RL.jl index ee37711..ad59da5 100644 --- a/src/RL/RL.jl +++ b/src/RL/RL.jl @@ -5,7 +5,9 @@ export run_rl, LocalCOMWithAdditionalShapeRewardEnv2, OriginEnv, NearestNeighbourEnv, - LocalCOMEnv + LocalCOMEnv, + OriginCompassEnv, + COMCompassEnv using Base: OneTo @@ -198,5 +200,7 @@ include("Envs/LocalCOMWithAdditionalShapeRewardEnv2.jl") include("Envs/OriginEnv.jl") include("Envs/NearestNeighbourEnv.jl") include("Envs/LocalCOMEnv.jl") +include("Envs/OriginCompass.jl") +include("Envs/COMCompass.jl") end # module \ No newline at end of file diff --git a/src/ReCo.jl b/src/ReCo.jl index ff02a9e..64a2f8c 100644 --- a/src/ReCo.jl +++ b/src/ReCo.jl @@ -10,7 +10,9 @@ export init_sim, LocalCOMWithAdditionalShapeRewardEnv2, OriginEnv, NearestNeighbourEnv, - LocalCOMEnv + LocalCOMEnv, + OriginCompassEnv, + COMCompassEnv using StaticArrays: SVector using JLD2: JLD2