using Distributions: Normal
using ProgressMeter: ProgressMeter
using StaticArrays: SVector
using CellListMap: CellListMap

"""
    rand_normal01()

Return a `Float64` sample from the standard normal distribution N(0, 1), drawn from the
global default RNG.
"""
function rand_normal01()
    # `randn` samples N(0, 1) directly from the same default RNG; constructing a
    # `Normal(0, 1)` distribution object on every call adds nothing in this hot path.
    return randn(Float64)
end

"""
    push_to_verlet_list!(verlet_lists, i, j)

Record the pair `(i, j)` exactly once: the larger index is appended, converted to
`Int64`, to the list stored at the smaller index. Returns `nothing`.
"""
function push_to_verlet_list!(verlet_lists, i, j)
    owner, neighbour = minmax(i, j)
    push!(verlet_lists[owner], Int64(neighbour))
    return nothing
end

"""
    update_verlet_lists!(args, cl::CellListMap.CellList)

Rebuild all Verlet lists in `args.verlet_lists` from the current particle positions.

Every list is emptied with `reset!`. The coordinates `c` of the first `args.n_particles`
particles are copied into `args.particles_c`, and the cell list `cl` is updated for them
(serially). Every pair reported by `CellListMap.map_pairwise!` for `args.box` is then
recorded with `push_to_verlet_list!`.

Returns the updated cell list; callers should keep using the returned object.
"""
function update_verlet_lists!(args, cl::CellListMap.CellList)
    lists = args.verlet_lists
    foreach(reset!, lists)

    particles = args.particles
    coords = args.particles_c
    for idx in 1:(args.n_particles)
        coords[idx] = particles[idx].c
    end

    updated_cl = CellListMap.UpdateCellList!(coords, args.box, cl; parallel=false)

    # The pairwise "output" is unused (`nothing`); the callback only records the pair.
    record_pair(_x, _y, i, j, _d2, _output) = push_to_verlet_list!(lists, i, j)
    CellListMap.map_pairwise!(record_pair, nothing, args.box, updated_cl; parallel=false)

    return updated_cl
end

"""
    euler!(args, first_integration_step, env_helper, state_update_helper_hook!,
           state_update_hook!, update_table_and_actions_hook!)

Advance every particle in `args.particles` by one explicit time step.

1. For each pair `(id1, id2)` stored in the Verlet list of `id1`:
   - call `state_update_helper_hook!`;
   - if `are_overlapping` reports the pair as overlapping, add the scaled pair force to
     `tmp_c` of particle 1 and subtract it from `tmp_c` of particle 2.
2. Run `RL.copy_states_to_old_states_hook!` and `state_update_hook!`.
3. For each particle:
   - add a displacement of `args.v₀δt` along `(cos φ, sin φ)` plus Gaussian noise to
     `tmp_c`;
   - apply `restrict_coordinates!`;
   - run `update_table_and_actions_hook!` and `RL.act_hook!`;
   - add Gaussian noise to `φ`;
   - commit `tmp_c` to `c`.

Within this function `c` is only written in step 3, so all pair forces use the positions
from before the step. Returns `nothing`.
"""
function euler!(
    args,
    first_integration_step::Bool,
    env_helper::Union{RL.EnvHelper,Nothing},
    state_update_helper_hook!::Function,
    state_update_hook!::Function,
    update_table_and_actions_hook!::Function,
)
    # Each unordered pair is stored once, in the list of its smaller index, so both
    # particles are updated here with equal and opposite contributions.
    for id1 in 1:(args.n_particles - 1)
        p1 = args.particles[id1]
        p1_c = p1.c
        verlet_list = args.verlet_lists[id1]

        for id2 in view(verlet_list.v, 1:(verlet_list.last_ind))
            p2 = args.particles[id2]

            overlapping, r⃗₁₂, distance² = are_overlapping(
                p1_c, p2.c, args.interaction_radius², args.half_box_len
            )

            state_update_helper_hook!(env_helper, id1, id2, r⃗₁₂, distance²)

            if overlapping
                # Form 24ϵσ⁶/r⁸ · (1 - 2σ⁶/r⁶): a Lennard-Jones-type force with the
                # constants premultiplied in `args` (judging by the field names).
                factor = args.ϵσ⁶δtμₜ24 / (distance²^4) * (1.0 - args.σ⁶2 / (distance²^3))
                μₜF⃗₁₂ = factor * r⃗₁₂ # Force acting on 1 from 2 multiplied with μₜ

                p1.tmp_c += μₜF⃗₁₂
                p2.tmp_c -= μₜF⃗₁₂
            end
        end
    end

    RL.copy_states_to_old_states_hook!(env_helper)
    state_update_hook!(env_helper, args.particles)

    # Deliberately NOT `@simd`: the body draws from the global RNG and calls hooks with
    # side effects (loop-carried state), which `@simd` asserts does not exist.
    for p in args.particles
        si, co = sincos(p.φ)
        p.tmp_c += SVector(
            args.v₀δt * co + args.sqrt_Dₜδt2 * rand_normal01(),
            args.v₀δt * si + args.sqrt_Dₜδt2 * rand_normal01(),
        )

        restrict_coordinates!(p, args.half_box_len)

        update_table_and_actions_hook!(env_helper, p, first_integration_step)

        RL.act_hook!(p, env_helper, args.δt, si, co)

        p.φ += args.sqrt_Dᵣδt2 * rand_normal01()

        p.c = p.tmp_c
    end

    return nothing
