module Shape

using StaticArrays: SVector

export center_of_mass,
    gyration_tensor_eigvals_ratio, gyration_tensor_eigvecs, elliptical_distance

using StaticArrays: SVector, SMatrix
using LinearAlgebra: LinearAlgebra as LA

using ..ReCo: ReCo, Particle

"""
    project_to_unit_circle(x, half_box_len)

Map the periodic coordinate `x` onto the unit circle and return `SVector(cos θ, sin θ)`.

The angle is `θ = (x + half_box_len) * π / half_box_len`, so `x = -half_box_len` maps to
`θ = 0`, `x = 0` to `θ = π` and `x = half_box_len` to `θ = 2π`.
"""
function project_to_unit_circle(x::Real, half_box_len::Real)
    angle = (x + half_box_len) * π / half_box_len

    # `sincos` yields `(sin, cos)`; reversing gives the `(cos, sin)` point on the circle.
    return SVector(reverse(sincos(angle)))
end

"""
    project_back_from_unit_circle(θ, half_box_len)

Invert the angle map of `project_to_unit_circle`, `θ ↦ θ * half_box_len / π - half_box_len`.

The raw result is then passed through `ReCo.restrict_coordinate`. This is presumably needed
because angles from `atan` lie in `(-π, π]`, so the raw value can fall below
`-half_box_len` and must be wrapped back into the box.
TODO confirm the wrapping semantics of `ReCo.restrict_coordinate`.
"""
function project_back_from_unit_circle(θ::Real, half_box_len::Real)
    return ReCo.restrict_coordinate(θ * half_box_len / π - half_box_len, half_box_len)
end

"""
    center_of_mass_from_proj_sums(x_proj_sum, y_proj_sum, half_box_len; digits=5)

Turn the summed unit-circle projections of the x and y coordinates into the periodic center
of mass `SVector(COM_x, COM_y)`.

Each summed projection is converted to an angle with `atan`. No `1/n_particles` factor is
needed, because `atan` depends only on the direction of the summed vector.

If both components of a sum round to zero at `digits` decimal places, the direction is
undefined. The COM could then be 0, `half_box_len` or `-half_box_len`; 0 is chosen by
using the angle π.
"""
function center_of_mass_from_proj_sums(
    x_proj_sum::SVector{2,<:Real},
    y_proj_sum::SVector{2,<:Real},
    half_box_len::Real;
    digits::Integer=5,
)
    x_θ = _proj_sum_to_angle(x_proj_sum, digits)
    y_θ = _proj_sum_to_angle(y_proj_sum, digits)

    COM_x = project_back_from_unit_circle(x_θ, half_box_len)
    COM_y = project_back_from_unit_circle(y_θ, half_box_len)

    return SVector(COM_x, COM_y)
end

# Angle of a summed `(cos, sin)` projection; falls back to π when the sum is numerically zero.
function _proj_sum_to_angle(proj_sum::SVector{2,<:Real}, digits::Integer)
    θ = atan(proj_sum[2], proj_sum[1])

    # Rounding prevents for example atan(1e-16, 1e-15) != 0 when projections cancel out
    # up to floating-point noise.
    is_degenerate =
        iszero(round(proj_sum[1]; digits=digits)) &&
        iszero(round(proj_sum[2]; digits=digits))

    # `oftype` keeps the angle's type equal to `atan`'s in both branches; returning the
    # bare `Irrational` π made the type (and downstream COM type) value-dependent.
    return is_degenerate ? oftype(θ, π) : θ
end

