# ruff: noqa: E402
# ruff: noqa: F841
import functools
import json
import logging
import os
from pathlib import Path

os.environ['JAX_DEFAULT_DTYPE_BITS'] = '32'
from collections import Counter, defaultdict
from typing import Tuple

import jax
import jax.flatten_util as jfu
import jax.tree_util as jtu
import numpy as np
import rich
from nix.utils.moving_average import RunningAverage
from nix.utils.tree_utils import tree_mul
from nix.utils.stopwatch import Stopwatch
from seml.experiment import Experiment
from seml_logger import Logger, add_default_observer_config, automain

import globe.systems as Systems
import globe.systems.property as Properties
from globe.systems.dataset import Dataset
from globe.trainer import Trainer
from globe.typing import SystemDefinitions

# Enable 64-bit JAX types at import time; `main` re-applies this flag according
# to its `max_precision` argument.
jax.config.update('jax_enable_x64', True)
# Set JAX's default matmul precision to 'float32' regardless of the x64 flag.
jax.config.update('jax_default_matmul_precision', 'float32')

# The experiment object that all config scopes and `main` below register on.
ex = Experiment()
# seml_logger observer defaults (presumably including a completion notification).
add_default_observer_config(ex, notify_on_completed=True)


# Preset: use the 'ExplicitPfaffian' orbital type and set
# `orbital_config.pretrain_match_pfaffian` to 1 (default config: 1e-4).
@ex.named_config
def explicit_pfaffian():
    globe = dict(
        orbital_type='ExplicitPfaffian', orbital_config=dict(pretrain_match_pfaffian=1)
    )


# Preset: use the 'AGP' orbital type (default config: 'Pfaffian').
@ex.named_config
def agp():
    globe = dict(orbital_type='AGP')


# Preset: 'ProductOrbitals' orbital type combined with 'globe' pretraining
# localization (default config: 'Pfaffian' and 'hf').
@ex.named_config
def globe():
    globe = dict(orbital_type='ProductOrbitals')
    pretrain_localization = 'globe'


# Preset (presumably "without meta model"): set `meta_model` to 'none' instead
# of the default 'metanet'.
@ex.named_config
def wo_meta():
    globe = dict(meta_model='none')


# Preset: skip pretraining and run 2000 MCMC thermalization steps per batch
# before VMC training (typically combined with `chkpt`).
@ex.named_config
def finetune():
    pretrain_epochs = 0
    thermalizing_steps = 2000


# Preset: pretraining only; no thermalization and no VMC steps.
@ex.named_config
def pretrain():
    pretrain_epochs = 20000
    thermalizing_steps = 0
    training_epochs = 0


# Preset: run a single VMC step with JAX profiler tracing enabled and skip
# pretraining and thermalization.
@ex.named_config
def profiler():
    thermalizing_steps = 0
    pretrain_epochs = 0
    training_epochs = 1
    profile = True


# Preset: short pretraining/thermalization and small batches for debugging.
# NOTE(review): `print_progress` is not a parameter of `main`; presumably it is
# consumed by seml_logger — verify.
@ex.named_config
def debug():
    pretrain_epochs = 100
    thermalizing_steps = 100
    samples_per_batch = 512
    print_progress = True


# Preset: log scratch runs to ~/logs/pfaffian_debug with TensorBoard.
# NOTE(review): `use_tensorboard`, `folder` and `print_progress` are not
# parameters of `main`; presumably they are read by seml_logger — verify.
@ex.named_config
def tmp():
    use_tensorboard = True
    folder = '~/logs/pfaffian_debug'
    print_progress = False


# Preset: log final runs to ~/logs/pfaffian_final with TensorBoard.
# NOTE(review): `use_tensorboard`, `folder` and `print_progress` are not
# parameters of `main`; presumably they are read by seml_logger — verify.
@ex.named_config
def final():
    use_tensorboard = True
    folder = '~/logs/pfaffian_final'
    print_progress = False


