"""
    TwoBodyProblem{T}()

Field-less tag type for the planar two-body (Kepler) problem with state
`u = (q₁, q₂, p₁, p₂)`. The parameter `T` is the element type of generated states
(see `initial_conditions`). Instances are callable as `(system)(du, u, p, t)` to
evaluate the ground-truth equations of motion in place.
"""
struct TwoBodyProblem{T} <: AbstractDynamicalSystem{T}
end

# Ground-truth equations of motion, in-place form:
#   dq/dt = p,   dp/dt = -q / |q|³
function (system::TwoBodyProblem)(du, u, p, t)
    q₁, q₂, p₁, p₂ = u
    # |q|³ is shared by both force components; evaluate it once.
    r_cubed = (q₁^2 + q₂^2)^(3 / 2)

    # Position derivatives are the momenta.
    du[1], du[2] = p₁, p₂
    # Momentum derivatives: inverse-square attraction toward the origin.
    du[3], du[4] = -q₁ / r_cubed, -q₂ / r_cubed

    return nothing
end

"""
    hamiltonian(u::AbstractVector{T})

Return the total energy of the state `u = (q₁, q₂, p₁, p₂)`: kinetic energy
`½(p₁² + p₂²)` plus potential energy `-1/√(q₁² + q₂²)`. The constant `½` is
converted to `T` so that the result keeps the element type of `u`.
"""
function hamiltonian(u::AbstractVector{T}) where {T}
    q₁, q₂, p₁, p₂ = u
    kinetic = T(0.5) * (p₁^2 + p₂^2)
    potential = -(1 / sqrt(q₁^2 + q₂^2))
    return kinetic + potential
end

"""
    initial_conditions(system::TwoBodyProblem{T})

Return a random initial state `T[q₁, q₂, p₁, p₂]` with eccentricity `e` drawn
uniformly from `[0.5, 0.7)`:
`q = (1 - e, 0)`, `p = (0, √((1 + e) / (1 - e)))`.

Under the equations of motion `dp/dt = -q/|q|³`, this is the perihelion of an orbit
with semi-major axis 1. Uses the global RNG.
"""
function initial_conditions(system::TwoBodyProblem{T}) where {T}
    # The random draw stays in Float64, so the RNG stream is unchanged. `e` is then
    # promoted to `T` *before* the state is derived. For high-precision `T` (e.g.
    # BigFloat), this computes the momentum at full precision instead of rounding
    # it through Float64.
    e = T(0.5 + 0.2 * rand())  # eccentricity in [0.5, 0.7)
    return T[1 - e, 0, 0, sqrt((1 + e) / (1 - e))]
end

"""
    get_trajectories(system::TwoBodyProblem{T}, experiment_version, seconds, dt,
                     transient_seconds, solver, reltol, abstol, N, steps,
                     stabilization_param, θ, restructure)

Generate `N` random two-body trajectories and return their multiple-shooting
segments concatenated with `vcat`.

Each trajectory is generated with a `TwoBodyProblem{BigFloat}` from
`initial_conditions` and spans one orbital period derived from its energy.
`NF = T` is passed to `generate_data`.

Note: the `seconds` argument is not used by this method; the time span is always
one period.

If `stabilization_param` is zero, the neural vector field is used directly as the
right-hand side. Otherwise it is wrapped in a `StabilizedNDE` whose constraints
function is built from the first stored sample of the trajectory.
"""
function get_trajectories(
    system::TwoBodyProblem{T},
    experiment_version,
    seconds,
    dt,
    transient_seconds,
    solver,
    reltol,
    abstol,
    N,
    steps,
    stabilization_param,
    θ,
    restructure,
) where {T}
    f = NeuralVectorField(system, experiment_version, restructure)
    F = ConstraintsPseudoinverse(system, experiment_version)
    γ = T(stabilization_param)

    # Ground-truth data are integrated in BigFloat regardless of T.
    systemBF = TwoBodyProblem{BigFloat}()
    # `map` yields a concretely typed collection instead of an untyped `[]`.
    trajectories = map(1:N) do _
        u0 = initial_conditions(systemBF)
        H₀ = hamiltonian(u0)
        # One full period per trajectory: with unit force constant, H = -1/(2a) and
        # the period is 2π a^(3/2) = 2π (2|H|)^(-3/2). π is converted to the
        # precision of H₀; the literal `2π` is a Float64 and would cap the period
        # at ~16 significant digits.
        period = 2 * oftype(H₀, π) * (2 * abs(H₀))^(-3 / 2)
        time_series = generate_data(
            systemBF;
            seconds = period,
            dt,
            transient_seconds,
            solver,
            reltol,
            abstol,
            u0,
            NF = T,
        )
        # First stored sample; presumably taken after the transient — confirm.
        u0 = time_series.trajectory[:, 1]
        t0 = time_series.times[1]

        # Set up the (optionally stabilized) neural ODE right-hand side.
        if γ == 0
            rhs = f
        else
            g = ConstraintsFunction(system, experiment_version, u0, t0)
            rhs = StabilizedNDE(f, g, F, γ)
        end
        # NOTE(review): u0 = zeros(T) and tspan (0, 1) look like placeholders that
        # `multiple_shooting` replaces per segment — confirm.
        prob = ODEProblem{false,SciMLBase.FullSpecialize}(
            rhs, zeros(T), (zero(T), one(T)), θ
        )
        return multiple_shooting(prob, time_series; steps)
    end
    return vcat(trajectories...)
