import torch
import torch.utils.data
from torch import nn
from torch.nn import functional as F
import argparse
import numpy as np
import sklearn.datasets

import matplotlib.pyplot as plt
import seaborn as sns
import torchmetrics
import time
# Apply seaborn's default theme to the matplotlib figures produced by this script.
sns.set()
# Report whether a CUDA device is visible; the script below calls .cuda() unconditionally.
print(torch.cuda.is_available())

# Command-line interface. Unknown arguments are tolerated (parse_known_args)
# so the script can be launched from wrappers that add their own flags.
parser = argparse.ArgumentParser()

# Which model to train.
parser.add_argument('-md', '--model', type=str, default='blnn', choices=['blnn', 'duq'])

# Values forwarded to BLNN as its `smooth` / `convex` arguments.
parser.add_argument('-smooth', '--smooth', type=float, default=1.0)
parser.add_argument('-convex', '--convex', type=float, default=0.0)

# Training length and network sizes.
parser.add_argument('-ep', '--num_epochs', type=int, default=100)
parser.add_argument('-hd', '--h_dim', type=int, default=10)
parser.add_argument('-od', '--out_dim', type=int, default=20)
parser.add_argument('-cd', '--contiz_dim', type=int, default=2)
parser.add_argument('-nhl', '--num_hidden_layers', type=int, default=2)

# Random seed (NumPy and PyTorch) and integer on/off switches read by BLNN.
parser.add_argument('-s', '--seed', type=int, default=100)
parser.add_argument('-brute', '--brute_force', type=int, default=1)
parser.add_argument('-comp', '--composite', type=int, default=1)

args, unknown = parser.parse_known_args()

# Base name shared by the log file and the final figure.
outfilename = f"{args.model}_{args.convex}_{args.smooth}_{args.seed}_two_moons4"

print(outfilename)

# Truncate any log left by a previous run; epochs append to this file later.
with open(outfilename + '.log', 'w') as log_file:
    log_file.write(" ")

#BLNN
class PositiveLinear(nn.Module):
    """Bias-free linear layer that only ever applies non-negative weights.

    The raw parameter ``weight`` (shape ``(out_features, in_features)``) is
    unconstrained; negative entries are clamped to zero (and values above
    ``1e10`` are capped) each time the layer is applied.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        use_bias: accepted but ignored; the layer never has a bias.
    """

    def __init__(self, in_features, out_features, use_bias=False):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        # Storage is uninitialised here; reset_parameters() fills it in.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the raw weights with Xavier (Glorot) normal initialisation."""
        nn.init.xavier_normal_(self.weight)

    def forward(self, input):
        """Return ``input @ clamp(weight, 0, 1e10).T``."""
        effective_weight = self.weight.clamp(min=0.0, max=1e10)
        return nn.functional.linear(input, effective_weight)

