import os
import re
import sys
import timm
import argparse

import torch
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from engine import val_epoch
from transforms import get_transforms
from surgery import fix_attention_layer

sys.path.append(os.pardir)
from src.pruner import Pruner
from src.utils.random import fix_seed


def _build_parser():
    """Build the command-line parser for the one-shot pruning script."""
    parser = argparse.ArgumentParser(description="One-shot pruning on ImageNet of timm models.")

    # Data params
    parser.add_argument('--data_dir', type=str, required=True, help="Path to ImageNet.")

    # Model params
    parser.add_argument('--model', type=str, required=True, help="Model pruned")
    parser.add_argument('--pretrained', action='store_true', help="Whether to use pretrained model")
    parser.add_argument('--module_regex', type=str, required=True, help="Modules to prune")
    parser.add_argument('--blocks', type=str, default=None, help="Blocks name")
    parser.add_argument(
        '--pre_modules',
        nargs="+",
        type=str,
        default=[],
        help="Name of modules before blocks",
    )

    # Dataloader params
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--val_batch_size', default=100, type=int)
    parser.add_argument('--num_workers', default=4, type=int)

    # Sparsification params
    parser.add_argument('--iterations', default=10, type=int)
    parser.add_argument('--pruning_method', default="FastOBC", choices=["FastOBC", "OBC"], type=str)
    parser.add_argument('--sparsity', default=0.50, type=float)
    parser.add_argument('--alpha', default=0.0, type=float)
    parser.add_argument('--calibration_dataset_size', default=1024, type=int)
    parser.add_argument('--block_size', default=64, type=int)
    parser.add_argument('--rel_damp', default=1e-2, type=float)
    parser.add_argument('--rows_in_parallel', default=None, type=int)
    parser.add_argument(
        '--perturbation',
        default='gradient',
        choices=['gradient', 'interpolation'],
        type=str,
    )
    parser.add_argument('--sequential', action='store_true', help='Whether to prune sequentially')
    parser.add_argument('--cpu_offload', action='store_true', help='Whether to offload model to CPU.')

    # misc params
    parser.add_argument(
        '--surgery',
        action='store_true',
        help='Whether to transform layers in a format compatible for calibration',
    )
    parser.add_argument('--seed', default=0, type=int)

    # evaluation params
    parser.add_argument('--eval_frequency', default=1, type=int)
    # Integer flag rather than store_true: `run` treats any non-zero value as true.
    parser.add_argument('--eval_only', default=0, type=int)

    # output params
    parser.add_argument('--output_dir', default=None, type=str)
    return parser


def main(argv=None):
    """Parse command-line options and launch the pruning run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so the
            script behaves exactly as before when invoked without arguments.

    Returns:
        None. argparse raises ``SystemExit`` on ``--help`` or on invalid /
        missing arguments.
    """
    args = _build_parser().parse_args(argv)
    run(args)


