import argparse

import ase.data
import ase.io
import torch

from mace import data
from mace.tools import torch_geometric, torch_tools, utils

import time


def parse_args() -> argparse.Namespace:
    """Read the benchmark's command-line options.

    Required: ``--configs`` (XYZ file of structures) and ``--model`` (saved
    model file). Optional: ``--device`` (``cpu`` or ``cuda``, default
    ``cpu``) and ``--default_dtype`` (``float32`` or ``float64``, default
    ``float64``).

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser()

    # Input files; both must be supplied.
    cli.add_argument("--configs", required=True, help="path to XYZ configurations")
    cli.add_argument("--model", required=True, help="path to model")

    # Execution settings with restricted choices.
    cli.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="select device",
    )
    cli.add_argument(
        "--default_dtype",
        type=str,
        default="float64",
        choices=["float32", "float64"],
        help="set default dtype",
    )

    namespace = cli.parse_args()
    return namespace


def _synchronize(device) -> None:
    """Wait for all queued CUDA work to finish when *device* is a CUDA device.

    ``torch.cuda.synchronize()`` raises on machines without CUDA, so it is
    only called when the benchmark actually runs on a GPU; on CPU the
    forward passes are already synchronous and nothing needs to be done.
    """
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()


def main():
    """Time repeated forward passes of a saved model on one batch of structures.

    Reads all configurations from ``--configs`` and builds batches of up to
    ``batch_size`` structures; only the first batch is used. Runs
    ``num_warmup`` untimed forward passes, then times ``num_repetitions``
    forward passes. Prints the total elapsed time, the time per repetition
    and the time per structure.

    Raises:
        ValueError: if ``--configs`` yields no configurations.
    """
    batch_size = 100
    num_warmup = 10
    num_repetitions = 100

    args = parse_args()
    torch_tools.set_default_dtype(args.default_dtype)
    device = torch_tools.init_device(args.device)

    # Load model.
    # NOTE(review): torch.load unpickles the file; only load trusted models.
    model = torch.load(f=args.model, map_location=args.device)
    model = model.to(device)  # shouldn't be necessary but seems to help with CUDA problems

    # Parameter gradients are not needed for timing forward passes.
    for param in model.parameters():
        param.requires_grad = False

    # Load data and prepare input
    print(args.configs)
    atoms_list = ase.io.read(args.configs, index=":")
    configs = [data.config_from_atoms(atoms) for atoms in atoms_list]
    if not configs:
        raise ValueError(f"No configurations found in {args.configs}")

    z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])

    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[
            data.AtomicData.from_config(
                config, z_table=z_table, cutoff=float(model.r_max)
            )
            for config in configs
        ],
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )

    batch = next(iter(data_loader)).to(device)
    # With shuffle=False and drop_last=False the first batch holds the first
    # min(batch_size, len(configs)) structures, not always batch_size.
    num_structures = min(batch_size, len(configs))

    # Warm-up passes are excluded from timing (lazy initialization, caches).
    for _ in range(num_warmup):
        model(batch.to_dict(), compute_stress=False)

    # Synchronize before and after the timed region so that asynchronous
    # GPU work is fully accounted for.
    _synchronize(device)
    start_time = time.perf_counter()
    for _ in range(num_repetitions):
        model(batch.to_dict(), compute_stress=False)
    _synchronize(device)
    elapsed = time.perf_counter() - start_time

    print("total_time", elapsed)
    print("time_per_repetition", elapsed / num_repetitions)
    print("time_per_structure", elapsed / num_repetitions / num_structures)


if __name__ == "__main__":
    main()
