import torch
import os.path as osp
import aug.augmentors as A

import numpy as np
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning

from torch import nn
from tqdm import tqdm
from torch.optim import Adam
from aug.eval import get_split, SVMEvaluator
from torch_geometric.nn import GCNConv, global_add_pool
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset

import copy

import ot
from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein2, gromov_wasserstein, fused_gromov_wasserstein
from torch_geometric.utils import to_scipy_sparse_matrix, to_dense_adj
# from torchmetrics.functional import pairwise_cosine_similarity
from geomloss import SamplesLoss  # See also ImagesLoss, VolumesLoss
import numpy as np
from ot.gromov._utils import init_matrix, gwloss, gwggrad, init_matrix_semirelaxed, tensor_product
from ot.backend import get_backend
import torch.nn.functional as F

# Prefer the first CUDA device, but fall back to the CPU so the script still
# runs (slowly) on machines without a usable GPU instead of failing on the
# first `.to(device)` call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class GConv(nn.Module):
    """Stack of ``GCNConv`` layers producing node- and graph-level embeddings.

    The first layer maps ``input_dim`` to ``hidden_dim``; every further layer
    maps ``hidden_dim`` to ``hidden_dim``.  A single ``nn.PReLU`` (created with
    ``hidden_dim`` parameters) is shared by all layers.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        self.activation = nn.PReLU(hidden_dim)
        for depth in range(num_layers):
            # Only the very first convolution consumes the raw feature width.
            in_channels = hidden_dim if depth else input_dim
            self.layers.append(GCNConv(in_channels, hidden_dim))

    def forward(self, x, edge_index, batch):
        """Return ``(z, g)`` for the given graph batch.

        ``z`` is the activated output of the last layer.  ``g`` is the
        concatenation along ``dim=1`` of ``global_add_pool`` applied to the
        activated output of every layer, in layer order.
        """
        hidden = x
        pooled = []
        for conv in self.layers:
            hidden = self.activation(conv(hidden, edge_index))
            # Pool each intermediate representation as soon as it is available.
            pooled.append(global_add_pool(hidden, batch))
        return hidden, torch.cat(pooled, dim=1)


class FC(nn.Module):
    """Three ``Linear`` + ``ReLU`` stages with a parallel linear shortcut.

    ``forward`` returns ``fc(x) + linear(x)``, where ``fc`` is the stacked
    branch and ``linear`` is a single ``input_dim -> output_dim`` projection.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Widths of consecutive stages: input -> output -> output -> output.
        widths = (input_dim, output_dim, output_dim, output_dim)
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        self.fc = nn.Sequential(*stages)
        # Shortcut is created after the stacked branch, as in the stacked
        # branch's own layer order, so seeded initialisation is reproducible.
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        stacked = self.fc(x)
        shortcut = self.linear(x)
        return stacked + shortcut


class Encoder(torch.nn.Module):
    """Two-view graph encoder that also returns optimal-transport quantities.

    Each view is produced by its own augmentor (``aug1`` / ``aug2``) and
    encoded by its own GNN (``gcn1`` / ``gcn2``).  ``mlp1`` projects the node
    embeddings of both views and ``mlp2`` projects the graph embeddings of
    both views.
    """

    def __init__(self, gcn1, gcn2, mlp1, mlp2, aug1, aug2):
        super(Encoder, self).__init__()
        self.gcn1 = gcn1
        self.gcn2 = gcn2
        self.mlp1 = mlp1
        self.mlp2 = mlp2
        self.aug1 = aug1
        self.aug2 = aug2

    def forward(self, x, edge_index, batch, mode):
        """Encode both augmented views of a batch.

        Args:
            x: node feature matrix passed to both augmentors.
            edge_index: edge index passed to both augmentors.
            batch: node-to-graph assignment forwarded to both GNNs.
            mode: ``'train'`` to additionally solve the two transport
                problems; any other value skips them.

        Returns:
            ``(a1, a2, g1, g2, P, B, Mp, Mb)`` where ``a*`` are the projected
            node embeddings, ``g*`` the projected graph embeddings, ``Mp`` /
            ``Mb`` the ``ot.dist`` (``metric='euclidean'``) cost matrices
            between the augmented input features and between the projected
            node embeddings, and ``P`` / ``B`` the ``ot.emd`` results for
            those costs under ``ot.unif`` marginals.  Outside ``'train'``
            mode ``P`` and ``B`` are the integer ``0``.
        """
        # Augmentors return a 4-tuple; only features and edges are used here.
        x1, edge_index1, _, _ = self.aug1(x, edge_index)
        x2, edge_index2, _, _ = self.aug2(x, edge_index)

        z1, g1 = self.gcn1(x1, edge_index1, batch)
        z2, g2 = self.gcn2(x2, edge_index2, batch)
        a1, a2 = [self.mlp1(h) for h in [z1, z2]]
        g1, g2 = [self.mlp2(g) for g in [g1, g2]]

        # Cost matrices: one in input-feature space, one in embedding space.
        Mp = ot.dist(x1, x2, metric='euclidean')
        Mb = ot.dist(a1, a2, metric='euclidean')

        # NOTE: earlier (fused) Gromov-Wasserstein experiments needed dense
        # adjacency matrices and a Sinkhorn loss object.  Neither fed into the
        # returned values, so they are no longer built on every forward pass.
        if mode == 'train':
            # Marginals are only needed when the plans are actually solved.
            h1 = ot.unif(x1.shape[0], type_as=x1)
            h2 = ot.unif(x2.shape[0], type_as=x2)
            P = ot.emd(h1, h2, Mp)
            B = ot.emd(h1, h2, Mb)
        else:
            P = 0
            B = 0

        return a1, a2, g1, g2, P, B, Mp, Mb


