import os

#os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'


import random
import numpy as np

import torch
import torch.utils.data
from torch.nn import functional as F
from torch import nn, Tensor
import time
import torch.nn.init as init

from utils.evaluate_ood import (
    get_fashionmnist_mnist_ood,
    get_fashionmnist_notmnist_ood,
)
from utils.datasets import FastFashionMNIST, get_FashionMNIST, FastFashionMNIST3
from utils.duq import DUQ
import torchmetrics

import argparse
print(torch.__version__)
parser = argparse.ArgumentParser()

# Which classifier to train and which inner solver the BLNN Legendre step uses.
parser.add_argument('-md', '--model', type=str, default='blnn', choices=['blnn', 'duq'])
parser.add_argument(
    '-op', '--optimizer', type=str, default='GD',
    choices=['GD', 'LBFGS', 'Adam', 'Adagrad', 'RMSprop'],
)

# Real-valued constants read by BLNN as ``smooth`` and ``convex``.
parser.add_argument('-smooth', '--smooth', type=float, default=1.0)
parser.add_argument('-convex', '--convex', type=float, default=0.0)

# Every remaining option is a plain integer; they are registered in the same
# order as before so that ``--help`` output is unchanged.
for _flags, _default in (
    (('-ep', '--num_epochs'), 30),
    (('-hd', '--h_dim'), 10),
    (('-od', '--out_dim'), 100),
    (('-nc', '--num_classes'), 10),
    (('-es', '--embedding_size'), 100),
    (('-cd', '--contiz_dim'), 28*28),
    (('-nhl1', '--num_hidden_layers1'), 3),
    (('-nhl2', '--num_hidden_layers2'), 1),
    (('-s', '--seed'), 1),
    (('-brute', '--brute_force'), 1),
    (('-r', '--resume'), 0),
    (('-learn', '--learn_params'), 0),
    (('-comp', '--composite'), 1),
    (('-b', '--batch_size'), 128),
):
    parser.add_argument(*_flags, type=int, default=_default)

# Unknown arguments are tolerated (and kept in ``unknown``) rather than rejected.
args, unknown = parser.parse_known_args()

# Common prefix of every log / checkpoint file written by this script.
outfilename = f"log_txt3_fullsize{args.model}{args.convex}_{args.smooth}_{args.seed}"


# Seed the Python, NumPy and PyTorch generators with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)


