from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from abc import ABC

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score

import numpy as np

from models.pointnet2_utils import distributed_sinkhorn, l2_normalize

class SCEModel(torch.nn.Module):
    """Soft contrastive cross-entropy between query/key similarities.

    For query row i the target distribution mixes a one-hot on column i (weight
    ``coeff``) with ``softmax(sim_k / temp_m)`` (weight ``1 - coeff``),
    L1-renormalised per row. The loss is the cross-entropy of
    ``softmax(q @ k.T / temp)`` against that target, averaged over rows.

    Args:
        temp: temperature applied to the query/key logits.
        temp_m: temperature applied to ``sim_k`` when building the soft target.
        coeff: weight of the identity (one-hot) part of the target.
    """

    def __init__(self, temp=0.1, temp_m=0.05, coeff=0.5):
        super(SCEModel, self).__init__()
        self.temp = temp
        self.temp_m = temp_m
        self.coeff = coeff

    def forward(self, q, k, sim_k):
        """Return the scalar soft contrastive loss.

        Args:
            q: (N, C) query embeddings.
            k: (M, C) key embeddings; row i of ``k`` is treated as the positive
                for query i (identity target).
            sim_k: similarities of shape (N, M) used to build the soft target.

        Returns:
            Scalar tensor.
        """
        batch_size = q.shape[0]
        sim_q = torch.einsum('nc,mc->nm', q, k)
        # Built on q's device (previously hard-coded to .cuda(), which broke CPU use).
        # eye(N, M) equals eye(N) in the usual square case M == N.
        mask = torch.eye(batch_size, sim_q.shape[1], dtype=torch.float32, device=q.device)
        logits_q = sim_q / self.temp
        logits_k = sim_k / self.temp_m
        prob_k = F.softmax(logits_k, dim=1)
        # Mix one-hot and soft targets, then renormalise each row to sum to 1.
        prob_q = F.normalize(self.coeff * mask + (1 - self.coeff) * prob_k, p=1, dim=1)
        loss = -torch.sum(prob_q * F.log_softmax(logits_q, dim=1), dim=1).mean(dim=0)
        return loss

