import logging

import argparse

import ase

import torch.nn.functional

from mace import data, modules
from mace.tools import torch_tools, utils
from mace.tools.scripts_utils import create_error_table


def parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface of the evaluation script.

    Returns:
        Namespace with ``configs`` and ``model`` (both required paths),
        ``device`` (``"cpu"`` or ``"cuda"``, default ``"cpu"``) and
        ``default_dtype`` (``"float32"`` or ``"float64"``, default ``"float64"``).
    """
    cli = argparse.ArgumentParser()

    # Mandatory inputs: the configurations to evaluate and the serialized model.
    required_paths = (
        ("--configs", "path to XYZ configurations"),
        ("--model", "path to model"),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, help=description, required=True)

    # Optional runtime settings, each restricted to a closed set of values.
    cli.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="select device",
    )
    cli.add_argument(
        "--default_dtype",
        type=str,
        default="float64",
        choices=["float32", "float64"],
        help="set default dtype",
    )

    namespace = cli.parse_args()
    return namespace


def _load_frozen_model(path: str, device: str) -> torch.nn.Module:
    """Load a serialized model onto ``device`` with all parameters frozen.

    Args:
        path: Path of the file passed to ``torch.load``.
        device: Device string used both as ``map_location`` and for ``.to``.

    Returns:
        The loaded model, moved to ``device``, with ``requires_grad`` disabled
        on every parameter.
    """
    model = torch.load(f=path, map_location=device)
    # Shouldn't be necessary after map_location but seems to help with CUDA problems.
    model = model.to(device)
    for param in model.parameters():
        param.requires_grad = False
    return model


def _model_r_max(model: torch.nn.Module, default: float = 5.0) -> float:
    """Return the cutoff radius stored on ``model``, or ``default`` if absent.

    Args:
        model: Loaded model; its ``r_max`` attribute is used when present.
        default: Cutoff used when the model exposes no ``r_max``.

    Returns:
        The cutoff as a Python float.
    """
    r_max = getattr(model, "r_max", None)
    if r_max is None:
        logging.warning(
            "Model has no r_max attribute; falling back to r_max=%s", default
        )
        return default
    # float() accepts both plain numbers and 0-d tensors.
    return float(r_max)


def main() -> None:
    """Evaluate a serialized model on XYZ configurations and print an error table.

    Reads options via ``parse_args``, loads the model once, converts every
    frame of ``--configs`` into a configuration, and prints the table built by
    ``create_error_table`` for a single "test" collection.
    """
    # ``import ase`` alone does not guarantee the ``ase.io`` submodule is loaded.
    import ase.io  # pylint: disable=import-outside-toplevel

    # Make the INFO messages below visible; no-op if logging is already configured.
    logging.basicConfig(level=logging.INFO)

    args = parse_args()
    torch_tools.set_default_dtype(args.default_dtype)
    device = torch_tools.init_device(args.device)

    # Load the model exactly once: reloading it would discard the device move
    # and the parameter freezing performed by the helper.
    model = _load_frozen_model(args.model, args.device)

    loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0)

    output_args = {
        "energy": True,
        "forces": True,
        "virials": False,
        "stress": True,
        "dipoles": False,
    }

    # Load data and prepare input: index=":" reads every frame of the file.
    atoms_list = ase.io.read(args.configs, index=":")
    configs = [data.config_from_atoms(atoms) for atoms in atoms_list]

    z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])

    # Evaluation on test datasets
    logging.info("Computing metrics for test set")

    all_collections = [
        ("test", configs),
    ]

    table = create_error_table(
        # table_type='TotalRMSE',
        table_type="PerAtomRMSE",
        all_collections=all_collections,
        z_table=z_table,
        # Build graphs with the model's own cutoff; a different cutoff would
        # change the neighbor lists the model sees.
        r_max=_model_r_max(model),
        valid_batch_size=10,
        model=model,
        loss_fn=loss_fn,
        output_args=output_args,
        log_wandb=False,
        device=device,
    )
    print(str(table))

    logging.info("Done")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
