import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Dropout, Linear as Lin, init
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR

import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.nn import DynamicEdgeConv, global_max_pool
from torch_geometric.utils import intersection_and_union as i_and_u
from torch_geometric.utils import to_dense_batch
from DiffGCN import DiffGCNBlock
import sys
from pointnet2_classification import MLP
from s3dis2 import S3DIS
from mgpool import mgunpool
import numpy as np
from FixedPoints2 import FixedPoints2
from focalLoss import FocalLoss
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable

# Report GPU compute capability; guarded so the script can still start on CPU-only hosts.
if torch.cuda.is_available():
    print(torch.cuda.get_device_capability())
expname = 's3dis.pth'
print("Exp name:", expname)

# Default data/checkpoint locations; overridden when "slurm" is on the command line.
path = '/home/cluster/users/erant_group/s3dis'
savepath = '/home/cluster/users/erant_group/diffops/' + expname
batchSize = 20
npoints = 2048  # number of points FixedPoints samples from every cloud

if "slurm" in sys.argv:
    path = '/home/eliasof/meshfit/pytorch_geometric/data/s3dis'
    savepath = '/home/eliasof/meshfit/pytorch_geometric/checkpoints/' + expname

# Training-time augmentation: small random translation and rotations about each
# axis, followed by the same fixed-size point sampling used for evaluation.
train_transform = T.Compose([
    T.RandomTranslate(0.01),
    T.RandomRotate(15, axis=0),
    T.RandomRotate(15, axis=1),
    T.RandomRotate(15, axis=2),
    T.FixedPoints(npoints, replace=False)
])
pre_transform, transform = T.NormalizeScale(), T.FixedPoints(npoints, replace=False)
# The augmentation pipeline is applied to the training split only (previously it
# was built but never used); the test split only samples a fixed number of points.
train_dataset = S3DIS(path, train=True, transform=train_transform, test_area=6,
                      pre_transform=pre_transform)
test_dataset = S3DIS(path, train=False, transform=transform, test_area=6,
                     pre_transform=pre_transform)
train_loader = DataLoader(train_dataset, batch_size=batchSize, shuffle=True,
                          num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=batchSize, shuffle=False,
                         num_workers=6)

# Small constant that keeps the IoU division in calculate_sem_IoU finite.
eps = 1e-20


def calculate_sem_IoU(pred_np, seg_np, num_classes=13):
    """Return per-class IoU with intersections/unions accumulated over samples.

    Args:
        pred_np: predicted labels, indexable per sample (e.g. an array of shape
            (num_samples, num_points)). Only the first ``seg_np.shape[0]``
            samples are used.
        seg_np: ground-truth labels with the same per-sample layout.
        num_classes: number of semantic classes (defaults to 13, the value that
            used to be hard-coded).

    Returns:
        Float ndarray of shape (num_classes,) holding
        ``intersection / (union + eps)`` per class. A class that appears in
        neither prediction nor ground truth therefore gets 0.
    """
    classes = np.arange(num_classes)
    I_all = np.zeros(num_classes)
    U_all = np.zeros(num_classes)
    for sample in range(seg_np.shape[0]):
        # Boolean membership matrix of shape (num_points, num_classes): one
        # column per class, so all classes are counted in a single pass.
        pred_hit = np.ravel(pred_np[sample])[:, None] == classes
        seg_hit = np.ravel(seg_np[sample])[:, None] == classes
        I_all += np.logical_and(pred_hit, seg_hit).sum(axis=0)
        U_all += np.logical_or(pred_hit, seg_hit).sum(axis=0)
    return I_all / (U_all + eps)


def stnknn(x, k):
    """Return indices of the ``k`` nearest points for every point.

    ``x`` is laid out as (batch, dims, num_points): the Gram matrix below is
    taken over the last axis. Each point counts as its own nearest neighbour.
    The result has shape (batch, num_points, k).
    """
    gram = torch.matmul(x.transpose(2, 1), x)
    sq_norms = torch.sum(x ** 2, dim=1, keepdim=True)
    # -(|a|^2 - 2 a.b + |b|^2): largest value == smallest squared distance.
    neg_sq_dist = 2 * gram - sq_norms - sq_norms.transpose(2, 1)
    _, nearest = neg_sq_dist.topk(k=k, dim=-1)
    return nearest