class SupConLoss(torch.nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR
    From: https://github.com/HobbitLong/SupContrast

    Args:
        temperature: temperature applied to anchor/contrast dot products.
        contrast_mode: 'all' uses every view as an anchor, 'one' only view 0.
        base_temperature: the loss is scaled by ``temperature / base_temperature``.
    """

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def _setup(self, features, labels, mask):
        """Validate inputs and build the tensors shared by ``forward`` and ``label_neg``.

        Returns:
            (pos_mask, self_mask, anchor_feature, contrast_feature, anchor_count,
            batch_size). ``pos_mask`` is the tiled positive-pair mask with
            self-pairs removed; ``self_mask`` is 0 at each anchor's own column
            and 1 elsewhere.

        Raises:
            ValueError: on bad feature rank, both/mismatched labels and mask,
                or an unknown ``contrast_mode``.
        """
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.reshape(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32, device=device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # [bsz * n_views, D]: all samples of view 0, then view 1, ...
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # Tile the [bsz, bsz] mask to [anchors, contrasts].
        mask = mask.repeat(anchor_count, contrast_count)
        # Anchor row r is contrast column r, so zero that entry to drop self-contrast.
        self_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count, device=device).view(-1, 1),
            0
        )
        pos_mask = mask * self_mask
        return pos_mask, self_mask, anchor_feature, contrast_feature, anchor_count, batch_size

    def _reduce(self, pos_mask, log_prob, anchor_count, batch_size):
        """Average log-probabilities over each anchor's positives and scale into the loss."""
        pos_count = pos_mask.sum(1)
        # An anchor with no positive pair would give 0/0 = NaN; use denominator 1 so it
        # contributes 0 instead (same guard as the upstream SupContrast repository).
        pos_count = torch.where(pos_count < 1e-6, torch.ones_like(pos_count), pos_count)
        # TODO (original note): replace the binary mask with weights.
        mean_log_prob_pos = (pos_mask * log_prob).sum(1) / pos_count
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        return loss.view(anchor_count, batch_size).mean()

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        pos_mask, self_mask, anchor_feature, contrast_feature, anchor_count, batch_size = \
            self._setup(features, labels, mask)

        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)

        # Subtract the row max for numerical stability (does not change log-softmax).
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        exp_logits = torch.exp(logits) * self_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        return self._reduce(pos_mask, log_prob, anchor_count, batch_size)

    def label_neg(self, features, neg_features, labels=None, mask=None):
        """Same as ``forward`` but with extra negatives in every anchor's denominator.

        Args:
            features, labels, mask: as in ``forward``.
            neg_features: extra negatives, multiplied as ``anchor @ neg_features.T``,
                so presumably shape [K, D] on the same device as ``features``.

        Returns:
            A loss scalar. With zero negatives this equals ``forward``.
        """
        pos_mask, self_mask, anchor_feature, contrast_feature, anchor_count, batch_size = \
            self._setup(features, labels, mask)

        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        neg_dot_contrast = torch.div(
            torch.matmul(anchor_feature, neg_features.T),
            self.temperature)

        # Stabilise with the max over both in-batch contrasts and extra negatives.
        logits_max, _ = torch.max(
            torch.cat([anchor_dot_contrast, neg_dot_contrast], dim=1), dim=1, keepdim=True)
        logits_max = logits_max.detach()
        logits = anchor_dot_contrast - logits_max
        neg_logits = neg_dot_contrast - logits_max

        # Self-pairs are removed by multiplying *after* exp (exp(0 * logit) would be 1, not 0).
        denom = ((torch.exp(logits) * self_mask).sum(1, keepdim=True)
                 + torch.exp(neg_logits).sum(1, keepdim=True))
        log_prob = logits - torch.log(denom)
        return self._reduce(pos_mask, log_prob, anchor_count, batch_size)


class BCE(nn.Module):
    """Pairwise binary cross-entropy on the inner product of two probability rows.

    For each pair, ``P = sum(prob1 * prob2)`` along dim 1, and:
      * simi == 1  -> -log(P + eps)
      * simi == -1 -> -log(1 - P + eps)
      * simi == 0  -> -log(eps); such pairs are NOT excluded from the mean, so
        callers should filter them out beforehand (e.g. via ``PairEnum``'s mask).
    NaN/inf terms are replaced by ``torch.nan_to_num`` before averaging.
    Inputs are no longer modified in place.
    """
    eps = 1e-7  # Avoid calculating log(0). Use the small value of float16.

    def forward(self, prob1, prob2, simi):
        # prob1 and prob2 are pair enum
        # simi: 1->similar; -1->dissimilar; 0->unknown (see class docstring)
        assert len(prob1)==len(prob2)==len(simi), 'Wrong input size:{0},{1},{2}'.format(str(len(prob1)),str(len(prob2)),str(len(simi)))
        P = (prob1 * prob2).sum(1)
        # For dissimilar pairs: -P + 1 = 1 - P; similar pairs keep P.
        P = P * simi.type_as(P) + simi.eq(-1).type_as(P)
        neglogP = -torch.log(P + BCE.eps)
        neglogP = torch.nan_to_num(neglogP)
        return neglogP.mean()

class SoftCE(nn.Module):
    """Soft pairwise cross-entropy: mean of ``-simi * log(sum(prob1 * prob2) + eps)``.

    ``simi`` is a soft similarity weight per pair. Inputs are no longer modified
    in place (the original overwrote both ``prob1`` and the caller's ``simi``).
    """
    eps = 1e-7  # Avoid calculating log(0). Use the small value of float16.

    def forward(self, prob1, prob2, simi):
        # prob1 and prob2 are pair enum
        # simi: a soft prob value
        assert len(prob1)==len(prob2)==len(simi), 'Wrong input size:{0},{1},{2}'.format(str(len(prob1)),str(len(prob2)),str(len(simi)))
        P = (prob1 * prob2).sum(1)
        logP = torch.log(P + SoftCE.eps)
        negsimlogP = -(simi * logP)
        return negsimlogP.mean()

