"""
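This environment corresponds to the local center of mass strategy with an additional shape reward term.
The additional term rewards each particle for minimizing its elliptical distance to the center of mass of all
particles; it is only added while the normalized mean distance to the local centers of mass is below a trigger threshold.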
"""

using ..ReCo: ReCo

# Default `trigger` of `LocalCOMWithAdditionalShapeRewardEnvHelper`: the additional shape
# reward term is switched on when the mean distance to the local center of mass, divided by
# its maximum (the skin radius), is below this value.
const DEFAULT_TRIGGER = 0.35

"""
    LocalCOMWithAdditionalShapeRewardEnv(args; n_distance_states=3, n_direction_angle_states=3)

Environment whose states combine a discretized distance and a discretized direction angle,
plus one extra (last) state for particles with no other particle inside the skin radius.

The distance state space is built by
`gen_distance_state_space(0.0, args.skin_radius, n_distance_states)` and the direction angle
state space by `gen_angle_state_space(n_direction_angle_states)`. The total number of states
is `n_distance_states * n_direction_angle_states + 1`.

# Throws
- `ArgumentError` if `n_distance_states <= 1` or `n_direction_angle_states <= 1`.
"""
struct LocalCOMWithAdditionalShapeRewardEnv <: Env
    shared::EnvSharedProps

    distance_state_space::Vector{Interval}
    direction_angle_state_space::Vector{Interval}

    function LocalCOMWithAdditionalShapeRewardEnv(
        args; n_distance_states::Int64=3, n_direction_angle_states::Int64=3
    )
        # Explicit argument checks: `@assert` is for internal invariants and may be disabled.
        n_distance_states > 1 || throw(
            ArgumentError("n_distance_states must be greater than 1, got $n_distance_states"),
        )
        n_direction_angle_states > 1 || throw(
            ArgumentError(
                "n_direction_angle_states must be greater than 1, " *
                "got $n_direction_angle_states",
            ),
        )

        direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)

        min_distance = 0.0
        max_distance = args.skin_radius

        distance_state_space = gen_distance_state_space(
            min_distance, max_distance, n_distance_states
        )

        # One state per (distance, direction angle) pair; the last state is when no
        # particle is in the skin radius.
        n_states = n_distance_states * n_direction_angle_states + 1

        state_spaces_labels = gen_state_spaces_labels(
            ("d", "\\theta"), (distance_state_space, direction_angle_state_space)
        )

        shared = EnvSharedProps(
            n_states, (n_distance_states, n_direction_angle_states), state_spaces_labels
        )

        return new(shared, distance_state_space, direction_angle_state_space)
    end
end

"""
    LocalCOMWithAdditionalShapeRewardEnvHelper(shared; half_box_len, skin_radius, trigger=DEFAULT_TRIGGER)

Mutable working state for `LocalCOMWithAdditionalShapeRewardEnv`.

It holds:
- the per-particle neighbor accumulators and distances to the local center of mass;
- the flag that switches the shape reward term on;
- the center of mass and gyration tensor eigenvectors used by that term.

The maximum distance to a local center of mass is `skin_radius`. The maximum elliptical
distance is `sqrt(half_box_len^2 + (half_box_len / shared.elliptical_b_a_ratio)^2)`.
"""
mutable struct LocalCOMWithAdditionalShapeRewardEnvHelper <: EnvHelper
    shared::EnvHelperSharedProps

    # Per-particle sum of vectors to neighbors and neighbor count.
    vec_to_neighbor_sums::Vector{SVector{2,Float64}}
    n_neighbors::Vector{Int64}

    # Per-particle distance to the local center of mass and the maximum used to normalize it.
    distances_to_local_center_of_mass::Vector{Float64}
    max_distance_to_local_center_of_mass::Float64

    # Whether the additional shape reward term is currently added.
    add_shape_reward_term::Bool

    # Quantities needed to evaluate the elliptical distance of a particle.
    center_of_mass::SVector{2,Float64}
    gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Float64}
    gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Float64}

    half_box_len::Float64
    # Maximum passed to `minimizing_reward` together with the elliptical distance.
    max_elliptical_distance::Float64

    # Threshold on the normalized mean distance to the local center of mass.
    trigger::Float64

    function LocalCOMWithAdditionalShapeRewardEnvHelper(
        shared::EnvHelperSharedProps;
        half_box_len::Float64,
        skin_radius::Float64,
        trigger::Float64=DEFAULT_TRIGGER,
    )
        n_particles = shared.n_particles
        origin = SVector(0.0, 0.0)

        scaled_half_box_len = half_box_len / shared.elliptical_b_a_ratio
        elliptical_distance_limit = sqrt(half_box_len^2 + scaled_half_box_len^2)

        return new(
            shared,
            fill(origin, n_particles),
            zeros(Int64, n_particles),
            zeros(Float64, n_particles),
            skin_radius,
            false,
            origin,
            origin,
            origin,
            half_box_len,
            elliptical_distance_limit,
            trigger,
        )
    end
end

"""
    gen_env_helper(env::LocalCOMWithAdditionalShapeRewardEnv, env_helper_shared; kwargs...)

Create the `LocalCOMWithAdditionalShapeRewardEnvHelper` for this environment.

All keyword arguments are forwarded unchanged to the helper constructor. That constructor
accepts `half_box_len`, `skin_radius` and an optional `trigger`.
"""
function gen_env_helper(
    ::LocalCOMWithAdditionalShapeRewardEnv,
    env_helper_shared::EnvHelperSharedProps;
    kwargs...,
)
    helper = LocalCOMWithAdditionalShapeRewardEnvHelper(env_helper_shared; kwargs...)
    return helper
end

"""
    pre_integration_hook!(env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper)

Reset the per-particle neighbor accumulators to zero.

This clears every vector-to-neighbor sum and every neighbor count, so that
`state_update_helper_hook!` accumulates from a clean state. Both vectors are allocated with
length `shared.n_particles` by the helper constructor, so `fill!` covers exactly the
particles the original index loop did.
"""
function pre_integration_hook!(env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper)
    fill!(env_helper.vec_to_neighbor_sums, SVector(0.0, 0.0))
    fill!(env_helper.n_neighbors, 0)

    return nothing
end

"""
    state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂, distance²)

Record one neighbor pair `(id1, id2)` in the helper's accumulators.

`r⃗₁₂` is added to the vector sum of `id1` and subtracted from that of `id2`. The neighbor
counts of both particles are incremented. `distance²` is not used by this environment.
"""
function state_update_helper_hook!(
    env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
    id1::Int64,
    id2::Int64,
    r⃗₁₂::SVector{2,Float64},
    distance²::Float64,
)
    vec_sums = env_helper.vec_to_neighbor_sums
    neighbor_counts = env_helper.n_neighbors

    # Presumably r⃗₁₂ points from id1 to id2, so each particle accumulates the vector toward
    # its neighbor — confirm against the caller.
    vec_sums[id1] += r⃗₁₂
    vec_sums[id2] -= r⃗₁₂

    neighbor_counts[id1] += 1
    neighbor_counts[id2] += 1

    return nothing
end

"""
    state_update_hook!(env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles)

Assign each particle its discrete state and refresh the quantities the rewards depend on.

For a particle with neighbors:
- The accumulated vector sum divided by the neighbor count gives the vector to the local
  center of mass.
- Its length selects the distance state.
- Its angle relative to the particle orientation `(cos φ, sin φ)` selects the direction
  state.
- The state id is looked up in `env.shared.state_id_tensor`.

A particle without neighbors gets the last state, `env.shared.n_states`. Its stored
distance is set to `max_distance_to_local_center_of_mass`.

The shape reward term is enabled when the mean stored distance, divided by
`max_distance_to_local_center_of_mass`, is below `trigger`. A `*` is printed to stdout each
time it is enabled. Finally, the center of mass of all particles and the gyration tensor
eigenvectors are recomputed.
"""
function state_update_hook!(
    env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper, particles::Vector{ReCo.Particle}
)
    n_particles = env_helper.shared.n_particles
    env = env_helper.shared.env
    max_distance = env_helper.max_distance_to_local_center_of_mass

    distance_to_local_center_of_mass_sum = 0.0

    for particle_id in 1:n_particles
        n_neighbors = env_helper.n_neighbors[particle_id]

        if n_neighbors == 0
            # Last state is when no particle is in the skin radius.
            state_id = env.shared.n_states
            # Store the maximum instead of keeping a stale distance from an earlier step.
            # This keeps the stored per-particle distances consistent with the mean below.
            distance = max_distance
        else
            vec_to_local_center_of_mass =
                env_helper.vec_to_neighbor_sums[particle_id] / n_neighbors
            distance = ReCo.norm2d(vec_to_local_center_of_mass)
            distance_state_ind = find_state_ind(distance, env.distance_state_space)

            si, co = sincos(particles[particle_id].φ)
            direction_angle = ReCo.angle2(SVector(co, si), vec_to_local_center_of_mass)
            direction_state_ind = find_state_ind(
                direction_angle, env.direction_angle_state_space
            )

            state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind]
        end

        env_helper.distances_to_local_center_of_mass[particle_id] = distance
        distance_to_local_center_of_mass_sum += distance
        env_helper.shared.states_id[particle_id] = state_id
    end

    mean_distance_to_local_center_of_mass =
        distance_to_local_center_of_mass_sum / n_particles
    env_helper.add_shape_reward_term =
        mean_distance_to_local_center_of_mass / max_distance < env_helper.trigger
    if env_helper.add_shape_reward_term
        # NOTE(review): bare stdout marker; consider `@debug` if it is only diagnostic.
        print("*")
    end

    env_helper.center_of_mass = ReCo.center_of_mass(particles, env_helper.half_box_len)

    v1, v2 = ReCo.gyration_tensor_eigvecs(
        particles, env_helper.half_box_len, env_helper.center_of_mass
    )

    env_helper.gyration_tensor_eigvec_to_smaller_eigval = v1
    env_helper.gyration_tensor_eigvec_to_bigger_eigval = v2

    return nothing
end

"""
    update_reward!(env, env_helper, particle)

Compute the reward of `particle` and store it in `env`.

A particle without neighbors gets `0.0`, written directly to `env.shared.reward`.
Otherwise, the reward for minimizing the distance to the local center of mass is computed.
While the shape reward term is enabled, a reward for minimizing the particle's elliptical
distance to the center of mass is added to it. The result is halved and passed to
`set_normalized_reward!`.
"""
function update_reward!(
    env::LocalCOMWithAdditionalShapeRewardEnv,
    env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
    particle::ReCo.Particle,
)
    id = particle.id

    # Guard: isolated particles receive no reward.
    if env_helper.n_neighbors[id] == 0
        env.shared.reward = 0.0
        return nothing
    end

    local_com_reward = minimizing_reward(
        env_helper.distances_to_local_center_of_mass[id],
        env_helper.max_distance_to_local_center_of_mass,
    )

    combined_reward = if env_helper.add_shape_reward_term
        elliptical_distance = ReCo.elliptical_distance(
            particle.c,
            env_helper.center_of_mass,
            env_helper.gyration_tensor_eigvec_to_smaller_eigval,
            env_helper.gyration_tensor_eigvec_to_bigger_eigval,
            env_helper.shared.elliptical_b_a_ratio,
            env_helper.half_box_len,
        )
        shape_reward = minimizing_reward(
            elliptical_distance, env_helper.max_elliptical_distance
        )
        local_com_reward + shape_reward
    else
        local_com_reward
    end

    # NOTE(review): the division by 2 also applies when the shape term is not added —
    # confirm this is intended.
    set_normalized_reward!(env, combined_reward / 2, env_helper)

    return nothing
end