using ..ReCo: ReCo

"""
    OriginCompassEnv(; n_distance_states=3, n_direction_angle_states=3,
                     n_position_angle_states=8, args)

Environment whose discrete per-particle state is the triple
(distance bin, direction-angle bin, position-angle bin).

# Keywords
- `n_distance_states`, `n_direction_angle_states`, `n_position_angle_states`: number of
  bins in each dimension; each must be greater than 1.
- `args`: run arguments; only `args.half_box_len` is read here (it sets the upper end of
  the distance state space).
"""
struct OriginCompassEnv <: Env
    shared::EnvSharedProps

    distance_state_space::Vector{Interval}
    direction_angle_state_space::Vector{Interval}
    position_angle_state_space::Vector{Interval}

    function OriginCompassEnv(;
        n_distance_states::Int64=3,
        n_direction_angle_states::Int64=3,
        n_position_angle_states::Int64=8,
        args,
    )
        # Every state dimension must have at least two bins.
        @assert n_distance_states > 1
        @assert n_direction_angle_states > 1
        @assert n_position_angle_states > 1

        direction_space = gen_angle_state_space(n_direction_angle_states)
        position_space = gen_angle_state_space(n_position_angle_states)

        # Distances are binned over [0, √2 · half_box_len] — presumably the distance from
        # the origin to a corner of a square box centered at the origin.
        distance_space = gen_distance_state_space(
            0.0, sqrt(2) * args.half_box_len, n_distance_states
        )

        dims = (n_distance_states, n_direction_angle_states, n_position_angle_states)
        labels = gen_state_spaces_labels(
            ("d", "\\theta", "\\alpha"), (distance_space, direction_space, position_space)
        )

        # Total number of states is the product of the per-dimension bin counts.
        shared = EnvSharedProps(prod(dims), dims, labels)

        return new(shared, distance_space, direction_space, position_space)
    end
end

"""
    OriginCompassEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64)

Mutable per-run helper for `OriginCompassEnv`.

It caches the center of mass and the two gyration-tensor eigenvectors, all refreshed by
`state_update_hook!` and all initialized to the zero vector here. It also stores
`half_box_len` and `max_elliptical_distance`, the value `update_reward!` passes to
`minimizing_reward` together with the elliptical distance.
"""
mutable struct OriginCompassEnvHelper <: EnvHelper
    shared::EnvHelperSharedProps

    center_of_mass::SVector{2,Float64}
    gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Float64}
    gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Float64}

    half_box_len::Float64
    max_elliptical_distance::Float64

    function OriginCompassEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64)
        # Half box length along the second axis, scaled down by the ellipse's a/b ratio.
        scaled_half_len = half_box_len / shared.elliptical_a_b_ratio
        max_elliptical_distance = sqrt(half_box_len^2 + scaled_half_len^2)

        # SVector is immutable, so one zero vector can seed all three cached fields.
        origin = SVector(0.0, 0.0)

        return new(shared, origin, origin, origin, half_box_len, max_elliptical_distance)
    end
end

"""
    gen_env_helper(::OriginCompassEnv, env_helper_shared::EnvHelperSharedProps; args)

Construct the `OriginCompassEnvHelper` for this environment from the shared helper
properties and `args.half_box_len`.
"""
function gen_env_helper(::OriginCompassEnv, env_helper_shared::EnvHelperSharedProps; args)
    half_box_len = args.half_box_len
    helper = OriginCompassEnvHelper(env_helper_shared, half_box_len)
    return helper
end

# No-op: this environment does no work before integration.
pre_integration_hook!(::OriginCompassEnvHelper) = nothing

# No-op: this environment ignores the pairwise data (two particle ids, their separation
# vector and squared distance) offered during the state update.
function state_update_helper_hook!(
    ::OriginCompassEnvHelper, ::Int64, ::Int64, ::SVector{2,Float64}, ::Float64
)
    return nothing
end

"""
    state_update_hook!(env_helper::OriginCompassEnvHelper, particles::Vector{ReCo.Particle})

Refresh the helper's cached center of mass. Write a discrete state id for each of the
first `env_helper.shared.n_particles` particles into `env_helper.shared.states_id`. Then
refresh the cached gyration-tensor eigenvectors.

A particle's state id is looked up in `env.shared.state_id_tensor` from three bins:
- `ReCo.norm2d(-particle.c)`, the distance to the origin (`env.distance_state_space`);
- `ReCo.angle2` between the orientation unit vector `(cos φ, sin φ)` and the vector to
  the origin `-particle.c` (`env.direction_angle_state_space`);
- `atan(sin φ, cos φ)` (`env.position_angle_state_space`).
"""
function state_update_hook!(
    env_helper::OriginCompassEnvHelper, particles::Vector{ReCo.Particle}
)
    helper_shared = env_helper.shared
    env = helper_shared.env
    half_box_len = env_helper.half_box_len

    com = ReCo.center_of_mass(particles, half_box_len)
    env_helper.center_of_mass = com

    for id in 1:(helper_shared.n_particles)
        particle = particles[id]
        to_origin = -particle.c

        # Distance to the origin (not to the center of mass).
        distance_ind = find_state_ind(ReCo.norm2d(to_origin), env.distance_state_space)

        sin_φ, cos_φ = sincos(particle.φ)
        direction_angle = ReCo.angle2(SVector(cos_φ, sin_φ), to_origin)
        direction_ind = find_state_ind(direction_angle, env.direction_angle_state_space)

        # NOTE(review): atan(sin φ, cos φ) is the particle's orientation φ mapped into
        # [-π, π], not the polar angle of its position `particle.c`. If the position
        # angle was intended, this should be atan(particle.c[2], particle.c[1]).
        # Confirm before changing.
        position_ind = find_state_ind(atan(sin_φ, cos_φ), env.position_angle_state_space)

        helper_shared.states_id[id] = env.shared.state_id_tensor[
            distance_ind, direction_ind, position_ind
        ]
    end

    smaller_eigvec, bigger_eigvec = ReCo.gyration_tensor_eigvecs(
        particles, half_box_len, com
    )
    env_helper.gyration_tensor_eigvec_to_smaller_eigval = smaller_eigvec
    env_helper.gyration_tensor_eigvec_to_bigger_eigval = bigger_eigvec

    return nothing
end

"""
    update_reward!(env::OriginCompassEnv, env_helper::OriginCompassEnvHelper,
                   particle::ReCo.Particle)

Compute `particle`'s elliptical distance to the cached center of mass. The ellipse is
oriented by the cached gyration-tensor eigenvectors and uses the shared
`elliptical_a_b_ratio`. Turn that distance into a reward with `minimizing_reward`
(bounded by `env_helper.max_elliptical_distance`) and hand it to
`set_normalized_reward!`.
"""
function update_reward!(
    env::OriginCompassEnv, env_helper::OriginCompassEnvHelper, particle::ReCo.Particle
)
    h = env_helper

    dist = ReCo.elliptical_distance(
        particle.c,
        h.center_of_mass,
        h.gyration_tensor_eigvec_to_smaller_eigval,
        h.gyration_tensor_eigvec_to_bigger_eigval,
        h.shared.elliptical_a_b_ratio,
        h.half_box_len,
    )

    set_normalized_reward!(env, minimizing_reward(dist, h.max_elliptical_distance), h)

    return nothing
end