import functools
import math
from typing import Callable, Literal, Sequence

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optax
from scipy.special import factorial2

from globe.nn.parameters import INVERSE_TRANSFORMS, ParamSpec, ParamTree, SpecTree
from globe.systems.scf import Scf
from globe.typing import HfOrbitalFunction, OrbitalMatchingFunction, SystemConfigs
from globe.utils import iterate_segments
from globe.utils.jax_utils import pmap, pmean_if_pmap


def eval_orbitals(
    scf_approx: Sequence[Scf], electrons: jax.Array, spins: tuple[tuple[int, int], ...]
) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Evaluates Hartree-Fock molecular and atomic orbitals at the given electron positions.

    The electron axis (second to last) of ``electrons`` holds the electrons of all
    H molecules back to back. It is split into one segment per molecule, sized by
    ``sum(spins[i])``.

    Args:
        scf_approx (Sequence[Scf]): Hartree Fock calculations, length H
        electrons (jax.Array): (..., total number of electrons over all molecules, 3)
        spins (tuple): length H, each entry is (spin_up, spin_down)

    Returns:
        List of length H. Each element is ``(mo_alpha, mo_beta, aos)``, all cast
        to float32:

        - ``mo_alpha``: the first ``spin_up`` electrons evaluated on the first
          ``spin_up`` orbitals of the first MO channel.
        - ``mo_beta``: the remaining electrons evaluated on the first
          ``spin_down`` orbitals of the second MO channel.
        - ``aos``: reshaped to the electron batch shape followed by the last two
          dims returned by ``eval_molecular_orbitals``.
    """
    assert len(scf_approx) == len(spins)
    assert np.sum(spins) == electrons.shape[-2]
    per_molecule = iterate_segments(electrons, np.sum(spins, axis=-1), axis=-2)

    def _evaluate(scf, elec, n_up, n_down):
        batch_shape = elec.shape[:-1]
        # The SCF object works on a flat list of 3D points; restore the batch shape after.
        mo_values, ao_values = scf.eval_molecular_orbitals(elec.reshape(-1, 3))
        ao_values = ao_values.reshape(*batch_shape, *ao_values.shape[-2:])
        # Leading axis of size 2: index 0 is used for alpha, index 1 for beta below.
        mo_values = mo_values.reshape(2, *batch_shape, mo_values.shape[-1])
        up_block = mo_values[0, ..., :n_up, :n_up]
        down_block = mo_values[1, ..., n_up:, :n_down]
        return tuple(arr.astype(np.float32) for arr in (up_block, down_block, ao_values))

    return [
        _evaluate(scf, elec, n_up, n_down)
        for scf, elec, (n_up, n_down) in zip(scf_approx, per_molecule, spins)
    ]


def make_mol_param_loss(
    spec_tree: SpecTree, scale: float, max_moment: int = 4, eps=1e-6
) -> Callable[[ParamTree], jax.Array]:
    """
    Computes the loss for the molecular parameters.

    For every parameter whose spec sets ``keep_distr``, the steps are:

    1. Map the parameter back through the inverse of its transform, if one is
       registered in ``INVERSE_TRANSFORMS``.
    2. Standardize it with the spec's ``mean`` / ``std``.
    3. Compare its first ``max_moment`` raw moments with those of a standard
       normal distribution (0, 1, 0, 3, ...).

    The loss is the weighted squared difference between target and observed moments.

    Args:
    - spec_tree: The parameter specification tree. Must have the same structure
      as the molecular parameters passed to the returned function.
    - scale: Global weight applied to the summed loss.
    - max_moment: The maximum moment to compute.
    - eps: A small number added to the spec's std to avoid division by zero.
    Returns:
    - A function that maps the molecular parameters to the loss.
    """
    p = np.arange(1, max_moment + 1)
    # Raw moments of N(0, 1): E[z^k] = (k-1)!! for even k and 0 for odd k.
    # https://en.wikipedia.org/wiki/Normal_distribution#Moments:~:text=standard%20normal%20distribution.-,Moments,-See%20also%3A
    target_moments = 1**p * factorial2(p - 1) * (1 - p % 2)

    def distr_loss(param: jax.Array, spec: ParamSpec):
        keep = spec.keep_distr
        if not keep:
            return 0
        # A numeric keep_distr (int or float, but not a plain bool flag) is a
        # per-parameter weight.
        if isinstance(keep, (int, float)) and not isinstance(keep, bool):
            weight = float(keep)
        else:
            weight = 1
        # We must reverse the transformation applied to the parameters if possible.
        if spec.transform in INVERSE_TRANSFORMS:
            param = INVERSE_TRANSFORMS[spec.transform](param)

        p_norm = (param - spec.mean) / (eps + spec.std)  # type: ignore
        powers = p_norm[..., None] ** p
        # Average over all but the last (moment) dimension.
        observed_moments = powers.mean(axis=tuple(range(powers.ndim - 1)))
        return weight * ((target_moments - observed_moments) ** 2).sum()

    def loss(mol_params):
        per_leaf = jtu.tree_map(distr_loss, mol_params, spec_tree)
        # The initializer makes an empty parameter tree give 0 instead of raising.
        total = jtu.tree_reduce(jnp.add, per_leaf, initializer=0.0)
        return scale * jnp.sum(total)

    return loss


def _correlation_svd_loss(correlation: jax.Array, config: SystemConfigs) -> jax.Array:
    """
    Penalizes singular values of per-molecule correlation matrices that deviate from 1.

    ``correlation`` is split with ``iterate_segments`` using the sizes
    ``n_nuc ** 2`` without the last entry. This mirrors the original call. Each
    segment is:

    1. Reshaped to ``(n, n, n_det, n_orb, n_orb)``.
    2. Rearranged into ``n_det`` square matrices of size ``n * n_orb``.
    3. Split into antisymmetric part ``A`` and symmetric part ``S``.
    4. Assembled into ``block(A, S, -S^T, A)``.

    The result is the mean of ``(s - 1) ** 2`` over the singular values ``s`` of
    that block matrix, summed over segments and divided by ``config.n_mols``.
    """
    # Kept as a local import as in the original code; presumably to avoid an
    # import cycle with globe.nn -- TODO confirm.
    from globe.nn import block

    loss = 0
    for segment in iterate_segments(correlation, np.array(config.n_nuc)[:-1] ** 2):
        n_orb = segment.shape[-1]
        n_det = segment.shape[-3]
        # segment.shape[0] == n ** 2; isqrt is exact, unlike int(x ** 0.5).
        n = math.isqrt(segment.shape[0])
        pairs = segment.reshape(n, n, n_det, n_orb, n_orb)
        mat = jnp.transpose(pairs, (2, 0, 3, 1, 4)).reshape(
            n_det, n * n_orb, n * n_orb
        )
        anti, sym = (mat - mat.mT) / 2, (mat + mat.mT) / 2
        full = block(anti, sym, -sym.mT, anti)
        s = jnp.linalg.svd(full, full_matrices=False, compute_uv=False)
        loss += ((s - 1) ** 2).mean()
    return loss / config.n_mols


def make_pretrain_step(
    mcmc_step: Callable,
    mol_param_fn: Callable,
    wave_function: Callable,
    orbital_fn: Callable,
    opt_update: Callable,
    orbital_matching_fn: OrbitalMatchingFunction,
    mol_param_aux_loss: Callable | None = None,
    natgrad_precond: Callable | None = None,
    match_gradients: Literal[False] | float = False,
    regularize_correlation: bool = False,
):
    """
    Creates a pretrain step function for the molecular parameters and the orbitals.

    The step runs in this order:

    1. Fit the network orbitals to the Hartree-Fock orbitals (after
       ``orbital_matching_fn``).
    2. Apply the optimizer update to ``params['params']``.
    3. Advance the electrons with one MCMC step.

    Args:
    - mcmc_step: A function that performs a single MCMC step.
    - mol_param_fn: A function that computes the molecular parameters from
      ``(params, atoms, config)``.
    - wave_function: Called as ``wave_function(params, electrons, atoms, config)``.
      Only used when ``match_gradients`` is enabled.
    - orbital_fn: A function that computes the orbitals from
      ``(params, electrons, atoms, config, mol_params)``. It is vmapped over the
      electron batch and jitted with ``config`` static.
    - opt_update: A function that updates the optimizer state.
    - orbital_matching_fn: Matches network orbitals to HF orbitals. Returns
      ``(orbitals, targets)`` or ``(orbitals, targets, cache)``.
    - mol_param_aux_loss: An auxiliary loss for the molecular parameters.
    - natgrad_precond: A function that computes the preconditioner for the
      natural gradient.
    - match_gradients: False to disable, otherwise the weight of the L1 loss
      between the electron-gradients of the summed network wave function output
      and of the HF log-determinant.
    - regularize_correlation: Whether to add the singular-value regularizer on
      ``mol_params['orbitals']['correlation']``.
    Returns:
    - A function that performs a single pretrain step.
    """
    orbital_fn = jax.vmap(orbital_fn, in_axes=(None, 0, None, None, None))
    orbital_fn = jax.jit(orbital_fn, static_argnums=3)

    @functools.partial(
        pmap,
        in_axes=(0, 0, None, None, None, 0, 0, None, 0, 0),
        static_broadcasted_argnums=(3, 4),
    )
    def pretrain_step(
        params,
        electrons,
        atoms,
        config: SystemConfigs,
        hf_mo_fns: tuple[HfOrbitalFunction, ...],
        opt_state,
        key,
        properties,
        natgrad_state,
        cache,
    ):
        def loss_fn(p):
            # Parameter dict whose trainable leaves are the traced `p`. Every loss
            # term that should train the network must be computed from this.
            full_params = {**params, 'params': p}
            mol_params = mol_param_fn(full_params, atoms, config)
            # NOTE(review): orbital_fn receives the original `params`, so gradients
            # reach `p` only through `mol_params`. Confirm this is intended.
            orbitals = orbital_fn(params, electrons, atoms, config, mol_params)

            hf_orbitals = []
            hf_gradients = []
            for hf_fn, elec in zip(
                hf_mo_fns,
                iterate_segments(electrons, np.sum(config.spins, axis=-1), axis=-2),
            ):
                if match_gradients:

                    def log_wf(x):
                        # HF orbitals are evaluated in float64 as some stuff might be
                        # a bit fragile in float32.
                        hf_up, hf_down = hf_fn(x.astype(jnp.float64))
                        hf_up, hf_down = hf_up.astype(x.dtype), hf_down.astype(x.dtype)
                        log_psi = (
                            jnp.linalg.slogdet(hf_up)[1]
                            + jnp.linalg.slogdet(hf_down)[1]
                        )
                        return log_psi.sum(), (hf_up, hf_down)

                    # log_wf closes over this iteration's hf_fn. It is consumed
                    # immediately, so late binding is not an issue.
                    gradient, (hf_up, hf_down) = jax.vmap(
                        jax.grad(log_wf, has_aux=True)
                    )(elec)
                    hf_orbitals.append((hf_up, hf_down))
                    hf_gradients.append(gradient)
                else:
                    hf_orbitals.append(jax.vmap(hf_fn)(elec))
            matched = orbital_matching_fn(orbitals, hf_orbitals, config, cache=cache)

            if len(matched) == 2:
                orbitals, final_targets = matched
                out_cache = None
            else:
                orbitals, final_targets, out_cache = matched

            def mse(x, y):
                # x: target, y: prediction. If the target has lower rank, insert a
                # broadcast axis before the last two (matrix) dims. Presumably this
                # is the determinant axis.
                if x.ndim < y.ndim:
                    x = x[..., None, :, :]
                return ((x - y) ** 2).mean()

            # Averaged over the different molecules.
            orbital_loss = jtu.tree_reduce(
                jnp.add, jtu.tree_map(mse, final_targets, orbitals)
            ) / len(orbitals)

            grad_loss = 0
            if match_gradients:
                for elec, at, conf, hf_grad in zip(
                    iterate_segments(electrons, config.n_elec, axis=-2),
                    iterate_segments(atoms, config.n_nuc, axis=0),
                    config.sub_configs,
                    hf_gradients,
                ):

                    def wf_close(x):
                        # Use `full_params` (not `params`) so this term depends on
                        # `p`. Otherwise its gradient w.r.t. `p` would be zero and
                        # gradient matching would never affect the update.
                        return wave_function(full_params, x, at, conf).sum()

                    wf_grad = jax.vmap(jax.grad(wf_close))(elec)
                    grad_loss += match_gradients * jnp.abs(wf_grad - hf_grad).mean()
                grad_loss /= config.n_mols

            if mol_param_aux_loss is not None:
                aux_loss = mol_param_aux_loss(mol_params)
            else:
                aux_loss = 0

            if regularize_correlation:
                svd_loss = _correlation_svd_loss(
                    mol_params['orbitals']['correlation'], config
                )
            else:
                svd_loss = 0.0

            total_loss = orbital_loss + grad_loss + aux_loss + svd_loss
            return total_loss, (
                {
                    'total_loss': total_loss,
                    'orbital_loss': orbital_loss,
                    'grad_loss': grad_loss,
                    'aux_loss': aux_loss,
                    'svd_loss': svd_loss,
                },
                out_cache,
            )

        (_, (losses, cache)), grad = jax.value_and_grad(loss_fn, has_aux=True)(
            params['params']
        )
        losses, grad = pmean_if_pmap((losses, grad))
        if natgrad_precond is not None:
            grad, natgrad_state = natgrad_precond(
                params, electrons, atoms, config, grad, natgrad_state
            )

        updates, opt_state = opt_update(grad, opt_state, params['params'])
        params['params'] = optax.apply_updates(params['params'], updates)

        key, subkey = jax.random.split(key)
        electrons, pmove = mcmc_step(
            params, electrons, atoms, config, subkey, properties['mcmc_width']
        )
        return params, electrons, opt_state, natgrad_state, losses, pmove, cache

    return pretrain_step