class SelfUniformLoss(nn.Module):
    """Cross-view cross-entropy between column-softmax targets and row-softmax predictions.

    Each entry of ``cls_out`` is presumably a 2-D (rows, classes) logit matrix --
    TODO confirm. Targets: softmax over dim 0 with ``col_tau``, then L1-normalised
    over dim 1. Predictions: softmax over dim 1 with ``row_tau``, L1-normalised
    over dim 0, rescaled by rows/classes (taken from the first view), then logged.

    Every ordered pair (target view i, prediction view j) with i != j contributes,
    except pairs where both i >= 2 and j >= 2; the result is their mean. With a
    single view no pair contributes and the final division raises ZeroDivisionError.
    """

    def __init__(self, row_tau=0.1, col_tau=0.1, eps=1e-8):
        super(SelfUniformLoss, self).__init__()
        self.row_tau = row_tau
        self.col_tau = col_tau
        self.eps = eps

    def forward(self, cls_out):
        first = cls_out[0]
        scale = first.shape[0] / first.shape[1]

        targets = [
            F.normalize(F.softmax(view / self.col_tau, dim=0), p=1, dim=1, eps=self.eps)
            for view in cls_out
        ]

        total = 0.0
        n_terms = 0
        for j, pred_view in enumerate(cls_out):
            pred = F.normalize(F.softmax(pred_view / self.row_tau, dim=1), p=1, dim=0, eps=self.eps)
            log_pred = torch.log(scale * pred + self.eps)
            for i, tgt in enumerate(targets):
                # Skip a view against itself and pairs made only of the extra (index >= 2) views.
                if i == j or (i >= 2 and j >= 2):
                    continue
                total += -torch.mean(torch.sum(tgt * log_pred, dim=1))
                n_terms += 1
        total /= n_terms
        return total

def PairEnum(x, mask=None):
    """Enumerate all ordered pairs of rows of a 2-D tensor ``x`` (n, d).

    Returns ``(x1, x2)``, each of shape (n*n, d), where row ``i*n + j`` holds
    ``x1 = x[j]`` and ``x2 = x[i]``. If ``mask`` (n*n elements, boolean) is given,
    only the pairs where it is True are kept.
    """
    assert x.ndimension() == 2, 'Input dimension must be 2'
    n_rows, n_feat = x.size(0), x.size(1)
    first = x.repeat(n_rows, 1)
    second = x.repeat(1, n_rows).view(-1, n_feat)
    if mask is None:
        return first, second
    # Broadcast the per-pair mask across the feature dimension, then select.
    keep = mask.view(-1, 1).repeat(1, n_feat)
    return first[keep].view(-1, n_feat), second[keep].view(-1, n_feat)