end

# EXPERIMENT 1: Constrain energy and angular momentum
# Experiment 1 network: 4 inputs → 2 outputs, with element type T. The two outputs
# presumably serve as the momentum derivatives in `rhs_neural` — confirm.
get_mlp(hidden_layers, hidden_width, activation, ::TwoBodyProblem{T}, ::Val{1}) where {T} =
    get_mlp(4 => 2, hidden_layers, hidden_width, activation, T)

"""
    constraints(u, t, ::TwoBodyProblem{T}, ::Val{1})

Experiment 1 invariants of the state `u = (q₁, q₂, p₁, p₂)`, returned as a
2-element vector:
1. the energy `½(p₁² + p₂²) - 1/√(q₁² + q₂²)`;
2. the angular momentum `q₁p₂ - q₂p₁`.
"""
function constraints(u, t, ::TwoBodyProblem{T}, ::Val{1}) where {T}
    q₁, q₂, p₁, p₂ = u
    energy = T(0.5) * (p₁^2 + p₂^2) - 1 / sqrt(q₁^2 + q₂^2)
    angular_momentum = q₁ * p₂ - q₂ * p₁
    return [energy, angular_momentum]
end

"""
    constraints_jacobian(u, t, ::TwoBodyProblem, ::Val{1})

Return the 2×4 Jacobian of the experiment 1 constraints with respect to
`(q₁, q₂, p₁, p₂)`:
- row 1 is the gradient of the energy;
- row 2 is the gradient of the angular momentum `q₁p₂ - q₂p₁`.
"""
function constraints_jacobian(u, t, ::TwoBodyProblem, ::Val{1})
    q₁, q₂, p₁, p₂ = u
    # d/dq of the potential -1/|q| is q/|q|³; the factor 1/|q|³ is shared.
    inv_r3 = (q₁^2 + q₂^2)^(-3 / 2)
    #! format: off
    return [
        q₁*inv_r3   q₂*inv_r3   p₁      p₂
        p₂          (-p₁)       (-q₂)   q₁
    ]
    #! format: on
end

# Experiment 1, in-place neural RHS: dq/dt = p, and dp/dt is the output of the
# network rebuilt from the flat parameters θ via `re`.
function rhs_neural(du, u, θ, t, re::Optimisers.Restructure, ::TwoBodyProblem, ::Val{1})
    # Copy the momenta through a view instead of allocating the slice `u[3:4]`.
    @views du[1:2] .= u[3:4]
    du[3:4] .= re(θ)(u)
    return nothing
end

# Experiment 1, out-of-place neural RHS: returns a new vector [p₁, p₂, ṗ₁, ṗ₂].
# The momentum derivatives come from the network rebuilt from θ via `re`.
function rhs_neural(u, θ, t, re::Optimisers.Restructure, ::TwoBodyProblem, ::Val{1})
    _, _, p₁, p₂ = u
    model = re(θ)
    pdot₁, pdot₂ = model(u)
    return [p₁, p₂, pdot₁, pdot₂]
end

# EXPERIMENT 2: Constrain angular momentum only
# Experiment 2 network: 4 inputs → 2 outputs, with element type T. The two outputs
# presumably serve as the momentum derivatives in `rhs_neural` — confirm.
get_mlp(hidden_layers, hidden_width, activation, ::TwoBodyProblem{T}, ::Val{2}) where {T} =
    get_mlp(4 => 2, hidden_layers, hidden_width, activation, T)

"""
    constraints(u, t, ::TwoBodyProblem, ::Val{2})

Experiment 2 invariant of the state `u = (q₁, q₂, p₁, p₂)`: a 1-element vector
holding the angular momentum `q₁p₂ - q₂p₁`.
"""
function constraints(u, t, ::TwoBodyProblem, ::Val{2})
    q₁, q₂, p₁, p₂ = u
    angular_momentum = q₁ * p₂ - q₂ * p₁
    return [angular_momentum]
end

"""
    constraints_jacobian(u, t, ::TwoBodyProblem, ::Val{2})

Return the 1×4 Jacobian of the angular momentum `q₁p₂ - q₂p₁` with respect to
`(q₁, q₂, p₁, p₂)`.
"""
function constraints_jacobian(u, t, ::TwoBodyProblem, ::Val{2})
    q₁, q₂, p₁, p₂ = u
    # hcat of scalars builds the same 1×4 matrix as the literal `[a b c d]`.
    return hcat(p₂, -p₁, -q₂, q₁)
end

# Experiment 2, in-place neural RHS: dq/dt = p, and dp/dt is the output of the
# network rebuilt from the flat parameters θ via `re`.
function rhs_neural(du, u, θ, t, re::Optimisers.Restructure, ::TwoBodyProblem, ::Val{2})
    # Copy the momenta through a view instead of allocating the slice `u[3:4]`.
    @views du[1:2] .= u[3:4]
    du[3:4] .= re(θ)(u)
    return nothing
end

# Experiment 2, out-of-place neural RHS: returns a new vector [p₁, p₂, ṗ₁, ṗ₂].
# The momentum derivatives come from the network rebuilt from θ via `re`.
function rhs_neural(u, θ, t, re::Optimisers.Restructure, ::TwoBodyProblem, ::Val{2})
    _, _, p₁, p₂ = u
    model = re(θ)
    pdot₁, pdot₂ = model(u)
    return [p₁, p₂, pdot₁, pdot₂]
end
