import torch

from utils.utils import *
import training.models as models
from training.data_utils import get_dataloaders, shapes_dict
import hydra

import torch.nn.functional as F

from tqdm import tqdm

def _evaluate_split(model, loader, criterion, device):
    """Evaluate ``model`` on every batch of ``loader``.

    Args:
        model: Callable mapping an input batch to a 2-D tensor of logits
            (the predicted class is the arg-max over dimension 1).
        loader: Sized iterable yielding 4-tuples; the first two items are the
            inputs and the labels, the last two items are ignored.
        criterion: Callable ``(logits, labels) -> scalar loss tensor``.
        device: Device the inputs and labels are moved to before the forward pass.

    Returns:
        Tuple ``(error_rate, average_loss)``: the misclassification rate in
        percent over all samples, and the mean of the per-batch losses
        (i.e. the summed batch losses divided by ``len(loader)``).
    """
    total_error, total_loss, total_samples = 0., 0., 0
    # Evaluation only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for x, y, _, _ in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            total_loss += criterion(outputs, y).item()
            _, y_hat = torch.max(outputs, 1)
            total_samples += y.size(0)
            total_error += (y_hat != y).sum().item()
    # Normalise the loss by the number of batches of *this* loader.
    return 100*total_error/total_samples, total_loss/len(loader)


def error_barrier(args, model1, model2, interpolation_model_template, train_loader, val_loader, test_loader):
    """Evaluate error and loss along the linear path between two models.

    For each ``a`` in ``torch.linspace(0, 1, args.steps)`` the weights
    ``a*model1 + (1-a)*model2`` are loaded into
    ``interpolation_model_template`` (so ``a=0`` is ``model2`` and ``a=1`` is
    ``model1``), and the resulting model is evaluated on the train, validation
    and test loaders.

    The train/eval mode of ``interpolation_model_template`` is not changed
    here; the caller is responsible for setting it.

    Args:
        args: Object exposing ``steps`` (number of interpolation points) and
            ``device`` (where batches are moved to).
        model1: First end point; only its ``state_dict()`` is read.
        model2: Second end point; must provide every key of ``model1``'s
            state dict.
        interpolation_model_template: Module that receives the interpolated
            weights and is used for the forward passes. It is mutated.
        train_loader: Sized iterable of ``(x, y, _, _)`` batches.
        val_loader: Sized iterable of ``(x, y, _, _)`` batches.
        test_loader: Sized iterable of ``(x, y, _, _)`` batches.

    Returns:
        Dict with, for each split in ``train``/``val``/``test``:
        ``all_<split>_err`` and ``all_<split>_loss`` (one value per
        interpolation point, in the order of increasing ``a``) and
        ``max_<split>_err`` / ``max_<split>_loss`` (the largest value seen,
        starting from 0).
    """
    alphas = torch.linspace(0, 1, steps=args.steps)
    criterion = F.cross_entropy

    state_dict1 = model1.state_dict()
    state_dict2 = model2.state_dict()

    loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}

    res = {}
    for split in loaders:
        res[f'all_{split}_err'] = []
        res[f'max_{split}_err'] = 0
        res[f'all_{split}_loss'] = []
        res[f'max_{split}_loss'] = 0

    for a in tqdm(alphas, desc='Interpolating...', total=args.steps):
        interpolation_model_template.load_state_dict(
            {k: a*state_dict1[k] + (1-a)*state_dict2[k] for k in state_dict1})

        # Same evaluation for every split, each normalised by its own loader.
        for split, loader in loaders.items():
            error_rate, average_loss = _evaluate_split(
                interpolation_model_template, loader, criterion, args.device)

            res[f'all_{split}_err'].append(error_rate)
            res[f'max_{split}_err'] = max(error_rate, res[f'max_{split}_err'])
            res[f'all_{split}_loss'].append(average_loss)
            res[f'max_{split}_loss'] = max(average_loss, res[f'max_{split}_loss'])

    return res

def _build_model(run_args, num_classes, device):
    """Instantiate the architecture described by ``run_args`` on ``device``.

    Every setting (model name, dataset shape, width, dropout rate) is read
    from the same ``run_args`` object, so a model is never assembled from a
    mix of two runs' configurations. The positional ``False`` flag and the
    ``'relu'`` activation name are passed to ``models.get_model`` unchanged.
    """
    return models.get_model(run_args.model,
                            num_classes,
                            False,
                            shapes_dict[run_args.dataset],
                            run_args.model_width,
                            'relu',
                            droprate=run_args.droprate).to(device)