def compute_bce(part_attn, prob2, prob2_bar, criterion2, topk, return_acc=False, label=None, loss=False, max_union=8):
    """Pairwise BCE whose similar/dissimilar targets come from overlap of top-k attended parts.

    Every ordered pair of samples is enumerated with ``PairEnum``. For each pair,
    the ``topk`` highest-scoring part indices of ``part_attn`` are compared. The
    pair is labelled dissimilar (-1) when the union of the two index sets has more
    than ``max_union`` elements, and similar (+1) otherwise.

    Args:
        part_attn: (b, num_parts) part attention scores.
        prob2, prob2_bar: logits of two views, softmaxed here along dim 1. The
            first pair element is taken from ``prob2``, the second from ``prob2_bar``.
        criterion2: pairwise criterion called as ``criterion2(p1, p2, target)``
            (e.g. ``BCE``).
        topk: number of top-ranked parts compared per sample.
        return_acc: also compare the pseudo targets with ground-truth label pairs.
        label: (b,) labels; required when ``return_acc`` is True.
        loss: together with ``return_acc``, return only the pair accuracy.
        max_union: union-size threshold. It defaults to 8, the previously hard-coded
            value, and does not depend on ``topk``. With ``topk=8`` it means "top-k
            sets differ"; with ``topk <= 4`` no pair is ever dissimilar.

    Returns:
        ``loss_bce``; ``(loss_bce, ulb_acc)`` when ``return_acc``; or only
        ``ulb_acc`` when both ``return_acc`` and ``loss`` are True.
    """
    device = part_attn.device

    # Top-k part indices per sample, then all ordered sample pairs.
    top_idx = torch.argsort(part_attn, dim=1, descending=True)[:, :topk]
    top_idx1, top_idx2 = PairEnum(top_idx)

    # Indices within a row are distinct (argsort output), so
    # |A union B| = |A| + |B| - |A intersect B|, and the intersection size is the
    # number of equal entries between the two rows. Vectorised replacement for a
    # per-pair Python loop over torch.unique.
    k = top_idx1.shape[1]
    shared = (top_idx1.unsqueeze(2) == top_idx2.unsqueeze(1)).sum(dim=(1, 2))
    union_size = 2 * k - shared

    target_ulb = torch.ones(union_size.shape[0], dtype=torch.float32, device=device)
    target_ulb[union_size > max_union] = -1

    ulb_acc = None
    if return_acc:
        label1, label2 = PairEnum(label.view(-1, 1))
        label12_diff = torch.sum(torch.abs(label1 - label2), dim=1)

        real_target_ulb = torch.ones(label12_diff.shape[0], dtype=torch.float32, device=device)
        real_target_ulb[(label12_diff > 0).to(device)] = -1

        ulb_acc = (real_target_ulb == target_ulb).sum().item() / real_target_ulb.shape[0]
        if loss:
            return ulb_acc

    # Pair element 1 from view prob2, element 2 from view prob2_bar (previously both
    # came from prob2 and the softmaxed prob2_bar was discarded).
    prob1_ulb, _ = PairEnum(F.softmax(prob2, dim=1))
    _, prob2_ulb = PairEnum(F.softmax(prob2_bar, dim=1))
    loss_bce = criterion2(prob1_ulb, prob2_ulb, target_ulb)
    if return_acc:
        return loss_bce, ulb_acc
    return loss_bce

class PPC(nn.Module, ABC):
    """Cross-entropy over temperature-scaled prototype logits.

    Args:
        temperature: divisor applied to the logits (default 0.1, the previously
            hard-coded value).
    """

    def __init__(self, temperature=0.1):
        super(PPC, self).__init__()
        self.temperature = temperature

    def forward(self, contrast_logits, contrast_target):
        """Flatten logits to (-1, num_classes) and the target to (-1,), then apply CE."""
        contrast_logits = torch.div(contrast_logits, self.temperature)
        flat_logits = contrast_logits.reshape(-1, contrast_logits.shape[-1])
        # Flatten the target too, so multi-dimensional targets line up with the flattened logits.
        flat_target = contrast_target.reshape(-1).long()
        loss_ppc = F.cross_entropy(flat_logits, flat_target)
        return loss_ppc

class PPD(nn.Module, ABC):
    """Squared distance ``(1 - s)^2`` of the target-prototype score ``s``, averaged."""

    def __init__(self):
        super(PPD, self).__init__()

    def forward(self, contrast_logits, contrast_target):
        """Gather each row's score at its target index and penalise its gap to 1.

        Logits are flattened to (-1, num_classes). The target is flattened to a
        (-1, 1) column so that any target rank matching the logits' leading dims
        works, not only 1-D targets.
        """
        flat_logits = contrast_logits.reshape(-1, contrast_logits.shape[-1])
        index = contrast_target.reshape(-1, 1).long()
        logits = torch.gather(flat_logits, 1, index)
        loss_ppd = (1 - logits).pow(2).mean()
        return loss_ppd