"""
    center_of_mass(centers, half_box_len)

Periodic center of mass of 2D positions `centers`.

Each coordinate is projected onto the unit circle, and the projections are summed per axis.
The sums are converted back to a position by `center_of_mass_from_proj_sums`. An empty
`centers` gives zero sums.
"""
function center_of_mass(centers::AbstractVector{<:SVector{2,<:Real}}, half_box_len::Real)
    origin = SVector(0.0, 0.0)

    # `foldl` accumulates strictly left to right, i.e. in the same order as a plain loop.
    x_proj_sum = foldl(centers; init=origin) do acc, c
        return acc + project_to_unit_circle(c[1], half_box_len)
    end
    y_proj_sum = foldl(centers; init=origin) do acc, c
        return acc + project_to_unit_circle(c[2], half_box_len)
    end

    return center_of_mass_from_proj_sums(x_proj_sum, y_proj_sum, half_box_len)
end

"""
    center_of_mass(particles, half_box_len)

Periodic center of mass of the positions `p.c` of `particles`.

Positions are projected onto the unit circle per axis and summed. The sums are converted
back by `center_of_mass_from_proj_sums`.
"""
function center_of_mass(particles::AbstractVector{Particle}, half_box_len::Real)
    sum_x = sum_y = SVector(0.0, 0.0)

    for particle in particles
        pos = particle.c
        sum_x = sum_x + project_to_unit_circle(pos[1], half_box_len)
        sum_y = sum_y + project_to_unit_circle(pos[2], half_box_len)
    end

    return center_of_mass_from_proj_sums(sum_x, sum_y, half_box_len)
end

"""
    gyration_tensor(particles, half_box_len, COM)

Unnormalized (no `1/N` factor) 2×2 gyration tensor of the positions `p.c` about `COM`,
returned as a `Hermitian` static matrix.

Displacements from `COM` are passed through `ReCo.restrict_coordinates`. This is presumably
the periodic minimum-image wrap (TODO confirm).
"""
function gyration_tensor(
    particles::AbstractVector{Particle}, half_box_len::Real, COM::SVector{2,<:Real}
)
    sxx = sxy = syy = 0.0

    for particle in particles
        d = ReCo.restrict_coordinates(particle.c - COM, half_box_len)
        dx, dy = d[1], d[2]

        sxx += dx^2
        sxy += dx * dy
        syy += dy^2
    end

    return LA.Hermitian(SMatrix{2,2}(sxx, sxy, sxy, syy))
end

"""
    gyration_tensor(centers, half_box_len, COM)

Unnormalized (no `1/N` factor) 2×2 gyration tensor of the positions `centers` about `COM`,
returned as a `Hermitian` static matrix.

Displacements are wrapped with `ReCo.restrict_coordinates` (presumably minimum image —
TODO confirm).
"""
function gyration_tensor(
    centers::AbstractVector{<:SVector{2,<:Real}}, half_box_len::Real, COM::SVector{2,<:Real}
)
    # Fold the three independent tensor entries (xx, xy, yy) left to right.
    sxx, sxy, syy = foldl(centers; init=(0.0, 0.0, 0.0)) do (axx, axy, ayy), c
        d = ReCo.restrict_coordinates(c - COM, half_box_len)
        return (axx + d[1]^2, axy + d[1] * d[2], ayy + d[2]^2)
    end

    return LA.Hermitian(SMatrix{2,2}(sxx, sxy, sxy, syy))
end

"""
    gyration_tensor(particles_or_centers, half_box_len)

Gyration tensor about the periodic center of mass of `particles_or_centers` itself.
"""
function gyration_tensor(
    particles_or_centers::Union{
        AbstractVector{Particle},AbstractVector{<:SVector{2,<:Real}}
    },
    half_box_len::Real,
)
    return gyration_tensor(
        particles_or_centers,
        half_box_len,
        center_of_mass(particles_or_centers, half_box_len),
    )
end

"""
    eigvals_ratio(matrix)

Return `abs(λ₁ / λ₂)` for the first two eigenvalues from `LA.eigvals(matrix)`.

For a `Hermitian` input the eigenvalues are sorted ascending. For a 2×2 gyration tensor
this is therefore smallest / largest. A zero matrix gives `NaN` (0/0).
"""
function eigvals_ratio(matrix)
    λ₁, λ₂ = LA.eigvals(matrix)
    return abs(λ₁ / λ₂)