end

# NOTE(review): this is type piracy. Neither `Base.wait` nor `Nothing` is owned by this
# code, so the method changes `wait` for every package in the session. It makes
# `wait(x)` a no-op when `x` is `nothing`. It is kept only because code outside this view
# may rely on it; prefer `isnothing(task) || wait(task)` at call sites.
Base.wait(::Nothing) = nothing

# With no RL environment helper (`nothing`) the additional hooks never run, whatever the
# remaining arguments are.
gen_run_additional_hooks(::Nothing, args...) = false

"""
    gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64)

Return `true` on the first integration step and on every step that is a multiple of
`env_helper.shared.n_steps_before_actions_update`; otherwise return `false`.
"""
function gen_run_additional_hooks(env_helper::RL.EnvHelper, integration_step::Int64)
    period = env_helper.shared.n_steps_before_actions_update
    # The modulo is evaluated first, exactly as before, so a zero period still throws.
    on_update_step = iszero(integration_step % period)
    return on_update_step || integration_step == 1
end

"""
    gen_cell_list(particles_c, box)

Build a serial (`parallel=false`) `CellListMap.CellList` for the 2D coordinates
`particles_c` in `box`.
"""
gen_cell_list(particles_c::Vector{SVector{2,Float64}}, box::CellListMap.Box) =
    CellListMap.CellList(particles_c, box; parallel=false)

"""
    simulate!(args, T0, T, n_steps_before_verlet_list_update, n_steps_before_snapshot,
              n_bundles, sim_dir, save_data, env_helper)

Integrate the system over `T0:args.δt:T`, calling `euler!` once per time point.

- Every `n_steps_before_verlet_list_update` steps the Verlet lists are rebuilt.
- When `save_data` is true, a snapshot is stored into `args.bundle` every
  `n_steps_before_snapshot` steps.
- Once `args.n_bundle_snapshots` snapshots are collected, the bundle is written with
  `save_bundle` in an `@async` task. That task is awaited before `args.bundle` is
  written again and before returning.
- After the loop, any leftover snapshots are written as a final, shorter bundle.
- New bundles are numbered consecutively starting at `n_bundles + 1`.
- On steps where `gen_run_additional_hooks` returns true, `RL.pre_integration_hook!` runs
  and the RL hooks are passed to `euler!`; otherwise `empty_hook` is passed for all three.
- Progress is reported with ProgressMeter when `args.show_progress` is true.

Returns `nothing`.
"""
function simulate!(
    args,
    T0::Float64,
    T::Float64,
    n_steps_before_verlet_list_update::Int64,
    n_steps_before_snapshot::Int64,
    n_bundles::Int64,
    sim_dir::String,
    save_data::Bool,
    env_helper::Union{RL.EnvHelper,Nothing},
)
    bundle_snapshot_counter = 0

    # Pending asynchronous `save_bundle` task, if any.
    task::Union{Task,Nothing} = nothing

    cl = gen_cell_list(args.particles_c, args.box)
    cl = update_verlet_lists!(args, cl)

    first_integration_step = true

    time_range = T0:(args.δt):T

    progress = ProgressMeter.Progress(
        length(time_range); dt=2, enabled=args.show_progress, desc="Simulation: "
    )

    for (integration_step, t) in enumerate(time_range)
        # Check `save_data` first so the modulo is skipped entirely when nothing is saved.
        if save_data && integration_step % n_steps_before_snapshot == 0
            # `args.bundle` must not be overwritten while a previous save is in flight.
            isnothing(task) || wait(task)

            bundle_snapshot_counter += 1
            save_snapshot!(args.bundle, bundle_snapshot_counter, t, args.particles)

            if bundle_snapshot_counter == args.n_bundle_snapshots
                # Update the counters here rather than inside the task: locals that are
                # reassigned inside a closure get boxed, making this function type-unstable.
                bundle_snapshot_counter = 0
                n_bundles += 1

                task = let bundle_id = n_bundles, snapshot_time = t
                    @async save_bundle(sim_dir, args.bundle, bundle_id, snapshot_time)
                end
            end
        end

        if integration_step % n_steps_before_verlet_list_update == 0
            cl = update_verlet_lists!(args, cl)
        end

        if gen_run_additional_hooks(env_helper, integration_step)
            RL.pre_integration_hook!(env_helper)

            euler!(
                args,
                first_integration_step,
                env_helper,
                RL.state_update_helper_hook!,
                RL.state_update_hook!,
                RL.update_table_and_actions_hook!,
            )
        else
            euler!(
                args, first_integration_step, env_helper, empty_hook, empty_hook, empty_hook
            )
        end

        first_integration_step = false

        ProgressMeter.next!(progress)
    end

    isnothing(task) || wait(task)

    # Flush the snapshots of a partially filled bundle.
    if bundle_snapshot_counter > 0
        bundle = first_n_snapshots(args.bundle, bundle_snapshot_counter)

        n_bundles += 1

        save_bundle(sim_dir, bundle, n_bundles, T)
    end

    return nothing
end