# Default configuration. Every local name assigned in this scope becomes a
# sacred config entry, which is passed to the parameter of `main` with the same
# name. Named configs override parts of it, and `compute_charges_and_orbitals`
# derives further `globe` entries from it.
@ex.config
def config():
    # Model configuration passed to `Trainer` (replaced by the config.json next
    # to `chkpt` when resuming).
    globe = dict(
        wf_params=dict(
            ferminet=dict(
                hidden_dims=((256, 32), (256, 32), (256, 32), (256, 32)),
                embedding_dim=256,
                embedding_adaptive_weights=True,
                restricted=True,
            ),
            moon=dict(
                hidden_dims=(256,),
                use_interaction=False,
                update_before_int=4,
                update_after_int=0,
                adaptive_update_bias=False,
                local_frames=False,
                edge_embedding='MLPEdgeEmbedding',
                edge_embedding_params=dict(
                    # MLPEdgeEmbedding
                    out_dim=8,
                    hidden_dim=16,
                    activation='silu',
                    adaptive_weights=True,
                    envelope='exponential',
                ),
                embedding_dim=256,
                embedding_e_out_dim=256,
                embedding_int_dim=32,
                embedding_adaptive_weights=True,
                embedding_adaptive_norm=False,
            ),
            attentive=dict(
                head_dim=64,
                heads=4,
                layer=4,
                use_layernorm=True,
                include_spin_emb=True,
                # Override defaults
                activation='tanh',
                jastrow_mlp_layers=0,
            ),
            shared=dict(
                activation='silu',
                jastrow_mlp_layers=3,
                jastrow_include_pair=True,
                adaptive_sum_weights=False,
                adaptive_jastrow=False,
            ),
        ),
        gnn_params=dict(
            layers=((32, 64), (32, 64), (32, 64)),
            embedding_dim=64,
            edge_embedding='MLPEdgeEmbedding',
            edge_embedding_params=dict(
                # MLPEdgeEmbedding
                out_dim=16,
                hidden_dim=32,
                activation='silu',
                adaptive_weights=True,
                adaptive=False,
                envelope='bessel',
                # SphHarmEmbedding
                n_rad=6,
                max_l=3,
            ),
            orb_edge_params=dict(param_type='orbital'),
            out_mlp_depth=3,
            out_scale='log',
            aggregate_before_out=True,
            activation='silu',
            enable_groups=False,
            enable_segments=True,
            # None: filled in from `systems` by `compute_charges_and_orbitals`.
            charges=None,
        ),
        orbital_type='Pfaffian',
        orbital_config=dict(
            separate_k=True,
            # Overwritten by `compute_charges_and_orbitals` unless
            # `minimal_orbitals` is None.
            orbitals_per_atom=8,
            envelope_per_atom=8,
            bottleneck_envelopes=32,
            use_spin_mask=False,
            sigma_per_det=True,
            pi_per_det=True,
            correlation='dense',
            det_precision='float64',
            pretrain_match_steps=50,
            pretrain_match_lr=1.0,
            pretrain_match_orbitals=True,
            pretrain_match_pfaffian=1e-4,
            down_projection=None,
            # Extra orbitals per atom, looked up with str(max(charges)) by
            # `compute_charges_and_orbitals` (string keys for that lookup).
            minimal_orbitals={
                '1': 1,
                '2': 1,
                '3': 4,
                '4': 4,
                '5': 4,
                '6': 4,
                '7': 4,
                '8': 4,
                '9': 4,
                '10': 4,
                '11': 4,
                '12': 4,
                '13': 4,
                '14': 4,
                '15': 4,
                '16': 4,
                '17': 4,
                '18': 4,
            },
        ),
        determinants=16,
        full_det=True,
        shared_orbitals=True,
        meta_model='metanet',
        # Presumably selects the matching entry of `wf_params` — verify in Trainer.
        wf_model='moon',
    )
    # Passed to `Trainer`; presumably MCMC steps per training step.
    mcmc_steps = 20
    # Per-preconditioner options; `preconditioner` presumably picks one entry.
    preconditioner_args = dict(
        cg=dict(maxiter=100, decay_factor=0.99, center=True),
        spring=dict(
            decay_factor=0.99,
            norm_constraint=10,
            momentum=0,
            damping=1e-3,
            dtype='float64',
            only_use_wf=False,
        ),
        minsr=dict(center=True),
    )
    preconditioner = 'spring'
    loss = dict(
        clip_local_energy=5.0,
        limit_scaling=True,
        target_std=1.0,
    )
    lr = dict(init=0.1, delay=100, decay=1)
    damping = dict(init=1e-3, base=1e-4)
    operator = 'forward'
    batched_vmap_size = 64

    # Structures per batch; capped at the number of structures.
    batch_size = 64
    batch_behavior = 'fill_random'
    # Total samples per batch, split evenly over the structures of a batch and
    # rounded down to a multiple of the device count.
    samples_per_batch = 4096
    # MCMC steps per batch in `thermalize_dataset`.
    thermalizing_steps = 1000
    # VMC steps at which checkpoints are stored (in addition to every 10000 steps).
    chkpts = (64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768)

    # (name, kwargs) pairs resolved via `Properties.PROPERTIES` and attached
    # to every dataset.
    properties = (
        (
            'WidthScheduler',
            dict(init_width=0.02, target_pmove=0.525, update_interval=20),
        ),
        ('EnergyStdEMA', dict(decay=0.99)),
        ('EnergyEMA', dict(decay=0.99)),
    )

    restricted = True
    basis = 'STO-6G'

    # Full passes over the pretraining dataset.
    pretrain_epochs = 20000
    # NOTE: despite the name, this is the number of VMC steps (one batch each).
    training_epochs = 10000

    # Pretraining may be very expensive if we do not group similar structures
    # Set to 'sorted' to accelerate pretraining
    pretrain_order = 'random'
    # If set, pretrain on these systems and run VMC on `systems` afterwards.
    pretrain_systems = None
    pretrain_localization = 'hf'

    # Path to a '.chk' file to resume from; `globe` is then read from the
    # config.json in the same directory.
    chkpt = None

    run_name = None
    # 'float64' (or '64' / 64) enables JAX x64 in `main`; anything else disables it.
    max_precision = 'float64'
    # Log parameter histograms during pretraining and training.
    log_distributions = False

    systems = None
    # Capture a JAX profiler trace for every pretraining and training step.
    profile = False

    # Seed of the datasets' PRNG key (independent of sacred's `seed`).
    dataset_seed = 42