end

"""
    gyration_tensor_eigvals_ratio(particles_or_centers, half_box_len)

Eigenvalue ratio (smaller / bigger) of the gyration tensor about the periodic center of mass
of `particles_or_centers`.
"""
function gyration_tensor_eigvals_ratio(
    particles_or_centers::Union{
        AbstractVector{Particle},AbstractVector{<:SVector{2,<:Real}}
    },
    half_box_len::Real,
)
    return eigvals_ratio(gyration_tensor(particles_or_centers, half_box_len))
end

"""
    gyration_tensor_eigvals_ratio(particles_or_centers, half_box_len, COM)

Eigenvalue ratio (smaller / bigger) of the gyration tensor about a given `COM`.
"""
function gyration_tensor_eigvals_ratio(
    particles_or_centers::Union{
        AbstractVector{Particle},AbstractVector{<:SVector{2,<:Real}}
    },
    half_box_len::Real,
    COM::SVector{2,<:Real},
)
    return eigvals_ratio(gyration_tensor(particles_or_centers, half_box_len, COM))
end

"""
    gyration_tensor_eigvecs(particles_or_centers, half_box_len[, COM])

Return the two eigenvectors `(v1, v2)` of the gyration tensor of `particles_or_centers`
about `COM`.

`COM` defaults to the periodic center of mass of `particles_or_centers`. Both `Particle`
vectors and vectors of 2D `SVector` positions are accepted, matching the other gyration
functions.

`LA.eigvecs` is applied to a `Hermitian` matrix. Its columns are expected in ascending
eigenvalue order, so `v1` belongs to the smaller and `v2` to the bigger eigenvalue. This is
the order `eigvals_ratio` also assumes; TODO confirm for the StaticArrays 2×2 path.
"""
function gyration_tensor_eigvecs(
    particles_or_centers::Union{
        AbstractVector{Particle},AbstractVector{<:SVector{2,<:Real}}
    },
    half_box_len::Real,
    COM::SVector{2,<:Real}=center_of_mass(particles_or_centers, half_box_len),
)
    g_tensor = gyration_tensor(particles_or_centers, half_box_len, COM)
    eig_vecs = LA.eigvecs(g_tensor)

    # Column slices of a static matrix stay static, so callers such as
    # `elliptical_distance` receive `SVector`s.
    v1 = eig_vecs[:, 1]
    v2 = eig_vecs[:, 2]

    return (v1, v2)
end

"""
    elliptical_distance(v, COM, eigvec_small, eigvec_big, elliptical_b_a_ratio, half_box_len)

Scaled distance of `v` from `COM` in the frame of the gyration-tensor eigenvectors.

The displacement (wrapped by `ReCo.restrict_coordinates`) is decomposed into components
along the bigger-eigenvalue axis (major) and the smaller-eigenvalue axis (minor). The minor
component is divided by `elliptical_b_a_ratio` before taking the Euclidean norm.

If the eigenvectors are orthonormal, every point on an ellipse with semi-axes `a` (major)
and `b = elliptical_b_a_ratio * a` (minor) gets distance `a`.
"""
function elliptical_distance(
    v::SVector{2,<:Real},
    COM::SVector{2,<:Real},
    gyration_tensor_eigvec_to_smaller_eigval::SVector{2,<:Real},
    gyration_tensor_eigvec_to_bigger_eigval::SVector{2,<:Real},
    elliptical_b_a_ratio::Real,
    half_box_len::Real,
)
    Δ = ReCo.restrict_coordinates(v - COM, half_box_len)

    along_major = LA.dot(Δ, gyration_tensor_eigvec_to_bigger_eigval)
    along_minor = LA.dot(Δ, gyration_tensor_eigvec_to_smaller_eigval)

    stretched_minor = along_minor / elliptical_b_a_ratio

    return sqrt(along_major^2 + stretched_minor^2)
end

end # module