def train(encoder_model, dataloader, optimizer, rho=1):
    """Run one optimisation pass over ``dataloader``.

    For every batch the loss is
    ``rho * KLDivLoss(Mp, Mb) + ||P - B||_F`` using the cost matrices and
    transport plans returned by ``encoder_model`` in ``'train'`` mode.

    Args:
        encoder_model: module called as
            ``encoder_model(x, edge_index, batch, mode='train')`` and
            returning an 8-tuple whose last four items are ``P, B, Mp, Mb``.
        dataloader: iterable of batches exposing ``x``, ``edge_index`` and
            ``batch`` and a ``to(device)`` method.
        optimizer: optimiser whose ``zero_grad`` / ``step`` are called once
            per batch.
        rho: weight of the KL term.

    Returns:
        Sum of the per-batch loss values (``0`` for an empty loader).
    """
    encoder_model.train()
    # The criterion holds no state, so build it once rather than per batch.
    # NOTE(review): torch documents KLDivLoss as expecting its first argument
    # in log-space; here a raw distance matrix is passed - confirm intended.
    kl_loss = nn.KLDivLoss(reduction='batchmean')
    epoch_loss = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()

        if data.x is None:
            # Feature-less datasets get a constant one-dimensional feature.
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)

        _, _, _, _, P, B, Mp, Mb = encoder_model(data.x, data.edge_index, data.batch, mode='train')

        # NOTE(review): P and B come from ot.emd; whether gradients flow
        # through them should be checked against the POT documentation.
        loss = rho * kl_loss(Mp, Mb) + torch.linalg.matrix_norm(P - B, ord='fro')

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    return epoch_loss

@ignore_warnings(category=ConvergenceWarning)
def test(encoder_model, dataloader):
    """Embed every graph in ``dataloader`` and score the embeddings.

    The graph embedding of a batch is ``g1 + g2`` (the two projected graph
    embeddings returned by ``encoder_model`` in ``'test'`` mode).  All
    embeddings and labels are concatenated, split with ``get_split``
    (``train_ratio=0.8``, ``test_ratio=0.1``) and passed to a linear
    ``SVMEvaluator``.

    Args:
        encoder_model: module called as
            ``encoder_model(x, edge_index, batch, mode='test')``.
        dataloader: iterable of batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y``.

    Returns:
        Whatever ``SVMEvaluator(linear=True)(x, y, split)`` returns.
    """
    encoder_model.eval()
    x = []
    y = []
    # Inference only: without no_grad every stored embedding would keep its
    # whole autograd graph alive until the full dataset has been processed.
    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)
            if data.x is None:
                # Feature-less datasets get a constant one-dimensional feature.
                num_nodes = data.batch.size(0)
                data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)

            _, _, g1, g2, _, _, _, _ = encoder_model(data.x, data.edge_index, data.batch, mode='test')

            x.append(g1 + g2)
            y.append(data.y)

    x = torch.cat(x, dim=0)
    y = torch.cat(y, dim=0)

    split = get_split(num_samples=x.size()[0], train_ratio=0.8, test_ratio=0.1)
    result = SVMEvaluator(linear=True)(x, y, split)

    return result


def main():
    """Train the two-view encoder on TUDataset ``PROTEINS`` and report accuracy.

    After every training epoch the model is evaluated with ``test`` on the
    same loader and the best accuracy observed so far is printed.
    """
    path = 'datasets'
    dataset = TUDataset(path, name='PROTEINS')

    dataloader = DataLoader(dataset, batch_size=16)

    # Datasets without node features are given a single constant feature
    # inside train/test, hence the lower bound of 1.
    input_dim = max(dataset.num_features, 1)
    hidden_dim = 512
    num_layers = 2
    num_epochs = 1000

    aug1 = A.Identity()
    aug2 = A.Compose([A.EdgePerturbation(pe=0.3), A.RWSampling(
        use=False, num_seeds=1000, walk_length=1000), A.FeatureMasking(pf=0.2), A.NodeDropping(pn=0.0)])

    gcn1 = GConv(input_dim=input_dim, hidden_dim=hidden_dim, num_layers=num_layers)
    gcn2 = GConv(input_dim=input_dim, hidden_dim=hidden_dim, num_layers=num_layers)
    mlp1 = FC(input_dim=hidden_dim, output_dim=hidden_dim)
    # GConv concatenates one pooled vector per layer, so the graph embedding
    # is num_layers * hidden_dim wide.
    mlp2 = FC(input_dim=hidden_dim * num_layers, output_dim=hidden_dim)

    # Moving the wrapper moves every registered sub-module with it.
    encoder_model = Encoder(gcn1=gcn1, gcn2=gcn2, mlp1=mlp1, mlp2=mlp2, aug1=aug1, aug2=aug2).to(device)

    optimizer = Adam(encoder_model.parameters(), lr=0.01)

    res = []
    with tqdm(total=num_epochs, desc='(T)') as pbar:
        for _ in range(num_epochs):
            loss = train(encoder_model, dataloader, optimizer, rho=1)
            pbar.set_postfix({'loss': loss})
            pbar.update()

            test_result = test(encoder_model, dataloader)
            res.append(test_result["acc"])
            # Report the running maximum so the "Best" label is accurate.
            print(f'Best test ACC={max(res):.4f}')

# Run the training/evaluation loop only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()
