from typing import Callable

import jax
import jax.numpy as jnp
import numpy as np
from folx import (
    ForwardLaplacianOperator,
    LoopLaplacianOperator,
    batched_vmap,
)

from globe.nn.ferminet import netjit
from globe.utils import adj_idx, tree_generator_zip, triu_idx
from globe.utils.config import (
    SystemConfigs,
    group_by_config,
    inverse_group_idx,
)


def make_kinetic_energy_function(f, group_parameters_fn, operator: str = 'loop'):
    operator = operator.lower()
    match operator.lower():
        case 'loop':
            op = LoopLaplacianOperator()
        case 'forward':
            op = ForwardLaplacianOperator(0.25)
        case _:
            raise ValueError(f'Unknown operator {operator}')

    @netjit
    def laplacian_of_f(params, electrons, atoms, config: SystemConfigs, mol_params):
        def _laplacian(elec, atoms, config, mol_params):
            def f_closure(x):
                return f(params, x, atoms, config, mol_params).squeeze()

            laplacian, quantum_force = op(f_closure)(elec)
            return -0.5 * (jnp.sum(laplacian) + jnp.sum(quantum_force**2))

        match operator.lower():
            case 'loop':
                _laplacian = jax.vmap(_laplacian, in_axes=(0, 0, None, 0))
            case 'forward':
                _laplacian = batched_vmap(
                    _laplacian, max_batch_size=1, in_axes=(0, 0, None, 0)
                )
            case _:
                raise ValueError(f'Unknown operator {operator}')

        result = []
        for (elec, (spins, charges)), at, pm in tree_generator_zip(
            group_by_config(
                config, electrons, lambda s, c: np.sum(s), return_config=True
            ),
            group_by_config(config, atoms, lambda s, c: len(c)),
            group_parameters_fn(mol_params, config),
        ):
            conf = SystemConfigs((spins,), (charges,))
            result.append(_laplacian(elec, at, conf, pm))
        idx = inverse_group_idx(config)
        return jnp.concatenate(result)[idx]

    return laplacian_of_f


def potential_energy(
    electrons: jax.Array, atoms: jax.Array, config: SystemConfigs
) -> jax.Array:
    """
    Evaluates the Coulomb potential energy of every graph in the batch.

    The result is the sum of three pairwise terms: electron-electron
    repulsion, electron-nucleus attraction and nucleus-nucleus repulsion.

    Args:
    - electrons: electron positions; reshaped to (-1, 3) before use
    - atoms: nucleus positions, indexed along the first axis
    - config: SystemConfigs providing `spins`, `n_nuc` and `flat_charges()`
    Returns:
    - (n_graphs,) array of potential energies, n_graphs = len(config.spins)
    """
    nuc_charges = config.flat_charges()
    spins = config.spins
    nuc_counts = config.n_nuc
    num_graphs = len(spins)
    elec_pos = electrons.reshape(-1, 3)
    # Number of electrons per graph (spins summed over the last axis).
    elec_counts = np.sum(spins, -1)

    # All segment sums pass indices_are_sorted=True, i.e. the graph ids
    # returned by triu_idx / adj_idx are taken to be sorted.

    # Electron-electron repulsion over unique pairs within each graph.
    first, second, graph_id = triu_idx(elec_counts, 1)
    dist_ee = jnp.linalg.norm(elec_pos[first] - elec_pos[second], axis=-1)
    v_ee = jax.ops.segment_sum(
        1.0 / dist_ee,
        graph_id,
        num_segments=num_graphs,
        indices_are_sorted=True,
    )

    # Electron-nucleus attraction over all electron/nucleus pairs of a graph.
    elec_id, nuc_id, graph_id = adj_idx(elec_counts, nuc_counts)
    dist_ae = jnp.linalg.norm(elec_pos[elec_id] - atoms[nuc_id], axis=-1)
    v_ae = -jax.ops.segment_sum(
        nuc_charges[nuc_id] / dist_ae,
        graph_id,
        num_segments=num_graphs,
        indices_are_sorted=True,
    )

    # Nucleus-nucleus repulsion over unique pairs within each graph.
    first, second, graph_id = triu_idx(nuc_counts, 1)
    dist_aa = jnp.linalg.norm(atoms[first] - atoms[second], axis=-1)
    v_aa = jax.ops.segment_sum(
        nuc_charges[first] * nuc_charges[second] / dist_aa,
        graph_id,
        num_segments=num_graphs,
        indices_are_sorted=True,
    )

    return v_ee + v_ae + v_aa


def make_local_energy_function(
    f, group_parameters_fn, operator: str = 'loop'
) -> Callable:
    """
    Returns a function that computes the local energy of the system.

    The local energy is the sum of the potential energy and the kinetic
    energy produced by `make_kinetic_energy_function`.

    Args:
    - f: function that computes the wavefunction
    - group_parameters_fn: function that groups the parameters by graph
    - operator: name of the Laplacian operator, forwarded unchanged to
      `make_kinetic_energy_function` ('loop' or 'forward'); defaults to
      'loop', the same default as that factory
    Returns:
    - netjit-wrapped function
      (params, electrons, atoms, config, mol_params) -> local energies
    Raises:
    - ValueError: propagated from `make_kinetic_energy_function` when the
      operator name is unknown
    """
    # Built once so the operator name is validated when the factory is called.
    kinetic_energy_fn = make_kinetic_energy_function(f, group_parameters_fn, operator)

    @netjit
    def local_energy(
        params, electrons, atoms, config: SystemConfigs, mol_params
    ) -> jax.Array:
        potential = potential_energy(electrons, atoms, config)
        kinetic = kinetic_energy_fn(params, electrons, atoms, config, mol_params)
        return potential + kinetic

    return local_energy