class ProtoDiff(nn.Module):
    """Hinge penalty on pairwise similarity between part prototypes of each group.

    For input (K, M, C), it computes the per-group Gram matrix (K, M, M), removes
    the identity, subtracts the margin ``dis``, and sums the positive entries.

    Args:
        dist: similarity margin (default 0.4, the previously hard-coded value).
    """

    def __init__(self, dist=0.4):
        super(ProtoDiff, self).__init__()
        self.dis = dist

    def forward(self, part_prototype):
        num_parts = part_prototype.shape[1]
        # Identity built on the input's device/dtype (was hard-coded to .cuda()).
        eye = torch.eye(num_parts, num_parts, dtype=part_prototype.dtype, device=part_prototype.device)
        part_diff = torch.einsum('kmc,knc->kmn', part_prototype, part_prototype) - eye - self.dis
        loss_diff = torch.sum(part_diff[part_diff > 0])
        return loss_diff

class CatProtoDiff(nn.Module):
    """Hinge penalty on pairwise similarity between prototypes in a 2-D (M, C) matrix.

    It computes the Gram matrix (M, M), removes the identity, subtracts the margin
    ``dis``, and sums the positive entries.

    Args:
        dist: similarity margin.
    """

    def __init__(self, dist=0.1):
        super(CatProtoDiff, self).__init__()
        self.dis = dist

    def forward(self, cat_prototype):
        num = cat_prototype.shape[0]
        # Identity built on the input's device/dtype (was hard-coded to .cuda()).
        eye = torch.eye(num, num, dtype=cat_prototype.dtype, device=cat_prototype.device)
        part_diff = torch.einsum('mc,nc->mn', cat_prototype, cat_prototype) - eye - self.dis
        loss_diff = torch.sum(part_diff[part_diff > 0])
        return loss_diff

class AttnDiff(nn.Module):
    """MSE between the batch-averaged attention (mean over dim 0) and all ones.

    ``num_classes`` is stored as ``num_class`` but not used by ``forward``.
    """

    def __init__(self, num_classes):
        super(AttnDiff, self).__init__()
        self.num_class = num_classes
        self.loss_fun = nn.MSELoss()

    def forward(self, attn):
        averaged = torch.mean(attn, dim=0)
        all_ones = torch.ones_like(averaged)
        return self.loss_fun(averaged, all_ones)

class ClusteringLoss(nn.Module):
    """DEC-style clustering loss.

    Soft assignments ``q`` use a Student's t kernel
    ``(1 + ||x - mu||^2 / alpha)^(-(alpha + 1) / 2)``, normalised per row. The
    target ``p`` is ``q^2 / sum_i q``, renormalised per row and detached. The loss
    is ``KL(p || q)`` summed over all entries and divided by the number of rows.

    ``n_clusters`` and ``hidden`` are stored but not read here.
    """

    def __init__(self, n_clusters=10, hidden=10, alpha=1.0):
        super(ClusteringLoss, self).__init__()
        self.n_clusters = n_clusters
        self.alpha = alpha
        self.hidden = hidden
        # reduction='sum' is the non-deprecated equivalent of size_average=False.
        self.cluster_loss = nn.KLDivLoss(reduction='sum')

    def target_distribution(self, q_):
        """Sharpen ``q_``: square it, divide by per-cluster (column) totals, renormalise rows."""
        weight = (q_ ** 2) / torch.sum(q_, 0)
        return (weight.t() / torch.sum(weight, 1)).t()

    def forward(self, x, cluster_centers):
        """Return the clustering loss.

        Args:
            x: embeddings, presumably (N, D) -- TODO confirm.
            cluster_centers: centres broadcastable against ``x.unsqueeze(1)``,
                presumably (K, D).
        """
        norm_squared = torch.sum((x.unsqueeze(1) - cluster_centers) ** 2, 2)
        numerator = 1.0 / (1.0 + (norm_squared / self.alpha))
        power = float(self.alpha + 1) / 2
        numerator = numerator ** power
        # Soft assignment using the t-distribution kernel, normalised per row.
        t_dist = (numerator.t() / torch.sum(numerator, 1)).t()
        target = self.target_distribution(t_dist).detach()
        loss = self.cluster_loss(t_dist.log(), target) / t_dist.shape[0]
        return loss

