import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist

# NOTE: the original import order is kept on purpose. The star imports below
# may export names such as `np` or `datetime`; the plain imports that follow
# them must keep winning, so they are not regrouped above the local imports.
from utils  import *
from data   import *
from buffer import Buffer
from copy   import deepcopy
from pydoc  import locate
from model  import ResNet18, MLP
import numpy as np
import datetime

# arguments
# Command-line interface of the experiment. Several of these values
# (n_tasks, samples_per_task, use_conv, multiple_heads, ...) are overridden
# further down depending on --dataset.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, choices=['split_mnist', 'permuted_mnist', 'split_cifar100_rahaf', 'split_cifar', 'split_cifar100_fb'], default = 'split_cifar')
parser.add_argument('--n_tasks', type=int, default=5)
parser.add_argument('--n_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=10)
parser.add_argument('--buffer_batch_size', type=int, default=10)
parser.add_argument('--use_conv', type=int, default=1)
parser.add_argument('--samples_per_task', type=int, default=1000, help='if negative, full dataset is used')
parser.add_argument('--mem_size', type=int, default=5, help='controls buffer size') # mem_size in the tf repo.
parser.add_argument('--n_runs', type=int, default=3, help='number of runs to average performance')
parser.add_argument('--n_iters', type=int, default=1, help='training iterations on incoming batch')
parser.add_argument('--rehearsal', type=int, default=1, help='whether to replay previous data')
parser.add_argument('--hidden_dim', type=int, default=20)
parser.add_argument('--multiple_heads', action='store_true', help='use multiple heads (logits outside the task mask are suppressed)')
parser.add_argument('--compare_to_old_logits', action='store_true', help='for max_loss')
parser.add_argument('--compare_to_old_ratio', type=float, default=1.0, help='ratio of old loss')
parser.add_argument('--pseudo_targets', action='store_true')
parser.add_argument('--mixup', action='store_true', help='use manifold mixup')
parser.add_argument('--subsample', type=int, help='number of buffer candidates scored by max_loss / age retrieval (<= 0 keeps all)', default=50)
parser.add_argument('--mixup_buf', action='store_true', help='use manifold mixup on the replayed buffer samples')
parser.add_argument('--name', type=str, default='', help='name_exp')
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--ratio', type=float, default=1.0)
parser.add_argument('--age', action='store_true', help='retrieve the buffer samples with the smallest age')
# NOTE(review): this help text looks copy-pasted from --max_loss; the flag is
# not read anywhere in this script, so its real meaning should be confirmed.
parser.add_argument('--logit_soft', action='store_true',help='maximize the true loss instead of KL')
#Added by Rahaf
parser.add_argument('--max_loss', action='store_true',help='maximize the true loss instead of KL')
parser.add_argument('--reuse_samples', action='store_true', help='reuse same samples over the iterations')
# The flag name keeps its historical spelling so existing command lines still work.
parser.add_argument('--diverse_retreival', action='store_true', help='retrieve diverse samples')
parser.add_argument('--entropy', type=float, default=0)
parser.add_argument('--validation', type=int, default=1,help='use validation')
args = parser.parse_args()

##################### Logs
# The log file name is "<dataset>_<ISO timestamp><random 0..999><name>.log";
# the random suffix lowers the chance of two concurrent launches colliding.
time_stamp = datetime.datetime.now().isoformat()
_log_suffix = str(np.random.randint(0, 1000))
name_log_txt = '{}_{}{}{}.log'.format(args.dataset, time_stamp, _log_suffix, args.name)
# First line of the log: the parsed command-line arguments.
with open(name_log_txt, "a") as text_file:
    text_file.write('{}\n'.format(args))

# fixed for now
# Dataset-dependent settings. They overwrite whatever was passed on the
# command line for n_tasks / samples_per_task / use_conv / multiple_heads.
args.ignore_mask = False
args.input_size = (3, 32, 32)
# Prefer the first GPU, but fall back to the CPU so the script still runs on
# machines without CUDA.
args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if args.dataset == 'split_cifar100_fb':
    # shorter task sequence when running in validation mode
    args.n_tasks = 3 if args.validation else 17
    args.samples_per_task = 2500
    buffer_size = args.n_tasks*args.mem_size*5
    # the output layer is sized as 5 classes per task
    args.n_classes = 5 * args.n_tasks
elif args.dataset == 'split_cifar100_rahaf':
    args.n_classes = 100
    args.n_tasks =  5
    args.samples_per_task = 5000
    args.multiple_heads = False
    buffer_size = args.n_tasks * args.mem_size * 20
elif args.dataset == 'split_mnist':
    args.n_classes = 10
    args.n_tasks = 5
    buffer_size = args.n_tasks * args.mem_size * 2
    args.input_size = (784,)
    args.use_conv = False
elif args.dataset == 'permuted_mnist':
    args.n_classes = 10
    args.n_tasks = 10
    buffer_size = args.mem_size * args.n_classes
    args.input_size = (784,)
    args.use_conv = False
    # when set, buffer retrieval does not filter out the current task's classes
    args.ignore_mask = True
else:
    # default: split_cifar
    args.n_classes = 10
    args.n_tasks = 5
    buffer_size = args.n_tasks*args.mem_size*2

args.gen = False

# short aliases for torch.distributions helpers
kl = dist.kl.kl_divergence
Cat = dist.categorical.Categorical

# fetch data
# The dataset factory is resolved by name: data.get_<dataset>(args).
_get_dataset = locate('data.get_%s' % args.dataset)
data = _get_dataset(args)

seed=0
torch.manual_seed(seed)
# make dataloaders
# The first split is wrapped as the training loader and the second as the
# test loader (zip stops after two elements).
_train_flags = (True, False)
train_loader, test_loader = [
    CLDataLoader(split, args, train=is_train)
    for split, is_train in zip(data, _train_flags)
]


def _build_resnet():
    """Return a fresh ResNet18 on the configured device."""
    return ResNet18(args.n_classes, nf=args.hidden_dim).to(args.device)


def _build_mlp():
    """Return a fresh MLP on the configured device."""
    return MLP(args).to(args.device)


# fetch model and ship to GPU: convolutional backbone when use_conv is truthy,
# otherwise the MLP
reset_model = _build_resnet if args.use_conv else _build_mlp


def reset_opt(model):
    """Return a plain SGD optimizer over *model*'s parameters with lr=args.lr."""
    return torch.optim.SGD(model.parameters(), lr=args.lr)


# run index -> list of model snapshots (one appended after each task)
all_models = {}

def CE(student, teacher):
    """Return KL(softmax(teacher) || softmax(student)), averaged over the batch.

    Both arguments are raw logits with classes on the last dimension. The
    teacher logits are detached, so gradients only flow through *student*.
    """
    student_log_probs = F.log_softmax(student, dim=-1)
    teacher_probs = F.softmax(teacher.detach(), dim=-1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')


def entropy_fn(x):
    """Return sum_c p_c * log p_c per row of the 2-D logits *x*.

    This is the *negative* Shannon entropy of softmax(x): values are <= 0 and
    closer to 0 for more confident predictions.
    """
    # dim=1 is given explicitly; relying on the implicit softmax dimension is
    # deprecated and emits a warning on every call.
    return torch.sum(F.softmax(x, dim=1) * F.log_softmax(x, dim=1), dim=1)
# Train the model 
# -------------------------------------------------------------------------------


# Online training with optional rehearsal from the replay buffer.
# For every run a fresh model/optimizer/buffer is created; a CPU snapshot of
# the model is stored in all_models[run] after each task.
for run in range(args.n_runs):
    all_models[run] = []
    model = reset_model()
    opt = reset_opt(model)
    # NOTE(review): the per-run seed is set only after the model has been
    # built, so the weight initialisation is not governed by seed+run.
    torch.manual_seed(seed+run)
    # indices (into `subsample`) whose stored buffer logits must be refreshed
    updated_inds = None
    # number of elements of every parameter tensor, used by the grad-vector helpers
    grad_dims = [param.data.numel() for param in model.parameters()]
    # build buffer
    buffer = Buffer(args,buffer_size=buffer_size)
    buffer.min_per_class = 0

    if run == 0:
        print("number of model parameters:", sum([np.prod(p.size()) for p in model.parameters()]))
        print("buffer parameters:         ", np.prod(buffer.bx.size()))

    for task, loader in enumerate(train_loader):
        # NOTE(review): this counter is reset per task, not per epoch, so with
        # n_epochs > 1 and samples_per_task > 0 later epochs may stop immediately.
        sample_amt = 0
        if task + 1 > args.n_tasks: break

        # iterate over samples from task
        for epoch in range(args.n_epochs):
            loss_ , correct, deno = 0., 0., 0.
            n_steps = 0  # number of losses accumulated into loss_
            for i, (data, target) in enumerate(loader):
                # stop once the per-task sample budget is exceeded (only if the budget is positive)
                if sample_amt > args.samples_per_task > 0: break
                sample_amt += data.size(0)

                data, target = data.to(args.device), target.to(args.device)
                buffer_batch_size = min(buffer.x.size(0), args.buffer_batch_size)
                for it in range(args.n_iters):
                    # new buffer samples are drawn on the first iteration, and on
                    # every iteration unless --reuse_samples is set
                    resample = it == 0 or not args.reuse_samples
                    if resample:
                        train_idx, track_idx = buffer.split(buffer_batch_size if args.mixup else 0)

                    input_x, input_y = data, target

                    if task > 0 and resample:
                        # Forget the refresh list of the previous selection: it indexes
                        # the previous `subsample` and must not be applied to a new one.
                        updated_inds = None
                        subsample = torch.LongTensor(np.random.permutation(
                            np.arange(0, buffer.x.size(0)))).to(args.device)

                        # only sample ones not in the current task,
                        if not args.ignore_mask:
                            for t_n in torch.nonzero(loader.dataset.mask).squeeze(1):
                                subsample = subsample[buffer.t_y[subsample] != t_n]
                        if args.max_loss or args.age:
                            if args.subsample>0:
                                subsample = subsample[0:args.subsample]
                        if args.max_loss:
                            # logits of the candidates before the update on the incoming batch
                            with torch.no_grad():
                                logits_track_pre = model(buffer.x[subsample])
                                b_task_sub = buffer.t[subsample]
                    lamb = 1
                    hid = model.return_hidden(input_x)

                    use_mixup = train_idx.nelement() > 0 and args.mixup
                    if use_mixup:
                        # manifold mixup between incoming and buffer hidden states
                        lamb = np.random.beta(2, 2)
                        hid_b = model.return_hidden(buffer.bx[train_idx])
                        hid = lamb * hid + (1 - lamb) * hid_b

                    logits = model.linear(hid)
                    if args.multiple_heads:
                        logits = logits.masked_fill(loader.dataset.mask == 0, -1e9)
                    loss_a = F.cross_entropy(logits, input_y, reduction='none')

                    if use_mixup:
                        loss_b = F.cross_entropy(logits, buffer.by[train_idx], reduction='none')
                    else:
                        loss_b = 0

                    loss = (lamb * loss_a + (1 - lamb) * loss_b).sum() / loss_a.size(0)

                    pred = logits.argmax(dim=1, keepdim=True)
                    correct += pred.eq(input_y.view_as(pred)).sum().item()
                    deno  += pred.size(0)
                    loss_ += loss.item()
                    n_steps += 1

                    opt.zero_grad()
                    loss.backward()

                    if args.max_loss and task>0:
                        # NOTE(review): the bound method `model.parameters` is passed
                        # uncalled; presumably the helper calls it itself - confirm in utils.
                        grad_vector = get_grad_vector(model.parameters, grad_dims)
                        # model as it would look after one SGD step on the incoming batch
                        model_temp = get_future_step_parameters(model, grad_vector,grad_dims, lr=args.lr)


                    if task > 0 and args.rehearsal:
                        if resample:
                            if args.max_loss:
                                with torch.no_grad():
                                    buffer_hid = model_temp.return_hidden(buffer.x[subsample])
                                    logits_track_post = model_temp.linear(buffer_hid)

                                ###########******************#######################
                                    if args.multiple_heads:
                                        mask = torch.zeros_like(logits_track_post)
                                        mask.scatter_(1, loader.dataset.task_ids[b_task_sub], 1)
                                        assert mask.nelement() // mask.sum() == args.n_tasks
                                        logits_track_post = logits_track_post.masked_fill(mask == 0, -1e9)
                                        logits_track_pre = logits_track_pre.masked_fill(mask == 0, -1e9)



                                    # score = increase of the loss caused by the virtual update
                                    pre_loss = F.cross_entropy(logits_track_pre, buffer.t_y[subsample], reduction="none")
                                    post_loss = F.cross_entropy(logits_track_post, buffer.t_y[subsample], reduction="none")
                                    scores = post_loss - pre_loss
                                    EN_logits = entropy_fn(logits_track_pre)
                                    if args.compare_to_old_logits and np.random.rand()<args.compare_to_old_ratio:
                                        old_loss = F.cross_entropy(buffer.logits[subsample],buffer.t_y[subsample],reduction="none")
                                        # candidates whose current loss beats the loss of the stored logits
                                        updated_mask = pre_loss < old_loss
                                        updated_inds = updated_mask.data.nonzero().squeeze(1)


                                        scores = post_loss - torch.min(pre_loss, old_loss)

                                    all_logits = args.entropy*EN_logits + 1.*scores
                                    if args.diverse_retreival:
                                        biggest_diff_ind=buffer.get_most_interfered(all_logits,buffer_hid,buffer_batch_size)
                                    else:
                                        biggest_diff_ind = all_logits.sort(descending=True)[1][:buffer_batch_size]
                                ##########*****************#########################
                                idx = subsample[biggest_diff_ind]
                            elif args.age:
                                # retrieve the candidates with the smallest age
                                all_ages = buffer.age[subsample]
                                biggest_diff_ind = all_ages.sort(descending=False)[1][:buffer_batch_size]
                                idx = subsample[biggest_diff_ind]
                            else:
                                # random retrieval: `subsample` is already a random permutation
                                idx = subsample[:buffer_batch_size]

                        mem_x, mem_y, logits_y, b_task_ids = buffer.x[idx], buffer.t_y[idx], buffer.logits[idx], buffer.t[idx]

                        # weight of the replay loss relative to the incoming-batch loss
                        ratio = args.ratio*float(args.buffer_batch_size)/float(args.batch_size)
                        if args.mixup_buf:
                            # One mixing coefficient per retrieved sample. Sizing it from
                            # mem_x covers the case where fewer than buffer_batch_size
                            # samples were retrieved.
                            n_mem = mem_x.size(0)
                            lamb = torch.FloatTensor(np.random.beta(2,2,n_mem)).to(args.device)
                            # Each sample is mixed with its mirror in the batch. A separate
                            # name is used so the buffer indices in `idx` stay intact for
                            # later iterations when --reuse_samples is set.
                            rev_idx = torch.arange(n_mem - 1, -1, -1)
                            mem_x = lamb[:,None]*model.return_hidden(mem_x)+(1.-lamb[:,None])*model.return_hidden(mem_x[rev_idx])

                            mem_y = lamb[:,None]*onehot(mem_y, args.n_classes, args.device) + (1.-lamb[:,None])*onehot(mem_y[rev_idx], args.n_classes,args.device)

                            logits_buffer = model.linear(mem_x)
                            if args.multiple_heads:
                                mask = torch.zeros_like(logits_buffer)
                                mask.scatter_(1, loader.dataset.task_ids[b_task_ids], 1)
                                assert mask.nelement() // mask.sum() == args.n_tasks
                                logits_buffer = logits_buffer.masked_fill(mask == 0, -1e9)
                            (ratio*naive_cross_entropy_loss(logits_buffer, mem_y)).backward()
                        else:
                            logits_buffer = model(mem_x)
                            if args.multiple_heads:
                                mask = torch.zeros_like(logits_buffer)
                                mask.scatter_(1, loader.dataset.task_ids[b_task_ids], 1)
                                assert mask.nelement() // mask.sum() == args.n_tasks
                                logits_buffer = logits_buffer.masked_fill(mask == 0, -1e9)
                            # with --pseudo_targets the stored logits are used as soft
                            # targets about half of the time (standard normal draw < 0)
                            if args.pseudo_targets and np.random.randn()<0:
                                (ratio*naive_cross_entropy_loss(logits_buffer, F.softmax(logits_y,dim=-1))).backward()
                            else:
                                (ratio*F.cross_entropy(logits_buffer, mem_y)).backward()

                        # refresh the stored logits of the candidates flagged above
                        if updated_inds is not None:
                            buffer.logits[subsample[updated_inds]] = deepcopy(logits_track_pre[updated_inds])

                    # gradients of the incoming batch and of the replayed batch are
                    # accumulated, then applied in a single step
                    opt.step()
                # add data to reservoir
                buffer.add_reservoir(input_x.detach(), target, logits.detach(), task)

            # Average over the updates actually performed. The guards avoid a
            # ZeroDivisionError when the epoch stopped before its first update.
            print('Task {}\t Epoch {}\t Loss {:.6f}\t, Acc {:.6f}'.format(task, epoch, loss_ / max(n_steps, 1), correct / max(deno, 1)))

        # snapshot of the model after this task, kept on the CPU for evaluation
        all_models[run] += [deepcopy(model).cpu()]


# Test the model 
# -------------------------------------------------------------------------------
# Evaluate every stored snapshot on every test task.
# accuracies[run][k][j] is the accuracy of the model saved after task k on
# test task j.
avgs = []
with torch.no_grad():
    accuracies = {}
    forgetting = {}
    for run in range(args.n_runs):
        accuracies[run] = {}  #this is going to be Tasks x Runs
        forgetting[run] = {}
        for task_model, model in enumerate(all_models[run]):
            model = model.eval().to(args.device)
            accuracies[run][task_model] = []
            forgetting[run][task_model] = []
            for task, loader in enumerate(test_loader):
                # iterate over samples from task
                correct, deno = 0., 0.
                for data, target in loader:
                    data, target = data.to(args.device), target.to(args.device)

                    logits = model(data)
                    if args.multiple_heads:
                        logits = logits.masked_fill(loader.dataset.mask == 0 , -1e9)
                    # only the accuracy is reported, so no loss is computed here
                    pred = logits.argmax(dim=1, keepdim=True)
                    correct += pred.eq(target.view_as(pred)).sum().item()
                    deno += data.size(0) #pred.size(0)

                accuracies[run][task_model] += [correct / deno]
                if task<= task_model:
                    # NOTE(review): max(accuracies[run][task]) is the best accuracy of
                    # the model saved after `task` over all test tasks, not the best
                    # accuracy ever reached on `task`; this dict is not read again in
                    # this script - confirm the intended definition before relying on it.
                    forgetting[run][task_model] += [max(accuracies[run][task])-accuracies[run][task_model][task] ]
            # move the snapshot back so only one model occupies the device at a time
            model = model.cpu()
        # report the accuracies of the last snapshot of this run
        final_accs = accuracies[run][len(all_models[run]) - 1]
        out = ''.join('{} : {:.2f}\t'.format(i, acc) for i, acc in enumerate(final_accs))
        print(out)
        avgs += [sum(final_accs) / len(final_accs)]
        with open(name_log_txt, "a") as text_file:
            print(out, file=text_file)

#print('Max loss = {}. AVG over {} runs : {:.4f}'.format(args.max_loss, args.n_runs, sum(avgs) / len(avgs)))
# For every snapshot k, log the mean accuracy over the tasks seen so far
# (tasks 0..k), averaged over runs, with 2 * std / sqrt(n_runs) as the spread.
_n_snapshots = len(all_models[0])
for task_model in range(_n_snapshots):
    avgs = []
    for run in range(args.n_runs):
        seen_accs = accuracies[run][task_model][:task_model + 1]
        avgs.append(sum(seen_accs) / len(seen_accs))

    _run_means = np.array(avgs)
    avg = _run_means.mean()
    std = _run_means.std()
    _half_width = std*2./np.sqrt(args.n_runs)
    _line = 'After Task {} Max loss = {}. AVG over {} runs : {:.4f} +- {:.4f}'.format(
        task_model, args.max_loss, args.n_runs, avg, _half_width)
    with open(name_log_txt, "a") as text_file:
        print(_line, file=text_file)
# Forgetting of task j = best accuracy on j over all snapshots except the
# last one, minus the accuracy of the last snapshot on j. The last task is
# skipped, so each run contributes (number of snapshots - 1) values.
_n_models = len(all_models[0])
all_forget2 = []
for run in range(args.n_runs):
    # rows: snapshots, columns: test tasks
    mat = np.array([accuracies[run][jj] for jj in range(_n_models)])
    final_row = mat[_n_models - 1]
    run_forgetting = []
    for jj in range(_n_models - 1):
        run_forgetting.append(max(mat[:-1, jj]) - final_row[jj])
    all_forget2.append(run_forgetting)
all_forget2 = np.array(all_forget2)

# per-task forgetting, averaged over runs
out = '\n----Forgetting----\n'
out += ''.join('{}:{:.2f}\t'.format(i, acc) for i, acc in enumerate(all_forget2.mean(0)))

# mean over tasks first, then mean / std over runs
_per_run_forgetting = all_forget2.mean(1)
avg2 = _per_run_forgetting.mean(0)
std2 = _per_run_forgetting.std()
_summary = 'Forgetting Total Max loss = {}. AVG over {} runs : {:.4f} +- {:.4f}'.format(
    args.max_loss, args.n_runs, avg2, std2*2./np.sqrt(args.n_runs))
with open(name_log_txt,"a") as text_file:
    print(out, file=text_file)
    print(_summary, file=text_file)