def naming_fn(systems, run_name):
    """Return the name of the run.

    An explicit ``run_name`` always wins. Otherwise the molecules in
    ``systems`` are counted, and the name joins ``<molecule>_<count>`` for each
    distinct molecule (in order of first appearance) with ``-``.
    """
    if run_name is None:
        counts = Counter(Systems.get_molecules(systems))
        return '-'.join(f'{molecule}_{count}' for molecule, count in counts.items())
    return run_name


def thermalize_dataset(trainer, dataset, logger, steps):
    """Run MCMC sampling on every batch of ``dataset`` without any training step.

    For each batch, the molecular parameters are computed once. Then ``steps``
    calls to ``trainer.p_wf_mcmc`` are made, and after each call the new
    electron positions and ``pmove`` values are written back into the batch.

    Args:
        trainer: Provides ``params``, ``p_get_mol_params``, ``p_wf_mcmc`` and
            ``next_key``.
        dataset: Iterable of batches exposing ``to_jax`` and ``update_states``.
        logger: Supplies the ``tqdm`` progress bar and ``set_postfix``.
        steps: Number of MCMC calls per batch.
    """
    logging.info('Thermalizing')
    for batch in dataset:
        # The molecular parameters come from the batch's initial atoms/config
        # and are reused for every MCMC call on this batch.
        _, nuclei, mol_config, _ = batch.to_jax()
        mol_params = trainer.p_get_mol_params(trainer.params, nuclei, mol_config)
        for _ in logger.tqdm(range(steps)):
            # Re-read the batch every iteration so the walkers and 'mcmc_width'
            # presumably reflect the preceding update_states call.
            walkers, nuclei, mol_config, batch_props = batch.to_jax()
            walkers, acceptance = trainer.p_wf_mcmc(
                trainer.params,
                walkers,
                nuclei,
                mol_config,
                mol_params,
                trainer.next_key(),
                batch_props['mcmc_width'],
            )
            batch.update_states(walkers, pmove=acceptance)
            logger.set_postfix({'pmove': np.mean(acceptance).item()})


