import enum
import functools
import logging
import os
from collections import OrderedDict, defaultdict
from typing import Callable, Tuple

import h5py
import jax
import jax.tree_util as jtu
import numpy as np
import pyscf
from pyscf import lo
from pyscf.scf.hf import SCF
from scipy.optimize import linear_sum_assignment, minimize

from globe.nn.coords import find_axes
from globe.nn.orbitals import get_orbitals
from globe.systems.molecule import Molecule
from globe.utils.config import SystemConfigs

# Rebind `get_orbitals` to a jit-compiled version. The argument at position 1
# (keyword name `config`) is declared static for the compilation.
get_orbitals = jax.jit(get_orbitals, static_argnums=1, static_argnames='config')


class OrbitalType(enum.Enum):
    """
    The orbital representations selectable via the `orbitals` argument of `Scf`.

    The string values allow construction from plain strings, e.g. `OrbitalType('hf')`.
    """

    # Use the mean-field orbitals as they are (only the sign is aligned).
    HF = 'hf'
    # Canonicalize the occupied orbitals with `globe_canonicalize_weights`.
    GLOBE = 'globe'
    # Localize the occupied orbitals with `boys_localization` and sort them by atom.
    BOYS = 'boys'


class Scf:
    """
    A Hartree-Fock mean-field solver for molecules.

    Attributes:
    - molecule (Molecule): The molecule to solve.
    - basis (str): The basis set handed to PySCF.
    - _mol (pyscf.gto.Mole): The PySCF molecule object.
    - _mean_field (pyscf.scf.hf.SCF): The PySCF mean-field object (RHF or UHF).
    - restricted (bool): Whether the calculation is restricted or unrestricted.
    - orbitals (OrbitalType): The orbital representation produced by `transform_orbitals`.
    - _coeff (np.ndarray | None): The transformed molecular orbital coefficients,
        or None as long as no transformation has been applied.
    - _correlation (np.ndarray | None): Cache for the `correlation` property.

    Methods:
    - __init__(self, molecule, restricted=True, basis='STO-6G', verbose=1, orbitals=OrbitalType.GLOBE) -> None:
        Initializes the Scf object.
    - run(self, chkfile: str | None = None):
        Runs the Hartree-Fock calculation (or restores it from a checkpoint).
    - transform_orbitals(self):
        Transforms the mean-field orbitals according to `orbitals`.
    - eval_molecular_orbitals(self, electrons) -> tuple[np.ndarray, np.ndarray]:
        Evaluates the molecular orbitals for a given set of electrons.
    - energy (property):
        Returns the Hartree-Fock energy.
    - mo_coeff (property):
        Returns the molecular orbital coefficients.
    - ao_nuc_idx (property):
        Returns the atom index of every atomic orbital.
    - correlation (property):
        Returns an antisymmetric matrix derived from orbital distances.
    """

    molecule: Molecule
    _mol: pyscf.gto.Mole
    _mean_field: SCF
    restricted: bool
    _coeff: np.ndarray | None = None

    def __init__(
        self,
        molecule: Molecule,
        restricted: bool = True,
        basis='STO-6G',
        verbose=1,
        orbitals: OrbitalType | str = OrbitalType.GLOBE,
    ) -> None:
        """
        Args:
        - molecule (Molecule): The molecule to solve.
        - restricted (bool): Whether the calculation is restricted or unrestricted.
        - basis (str): The basis set to use.
        - verbose (int): The verbosity level.
        - orbitals (OrbitalType | str): The orbital representation to produce in
            `transform_orbitals`; either an `OrbitalType` or its string value.
        """
        self.molecule = molecule
        self.basis = basis
        self.restricted = restricted
        self._mol = self.molecule.to_pyscf(basis, verbose)
        self.orbitals = OrbitalType(orbitals)
        self._correlation = None
        if restricted:
            self._mean_field = pyscf.scf.RHF(self._mol)
        else:
            self._mean_field = pyscf.scf.UHF(self._mol)

    def run(self, chkfile: str | None = None):
        """
        Runs the Hartree-Fock calculation.

        If `chkfile` points to an existing file, the results stored in its `scf`
        group are loaded instead of running the solver. For restricted calculations
        the orbitals are transformed afterwards (see `transform_orbitals`).

        Args:
        - chkfile (str): The checkpoint file to use.

        Returns:
        - Scf: This object, to allow chaining.
        """
        self._mean_field.max_cycle = 10_000
        self._mean_field.chkfile = chkfile
        if chkfile is not None and os.path.exists(chkfile):
            with h5py.File(chkfile, 'r') as inp:
                scf_group = inp['scf']
                self._mean_field.mo_coeff = scf_group['mo_coeff'][()]  # type: ignore
                self._mean_field.e_tot = scf_group['e_tot'][()]  # type: ignore
                # The orbital transformations and `correlation` read the occupations
                # (and Boys localization the energies), so restore them when the
                # checkpoint provides them.
                for key in ('mo_occ', 'mo_energy'):
                    if key in scf_group:  # type: ignore
                        setattr(self._mean_field, key, scf_group[key][()])  # type: ignore
        else:
            self._mean_field.kernel()
        if self.restricted:
            self.transform_orbitals()
        logging.info(f'HF energy: {self.energy}')
        return self

    def transform_orbitals(self):
        """
        Computes the orbital coefficients selected by `self.orbitals` and stores
        them in `self._coeff`.

        If the transformation raises, the untransformed mean-field coefficients are
        used instead. For Boys and HF orbitals the sign of every orbital is chosen
        such that the sum of its coefficients is non-negative.
        """
        # Start from the raw mean-field coefficients; otherwise `mo_coeff` would
        # hand back the result of a previous transformation.
        self._coeff = None
        try:
            if self.orbitals is OrbitalType.GLOBE:
                self._coeff = globe_canonicalize_weights(self)
            elif self.orbitals is OrbitalType.BOYS:
                self._coeff = boys_localization(self)
                self._coeff = sort_coeff_by_atom(self)
            else:
                self._coeff = self.mo_coeff
        except Exception:  # just use the default coefficients if something goes wrong
            logging.warning('Failed to canonicalize orbitals.', exc_info=True)
            # Drop any partial result so that `mo_coeff` really returns the defaults.
            self._coeff = None
            self._coeff = self.mo_coeff
        if self.orbitals in (OrbitalType.BOYS, OrbitalType.HF):
            signs = np.sign(np.sum(self._coeff, axis=-2, keepdims=True))
            # np.sign(0) == 0 would wipe out an orbital whose coefficients sum to zero.
            signs[signs == 0] = 1
            self._coeff = self._coeff * signs

    @property
    def energy(self):
        """The total energy stored on the mean-field object."""
        return self._mean_field.e_tot

    @property
    def mo_coeff(self):
        """
        The molecular orbital coefficients with a leading axis.

        Returns the transformed coefficients if available, otherwise a copy of the
        mean-field coefficients (restricted coefficients are wrapped so that the
        leading axis has length 1).
        """
        if self._coeff is not None:
            return self._coeff
        if self.restricted:
            coeffs = (self._mean_field.mo_coeff,)
        else:
            coeffs = self._mean_field.mo_coeff
        return np.array(coeffs)

    def eval_molecular_orbitals(
        self, electrons: jax.Array | np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Evaluates the molecular orbitals for a given set of electrons.

        Args:
        - electrons (jax.Array): A 2D array of shape (n_electrons, 3) containing the electron coordinates.

        Returns:
        - Tuple[np.ndarray, np.ndarray]: A tuple containing the molecular orbitals and atomic orbitals.
            The molecular orbitals are stacked along a leading axis; for restricted
            calculations the single set of orbitals is repeated twice. The atomic
            orbitals are grouped per atom along axis 1 and zero-padded to the largest
            number of atomic orbitals of any atom.

        Raises:
        - NotImplementedError: If the PySCF molecule uses cartesian GTOs.
        """
        if self._mol.cart:
            raise NotImplementedError(
                'Evaluation of molecular orbitals using cartesian GTOs.'
            )

        gto_op = 'GTOval_sph'
        electrons = np.array(electrons)
        ao_values = self._mol.eval_gto(gto_op, electrons)
        mo_values = tuple(np.matmul(ao_values, coeff) for coeff in self.mo_coeff)
        if self.restricted:
            # Tuple repetition (not scaling): reuse the same orbitals a second time.
            mo_values = mo_values * 2

        ao_per_atom = get_number_ao(self)
        ao_idx, max_ao = np.cumsum(ao_per_atom[:-1]), max(ao_per_atom)
        ao_orbitals = np.stack(
            [
                np.concatenate(
                    [x, np.zeros((x.shape[0], max_ao - x.shape[1]))], axis=-1
                )
                for x in np.split(ao_values, ao_idx, axis=-1)
            ],
            axis=1,
        )
        return np.array(mo_values), ao_orbitals

    @property
    def ao_nuc_idx(self):
        """The atom index of every atomic orbital, parsed from the PySCF AO labels."""
        return [int(label.split(' ')[0]) for label in self._mol.ao_labels()]

    @property
    def correlation(self):
        """
        An antisymmetric matrix derived from the distances between occupied orbitals.

        Every occupied orbital of the first coefficient set is assigned a position
        (the coefficient-weighted sum of the coordinates of the atoms its atomic
        orbitals belong to). The positions are duplicated, pairwise distances `d`
        are mapped to `exp(-d / 10)` and the upper triangle minus its transpose is
        returned. The result is cached.
        """
        if self._correlation is not None:
            return self._correlation
        coeff = self.mo_coeff[0][..., self._mean_field.mo_occ > 0]
        coords = self._mol.atom_coords()[np.array(self.ao_nuc_idx)]
        mo_pos = np.einsum('...ij,...id->...jd', coeff, coords)
        mo_pos = np.concatenate([mo_pos] * 2, axis=-2)  # for full det
        diffs = np.linalg.norm(mo_pos[..., None, :] - mo_pos[..., None, :, :], axis=-1)
        adjacency = np.exp(-diffs / 10)
        correlation = np.triu(adjacency)
        correlation = correlation - np.swapaxes(correlation, -1, -2)
        self._correlation = correlation
        return correlation


def sort_coeff_by_atom(scf: Scf):
    """
    Reorders the occupied orbitals of `scf._coeff` by the atom they mainly belong to.

    For every coefficient set, the absolute coefficients of each occupied orbital
    are summed per atom. The orbital is attributed to the atom with the largest
    sum, and the occupied orbitals are stably sorted by that atom index. The
    remaining orbitals are appended unchanged.

    Args:
    - scf: An instance of the Scf class with `_coeff` set.

    Returns:
    - np.ndarray: The reordered coefficients with the same shape as `scf._coeff`.
    """
    ao_to_atomic_idx = np.array(scf.ao_nuc_idx)
    new_coeff = []
    occ = scf._mean_field.mo_occ  # type: ignore
    for i, all_coeff in enumerate(scf._coeff):
        coeff = all_coeff[..., occ > i]
        # Rows: atoms, columns: orbitals.
        affinities = np.asarray(jax.ops.segment_sum(np.abs(coeff), ao_to_atomic_idx))
        # The main nucleus is the one carrying the most weight, hence argmax.
        main_nuc = np.argmax(affinities, axis=-2)
        order = np.argsort(main_nuc, kind='stable')
        coeff = np.concatenate([coeff[..., order], all_coeff[..., occ <= i]], axis=-1)
        new_coeff.append(coeff)
    return np.array(new_coeff)


def boys_localization(scf: Scf):
    """
    Localizes the occupied mean-field orbitals with PySCF's Boys scheme.

    The localized orbitals are ordered by an approximate orbital energy: the
    mixing matrix between canonical and localized orbitals is obtained by least
    squares, its absolute values are normalized per localized orbital and used
    as weights to average the canonical orbital energies.

    Args:
    - scf: An instance of the Scf class holding a finished mean-field calculation.

    Returns:
    - np.ndarray: The coefficients with a new leading axis of length 1; localized
        occupied orbitals first, followed by the untouched remaining orbitals.
    """
    mean_field = scf._mean_field
    full_coeff: np.ndarray = mean_field.mo_coeff  # type: ignore
    n_occ = np.sum(mean_field.mo_occ > 0)  # type: ignore
    occupied = full_coeff[..., :n_occ]
    remaining = full_coeff[..., n_occ:]

    localized: np.ndarray = lo.Boys(scf._mol, mo_coeff=occupied).kernel()  # type: ignore

    # Solve occupied @ mixing ~= localized and turn the mixing into column weights.
    weights = np.abs(np.linalg.lstsq(occupied, localized, rcond=None)[0])
    weights /= weights.sum(-2, keepdims=True)

    energies: np.ndarray = mean_field.mo_energy  # type: ignore
    ranking = np.argsort(energies[..., :n_occ] @ weights)
    return np.concatenate([localized[..., ranking], remaining], axis=-1)[None]


@functools.cache
def get_number_ao(scf: Scf):
    """
    Counts the atomic orbitals belonging to each atom.

    The atom index is the leading integer of every PySCF AO label. The result
    is memoized per `scf` object.

    Args:
    - scf: An instance of the Scf class.

    Returns:
    - Tuple[int]: The number of atomic orbitals for atom 0, 1, ... up to the
        largest atom index that occurs (atoms without orbitals count as 0).
    """
    tally: dict[int, int] = {}
    for label in scf._mol.ao_labels():
        atom = int(label.split(' ')[0])
        tally[atom] = tally.get(atom, 0) + 1
    n_atoms = max(tally) + 1
    return tuple(tally.get(atom, 0) for atom in range(n_atoms))


def make_mask(
    nuc_orbitals: tuple[int], nuc_assign: list[Tuple[int, int]]
) -> np.ndarray:
    """
    Builds a boolean (atomic orbital x molecular orbital) mask.

    Column k is True on the rows of all atomic orbitals that belong to either
    atom of the k-th pair in `nuc_assign`.

    Args:
    - nuc_orbitals (Tuple[int]): The number of atomic orbitals for each atom.
    - nuc_assign (List[Tuple[int, int]]): The pair of atoms of every molecular orbital.

    Returns:
    - np.ndarray: A boolean array of shape (sum(nuc_orbitals), len(nuc_assign)).
    """
    # bounds[a]:bounds[a + 1] is the row range of atom a.
    bounds = np.cumsum((0,) + nuc_orbitals + (0,))
    n_rows, n_cols = bounds[-1], len(nuc_assign)
    mask = np.zeros((n_rows, n_cols), dtype=bool)
    for column, pair in enumerate(nuc_assign):
        first, second = pair
        for atom in (first, second):
            mask[bounds[atom] : bounds[atom + 1], column] = True
    return mask


def get_ao_scores(nuc_orbitals: tuple[int]) -> np.ndarray:
    """
    Assigns a score to every atomic orbital.

    Within each atom the orbitals are numbered 1, 2, ..., n in order.

    Args:
    - nuc_orbitals (Tuple[int]): The number of atomic orbitals for each atom.

    Returns:
    - np.ndarray: The concatenated per-atom scores of all atomic orbitals.
    """
    per_atom = []
    for n_ao in nuc_orbitals:
        per_atom.append(np.arange(1, n_ao + 1))
    return np.concatenate(per_atom)


def get_mo_scores(types: np.ndarray) -> np.ndarray:
    """
    Assigns a score to every molecular orbital based on its type.

    The distinct types are ranked in ascending order; the largest type receives
    the score 1 and every smaller type a score that is one higher than the next
    larger type.

    Args:
    - types (jax.Array): An array containing the types of each molecular orbital.

    Returns:
    - np.ndarray: An array containing the scores of the molecular orbitals.
    """
    kinds = np.array(types)
    ranks = np.zeros_like(kinds)
    for rank, value in enumerate(np.unique(kinds)):
        ranks[kinds == value] = rank
    highest = ranks.max()
    return highest - ranks + 1


def make_target(
    nuc_orbitals: tuple[int, ...], nuc_assign: list[tuple[int, int]], types: np.ndarray
) -> np.ndarray:
    """
    Builds the boolean (atomic orbital x molecular orbital) target pattern.

    For every atom, the molecular orbitals touching it are grouped by type in
    order of first appearance. Each group of k orbitals claims the next k atomic
    orbital rows of that atom and marks all of its k columns in those rows.

    Args:
    - nuc_orbitals (Tuple[int]): The number of atomic orbitals for each atom.
    - nuc_assign (List[Tuple[int, int]]): The indices of the atoms involved in each molecular orbital.
    - types (jax.Array): The type of each molecular orbital.

    Returns:
    - np.ndarray: A boolean array of shape (sum(nuc_orbitals), len(nuc_assign)).
    """
    # atom -> type -> indices of the molecular orbitals; insertion order matters.
    groups: dict = {}
    for mo_idx, (atoms, mo_type) in enumerate(zip(nuc_assign, types)):
        for atom in np.unique(atoms):
            groups.setdefault(atom, {}).setdefault(mo_type, []).append(mo_idx)

    first_row = np.cumsum([0, *nuc_orbitals[:-1]])
    target = np.zeros((np.sum(nuc_orbitals), len(nuc_assign)), dtype=bool)
    for atom, by_type in groups.items():
        row = first_row[atom]
        for members in by_type.values():
            # One row per member, each marking every column of the group.
            for _ in members:
                target[row, members] = True
                row += 1
    return target


def to_mat(x: np.ndarray) -> np.ndarray:
    """
    Reshapes the input array into a square matrix and rescales it such that the
    absolute value of its determinant is 1.

    The scale is computed from the log-determinant, so it neither underflows nor
    overflows for large or badly scaled matrices. The sign of the determinant is
    left unchanged. A singular matrix still yields a zero scale and therefore
    non-finite entries.

    Args:
    - x (np.ndarray): An array with a square number of elements.

    Returns:
    - np.ndarray: A matrix whose determinant has absolute value 1.
    """
    mat = x.reshape(int(np.sqrt(x.size)), -1)
    # |det| ** (1 / n) == exp(log|det| / n), without forming |det| itself.
    _, log_abs_det = np.linalg.slogdet(mat)
    scale = np.exp(log_abs_det / mat.shape[0])
    return mat / scale


def make_minbasis_loss(coeff: np.ndarray, target: np.ndarray):
    """
    Creates two loss functions that measure how well transformed coefficients match a target pattern.
    The first is meant for direct optimization, the second produces a cost matrix over all pairings
    of predicted and target columns.

    Args:
    - coeff (jax.Array): A matrix of coefficients.
    - target (np.ndarray): A boolean array indicating which molecular orbitals correspond to each atomic orbital.

    Returns:
    - Tuple[Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]]: A tuple containing two functions:
        - loss: Takes a (flattened) matrix and returns the loss of `coeff @ to_mat(x)` as a float.
        - perm_loss: Takes a (flattened) matrix and returns a matrix whose entry (i, j) is the loss of
          predicted column i measured against target column j.
    """

    def loss(x: np.ndarray) -> float:
        pred = coeff @ to_mat(x)
        # Everything outside the target pattern should vanish ...
        leakage = (pred[~target] ** 2).sum()
        # ... and the part inside the pattern should have unit norm per column.
        inside_norm = np.linalg.norm(np.where(target, pred, 0), axis=-2)
        return leakage + ((1 - inside_norm) ** 2).sum()

    def perm_loss(x: np.ndarray) -> np.ndarray:
        pred = coeff @ to_mat(x)
        n = coeff.shape[-1]
        costs = np.zeros((n, n))
        for i, j in np.ndindex(n, n):
            column = pred[:, i]
            wanted = target[:, j]
            norm_gap = 1 - np.linalg.norm(column[wanted]).item()
            costs[i, j] = (column[~wanted] ** 2).sum()
            costs[i, j] += norm_gap**2
        return costs

    return loss, perm_loss


def make_generic_loss(
    coeff: jax.Array,
    mask: np.ndarray,
    ao_scores: np.ndarray,
    mo_scores: np.ndarray,
    score_weight: float = 0,
) -> Tuple[Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]]:
    """
    Creates two loss functions that measure how well transformed coefficients respect a mask.
    The first is meant for direct optimization, the second produces a matrix with the loss after
    swapping each pair of columns of the input matrix.

    Unlike the minimal-basis loss, this works for arbitrary bases, but results in worse performance.

    Args:
    - coeff (jax.Array): A matrix of coefficients.
    - mask (np.ndarray): A boolean array indicating which molecular orbitals correspond to each atomic orbital.
    - ao_scores (np.ndarray): A 1D array of scores for each atomic orbital.
    - mo_scores (np.ndarray): A 1D array of scores for each molecular orbital.
    - score_weight (float): A weight for the score penalty term in the loss function.

    Returns:
    - Tuple[Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]]: A tuple containing two functions:
        - loss: Takes a (flattened) matrix and returns the loss of `coeff @ to_mat(x)` as a float.
        - perm_loss: Takes a (flattened) matrix and returns a matrix whose entry (i, j) is the loss
          after exchanging columns i and j of the matrix.
    """
    # Inside the mask the penalty is the product of AO and MO score, outside a large constant.
    penalty_mask = np.where(mask, ao_scores[..., None] * mo_scores, 10000)

    def loss(x: np.ndarray) -> float:
        pred = coeff @ to_mat(x)
        score_term = score_weight * np.sum(penalty_mask * np.abs(pred))  # type: ignore
        leakage_term = ((pred * ~mask) ** 2).sum()
        norm_term = 1000 * ((1 - np.linalg.norm(pred, axis=0)) ** 2).sum()
        return score_term + leakage_term + norm_term

    def perm_loss(x: np.ndarray) -> np.ndarray:
        mat = to_mat(x)
        n = mat.shape[-1]

        def swapped_loss(i, j):
            candidate = np.array(mat, order='K')
            candidate[:, [j, i]] = candidate[:, [i, j]]
            return loss(candidate)

        # All (row, column) index pairs in row-major order.
        rows, cols = np.divmod(np.arange(n * n), n)
        return np.vectorize(swapped_loss)(rows, cols).reshape(n, n)

    return loss, perm_loss


def globe_canonicalize_weights(scf: Scf, maxiter: int = 10) -> np.ndarray:
    """
    Canonicalizes the molecular orbital coefficients of a Hartree-Fock calculation.

    The occupied orbitals are mixed by a matrix with unit absolute determinant
    such that they match the target pattern from `make_target` as well as
    possible. Each iteration first permutes the columns of the mixing matrix
    (linear sum assignment on `perm_loss`) and then minimizes `loss`. Iteration
    stops once the loss changes by less than 1e-5 between two iterations.

    Args:
    - scf (Scf): A Hartree-Fock calculation object with a single coefficient set.
    - maxiter (int): The maximum number of iterations to perform.

    Returns:
    - np.ndarray: The coefficients with a leading axis of length 1; canonicalized
        occupied orbitals first, followed by the untouched remaining orbitals.

    Raises:
    - RuntimeError: If the last of the `maxiter` iterations did not converge.
    """
    orbitals = get_number_ao(scf)
    conf = SystemConfigs((scf.molecule.spins,), (scf.molecule.charges,))
    axes = find_axes(scf.molecule.positions, conf)
    axes = np.array(axes).reshape(3, 3)
    _, orb_type, orb_assoc, _ = jtu.tree_map(
        np.array, get_orbitals(scf.molecule.positions @ axes, conf)
    )
    types = orb_type.tolist()
    nuc_idx = orb_assoc.tolist()
    n = (scf._mean_field.mo_occ > 0).sum()  # type: ignore
    mo_coeff = scf.mo_coeff
    assert mo_coeff.shape[0] == 1
    coeff = mo_coeff[0, :, :n]
    assert len(coeff.shape) == 2

    # Minimal basis calc. `make_generic_loss` (with `make_mask`, `get_ao_scores`
    # and `get_mo_scores`) is the basis-agnostic alternative.
    target = make_target(orbitals, nuc_idx, types)
    score_mask = np.arange(target.shape[0])[::-1][:, None] * target
    loss, perm_loss = make_minbasis_loss(coeff, target)

    mat = np.eye(n)
    best_loss = np.inf
    for i in range(maxiter):
        # find optimal permutation to minimize initial loss
        perm = linear_sum_assignment(perm_loss(mat))[1]
        init = mat[..., perm].reshape(-1)
        # Minimize objective
        x = minimize(loss, init)
        mat = np.array(to_mat(x.x))
        # Check for convergence
        if np.abs(best_loss - x.fun) < 1e-5:
            break
        best_loss = x.fun
        # `i` runs up to maxiter - 1, so this is the last iteration.
        if i == maxiter - 1:
            raise RuntimeError('Reached maxiter.')
    # Align signs
    result = coeff @ mat
    flips = np.sign((result * score_mask).sum(0))
    # np.sign(0) == 0 would wipe out the orbital instead of keeping its sign.
    flips[flips == 0] = 1
    result *= flips
    return np.concatenate([result[None], mo_coeff[..., n:]], axis=-1)