class l1_regularization_loss(nn.Module):
    """Mean (over all leading dims) of the L1 norm taken along the last dimension."""

    def __init__(self):
        super(l1_regularization_loss, self).__init__()

    def forward(self, attention):
        per_row_l1 = attention.abs().sum(dim=-1)
        return per_row_l1.mean()

class SelfEntropyLoss(nn.Module):
    """Average softmax entropy per row, after flattening to (-1, last_dim)."""

    def __init__(self):
        super(SelfEntropyLoss, self).__init__()

    def forward(self, x):
        rows = x.reshape(-1, x.shape[-1])
        probs = F.softmax(rows, dim=1)
        log_probs = F.log_softmax(rows, dim=1)
        plogp = probs * log_probs
        n_rows = plogp.shape[0]
        return -1.0 * plogp.sum() / n_rows

class LessPositiveLoss(torch.nn.Module):
    """Clamp values from below at ``threshold``, sum along dim 1, and average the rows."""

    def __init__(self, threshold=0):
        super(LessPositiveLoss, self).__init__()
        self.threshold = threshold

    def forward(self, input_tensor):
        floored = input_tensor.clamp(min=self.threshold)
        per_row = floored.sum(dim=1)
        return per_row.mean()

class DistillLoss(nn.Module):
    """Cross-entropy of student log-probabilities against a temperature-sharpened teacher.

    NOTE(review): ``forward`` only uses ``student_temp``, and it uses it to sharpen
    the *teacher* softmax; the student logits are not temperature-scaled.
    ``ncrops`` and ``teacher_temp_schedule`` are built but not read by ``forward``.
    The teacher distribution is not detached, so gradients also flow through it.
    Confirm this is intended.
    """

    def __init__(self, warmup_teacher_temp_epochs=30, nepochs=250,
                 ncrops=2, warmup_teacher_temp=0.07, teacher_temp=0.04,
                 student_temp=0.01):
        super().__init__()
        self.student_temp = student_temp
        self.ncrops = ncrops
        # Linear warm-up of the teacher temperature, then constant until nepochs.
        warmup_part = np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs)
        steady_part = np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        self.teacher_temp_schedule = np.concatenate((warmup_part, steady_part))

    def forward(self, student_output, teacher_output):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        student_rows = student_output.reshape(-1, student_output.shape[-1])
        teacher_rows = teacher_output.reshape(-1, teacher_output.shape[-1])
        student_log_prob = F.log_softmax(student_rows, dim=-1)
        teacher_prob = F.softmax(teacher_rows / self.student_temp, dim=-1)
        per_row = (-teacher_prob * student_log_prob).sum(dim=-1)
        return per_row.mean()