@ex.config
def compute_charges_and_orbitals(globe, systems, chkpt):
    # Derive `globe` entries that depend on the systems being trained on.
    # Sacred turns locals left in this scope into config entries, which is why
    # temporaries are deleted explicitly below.

    # Let's reduce the number of parameters if we can:
    # restrict the charges to those occurring in `systems` unless set explicitly.
    if globe['gnn_params']['charges'] is None and systems is not None:
        mols = Systems.get_molecules(systems)
        charges = tuple(sorted({c for m in mols for c in m.charges}))
        globe['gnn_params']['charges'] = charges
        del mols
        del charges

    # orbitals_per_atom = ceil(max_charge / 2) + extra, where `extra` comes from
    # `minimal_orbitals` (a table keyed by str(charge) or a single number).
    # Skipped while the charges are still unknown (None, e.g. `systems` not
    # given) or empty, where max() would raise.
    if (
        globe['orbital_config']['minimal_orbitals'] is not None
        and globe['gnn_params']['charges']
    ):
        max_charge = max(globe['gnn_params']['charges'])
        if isinstance(globe['orbital_config']['minimal_orbitals'], dict):
            extra_orbitals = globe['orbital_config']['minimal_orbitals'][
                str(max_charge)
            ]
        else:
            extra_orbitals = globe['orbital_config']['minimal_orbitals']
        globe['orbital_config']['orbitals_per_atom'] = (
            max_charge + 1
        ) // 2 + extra_orbitals
        del extra_orbitals
        del max_charge

    # When resuming, use the model config stored next to the checkpoint
    # (presumably so the restored parameters match the architecture).
    if chkpt is not None:
        with open(Path(chkpt).parent / 'config.json') as inp:
            globe = json.load(inp)['globe']
        del inp


# Preset: a smaller model. Most widths are halved relative to the default
# config (FermiNet/Moon hidden and embedding sizes, GNN layers, embedding and
# edge-embedding sizes). The Moon edge embedding keeps its default sizes.
@ex.named_config
def small():
    globe = dict(
        wf_params=dict(
            ferminet=dict(
                hidden_dims=((128, 16), (128, 16), (128, 16), (128, 16)),
                embedding_dim=128,
            ),
            moon=dict(
                hidden_dims=(128,),
                edge_embedding_params=dict(
                    # MLPEdgeEmbedding
                    out_dim=8,
                    hidden_dim=16,
                ),
                embedding_dim=128,
                embedding_e_out_dim=128,
                embedding_int_dim=16,
            ),
        ),
        gnn_params=dict(
            layers=((16, 32), (16, 32), (16, 32)),
            embedding_dim=32,
            edge_embedding='MLPEdgeEmbedding',
            edge_embedding_params=dict(
                # MLPEdgeEmbedding
                out_dim=8,
                hidden_dim=16,
            ),
            out_mlp_depth=3,
        ),
        orbital_config=dict(
            down_projection=None,
        ),
    )


def _build_dataset(
    mols,
    order,
    batch_size,
    batch_behavior,
    samples_per_batch,
    n_devices,
    properties,
    restricted,
    basis,
    localization,
    dataset_seed,
):
    """Create a ``Dataset`` over ``mols`` with the batch geometry used by ``main``.

    Args:
        mols: Structures as returned by ``Systems.get_molecules``.
        order: Loader order forwarded to ``Dataset`` (e.g. ``pretrain_order``
            or ``'random'``).
        batch_size: Requested structures per batch; capped at ``len(mols)``.
        batch_behavior: Forwarded to ``Dataset``.
        samples_per_batch: Total samples per batch, split evenly over the
            structures of a batch.
        n_devices: Number of JAX devices. Samples per structure are rounded down
            to a multiple of it so batches can be split across devices.
        properties: ``(name, kwargs)`` pairs resolved via ``Properties.PROPERTIES``.
        restricted, basis, localization: Forwarded to ``Dataset``.
        dataset_seed: Seed for the dataset's PRNG key.

    Returns:
        ``(dataset, eff_batch_size, samples_per_molecule)``.
    """
    # If we have fewer structures than `batch_size`, shrink the batch accordingly.
    eff_batch_size = min(batch_size, len(mols))
    # Divide and multiply by n_devices so batches can be parallelized across GPUs.
    samples_per_molecule = samples_per_batch // eff_batch_size // n_devices * n_devices
    property_factories = tuple(
        functools.partial(Properties.PROPERTIES[prop_name], **kwargs)
        for prop_name, kwargs in properties
    )
    dataset = Dataset(
        jax.random.PRNGKey(dataset_seed),
        mols,
        order,
        batch_behavior,
        eff_batch_size,
        samples_per_molecule,
        property_factories,
        restricted,
        basis,
        localization,
    )
    return dataset, eff_batch_size, samples_per_molecule