def get_graph_feature(x, k=20, normals=None, idx=None, dim9=False):
    """Build DGCNN-style edge features for each point and its k neighbours.

    Args:
        x: tensor viewed as (batch, dims, num_points).
        k: number of neighbours per point.
        normals: optional tensor that is reshaped to
            (batch, num_points, 1, dims) and appended as extra channels.
            It must therefore hold batch * num_points * dims elements.
        idx: optional precomputed neighbour indices of shape
            (batch, num_points, k), local to each sample. When omitted they
            are computed with ``stnknn``.
        dim9: when truthy, the kNN graph is built from channels 6 onward only.

    Returns:
        Tensor of shape (batch, 2*dims, num_points, k) holding
        [x_i, x_j - x_i]. With ``normals`` it is (batch, 3*dims, num_points, k).
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = stnknn(x[:, 6:] if dim9 else x, k=k)  # (batch_size, num_points, k)

    # Offset each sample's local indices so they address rows of the flattened
    # (batch_size * num_points, dims) tensor. Built on x's device; this used to
    # be hard-coded to CUDA.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)

    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    if normals is None:
        feature = torch.cat((x, feature - x), dim=3).permute(0, 3, 1, 2)
    else:
        normals = normals.contiguous()
        normals = normals.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
        feature = torch.cat((x, feature - x, normals), dim=3).permute(0, 3, 1, 2)

    return feature


class STN3d(nn.Module):
    """PointNet-style spatial transformer that regresses a 3x3 matrix per cloud.

    ``forward(x)`` expects x of shape (batch, 3, num_points), since conv1 takes
    3 input channels. It returns (batch, 3, 3): the fc3 regression output
    reshaped to 3x3, plus the identity matrix.
    """

    def __init__(self):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        batchsize = x.size(0)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]  # global max over points
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Flattened 3x3 identity, created directly on x's device (the old
        # `.cuda()` targeted the *current* GPU, not necessarily x's). It is
        # kept float32, as before, so dtype promotion of the sum is unchanged.
        iden = torch.eye(3, dtype=torch.float32, device=x.device).view(1, 9).repeat(batchsize, 1)
        x = x + iden
        x = x.view(-1, 3, 3)
        return x


class Transform_Net(nn.Module):
    """DGCNN-style transform regressor producing one square matrix per sample.

    ``forward`` expects a 4-D tensor (batch, initialFeatSize, num_points, k),
    because conv1 is a Conv2d over ``initialFeatSize`` channels. It returns
    (batch, outputSize, outputSize). The final linear layer starts with zero
    weights and an identity bias, so an untrained network outputs identity
    matrices.

    Args:
        args: stored as ``self.args``; not read inside this class.
        normals: if True, use 6 input channels and a 3x3 output, ignoring
            ``initialsize`` / ``outputsize``.
        initialsize: input channel count when ``normals`` is False.
        outputsize: side length of the output matrix when ``normals`` is False.
    """

    def __init__(self, args=None, normals=False, initialsize=2 * 64, outputsize=64):
        super(Transform_Net, self).__init__()
        self.args = args
        self.k = 3  # not read by forward(); kept for attribute compatibility
        if normals:
            self.initialFeatSize = 6
            self.outputSize = 3
        else:
            self.initialFeatSize = initialsize
            self.outputSize = outputsize
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        # BatchNorm for linear1's 512 outputs. It is registered at this position
        # so the child-module order (hence parameters() order and the state_dict
        # keys) is unchanged. Previously a 1024-channel layer was first bound to
        # this name and then silently overwritten.
        self.bn3 = nn.BatchNorm1d(512)

        self.conv1 = nn.Sequential(nn.Conv2d(self.initialFeatSize, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        # conv3 owns its own 1024-channel BatchNorm (state_dict key "conv3.1.*").
        self.conv3 = nn.Sequential(nn.Conv1d(128, 1024, kernel_size=1, bias=False),
                                   nn.BatchNorm1d(1024),
                                   nn.LeakyReLU(negative_slope=0.2))

        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.linear2 = nn.Linear(512, 256, bias=False)
        self.bn4 = nn.BatchNorm1d(256)

        self.transform = nn.Linear(256, self.outputSize * self.outputSize)
        # Zero weights + identity bias => the initial output is the identity.
        init.constant_(self.transform.weight, 0)
        init.eye_(self.transform.bias.view(self.outputSize, self.outputSize))

    def forward(self, x):
        batch_size = x.size(0)

        x = self.conv1(x)  # (B, initialFeatSize, N, k) -> (B, 64, N, k)
        x = self.conv2(x)  # (B, 64, N, k) -> (B, 128, N, k)
        x = x.max(dim=-1, keepdim=False)[0]  # max over neighbours -> (B, 128, N)

        x = self.conv3(x)  # (B, 128, N) -> (B, 1024, N)
        x = x.max(dim=-1, keepdim=False)[0]  # max over points -> (B, 1024)

        x = F.leaky_relu(self.bn3(self.linear1(x)), negative_slope=0.2)  # (B, 1024) -> (B, 512)
        x = F.leaky_relu(self.bn4(self.linear2(x)), negative_slope=0.2)  # (B, 512) -> (B, 256)

        x = self.transform(x)  # (B, 256) -> (B, outputSize * outputSize)
        x = x.view(batch_size, self.outputSize, self.outputSize)

        return x


class Net(torch.nn.Module):
    """Point-wise segmentation network built from DiffGCN blocks.

    Features from four DiffGCN stages are unpooled back to full resolution. They
    are concatenated with a per-cloud global max-pooled descriptor and fed to an
    MLP head that emits per-point log-probabilities over ``out_channels`` classes.

    Args:
        out_channels: number of output classes.
        k: neighbourhood size passed to every DiffGCNBlock.
        aggr: accepted for interface compatibility; not used by this class.
    """

    def __init__(self, out_channels, k=30, aggr='max'):
        super(Net, self).__init__()
        self.k = k
        self.conv0 = DiffGCNBlock(9, 64, k, 1)

        self.conv1 = DiffGCNBlock(64, 64, k, 2, pool=True)

        self.conv2 = DiffGCNBlock(64, 64, int(k), 2, pool=True)

        self.conv3 = DiffGCNBlock(64, 128, int(k), 2, pool=True)

        self.lin1 = MLP([64 * 3 + 128, 2048])

        self.mlp = Seq(MLP([2048 + 3 * 64 + 128, 512]), Dropout(0.5), MLP([512, 256]),
                       Dropout(0.5), MLP([256, 128]), Dropout(0.5), Lin(128, out_channels))

    def forward(self, data):
        x, pos, batch = data.x, data.pos, data.batch

        # Opening conv on positions concatenated with the input features.
        x0 = torch.cat([pos, x], dim=1)
        x0, pos, batch = self.conv0(x0, pos, batch)

        # Batch vector aligned with the full-resolution rows of x0.
        origbatch = batch.clone()
        x1, pos, batch, pooldata1 = self.conv1(x0, pos, batch)
        x2, pos, batch, pooldata2 = self.conv2(x1, pos, batch)
        x3, pos, batch, pooldata3 = self.conv3(x2, pos, batch)

        # Unpool every coarse level back to the resolution of x0.
        x3 = mgunpool(mgunpool(mgunpool(x3, *pooldata3), *pooldata2), *pooldata1)
        x2 = mgunpool(mgunpool(x2, *pooldata2), *pooldata1)
        x1 = mgunpool(x1, *pooldata1)

        local = torch.cat([x0, x1, x2, x3], dim=1)
        out = self.lin1(local)
        out = global_max_pool(out, origbatch)
        # Broadcast each cloud's global descriptor back to its points by batch
        # index. Unlike repeat_interleave(npoints), this is correct for clouds
        # of any size and does not depend on the module-level `npoints`.
        out = out[origbatch]
        out = torch.cat([local, out], dim=1)
        out = self.mlp(out)
        return F.log_softmax(out, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(train_dataset.num_classes, k=10).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
# Alternative schedule: scheduler = CosineAnnealingLR(optimizer, 800)
print(train_dataset.num_classes)
print("optimizer:", optimizer)
print(model)
# Best mean test IoU seen so far (name kept for checkpoint compatibility).
prev_test_acc = 0.0
# NOTE: resuming with "continue" in sys.argv is handled once, just before the
# epoch loop. A duplicate checkpoint load that used to live here was removed:
# that later block reloaded the same weights and replaced this optimizer anyway.


def one_hot_embedding(labels, num_classes, device=None):
    """Embedding labels to one-hot form.

    Args:
      labels: (LongTensor or sequence of ints) class labels, sized [N,].
      num_classes: (int) number of classes.
      device: target device. Defaults to CUDA when available (the previous
        unconditional behaviour) and otherwise falls back to CPU instead of
        crashing.

    Returns:
      (tensor) float32 encoded labels, sized [N, #classes], on ``device``.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    y = torch.eye(num_classes, device=device)
    # Move the indices to the same device as the table being indexed.
    return y[torch.as_tensor(labels, device=device)]


# Training criterion used by train().
lfunc = FocalLoss(reduction='mean', alpha=0.2)
# Start from cleared gradients; train() also clears them after every step.
optimizer.zero_grad()


def train():
    """Run one training epoch over ``train_loader``.

    Every 10 batches it prints the mean loss and node accuracy of those 10
    batches, then resets the counters. Batches after the last complete
    group of 10 are not reported.
    """
    model.train()

    log_every = 10
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for step, batch_data in enumerate(train_loader, start=1):
        batch_data = batch_data.to(device)
        log_probs = model(batch_data)
        loss = lfunc(log_probs, batch_data.y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        predicted = log_probs.max(dim=1)[1]
        n_correct += predicted.eq(batch_data.y).sum().item()
        n_seen += batch_data.num_nodes

        if step % log_every == 0:
            print('[{}/{}] Loss: {:.4f}, Train Accuracy: {:.4f}'.format(
                step, len(train_loader), running_loss / log_every,
                n_correct / n_seen))
            running_loss = n_correct = n_seen = 0


def test(loader):
    """Evaluate ``model`` on ``loader``.

    Returns:
        (node accuracy, per-class IoU tensor). Intersections and unions are
        summed over all batches before dividing. A class absent from both
        predictions and targets (0/0) is assigned IoU 1.
    """
    model.eval()
    # Use the class count of the dataset actually being evaluated. Fall back
    # to the global test set for datasets that do not expose it.
    num_classes = getattr(loader.dataset, 'num_classes', test_dataset.num_classes)
    correct_nodes = total_nodes = 0

    intersections, unions = [], []
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data).max(dim=1)[1]
        correct_nodes += pred.eq(data.y).sum().item()
        total_nodes += data.num_nodes
        i, u = i_and_u(pred, data.y, num_classes, data.batch)
        intersections.append(i.cpu())
        unions.append(u.cpu())

    intersection = torch.cat(intersections, dim=0).to(torch.float)
    union = torch.cat(unions, dim=0).to(torch.float)
    print("intersection size:", intersection.shape)
    print("unions shape:", union.shape)
    print("len loader;", len(loader.dataset))

    # Aggregate over all rows before dividing.
    intersection_all = torch.sum(intersection, dim=0)
    union_all = torch.sum(union, dim=0)

    iou_all = intersection_all / union_all
    iou_all[torch.isnan(iou_all)] = 1  # 0/0: class in neither pred nor target
    print("Iou All:", iou_all, "Mean IoU:", torch.mean(iou_all, dim=0))

    return correct_nodes / total_nodes, iou_all


print(model)

if "continue" in sys.argv:
    # Resume: fresh Adam at a very small lr, then restore weights and the best
    # score from the checkpoint at savepath.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.000005)
    checkpoint = torch.load(savepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch = 0
    acc = checkpoint['acc']
    prev_test_acc = checkpoint['test_acc']
    print("Continuing... current test acc:", prev_test_acc)

# NOTE(review): this loop never calls scheduler.step(), so the lr stays at the
# optimizer's initial value - confirm that is intended.
for epoch in range(1, 801):
    train()
    acc, iou = test(test_loader)
    iou = iou.clone().detach().cpu().numpy()
    mean_iou = np.mean(iou)
    print('Epoch: {:02d}, Acc: {:.4f}'.format(epoch, acc))
    print("IoU:", mean_iou)

    # Checkpoint only when the mean IoU improves on the best so far.
    if mean_iou > prev_test_acc:
        prev_test_acc = mean_iou
        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'acc': acc,
            'test_acc': mean_iou,
        }
        torch.save(state, savepath)
    print("Best IOU so far:", prev_test_acc)