class PGD_Prototype_novel(nn.Module, ABC):
    """Combined training objective built from the losses in this module.

    ``forward`` dispatches on ``flag``:
      * 0: cross-entropy on both views' logits (divided by 0.1), a
        self-distillation term on the stacked part logits, a pairwise margin
        penalty on the flattened part prototypes, and a supervised contrastive
        loss on the two views' L2-normalised part scores.
      * 4: ``SelfUniformLoss`` on both views' logits plus the self-distillation
        term on the stacked part logits.
    """

    def __init__(self, configer=None):
        super(PGD_Prototype_novel, self).__init__()
        self.configer = configer
        # configer must expose these attributes; they are stored but not read by forward().
        self.loss_ppc_weight = self.configer.loss_ppc_weight
        self.loss_ppd_weight = self.configer.loss_ppd_weight
        self.seg_criterion = nn.CrossEntropyLoss()
        self.ppc_criterion = PPC()
        self.cat_proto_criterion = CatProtoDiff(0.1)
        # forward() (flag == 0) calls proto_criterion on a flattened 2-D prototype matrix,
        # but it was never defined (AttributeError). CatProtoDiff is the 2-D variant.
        # NOTE(review): margin 0.4 mirrors ProtoDiff's part-prototype margin -- confirm.
        self.proto_criterion = CatProtoDiff(0.4)
        self.cluster_criterion = DistillLoss()
        self.student_criterion = ClusteringLoss()
        self.part_score_critertion = SupConLoss()
        self.self_uniform_criterion = SelfUniformLoss()
        self.self_entropy = SelfEntropyLoss()
        self.soft_part_contrast_criterion = SCEModel()

    def forward(self, preds, target, predsbar=None, flag=0):
        """Compute the combined loss for one step.

        Args:
            preds: dict of outputs for the first view; the keys read depend on ``flag``.
            target: class labels; used by the flag-0 path only.
            predsbar: dict of outputs for the second view (required).
            flag: 0 or 4; selects the loss recipe (see class docstring).

        Returns:
            ``(logits, loss, loss_list)``. ``logits`` are ``preds['logits']``,
            divided by 0.1 when flag == 0. ``loss`` is the summed loss tensor and
            ``loss_list`` holds per-term float values.

        Raises:
            ValueError: if ``flag`` is not 0 or 4, or ``predsbar`` is None.
        """
        if flag not in (0, 4):
            raise ValueError('Unsupported flag: {} (expected 0 or 4)'.format(flag))
        if predsbar is None:
            raise ValueError('`predsbar` is required')

        loss_list = {
            'loss_ce': 0.0,
            'loss_part_proto': 0.0,
            'loss_ppd': 0.0,
            'loss_cpd': 0.0,
            'loss_ppc': 0.0,
            'loss_cat_diff': 0.0,
            'loss_attn': 0.0,
            'loss_pair': 0.0,
            'loss_consist': 0.0,
            'pair_acc': 0.0,
            'loss_part_diff': 0.0,
            'loss_novel_cat_diff': 0.0,
            'loss_self_entropy': 0.0,
            'loss_feature_recon': 0.0,
            'loss_part_contrast': 0.0,
            'loss_l1': 0.0,
            'loss_soft_part_contrast': 0.0
        }

        loss = 0
        if flag == 0:
            logits = torch.div(preds['logits'], 0.1)
            logits_bar = torch.div(predsbar['logits'], 0.1)
            part_prototypes = preds['part_protos']
            part_scores = torch.stack([preds['part_score'], predsbar['part_score']], dim=1)
            part_logits_both = torch.vstack([preds['part_logits'], predsbar['part_logits']])

            loss_ce = self.seg_criterion(torch.vstack([logits, logits_bar]),
                                         torch.cat([target.long(), target.long()]))
            # Self-distillation: the same stacked part logits serve as student and teacher.
            loss_self = self.cluster_criterion(part_logits_both, part_logits_both)
            loss_diff = self.proto_criterion(part_prototypes.view(-1, part_prototypes.shape[-1]))
            loss_part_contrast = self.part_score_critertion(l2_normalize(part_scores), target.long())

            loss = loss + loss_ce + loss_self + loss_part_contrast + loss_diff
            loss_list['loss_part_diff'] = loss_diff.item()
            loss_list['loss_ce'] = loss_ce.item()
            loss_list['loss_ppc'] = loss_self.item()
            loss_list['loss_part_contrast'] = loss_part_contrast.item()
        else:  # flag == 4
            logits = preds['logits']
            logits_bar = predsbar['logits']
            part_logits_both = torch.vstack([preds['part_logits'], predsbar['part_logits']])

            loss_uniform = self.self_uniform_criterion([logits, logits_bar])
            loss_ppc = self.cluster_criterion(part_logits_both, part_logits_both)

            loss = loss + loss_uniform + loss_ppc
            # The uniform loss is reported under the 'loss_ce' key.
            loss_list['loss_ce'] = loss_uniform.item()
            loss_list['loss_ppc'] = loss_ppc.item()

        return logits, loss, loss_list