class BLNN(nn.Module):
    """
    Bi-Lipschitz Neural network

    The network owns one (or, when ``composite`` is on, two) ICNN-style
    scalar potentials built from ``nn.Linear`` skip connections and
    ``PositiveLinear`` hidden-to-hidden layers with softplus activations.
    ``f1`` / ``f2`` return the analytic input-gradient of each potential plus
    ``input / smooth``.  ``legendre`` numerically inverts those gradient maps,
    and the result is fed to a DUQ-style head: a per-class linear map ``W``
    followed by an RBF kernel against running class centroids ``m / N``.

    Args:
        contiz_dim: dimensionality of the network input.
        h_dim: hidden width of both potentials.
        out_dim: dimensionality of the second stage / feature vector.
        px_num_hidden_layers: number of layers after the first one in each
            potential (the last of them has a single output unit).
        smooth: scalar; ``input / smooth`` is added to each gradient map.
        convex: scalar; ``convex * z`` is added to each inverse map.
        num_embeddings: number of classes (centroids).
        composite: accepted for interface compatibility; the flag actually
            used is ``args.composite``.
        args: namespace providing ``composite`` and ``brute_force``.
        USE_BATCHNORM: accepted but unused.
    """
    def __init__(
        self,
        contiz_dim,
        h_dim,
        out_dim,
        px_num_hidden_layers,
        smooth,
        convex,
        num_embeddings,
        composite,
        args,
        USE_BATCHNORM=False):

        super(BLNN, self).__init__()
        self.contiz_dim = contiz_dim
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.smooth = smooth
        self.convex = convex

        self.composite = bool(args.composite)
        self.qy_layers = []
        self.brute_force = bool(args.brute_force)
        # Prefer CUDA (the original hard-coded it) but stay usable on CPU-only hosts.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.px1_num_hidden_layers = px_num_hidden_layers
        self.px2_num_hidden_layers = px_num_hidden_layers

        # First potential acts on the raw input.
        self.icnn1_Wy0, self.icnn1_Wy_layers, self.icnn1_Wz_layers = \
            self._build_icnn(self.contiz_dim, self.px1_num_hidden_layers)

        # Optional second potential acts on the out_dim-dimensional lift.
        if self.composite:
            self.icnn2_Wy0, self.icnn2_Wy_layers, self.icnn2_Wz_layers = \
                self._build_icnn(self.out_dim, self.px2_num_hidden_layers)

        self.gamma = 0.99   # momentum of the running centroid statistics
        self.sigma = 0.3    # length scale of the RBF kernel in bilinear()

        embedding_size = 10

        # Drawn on CPU and then moved, so the CPU RNG stream is consumed
        # exactly as before.
        self.W = nn.Parameter(torch.normal(torch.zeros(embedding_size, num_embeddings, out_dim), 1).to(self.device))

        self.register_buffer('N', (torch.ones(num_embeddings) * 20).to(self.device))
        # Drawn directly on the target device (as the original did on CUDA).
        self.register_buffer('m', torch.normal(torch.zeros(embedding_size, num_embeddings, device=self.device), 1))

        self.m = self.m * self.N.unsqueeze(0)

    def _build_icnn(self, in_dim, num_hidden_layers):
        """Create the layers of one potential: (first layer, skip layers, positive layers)."""
        first = nn.Linear(in_dim, self.h_dim).to(self.device)
        skip_layers = []
        positive_layers = []
        # Construction order matters: it fixes the order in which the RNG is consumed.
        for _ in range(num_hidden_layers - 1):
            skip_layers.append(nn.Linear(in_dim, self.h_dim).to(self.device))
            positive_layers.append(PositiveLinear(self.h_dim, self.h_dim).to(self.device))
        skip_layers.append(nn.Linear(in_dim, 1).to(self.device))
        positive_layers.append(PositiveLinear(self.h_dim, 1).to(self.device))
        return first, nn.ModuleList(skip_layers), nn.ModuleList(positive_layers)

    def _icnn_forward(self, input, Wy0, Wy_layers, Wz_layers, num_layers, in_dim, with_output):
        """Evaluate one potential and its analytic input-gradient (shared by f1 and f2)."""
        pre = Wy0(input)
        h2 = F.softplus(pre)
        # h1 tracks d(h2)/d(input); softplus' derivative is the sigmoid.
        h1 = torch.sigmoid(pre.view(-1, self.h_dim, 1)) * Wy0.weight
        for i in range(num_layers):
            Wy, Wz = Wy_layers[i], Wz_layers[i]
            pre = Wz(h2) + Wy(input)
            # Chain rule through the clamped (non-negative) hidden weights plus the skip path.
            inner = torch.clamp(Wz.weight, min=0.0, max=1e10) @ h1 + Wy.weight
            h1 = torch.sigmoid(pre).view(-1, Wy.weight.size()[0], 1) * inner
            h2 = F.softplus(pre)

        grad_icnn = h1.view(-1, in_dim) + 1/self.smooth*input
        if not with_output:
            return grad_icnn
        icnn_output = h2 + 1/(2*self.smooth)*(torch.norm(input, dim=1)**2).view(-1, 1)
        return grad_icnn, icnn_output

    def f1(self, input, with_output=False):
        """Gradient map of the first potential (plus ``input / smooth``).

        Returns the gradient, or ``(gradient, potential value)`` when
        ``with_output`` is true.
        """
        return self._icnn_forward(
            input, self.icnn1_Wy0, self.icnn1_Wy_layers, self.icnn1_Wz_layers,
            self.px1_num_hidden_layers, self.contiz_dim, with_output)

    def f2(self, input, with_output=False):
        """Gradient map of the second potential; only available when ``composite`` is on."""
        return self._icnn_forward(
            input, self.icnn2_Wy0, self.icnn2_Wy_layers, self.icnn2_Wz_layers,
            self.px2_num_hidden_layers, self.out_dim, with_output)

    @staticmethod
    def _solve_gradient_equation(grad_fn, target, step, max_it):
        """Iterate x <- x + step/(i+1) * (target - grad_fn(x)) from x = 1 until the mean residual norm is < 1e-3."""
        x = torch.ones_like(target)
        for i in range(max_it):
            residual = target - grad_fn(x)
            x = x + step/(i+1) * residual
            if torch.mean(torch.norm(residual, dim=1)) < 0.001:
                break
        return x

    def legendre(self, z, pr=False):
        """Numerically invert the gradient map(s) at ``z``.

        Solves ``f1(x) = z`` iteratively and forms ``x + convex * z``; that
        result is lifted to ``out_dim`` with a rectangular identity.  When
        ``composite`` is on, ``f2`` is inverted at the lifted point in the same
        way.  The solves run without autograd.

        Args:
            z: tensor of shape ``(batch, contiz_dim)``.
            pr: unused.

        Returns:
            ``(f_ast1, f_ast2)`` with shapes ``(batch, contiz_dim)`` and
            ``(batch, out_dim)``.  Without ``composite`` the second element is
            the lifted first-stage output (previously this path raised
            ``UnboundLocalError``).
        """
        with torch.no_grad():
            step = 2*self.smooth
            # The original tested the builtin `eval == True`, which is never
            # true, so the larger evaluation budget was unreachable.
            max_it = 500 if self.training else 5000

            x = self._solve_gradient_equation(self.f1, z, step, max_it)

            lift = torch.eye(self.contiz_dim, self.out_dim, device=z.device, dtype=z.dtype)
            z2 = torch.matmul(x + self.convex*z, lift)

            x2 = None
            if self.composite:
                x2 = self._solve_gradient_equation(self.f2, z2, step, max_it)

        f_ast1 = x + self.convex*z
        if x2 is None:
            return f_ast1, z2
        return f_ast1, x2 + self.convex*z2

    def embed(self, x, pr=False):
        """Return ``(f_ast1, f_ast2, features)``; features has shape (batch, embedding_size, num_embeddings)."""
        f_ast1, f_ast2 = self.legendre(x)
        # Make f_ast2 a grad-tracking leaf so callers can read d(loss)/d(f_ast2) from f_ast2.grad.
        f_ast2.requires_grad_(True)
        x = torch.einsum('ij,mnj->imn', f_ast2, self.W)
        return f_ast1, f_ast2, x

    def bilinear(self, z):
        """RBF kernel value between features ``z`` and each class centroid ``m / N``; shape (batch, num_embeddings)."""
        embeddings = (self.m / self.N.unsqueeze(0)).to(z.device)
        diff = z - embeddings.unsqueeze(0)
        y_pred = (- diff**2).mean(1).div(2 * self.sigma**2).exp()

        return y_pred

    def forward(self, x):
        """Return ``(features, f_ast1, f_ast2, y_pred)`` for inputs whose last dimension is the feature axis."""
        x = x.view(-1, x.shape[-1])
        f_ast1, f_ast2, z = self.embed(x)
        y_pred = self.bilinear(z)

        return z, f_ast1, f_ast2, y_pred

    def update_embeddings(self, x, y):
        """Exponential-moving-average update of the class counts ``N`` and feature sums ``m``."""
        _, _, z = self.embed(x)

        # normalizing value per class, assumes y is one_hot encoded
        counts = self.gamma * self.N.to(y.device) + (1 - self.gamma) * y.sum(0)
        # Floor at one so the centroid division in bilinear() never blows up.
        self.N = torch.max(counts, torch.ones_like(counts))

        # compute sum of embeddings on class by class basis
        features_sum = torch.einsum('ijk,ik->jk', z, y)
        self.m = self.gamma * self.m + (1 - self.gamma) * features_sum

