export LocalCOMEnv

"""
    LocalCOMEnv(sim_consts; n_distance_states=3, n_direction_angle_states=3)

Environment whose state encodes where the local center of mass (COM) of a particle's
neighbours lies relative to that particle.

The COM distance is discretised by `gen_distance_state_space` between `0.0` and
`sim_consts.skin_r` into `n_distance_states` states. The direction angle is discretised
by `gen_angle_state_space` into `n_direction_angle_states` states. Every
(distance, angle) pair is one state, stored distance-major in `params.state_space`.
One extra, final state index is reserved for a particle with no neighbours, giving
`n_distance_states * n_direction_angle_states + 1` states in total.

# Throws
- `ArgumentError` if `n_distance_states < 1` or `n_direction_angle_states < 2`.
"""
struct LocalCOMEnv <: Env
    params::EnvParams

    distance_state_space::Vector{Interval}
    direction_angle_state_space::Vector{Interval}

    max_distance::Float64

    function LocalCOMEnv(
        sim_consts; n_distance_states::Int64=3, n_direction_angle_states::Int64=3
    )
        # Validate explicitly: `@assert` is for internal invariants and may be disabled,
        # so it must not be the only guard on user-supplied arguments.
        n_distance_states >= 1 || throw(
            ArgumentError("n_distance_states must be at least 1, got $n_distance_states"),
        )
        n_direction_angle_states > 1 || throw(
            ArgumentError(
                "n_direction_angle_states must be greater than 1, " *
                "got $n_direction_angle_states",
            ),
        )

        direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)

        min_distance = 0.0
        max_distance = sim_consts.skin_r

        distance_state_space = gen_distance_state_space(
            min_distance, max_distance, n_distance_states
        )

        # One state per (distance, angle) pair, plus a last state for when no particle
        # is in the skin radius.
        n_states = n_distance_states * n_direction_angle_states + 1

        state_space = Vector{SVector{2,Interval}}(undef, n_states - 1)

        # Distance-major order: the angle interval varies fastest.
        ind = 1
        for distance_state in distance_state_space
            for direction_angle_state in direction_angle_state_space
                state_space[ind] = SVector(distance_state, direction_angle_state)
                ind += 1
            end
        end

        params = EnvParams(n_states, state_space)

        return new(params, distance_state_space, direction_angle_state_space, max_distance)
    end
end

"""
    LocalCOMEnvHelper(params::EnvHelperParams)

Per-particle accumulators used by `LocalCOMEnv`.

`vec_to_neighbour_sums[i]` holds the running sum of the vectors from particle `i` to its
neighbours, and `n_neighbours[i]` counts how many neighbours were added. Both vectors
have length `params.n_particles` and start out zeroed.
"""
struct LocalCOMEnvHelper <: EnvHelper
    params::EnvHelperParams

    vec_to_neighbour_sums::Vector{SVector{2,Float64}}
    n_neighbours::Vector{Int64}

    function LocalCOMEnvHelper(params::EnvHelperParams)
        n = params.n_particles

        zero_sums = fill(SVector(0.0, 0.0), n)
        zero_counts = zeros(Int64, n)

        return new(params, zero_sums, zero_counts)
    end
end

"""
    gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperParams)

Build the `LocalCOMEnvHelper` that goes with a `LocalCOMEnv`, wrapping `env_helper_params`.
"""
gen_env_helper(::LocalCOMEnv, env_helper_params::EnvHelperParams) =
    LocalCOMEnvHelper(env_helper_params)

"""
    pre_integration_hook(env_helper::LocalCOMEnvHelper)

Zero the neighbour-vector sums and neighbour counts of the first
`env_helper.params.n_particles` particles so they can be accumulated afresh.
"""
function pre_integration_hook(env_helper::LocalCOMEnvHelper)
    n = env_helper.params.n_particles
    sums = env_helper.vec_to_neighbour_sums
    counts = env_helper.n_neighbours
    zero_vec = SVector(0.0, 0.0)

    @simd for i in 1:n
        sums[i] = zero_vec
        counts[i] = 0
    end

    return nothing
end

"""
    state_update_helper_hook(env_helper, id1, id2, r⃗₁₂)

Record that particles `id1` and `id2` are neighbours.

`r⃗₁₂` is added to the vector sum of `id1` and subtracted from that of `id2`, and both
neighbour counts are incremented. This presumes `r⃗₁₂` points from `id1` to `id2`;
confirm against the caller.
"""
function state_update_helper_hook(
    env_helper::LocalCOMEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
)
    sums = env_helper.vec_to_neighbour_sums
    counts = env_helper.n_neighbours

    # Each particle sees the other on the opposite side, hence the opposite signs.
    sums[id1] = sums[id1] + r⃗₁₂
    sums[id2] = sums[id2] - r⃗₁₂

    counts[id1] = counts[id1] + 1
    counts[id2] = counts[id2] + 1

    return nothing
end

"""
    state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})

Recompute every particle's state index from the accumulated neighbour vectors.

The current `states_ind` are first copied into `old_states_ind`. A particle with no
neighbours gets the last state index, `env.params.n_states`. Otherwise the vector to its
local center of mass (the mean of the vectors to its neighbours) is classified by:
- its length, and
- its `angle2` relative to the particle's heading `(cos φ, sin φ)`.

The resulting interval pair is then looked up in `env.params.state_space`.
"""
function state_update_hook(env_helper::LocalCOMEnvHelper, particles::Vector{Particle})
    helper_params = env_helper.params
    n_particles = helper_params.n_particles

    # Keep the previous states before they are overwritten below.
    copyto!(helper_params.old_states_ind, 1, helper_params.states_ind, 1, n_particles)

    env = helper_params.env

    for id in 1:n_particles
        neighbour_count = env_helper.n_neighbours[id]

        new_state_ind = if neighbour_count == 0
            # Reserved state for "no particle in the skin radius".
            env.params.n_states
        else
            com_vec = env_helper.vec_to_neighbour_sums[id] / neighbour_count

            com_x, com_y = com_vec
            com_distance = sqrt(com_x^2 + com_y^2)
            distance_interval = find_state_interval(
                com_distance, env.distance_state_space
            )

            sin_φ, cos_φ = sincos(particles[id].φ)
            com_angle = angle2(SVector(cos_φ, sin_φ), com_vec)
            angle_interval = find_state_interval(
                com_angle, env.direction_angle_state_space
            )

            interval_pair = SVector{2,Interval}(distance_interval, angle_interval)
            find_state_ind(interval_pair, env.params.state_space)
        end

        helper_params.states_ind[id] = new_state_ind
    end

    return nothing
end

"""
    update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)

Set `env.params.reward` for `particle` to minus the squared distance from the particle to
its neighbours' local center of mass, divided by `env.max_distance * n_particles`. A
particle without neighbours is scored as if that squared distance were
`env.max_distance^2`.
"""
function update_reward!(env::LocalCOMEnv, env_helper::LocalCOMEnvHelper, particle::Particle)
    id = particle.id
    neighbour_count = env_helper.n_neighbours[id]

    # NOTE(review): a squared length divided by a length is not dimensionless; confirm
    # whether `env.max_distance^2` was the intended scale.
    normalization = env.max_distance * env_helper.params.n_particles

    squared_com_distance = if neighbour_count == 0
        env.max_distance^2
    else
        # TODO: Reuse vec_to_local_center_of_mass from state_update_hook
        com_vec = env_helper.vec_to_neighbour_sums[id] / neighbour_count
        com_vec[1]^2 + com_vec[2]^2
    end

    env.params.reward = -squared_com_distance / normalization

    return nothing
end