class PositiveLinear(nn.Module):
    """Linear layer whose effective weights are constrained to be non-negative.

    The raw ``weight`` parameter is unconstrained; ``forward`` clamps it to
    ``[0, 1e10]`` before applying it, so negative entries behave like zeros.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        use_bias: When true, an additive bias (not clamped) is learned as well.
            The default ``False`` keeps the layer bias-free, so existing
            callers and existing ``state_dict`` keys are unaffected.
    """

    def __init__(self, in_features, out_features, use_bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            # Registering ``None`` keeps ``self.bias`` readable without adding
            # an entry to ``parameters()`` or ``state_dict()``.
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the weight (Xavier normal) and zero the bias, if any."""
        nn.init.xavier_normal_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input):
        """Apply the layer using the clamped, non-negative weight matrix."""
        positive_weight = torch.clamp(self.weight, min=0.0, max=1e10)
        return nn.functional.linear(input, positive_weight, self.bias)

class BLNN(nn.Module):
    """Bi-Lipschitz neural network with a per-class centroid (RBF) output.

    The feature map is obtained by numerically inverting ``f1``, the gradient
    of ``icnn(x) + ||x||^2 / (2 * smooth)`` where ``icnn`` is built from the
    ``icnn1_*`` layers: for an input ``z`` the solvers look for ``x`` with
    ``f1(x) ~= z`` (see ``legendre``). The features are then projected per
    class with ``W`` and compared with running class centroids ``m / N``.

    Args:
        gamma: Momentum of the running centroid statistics ``N`` and ``m``.
        length_scale: Kernel length scale, stored as ``sigma``.
        args: Namespace providing ``contiz_dim``, ``h_dim``, ``out_dim``,
            ``composite``, ``brute_force``, ``num_hidden_layers1``,
            ``num_hidden_layers2``, ``optimizer``, ``learn_params``,
            ``convex``, ``smooth``, ``embedding_size`` and ``num_classes``.
    """

    def __init__(
        self,
        gamma,
        length_scale,
        args):

        super().__init__()
        self.contiz_dim = args.contiz_dim
        self.h_dim = args.h_dim
        self.out_dim = args.out_dim
        self.composite = bool(args.composite)
        self.qy_layers = []
        self.brute_force = bool(args.brute_force)
        # CUDA is used whenever it is present (as before); falling back to the
        # CPU lets the module be built and exercised on machines without a GPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.px1_num_hidden_layers = args.num_hidden_layers1
        self.px2_num_hidden_layers = args.num_hidden_layers2
        self.opt = args.optimizer

        # Layers are created in a fixed order so that seeded initialisation
        # stays reproducible: first layer, hidden (Wy, Wz) pairs, scalar head.
        self.icnn1_Wy0 = nn.Linear(self.contiz_dim, self.h_dim).to(self.device)
        icnn1_Wy_layers = []
        icnn1_Wz_layers = []
        for _ in range(self.px1_num_hidden_layers - 1):
            icnn1_Wy_layers.append(nn.Linear(self.contiz_dim, self.h_dim).to(self.device))
            icnn1_Wz_layers.append(PositiveLinear(self.h_dim, self.h_dim).to(self.device))
        icnn1_Wy_layers.append(nn.Linear(self.contiz_dim, 1).to(self.device))
        icnn1_Wz_layers.append(PositiveLinear(self.h_dim, 1).to(self.device))

        self.icnn1_Wy_layers = nn.ModuleList(icnn1_Wy_layers)
        self.icnn1_Wz_layers = nn.ModuleList(icnn1_Wz_layers)

        self.gamma = gamma
        self.sigma = length_scale
        if bool(args.learn_params):
            # Learnable variants of the two constants (one-element tensors).
            self.convex = nn.Parameter(torch.zeros(1) + args.convex)
            self.smooth = nn.Parameter(torch.zeros(1) + args.smooth)
        else:
            self.smooth = args.smooth
            self.convex = args.convex

        # Per-class projection applied to the features: (embedding, class, out_dim).
        self.W = nn.Parameter(
            torch.normal(
                torch.zeros(args.embedding_size, args.num_classes, self.out_dim), 0.05
            ).to(self.device)
        )

        # Running statistics: N is the per-class count, m the per-class sum of
        # embeddings; the centroids used in ``bilinear`` are m / N.
        self.register_buffer('N', (torch.ones(args.num_classes) * 12))
        self.register_buffer('m', torch.normal(torch.zeros(args.embedding_size, args.num_classes), 1))

        self.m = self.m * self.N.unsqueeze(0)
        # Warm-start cache for the GD solver, one row per sample index.
        # NOTE(review): 60000 rows are hard-coded; presumably the FashionMNIST
        # training-set size - confirm before using another dataset.
        self.init_points1 = torch.ones((60000, args.contiz_dim)).to(self.device)
        self.init_points2 = torch.ones((60000, args.out_dim)).to(self.device)

    def f1(self, input, with_output=False, create_graph = True):
        """Return the gradient of ``icnn(input) + ||input||^2 / (2 * smooth)``.

        Gradient tracking is forced on internally, so this also works when the
        caller is inside ``torch.no_grad()``. Side effect: ``input`` itself is
        switched to ``requires_grad=True``.

        Args:
            input: Tensor of shape ``(batch, contiz_dim)``.
            with_output: Also return the potential values, shape ``(batch, 1)``.
            create_graph: Keep the graph of the gradient so that it can be
                differentiated again (needed to train through the solver).

        Returns:
            The gradient (same shape as ``input``), or ``(gradient, potential)``
            when ``with_output`` is true.
        """
        softplus = nn.Softplus()
        with torch.enable_grad():
            input.requires_grad_(True)
            h2 = softplus(self.icnn1_Wy0(input))

            for i in range(self.px1_num_hidden_layers):
                h2 = softplus(self.icnn1_Wz_layers[i](h2) + self.icnn1_Wy_layers[i](input))

            # The quadratic term adds input / smooth to the gradient.
            icnn_output = h2 + 1/(2*self.smooth)*(torch.norm(input,dim=1)**2).view(-1,1)
            grad_icnn = torch.autograd.grad(icnn_output, [input], torch.ones_like(icnn_output), create_graph=create_graph)[0]

            if with_output:
                return grad_icnn, icnn_output
            return grad_icnn

    def _gradient_descent(self, z, id, eval, tol, create_graph, log):
        """Iterate ``x <- x + step / (i + 1) * (z - f1(x))`` with ``step = 2 * smooth``.

        Starts from ones, or from the cached rows ``init_points1[id]`` when
        ``id`` is given, and writes the result back to that cache. Stops once
        the batch mean of ``||z - f1(x)||`` drops below ``tol`` or after 500
        iterations (5000 when ``eval`` is true). With ``log`` set, the final
        iteration index is printed and appended to the ``it1_`` log file.
        """
        if id is None:
            x1 = torch.ones(z.size(), device=self.device)
        else:
            x1 = self.init_points1[id]

        step = 2*self.smooth
        max_it = 5000 if eval else 500
        for i in range(max_it):
            grad = self.f1(x1, create_graph=create_graph)
            residual = z - grad
            x1 = x1 + step/(i+1) * residual
            converged = bool(torch.mean(torch.norm(residual, dim=1)) < tol)
            if log and (converged or i == max_it-1):
                print("i",i)
                with open(outfilename+"it1_"+self.opt+'.log', 'a') as f:
                    f.write(str(i)+"\n")
            if converged:
                break

        if id is not None:
            self.init_points1[id] = x1.clone().detach()
        return x1

    def _minimize_with_optimizer(self, z, eval):
        """Minimise ``sum(F(x)) - sum(x * z)`` over ``x`` with ``self.opt``.

        ``F`` is the potential whose gradient is ``f1``, so the gradient handed
        to the optimizer is ``f1(x) - z``. Always starts from ones; the warm
        start cache is not used on this path.
        """
        learning_rate = 2*self.smooth
        x = torch.ones(z.size(), device=self.device)
        tol = 1e-12

        if self.opt == "LBFGS":
            # A single outer step; LBFGS iterates internally (max_iter=500).
            max_iter = 1
        elif eval:
            max_iter = 1000000
        else:
            max_iter = 1000

        def closure():
            # The gradient is supplied analytically through f1 (which turns
            # autograd on itself), so nothing here needs to be tracked.
            with torch.no_grad():
                grad, value = self.f1(x, with_output=True)
                loss = torch.sum(value) - torch.sum(x * z)
                x.grad = grad - z
            return loss

        if self.opt == "Adam":
            optim = torch.optim.Adam([x],lr=learning_rate,eps=tol)
        elif self.opt == "Adagrad":
            optim = torch.optim.Adagrad([x],lr=learning_rate,eps=tol)
        elif self.opt == "RMSprop":
            optim = torch.optim.RMSprop([x],lr=learning_rate,eps=tol)
        elif self.opt == "LBFGS":
            optim = torch.optim.LBFGS([x], lr=learning_rate, line_search_fn="strong_wolfe", max_iter=500, tolerance_grad=tol, tolerance_change=tol)
        else:
            raise ValueError(f"unsupported inner optimizer: {self.opt!r}")

        for _ in range(max_iter):
            optim.step(closure)
        return x

    def _project(self, x1, z):
        """Map the first-stage features to ``out_dim`` coordinates.

        Multiplying by the rectangular identity keeps the leading coordinates
        and zero-pads when ``out_dim > contiz_dim``. Returns ``(x2, z2)``.
        """
        if not self.composite:
            # Only the composite variant defines the second output; this used
            # to fail with an unbound local variable at the return statement.
            raise NotImplementedError("BLNN only implements composite=True")
        z2 = torch.matmul(
            x1 + self.convex*z,
            torch.eye(self.contiz_dim, self.out_dim, device=self.device),
        )
        x2 = z2 - self.convex*z2
        return x2, z2

    def legendre(self, z, id = None, eval=False):
        """Solve ``f1(x) ~= z`` and return the two feature tensors.

        With ``self.opt == "GD"`` the fixed-point iteration is kept in the
        autograd graph so that the loss can be back-propagated through it;
        otherwise the selected ``torch.optim`` optimizer is used.

        Args:
            z: Inputs of shape ``(batch, contiz_dim)``.
            id: Optional sample indices used to warm-start the GD solver from,
                and update, ``init_points1``.
            eval: Use the larger iteration budget.

        Returns:
            ``(x + convex * z, x2 + convex * z2)`` where ``z2`` is the first
            entry projected to ``out_dim`` coordinates and
            ``x2 = z2 - convex * z2``.
        """
        if self.opt == "GD":
            x1 = self._gradient_descent(z, id, eval, tol=0.001, create_graph=True, log=True)
        else:
            # The solution used to be bound to ``x`` here while the return
            # statement read ``x1``, so this path always raised NameError.
            x1 = self._minimize_with_optimizer(z, eval)
        x2, z2 = self._project(x1, z)
        return x1+self.convex*z, x2+self.convex*z2

    def legendre_ng(self, z, id = None, eval=False):
        """Same as ``legendre`` but the solve is not tracked by autograd.

        Uses a tighter GD tolerance (1e-4) and does not log iteration counts.
        """
        with torch.no_grad():
            if self.opt == "GD":
                x1 = self._gradient_descent(z, id, eval, tol=0.0001, create_graph=False, log=False)
            else:
                x1 = self._minimize_with_optimizer(z, eval)
            x2, z2 = self._project(x1, z)

        return x1+self.convex*z, x2+self.convex*z2

    def embed(self, x, id):
        """Return ``(f_ast1, f_ast2, z)`` for a batch of images.

        ``z`` has shape ``(batch, embedding_size, num_classes)``.
        """
        # NOTE(review): the 28x28 image size is hard-coded, so this only lines
        # up with the network when contiz_dim == 784 - confirm for other sizes.
        x = x.view(-1,28,28).flatten(1)
        f_ast1, f_ast2 = self.legendre(x, id)
        x = torch.einsum('ij,mnj->imn', f_ast2, self.W)
        return f_ast1, f_ast2, x

    def bilinear(self, z):
        """RBF score of each embedding against the class centroids ``m / N``.

        Computes ``exp(-mean_over_embedding((z - centroid)^2) / (2 * sigma^2))``
        and returns a tensor of shape ``(batch, num_classes)``.
        """
        embeddings = (self.m / self.N.unsqueeze(0)).to(self.device)
        diff = z - embeddings.unsqueeze(0)
        y_pred = (- diff**2).mean(1).div(2 * self.sigma**2).exp()

        return y_pred

    def forward(self, x, id=None, extended=False):
        """Return the class scores, or ``(z, f_ast1, f_ast2, y_pred)`` if ``extended``."""
        f_ast1, f_ast2, z = self.embed(x, id)
        y_pred = self.bilinear(z)
        if extended:
            return z, f_ast1, f_ast2, y_pred
        return y_pred

    def update_embeddings(self, x, y, ids):
        """Update the running class statistics ``N`` and ``m`` from one batch.

        Args:
            x: Batch of images.
            y: One-hot labels of shape ``(batch, num_classes)``.
            ids: Sample indices, used to warm-start the GD solver.
        """
        with torch.no_grad():
            x = x.view(-1,28,28).flatten(1)
            _, f_ast2 = self.legendre_ng(x, ids)
            z = torch.einsum('ij,mnj->imn', f_ast2, self.W)

            # normalizing value per class, assumes y is one_hot encoded
            self.N = self.gamma * self.N + (1 - self.gamma) * y.sum(0)

            # compute sum of embeddings on class by class basis
            features_sum = torch.einsum('ijk,ik->jk', z, y)
            self.m = self.gamma * self.m + (1 - self.gamma) * features_sum


def _evaluate_classifier(model, loader, accuracy, num_classes):
    """Return ``(accuracy, mean BCE)`` of ``model`` over every batch of ``loader``.

    Both figures are aggregated over the whole loader: the accuracy through
    the metric's accumulated state, the BCE as the element-wise mean over all
    samples and classes.
    """
    model.eval()
    accuracy.reset()
    bce_sum = 0.0
    n_elements = 0
    with torch.no_grad():
        for batch in loader:
            x, y = batch[0], batch[1]
            y = F.one_hot(y, num_classes=num_classes).float()

            x, y = x.cuda(), y.cuda()
            y_pred = model(x)
            accuracy.update(y_pred, torch.argmax(y, dim=1))
            bce_sum += F.binary_cross_entropy(y_pred, y, reduction="sum").item()
            n_elements += y.numel()
    acc = accuracy.compute()
    accuracy.reset()
    bce = bce_sum / n_elements if n_elements else float("nan")
    return acc, bce


def train_model(l_gradient_penalty, length_scale, final_model, args):
    """Train a BLNN or DUQ classifier and log test / OOD metrics.

    Args:
        l_gradient_penalty: Weight of the two-sided gradient penalty; only used
            when ``args.model == "duq"``.
        length_scale: Kernel length scale handed to the model.
        final_model: When true, train on the full training set; otherwise on a
            random 55000-sample subset of it.
        args: Parsed command-line namespace (model, dimensions, ``num_epochs``,
            ``batch_size``, ``resume``, ...).

    Returns:
        ``(model, val_accuracy, test_accuracy)``. The two accuracies are
        placeholders that are always 0; the real metrics are printed and
        appended to the ``.log`` file.

    Raises:
        ValueError: If ``args.model`` is neither ``"blnn"`` nor ``"duq"``.
    """
    if args.model not in ("blnn", "duq"):
        # Previously this surfaced later as an unbound ``model`` variable.
        raise ValueError(f"unknown model: {args.model!r}")

    dataset = FastFashionMNIST3(args.contiz_dim, args.out_dim, "data/", train=True, download=True)
    test_dataset = FastFashionMNIST3(args.contiz_dim, args.out_dim, "data/", train=False, download=True)
    idx = list(range(60000))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        train_dataset = torch.utils.data.Subset(dataset, indices=idx[:55000])
        # NOTE(review): the held-out split is built but never evaluated.
        val_dataset = torch.utils.data.Subset(dataset, indices=idx[55000:])

    num_classes = args.num_classes
    embedding_size = args.embedding_size
    learnable_length_scale = False
    gamma = 0.999
    if args.model == "duq":
        model = DUQ(
            num_classes,
            embedding_size,
            learnable_length_scale,
            length_scale,
            gamma,
        )
    else:
        model = BLNN(
            gamma,
            length_scale,
            args
        )
    if args.resume == 1:
        chpt = torch.load("model"+outfilename+".ckpt")
        model.load_state_dict(chpt)
    model = model.cuda()

    nparams = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if nparams >= 1000000:
        summary = f"num_params: {1e-6*nparams:.1f}M"
    else:
        summary = f"num_params: {1e-3*nparams:.1f}K"
    with open(outfilename+'.log', 'a') as f:
        f.write(summary + " \n")
    print(summary)

    # NOTE(review): a ``milestones = [15, 25, 30]`` list used to be defined
    # here, but no LR scheduler was ever attached; the rate stays at 0.05.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)

    def calc_gradient_penalty(x, y_pred_sum):
        """Two-sided penalty pulling the input-gradient norm towards 1."""
        gradients = torch.autograd.grad(
            outputs=y_pred_sum,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred_sum),
            create_graph=True,
            retain_graph=True,
        )[0]

        gradients = gradients.flatten(start_dim=1)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1) ** 2).mean()

        return gradient_penalty

    dl_train = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True
    )

    dl_test = torch.utils.data.DataLoader(
        test_dataset, batch_size=2000, shuffle=False, num_workers=0
    )

    accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).cuda()

    # Keeps the final log line well-defined when num_epochs == 0
    # (for example when only re-evaluating a resumed checkpoint).
    epoch = -1
    for epoch in range(args.num_epochs):
        for batch in dl_train:
            model.train()
            optimizer.zero_grad()

            x, y = batch[0], batch[1]
            y = F.one_hot(y, num_classes=num_classes).float()
            x, y = x.cuda(), y.cuda()

            if args.model == "blnn":
                # The third batch entry holds the sample indices, which the
                # model uses to warm-start its inner solver.
                ids = batch[2]
                y_pred = model(x, ids)
                loss = F.binary_cross_entropy(y_pred, y)
                loss.backward()
            else:
                x.requires_grad_(True)

                y_pred = model(x)
                loss = F.binary_cross_entropy(y_pred, y)
                loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred.sum(1))

                x.requires_grad_(False)
                print("loss", loss)

                loss.backward()

            optimizer.step()
            with torch.no_grad():
                model.eval()
                if args.model == "blnn":
                    model.update_embeddings(x, y, ids)
                else:
                    model.update_embeddings(x, y)

        if epoch % 5 == 0:
            acc, bce = _evaluate_classifier(model, dl_test, accuracy, num_classes)
            # OOD metrics are only computed once, after training.
            roc_auc_mnist = 0
            roc_auc_notmnist = 0
            with open(outfilename+'.log', 'a') as f:
                f.write(f"Test Results - Epoch: {epoch}, Acc: {acc:.4f}, BCE: {bce:.2f}, AUROC MNIST: {roc_auc_mnist:.2f}, AUROC NotMNIST: {roc_auc_notmnist:.2f} \n")
                f.write(f"Sigma: {model.sigma} \n")

            print(
                f"Test Results - Epoch: {epoch} "
                f"Acc: {acc:.4f} "
                f"BCE: {bce:.2f} "
                f"AUROC MNIST: {roc_auc_mnist:.2f} "
                f"AUROC NotMNIST: {roc_auc_notmnist:.2f} "
            )
    if args.resume == 0:
        torch.save(model.state_dict(), "model"+outfilename+".ckpt")

    acc, bce = _evaluate_classifier(model, dl_test, accuracy, num_classes)
    _, roc_auc_mnist = get_fashionmnist_mnist_ood(model,outfilename)
    _, roc_auc_notmnist = get_fashionmnist_notmnist_ood(model,outfilename)

    with open(outfilename+'.log', 'a') as f:
        f.write(f"Test Results - Epoch: {epoch}, Acc: {acc:.4f}, BCE: {bce:.2f}, AUROC MNIST: {roc_auc_mnist:.2f}, AUROC NotMNIST: {roc_auc_notmnist:.2f} \n")
        f.write(f"Sigma: {model.sigma} \n")
    print(f"Test Results - Epoch: {epoch}, Acc: {acc:.4f}, BCE: {bce:.2f}, AUROC MNIST: {roc_auc_mnist:.2f}, AUROC NotMNIST: {roc_auc_notmnist:.2f}")

    # Placeholders kept for interface compatibility; see the log for metrics.
    val_accuracy = 0
    test_accuracy = 0
    return model, val_accuracy, test_accuracy


if __name__ == "__main__":
    # Wall-clock timing of the whole experiment.
    start = time.time()
    main_log = outfilename + '.log'
    print(main_log)

    # Truncate the main log and both iteration-count logs so that every run
    # starts from fresh files.
    with open(main_log, 'w') as f:
        f.write(" ")
    for tag in ("it1_", "it2_"):
        with open(outfilename + tag + args.optimizer + '.log', 'w') as f:
            f.write("\n")

    # The returned datasets are not used below; the call is kept as-is.
    _, _, _, fashionmnist_test_dataset = get_FashionMNIST()

    # Hyper-parameter grid (a single point by default).
    l_gradient_penalties = [0.05]
    length_scales = [0.1]

    repetition = 1  # Increase for multiple repetitions
    final_model = False  # set true for final model to train on full train set
    results = {}
    for l_gradient_penalty in l_gradient_penalties:
        for length_scale in length_scales:
            val_accuracies, test_accuracies = [], []
            roc_aucs_mnist, roc_aucs_notmnist = [], []
            for _ in range(repetition):
                # Mark the start of a new training run in the main log.
                with open(main_log, 'a') as f:
                    f.write(f" ### NEW MODEL ### \n{args.convex} {args.smooth}\n")
                model, val_accuracy, test_accuracy = train_model(
                    l_gradient_penalty, length_scale, final_model, args
                )

    end = time.time()
    print("time", end - start)