def run(args):
    """Run iterative pruning of a timm model with periodic validation.

    Builds the model and the ImageFolder train/val loaders, then either
    evaluates the dense model once (``args.eval_only``) or calls
    ``Pruner.prune`` ``args.iterations`` times, evaluating every
    ``args.eval_frequency`` iterations.

    Args:
        args: Namespace-like object with the attributes read below: ``seed``,
            ``model``, ``pretrained``, ``surgery``, ``data_dir``,
            ``batch_size``, ``val_batch_size``, ``num_workers``,
            ``eval_only``, ``module_regex``, ``pruning_method``,
            ``block_size``, ``rows_in_parallel``, ``rel_damp``,
            ``sequential``, ``cpu_offload``, ``blocks``, ``pre_modules``,
            ``calibration_dataset_size``, ``iterations``, ``sparsity``,
            ``alpha``, ``eval_frequency`` and ``output_dir``.

    Returns:
        None. In eval-only mode the stats are written to
        ``eval_only_<model>.txt`` in the current working directory; otherwise,
        when ``args.output_dir`` is set, the evaluation history is saved to
        ``<output_dir>/history.pth`` after every evaluation.

    Raises:
        ValueError: If ``args.pruning_method`` is neither ``"FastOBC"`` nor
            ``"OBC"``.
    """
    fix_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = timm.create_model(args.model, pretrained=args.pretrained)
    if args.surgery:
        fix_attention_layer(model)
    model = model.to(device)

    transform_train, transform_test = get_transforms(model)
    # datasets
    train_dataset = ImageFolder(os.path.join(args.data_dir, 'train'), transform=transform_train)
    val_dataset = ImageFolder(os.path.join(args.data_dir, 'val'), transform=transform_test)
    # loaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.val_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=True,
    )

    if bool(args.eval_only):  # evaluate the dense model, then stop
        print(f'Running evaluation-only on {args.model}')
        stats = val_epoch(model, val_loader, F.cross_entropy, device=device)
        with open(f'eval_only_{args.model}.txt', 'w') as f:
            f.write(str(stats))
        # Return normally so the process reports success; the previous
        # sys.exit(666) made a successful evaluation look like a failure.
        return

    # Snapshot the weights of every module selected for pruning; the copies
    # are handed to the Pruner as `weights_orig`.
    module_pattern = re.compile(args.module_regex)
    weights_orig = {
        module_name: module.weight.clone()
        for module_name, module in model.named_modules()
        if module_pattern.search(module_name)
    }

    if args.pruning_method == "FastOBC":
        obc_util_kwargs = {"block_size": args.block_size}
    elif args.pruning_method == "OBC":
        obc_util_kwargs = {"rows_in_parallel": args.rows_in_parallel}
    else:
        # Only reachable when run() is called without the argparse `choices`
        # check; fail clearly instead of hitting an unbound local below.
        raise ValueError(
            f"Unsupported pruning_method {args.pruning_method!r}; expected 'FastOBC' or 'OBC'."
        )

    def calibration_loader():
        # Endless stream over the shuffled training loader; labels are dropped.
        # Each item is ([inputs], {}) -- presumably positional and keyword
        # arguments for the model's forward call; confirm against Pruner.
        while True:
            for inputs, _ in train_dataloader:
                yield [inputs], {}

    pruner = Pruner(
        model,
        data_loader=calibration_loader(),
        module_regex=args.module_regex,
        weights_orig=weights_orig,
        pruning_method=args.pruning_method,
        rel_damp=args.rel_damp,
        obc_util_kwargs=obc_util_kwargs,
        sequential=args.sequential,
        cpu_offload=args.cpu_offload,
        blocks=args.blocks,
        pre_modules=args.pre_modules,
        max_samples=args.calibration_dataset_size,
    )

    history = []
    print(f'{args.output_dir=}, {args.alpha=}')
    for i in range(args.iterations):
        # 1-based counter so the final line reads N/N and matches stats['it'].
        print(f"Iteration {i + 1}/{args.iterations} | {args.model}")
        pruner.prune(args.sparsity, args.alpha)

        # eval_frequency == 0 disables evaluation altogether.
        if args.eval_frequency and (i + 1) % args.eval_frequency == 0:
            stats = val_epoch(model, val_loader, F.cross_entropy, device=device)
            stats['it'] = i + 1
            print('-' * 10)
            print(f"Loss:     {stats['val/loss']:4.3f}")
            print(f"Acc1:     {stats['val/acc1']:4.3f}")
            history.append(stats)

            # Re-save the whole history each time so partial results survive
            # an interrupted run.
            if args.output_dir:
                os.makedirs(args.output_dir, exist_ok=True)
                torch.save(history, os.path.join(args.output_dir, 'history.pth'))

if __name__ == "__main__":
    # Script entry point: hand main()'s return value to the interpreter as
    # the process exit status.
    exit_status = main()  # pragma: no cover
    sys.exit(exit_status)  # pragma: no cover