def _property_values_to_numpy(molecules):
    """Map ``repr(mol)`` to ``mol.property_values``, with JAX arrays converted to NumPy."""
    return {
        repr(mol): jtu.tree_map(
            lambda x: np.array(x) if isinstance(x, jax.Array) else x,
            mol.property_values,
        )
        for mol in molecules
    }


# NOTE: the helpers above must be defined before `automain`, because automain
# may start the experiment as soon as `main` is decorated.
@automain(ex, naming_fn, default_folder='~/logs/dev_globe')
def main(
    seed: int,
    pretrain_systems: SystemDefinitions,
    systems: SystemDefinitions,
    globe: dict,
    mcmc_steps: int,
    preconditioner_args: dict,
    preconditioner: str,
    loss: dict,
    lr: dict,
    damping: dict,
    operator: str,
    batched_vmap_size: int,
    batch_size: int,
    batch_behavior: str,
    samples_per_batch: int,
    thermalizing_steps: int,
    chkpts: tuple[int, ...] | set[int],
    properties: Tuple[Tuple[str, dict], ...],
    restricted: bool,
    basis: str,
    pretrain_order: str,
    pretrain_epochs: int,
    pretrain_localization: str,
    training_epochs: int,
    chkpt: str,
    run_name: str | None,
    max_precision: str,
    log_distributions: bool,
    dataset_seed: int,
    profile: bool,
    logger: Logger = None,  # type: ignore
):
    """Pretrain and then VMC-train the wave-function model configured by ``globe``.

    Stages:
      1. Build the trainer and the pretraining dataset (``pretrain_systems`` if
         given, otherwise ``systems``).
      2. Restore parameters (and, if available, dataset state) from ``chkpt``,
         or initialize fresh parameters.
      3. Pretrain for ``pretrain_epochs`` passes over the dataset.
      4. Rebuild the dataset for ``systems`` if pretraining used other systems,
         thermalize, and store the ``chk_pretrained`` checkpoint.
      5. Run ``training_epochs`` VMC steps (one batch per step). Log per-step,
         per-epoch and per-molecule statistics and store checkpoints.

    Returns:
        ``{'training': {repr(mol): property_values}}`` with JAX arrays
        converted to NumPy arrays.

    Raises:
        RuntimeError: If NaNs appear in the pretraining losses or VMC energies.
    """
    # compilation_cache.set_cache_dir(str(Path.home().absolute() / '.jax_cache'))
    rich.print(locals())
    if max_precision in ('float64', '64', 64):
        jax.config.update('jax_enable_x64', True)
    else:
        jax.config.update('jax_enable_x64', False)
    chkpts = set(chkpts)
    key = jax.random.PRNGKey(seed)
    n_devices = jax.device_count()
    key, subkey = jax.random.split(key)
    trainer = Trainer(
        subkey,
        globe,
        mcmc_steps,
        preconditioner,
        preconditioner_args,
        loss,
        lr,
        damping,
        operator=operator,
        batched_vmap_size=batched_vmap_size,
    )

    logging.info(f'Using the following devices: {jax.devices()}')

    # Initialize pretraining dataset
    mols = Systems.get_molecules(
        pretrain_systems if pretrain_systems is not None else systems
    )
    for m in set(mols):
        logger.add_tag(str(m))
    dataset, eff_batch_size, samples_per_molecule = _build_dataset(
        mols,
        pretrain_order,
        batch_size,
        batch_behavior,
        samples_per_batch,
        n_devices,
        properties,
        restricted,
        basis,
        pretrain_localization,
        dataset_seed,
    )
    logging.info(
        'Dataset info:\n'
        f'{len(mols)} structures\n'
        f'{eff_batch_size} batch size\n'
        f'{samples_per_molecule} samples per molecule\n'
        f'{globe["gnn_params"]["charges"]} charges\n'
        f'{globe["orbital_config"]["orbitals_per_atom"]} orbitals per atom'
    )

    # Initialization
    if chkpt is not None:
        logging.info(f'Loading checkpoint: {chkpt}')
        with open(chkpt, 'rb') as inp:
            trainer.load_params(inp.read())
        # The dataset state is stored next to the parameters with a '.dataset'
        # instead of a '.chk' suffix (see the store_blob calls below).
        dataset_chkpt = Path(chkpt).with_suffix('.dataset')
        try:
            with open(dataset_chkpt, 'rb') as inp:
                dataset.deserialize(inp.read())
        except Exception:
            # Best effort: keep going, but record why loading failed.
            logging.warning(
                'Failed to load dataset from checkpoint\nThermalizing instead.',
                exc_info=True,
            )
            # Thermalize now only if pretraining follows; the post-pretraining
            # thermalization below runs regardless when thermalizing_steps > 0.
            if pretrain_epochs > 0:
                thermalize_dataset(trainer, dataset, logger, thermalizing_steps)
    else:
        electrons, atoms, config, _ = next(iter(dataset)).to_jax()
        trainer.init_params(electrons[0, 0], atoms, config)

    # Pretrain
    logging.info('Pretraining')
    logging.info(
        f'Parameters: {jfu.ravel_pytree(trainer.params)[0].size // n_devices}'
    )
    step = 0
    electrons, atoms, config, props = next(
        iter(dataset)
    ).to_jax()  # This line is just for type checking
    for pretrain_epoch in logger.tqdm(range(pretrain_epochs)):
        for batch in dataset:
            electrons, atoms, config, props = batch.to_jax()
            if profile:
                jax.profiler.start_trace(f'{logger.log_dir}/profiler_prestep_{step}')
            losses, electrons, pmove, cache = trainer.pretrain_step(
                electrons, atoms, config, batch.mo_orbital_fns, props, batch.cache
            )
            if profile:
                jax.profiler.stop_trace()
            batch.update_states(electrons, pmove=pmove, cache=cache)
            if jtu.tree_reduce(
                np.logical_or, jtu.tree_map(lambda x: np.isnan(x).any(), losses)
            ).item():
                raise RuntimeError(f'Encountered NaNs in pretraining step {step}!')
            for k, v in losses.items():
                logger.add_scalar(f'pretrain/{k}', v.mean().item(), step=step)
            logger.add_scalar('pretrain/pmove', pmove.mean().item(), step=step)
            logger.add_scalar('pretrain/elec_std', electrons.std().item(), step=step)
            step += 1
            logger.set_postfix(jtu.tree_map(lambda x: np.mean(x).item(), losses))
        # Log parameters every 1000 steps and after the final epoch. `step`
        # counts batches, so the final-epoch test must use the epoch index.
        is_last_epoch = pretrain_epoch == pretrain_epochs - 1
        if (step % 1000 == 0 or is_last_epoch) and log_distributions:
            with logger.without_aim():
                logger.add_distribution_dict(trainer.params, 'pretrain', step=step)
                logger.add_distribution_dict(
                    trainer.mol_params(atoms, config), 'pretrain/mol_params', step=step
                )
    dataset.clear_cache()

    # Initialize VMC dataset
    if pretrain_systems is not None:
        mols = Systems.get_molecules(systems)
        dataset, _, _ = _build_dataset(
            mols,
            'random',
            batch_size,
            batch_behavior,
            samples_per_batch,
            n_devices,
            properties,
            restricted,
            basis,
            pretrain_localization,
            dataset_seed,
        )
    else:
        dataset.set_loader('random')

    # Thermalize
    if thermalizing_steps:
        thermalize_dataset(trainer, dataset, logger, thermalizing_steps)
    logger.store_blob('chk_pretrained.chk', trainer.serialize_params())
    logger.store_blob('chk_pretrained.dataset', dataset.serialize())

    # VMC training
    logging.info('VMC Training')
    stopwatch = Stopwatch()
    iterator = iter(dataset)
    epoch_data = None
    step_in_epoch = 0
    epoch = 0
    epoch_average = RunningAverage(500)
    # NOTE: each iteration processes one batch, so `training_epochs` counts steps.
    for step in logger.tqdm(range(training_epochs)):
        # Load data; on exhaustion, start a new pass and log the finished epoch.
        try:
            batch = next(iterator)
        except StopIteration:
            iterator = iter(dataset)
            batch = next(iterator)
            # Log epoch averages
            epoch_data = tree_mul(epoch_data, 1 / step_in_epoch)
            logger.add_scalar_dict(
                epoch_data,
                'epoch',
                step=epoch,
            )
            logger.add_scalar_dict(
                epoch_average(epoch_data),
                'epoch_averaged',
                step=epoch,
            )
            epoch_data = None
            step_in_epoch = 0
            epoch += 1
        step_in_epoch += 1
        electrons, atoms, config, props = batch.to_jax()

        # Step
        if profile:
            jax.profiler.start_trace(f'{logger.log_dir}/profiler_step_{step}')
        electrons, mol_data, aux_data = trainer.step(electrons, atoms, config, props)
        if profile:
            jax.profiler.stop_trace()
        batch.update_states(electrons, **mol_data)
        # Move to CPU and reduce parallel GPU dimension
        aux_data = jtu.tree_map(lambda x: np.mean(x, 0), aux_data)
        if np.isnan(aux_data['E']).any():
            raise RuntimeError(f'Encountered NaNs in step {step}!')
        log_data = jtu.tree_map(np.mean, aux_data)
        if epoch_data is None:
            epoch_data = jtu.tree_map(np.zeros_like, log_data)
        epoch_data = jtu.tree_map(np.add, epoch_data, log_data)
        logger.add_scalar_dict(log_data, 'train', step=step)
        step_data = defaultdict(list)
        for mol, e, e_var in zip(batch.molecules, aux_data['E'], aux_data['E_var']):
            step_data[mol].append((e, e_var))
        if step % 10000 == 0 or step in chkpts:
            logging.info(f'Checkpoint {step}')
            logger.store_blob(f'chk_{step}.chk', trainer.serialize_params())
            logger.store_blob(f'chk_{step}.dataset', dataset.serialize())
            logger.store_data(
                f'chk_{step}',
                _property_values_to_numpy(dataset.molecules),
                use_json=True,
                use_pickle=False,
            )
        logger.add_scalar('train/epoch_time', stopwatch(), step=step)
        # Log per molecule data
        postfix = {'E': {}, 'E_std': {}}
        for m, data in step_data.items():
            data = np.array(data)
            E = data[:, 0].mean()
            E_std = data[:, 1].mean() ** 0.5
            postfix['E'][str(m)] = E
            postfix['E_std'][str(m)] = E_std
            logger.add_scalar('mol/E', E, step=step, context={'subset': f'{m}'})
            logger.add_scalar('mol/E_std', E_std, step=step, context={'subset': f'{m}'})
        logger.set_postfix(postfix)
        # Log parameters
        if (
            (step < 1000 and step % 100 == 0) or step % 1000 == 0
        ) and log_distributions:
            with logger.without_aim():
                logger.add_distribution_dict(
                    trainer.params, step=step, context={'subset': 'train'}
                )
                logger.add_distribution_dict(
                    trainer.mol_params(atoms, config),
                    'mol_params',
                    step=step,
                    context={'subset': 'train'},
                )
                # logger.add_distribution_dict(trainer.intermediates(electrons, atoms, config), step=epoch, context={'subset': 'train'})
    logging.info('Training complete')
    logger.store_blob('chk_final.chk', trainer.serialize_params())
    logger.store_blob('chk_final.dataset', dataset.serialize())
    training_results = _property_values_to_numpy(dataset.molecules)

    logging.info('Returning')
    return {'training': training_results}