#Simple Neural Network for DUQ
class Model_bilinear(nn.Module):
    """Small MLP feature extractor with a DUQ-style RBF centroid head.

    A three-layer MLP maps 2-D inputs to ``features`` dimensions.  A per-class
    linear map ``W`` turns them into one embedding per class, which is compared
    with running class centroids ``m / N`` through an RBF kernel.

    Args:
        features: hidden and output width of the MLP.
        num_embeddings: number of classes (centroids).
    """

    def __init__(self, features, num_embeddings):
        super().__init__()

        self.gamma = 0.99   # momentum of the running centroid statistics
        self.sigma = 0.3    # length scale of the RBF kernel

        embedding_size = 10

        self.fc1 = nn.Linear(2, features)
        self.fc2 = nn.Linear(features, features)
        self.fc3 = nn.Linear(features, features)

        self.W = nn.Parameter(torch.normal(torch.zeros(embedding_size, num_embeddings, features), 1))

        # N: running per-class sample count; m: running per-class sum of embeddings.
        self.register_buffer('N', torch.ones(num_embeddings) * 20)
        self.register_buffer('m', torch.normal(torch.zeros(embedding_size, num_embeddings), 1))

        self.m = self.m * self.N.unsqueeze(0)

    def embed(self, x):
        """Map inputs (batch, 2) to per-class embeddings (batch, embedding_size, num_embeddings)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        # i is batch, m is embedding_size, n is num_embeddings (classes)
        x = torch.einsum('ij,mnj->imn', x, self.W)
        return x

    def bilinear(self, z):
        """RBF kernel value between embeddings ``z`` and each class centroid; shape (batch, num_embeddings)."""
        embeddings = self.m / self.N.unsqueeze(0)

        diff = z - embeddings.unsqueeze(0)
        y_pred = (- diff**2).mean(1).div(2 * self.sigma**2).exp()

        return y_pred

    def forward(self, x):
        """Return ``(embeddings, y_pred)``."""
        z = self.embed(x)
        y_pred = self.bilinear(z)

        return z, y_pred

    def update_embeddings(self, x, y):
        """Exponential-moving-average update of the class counts ``N`` and embedding sums ``m``."""
        z = self.embed(x)

        # normalizing value per class, assumes y is one_hot encoded
        # ones_like already lives on N's device; the former extra .cuda() was
        # redundant on GPU and made the module unusable on CPU.
        self.N = torch.max(self.gamma * self.N + (1 - self.gamma) * y.sum(0), torch.ones_like(self.N))

        # compute sum of embeddings on class by class basis
        features_sum = torch.einsum('ijk,ik->jk', z, y)

        self.m = self.gamma * self.m + (1 - self.gamma) * features_sum

# Seed NumPy (used by make_moons) and PyTorch (used for weight init) first,
# so everything below is reproducible for a given --seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Experiment constants.
l_gradient_penalty = 1.0   # weight of the DUQ gradient penalty
noise = 0.1                # noise level passed to make_moons
num_classes = 2
batch_size = 64

# Moons: the train and test sets are drawn in this order from the seeded NumPy RNG.
X_train, y_train = sklearn.datasets.make_moons(n_samples=1500, noise=noise)
X_test, y_test = sklearn.datasets.make_moons(n_samples=200, noise=noise)

if args.model == "blnn":
    contiz_dim = 2
    h_dim = args.h_dim
    out_dim = args.out_dim
    gen_model_hidden_layers = args.num_hidden_layers
    smooth = args.smooth
    convex = args.convex
    composite = args.composite
    model = BLNN(
        contiz_dim,
        h_dim,
        out_dim,
        gen_model_hidden_layers,
        smooth,
        convex,
        num_classes,
        composite,
        args,
    ).cuda()
elif args.model == "duq":
    model = Model_bilinear(40, num_classes).cuda()

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-4,
)
#optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # Adam shows better convergence for DUQ+BLL


def calc_gradient_penalty(x, y_pred):
    """Two-sided gradient penalty on the input-gradient of ``y_pred``.

    Args:
        x: input tensor with ``requires_grad=True`` from which ``y_pred`` was
            computed.
        y_pred: model output; its entries are summed (via an all-ones
            ``grad_outputs``) before differentiating.

    Returns:
        Scalar tensor: the batch mean of ``(||d sum(y_pred) / d x||_2 - 1) ** 2``.
        The graph is kept (``create_graph=True``) so the penalty can itself be
        back-propagated.
    """
    (input_grads,) = torch.autograd.grad(
        outputs=y_pred,
        inputs=x,
        grad_outputs=torch.ones_like(y_pred),
        create_graph=True,
    )

    # One L2 norm per sample, taken over all non-batch dimensions.
    per_sample_norm = input_grads.flatten(start_dim=1).norm(2, dim=1)

    # Two sided penalty: norms above and below 1 are both penalised.
    deviation = per_sample_norm - 1
    return deviation.pow(2).mean()


print(args.convex, args.smooth)

# Classification accuracy metric, kept on the GPU next to the predictions.
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=2).cuda()


def _as_tensor_dataset(features, labels):
    """Wrap NumPy features / integer labels as float inputs and one-hot float targets."""
    inputs = torch.from_numpy(features).float()
    targets = F.one_hot(torch.from_numpy(labels)).float()
    return torch.utils.data.TensorDataset(inputs, targets)


# Training batches are shuffled and the last incomplete batch is dropped.
ds_train = _as_tensor_dataset(X_train, y_train)
dl_train = torch.utils.data.DataLoader(
    ds_train, batch_size=batch_size, shuffle=True, drop_last=True)

# Test data is served in order, 200 samples per batch.
ds_test = _as_tensor_dataset(X_test, y_test)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=200, shuffle=False)

# Report the number of trainable parameters in millions or thousands.
nparams = np.sum([p.numel() for p in model.parameters() if p.requires_grad])
scale, unit = (1e-6, "M") if nparams >= 1000000 else (1e-3, "K")
print(f"num_params: {scale*nparams:.1f}{unit}")

print("Brute", args.brute_force)
for epoch in range(args.num_epochs):
    # ---- training pass -------------------------------------------------
    for batch in dl_train:
        model.train()
        optimizer.zero_grad()

        x, y = batch[0].cuda(), batch[1].cuda()
        if args.model != "blnn":
            # The gradient penalty differentiates y_pred w.r.t. the input,
            # so the input itself has to track gradients.
            x.requires_grad_(True)
            z, y_pred = model(x)
            loss1 = F.binary_cross_entropy(y_pred, y)
            loss = loss1
            if args.model == "duq":
                loss2 = l_gradient_penalty * calc_gradient_penalty(x, y_pred)
                loss += loss2
            loss.backward()
        else:
            z, f_ast1, f_ast2, y_pred = model(x)
            loss = F.binary_cross_entropy(y_pred, y)
            # f_ast1 / f_ast2 come out of an iterative solve run without
            # autograd, so this backward stops at f_ast2 (a grad-tracking
            # leaf) and leaves d(loss)/d(f_ast2) in f_ast2.grad.
            loss.backward()
            nabla_loss = f_ast2.grad.clone().view(-1, 1, model.out_dim)
            f_ast2.requires_grad_(False)

            # Rectangular identities moving between contiz_dim and out_dim.
            D = torch.eye(model.out_dim, model.contiz_dim).cuda()
            Dt = torch.eye(model.contiz_dim, model.out_dim).cuda()

            # Points at which the two gradient maps were inverted.
            y1 = f_ast1 - model.convex * x
            y2 = f_ast2 - model.convex * torch.matmul(f_ast1, Dt)

            # Per-sample inverse Jacobians of f1 / f2 at those points,
            # treated as constants (detached).
            jac1 = torch.func.vmap(torch.func.jacrev(model.f1))(y1.view(-1, 1, model.contiz_dim))
            hessian1 = torch.linalg.inv(jac1.view(-1, model.contiz_dim, model.contiz_dim))
            hessian1 = hessian1.detach()
            jac2 = torch.func.vmap(torch.func.jacrev(model.f2))(y2.view(-1, 1, model.out_dim))
            hessian2 = torch.linalg.inv(jac2.view(-1, model.out_dim, model.out_dim))
            hessian2 = hessian2.detach()

            # NOTE(review): L1 + L2 looks like a surrogate whose parameter
            # gradient equals the implicit-function-theorem gradient of the
            # loss through the two inverse maps -- confirm against the method
            # description before changing any sign or factor here.
            v2 = torch.bmm(nabla_loss, hessian2)
            L2 = -torch.bmm(v2, model.f2(y2).view(-1, model.out_dim, 1)).sum()

            v1 = torch.bmm(torch.matmul(torch.bmm(nabla_loss, hessian2 + model.convex * torch.eye(model.out_dim).cuda()), D), hessian1)
            L1 = -torch.bmm(v1, model.f1(y1).view(-1, model.contiz_dim, 1)).sum()
            L = L1 + L2
            L.backward()
        optimizer.step()
        # Refresh the running class centroids with the batch just seen.
        with torch.no_grad():
            model.update_embeddings(x.cuda(), y.cuda())

    # ---- evaluation pass -----------------------------------------------
    # Switch to eval mode for testing; the next training batch calls
    # model.train() again.
    model.eval()
    with torch.no_grad():
        n = 0
        acc = 0
        bce = 0
        for batch in dl_test:
            x, y = batch
            if args.model != "blnn":
                z, y_pred = model(x.cuda())
            else:
                z, _, _, y_pred = model(x.cuda())

            n += y.size(0)
            bce += F.binary_cross_entropy(y_pred, y.cuda()).item() * y.size(0)
            y = torch.argmax(y, dim=1)

            # Accumulate (the original overwrote `acc` on every batch, which
            # is only correct when the loader yields a single batch).
            acc += accuracy(y_pred, y.cuda()).item() * y.size(0)
        acc = acc / n
        bce = bce / n
    with open(outfilename + '.log', 'a') as f:
        f.write("Test Results - Epoch: {} Acc: {:.4f} BCE: {:.2f}".format(epoch, acc, bce))
        f.write("\n")

    print("Test Results - Epoch: {} Acc: {:.4f} BCE: {:.2f}"
          .format(epoch, acc, bce))



# Evaluate the trained model on a regular 100 x 100 grid and plot its
# confidence (largest per-class kernel value) behind a fresh moons sample.
domain = 3
x_lin = np.linspace(-domain+0.5, domain+0.5, 100)
y_lin = np.linspace(-domain, domain, 100)
xx, yy = np.meshgrid(x_lin, y_lin)

# One (x, y) row per grid point, in row-major order.
X_grid = np.stack((xx.ravel(), yy.ravel()), axis=1)

X_vis, y_vis = sklearn.datasets.make_moons(n_samples=1000, noise=noise)
mask = y_vis.astype(bool)

with torch.no_grad():
    grid_inputs = torch.from_numpy(X_grid).float().cuda()
    # y_pred is the last element of the tuple returned by either model.
    output = model(grid_inputs)[-1]
    confidence = output.max(dim=1).values.cpu().numpy()

z = confidence.reshape(xx.shape)

plt.figure()
plt.contourf(x_lin, y_lin, z, cmap='cividis')

# Class 1 first, then class 0, so each class gets its own default colour.
for selector in (mask, ~mask):
    plt.scatter(X_vis[selector, 0], X_vis[selector, 1])

plt.savefig(outfilename+".png")