@hydra.main(version_base=None, config_path='.', config_name='error_barrier_params')
def main(cfg):
    """Compute the error barrier between two trained runs and save it.

    Reads two (experiment, run) pairs from the hydra config, rebuilds both
    models from their stored run arguments, loads the ``'last'`` weights from
    each run's ``epochs=200.pt`` checkpoint, evaluates the linear interpolation
    between them with ``error_barrier`` and writes the resulting statistics
    with ``torch.save`` to ``args.output_dir`` (which is used as a file path).

    Args:
        cfg: Hydra configuration; wrapped in ``DotDict`` and expected to
            provide ``exp_name1``, ``run_name1``, ``exp_name2``, ``run_name2``,
            ``team_path``, ``gpu``, ``seed``, ``steps`` and ``output_dir``.
    """
    # Remove hydra logger
    try:
        os.remove(f'{os.path.splitext(os.path.basename(__file__))[0]}.log')
    except FileNotFoundError:
        # No log file in the working directory: nothing to clean up.
        pass
    os.umask(0)

    # Argument intake and validation
    args = DotDict(cfg)

    # Both runs must be fully specified (names of run 1 AND run 2).
    assert args.exp_name1 and args.run_name1 and args.exp_name2 and args.run_name2, \
           'Must specify experiment and run name for 2 models to compute their error barrier.'

    args.ckpt_path1 = os.path.join(args.team_path, args.exp_name1, args.run_name1, 'ckpt')
    args.ckpt_path2 = os.path.join(args.team_path, args.exp_name2, args.run_name2, 'ckpt')

    # Set environment meta-config (gpu, seed, etc.)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    print(f'Using GPU: {args.gpu}')
    # Only query device properties when CUDA exists; otherwise fall back to CPU below.
    if torch.cuda.is_available():
        print(f'GPU memory available: {(torch.cuda.get_device_properties("cuda").total_memory / 10**9):.2f} GB')

    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    set_all_seeds(args.seed)

    # NOTE(review): presumably the stored training configuration of each run — confirm in get_wandb_args.
    run_args1 = get_wandb_args(args.team_path, args.exp_name1, args.run_name1)
    run_args2 = get_wandb_args(args.team_path, args.exp_name2, args.run_name2)

    assert run_args1.model == run_args2.model, \
        'Cannot compute the barrier between 2 different model architectures!'

    # These two flags are forced off before building the loaders; the data
    # (and num_classes) come from run 1's configuration only.
    run_args1.us, run_args1.aus = False, False
    train_loader, val_loader, test_loader, num_classes = get_dataloaders(run_args1)

    # Each end point is built from its own run's settings; the template that
    # receives the interpolated weights follows run 1.
    model1 = _build_model(run_args1, num_classes, args.device)
    model2 = _build_model(run_args2, num_classes, args.device)
    interpolation_model_template = _build_model(run_args1, num_classes, args.device)

    # map_location lets checkpoints saved on a GPU be loaded on the selected device.
    # Checkpoint entries whose key contains 'model_preact_hl1' are dropped before loading.
    model_dict1 = torch.load(os.path.join(args.ckpt_path1, 'epochs=200.pt'),
                             map_location=args.device)['last']
    model1.load_state_dict({k: v for k, v in model_dict1.items() if 'model_preact_hl1' not in k})
    model1.eval()

    model_dict2 = torch.load(os.path.join(args.ckpt_path2, 'epochs=200.pt'),
                             map_location=args.device)['last']
    model2.load_state_dict({k: v for k, v in model_dict2.items() if 'model_preact_hl1' not in k})
    model2.eval()

    # Interpolation combines the tensors key by key, so every tensor of model1
    # needs a same-shaped counterpart in model2 (fail early instead of relying
    # on a later shape error or on silent broadcasting).
    sd1, sd2 = model1.state_dict(), model2.state_dict()
    assert all(k in sd2 and sd1[k].shape == sd2[k].shape for k in sd1), \
        'Cannot interpolate: the two models have different parameter names or shapes!'

    interpolation_model_template.eval()

    barrier_stats = error_barrier(args, model1, model2, interpolation_model_template, train_loader, val_loader, test_loader)

    # args.output_dir is the output *file*; create its parent directory only
    # when the path has one (os.makedirs('') raises).
    output_parent = os.path.dirname(args.output_dir)
    if output_parent:
        os.makedirs(output_parent, exist_ok=True)
    torch.save(barrier_stats, args.output_dir)

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
