# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F

def InfoNCE(z1: torch.Tensor, z2: torch.Tensor, tau=0.2):
    """Symmetric contrastive loss between two embedding views.

    Evaluates ``_infonce`` once with ``z1`` as the anchor view and once with
    ``z2`` as the anchor view, then averages the two results.

    Args:
        z1: Embeddings of the first view.
        z2: Embeddings of the second view.
        tau: Temperature, forwarded unchanged to ``_infonce``.

    Returns:
        The mean of the two directional losses.
    """
    anchored_on_z1 = _infonce(z1, z2, tau)
    anchored_on_z2 = _infonce(z2, z1, tau)
    return 0.5 * (anchored_on_z1 + anchored_on_z2)

def _similarity(z1: torch.Tensor, z2: torch.Tensor):
    z1 = F.normalize(z1)
    z2 = F.normalize(z2)
    return torch.mm(z1, z2.t())

def _infonce(z1: torch.Tensor, z2: torch.Tensor, tau):
    """One-directional InfoNCE loss with ``z1`` as the anchor view.

    For anchor row ``i`` the positive pair is ``(z1[i], z2[i])``.  The
    negatives are every other entry of row ``i`` in the inter-view
    similarity matrix (``z1`` vs ``z2``) plus every off-diagonal entry of
    row ``i`` in the intra-view similarity matrix (``z1`` vs ``z1``).  With
    scores ``exp(similarity / tau)`` the per-row loss is
    ``-log(pos / (pos + neg))``, and the mean over rows is returned.

    The loss is evaluated in log space: ``log(pos + neg) - log(pos)`` with
    ``torch.logsumexp``.  Exponentiating the scaled similarities directly
    overflows to ``inf`` for small ``tau`` (or in half precision) and turns
    the loss into NaN; the log-space form is the same quantity without the
    overflow.

    Args:
        z1: Anchor-view embeddings (2-D, as required by ``_similarity``).
        z2: Other-view embeddings; row ``i`` is the positive for ``z1[i]``.
        tau: Temperature dividing the cosine similarities.

    Returns:
        Scalar tensor: the loss averaged over anchor rows.
    """
    intra_logits = _similarity(z1, z1) / tau
    inter_logits = _similarity(z1, z2) / tau
    pos_logits = inter_logits.diag()

    # An anchor is not its own negative: remove the intra-view diagonal from
    # the denominator by sending those logits to -inf (exp(-inf) == 0).
    self_pairs = torch.eye(
        intra_logits.size(0), device=intra_logits.device
    ).bool()
    intra_logits = intra_logits.masked_fill(self_pairs, float("-inf"))

    # Row-wise log(pos + neg): all inter-view scores plus the off-diagonal
    # intra-view scores.
    log_denominator = torch.logsumexp(
        torch.cat((inter_logits, intra_logits), dim=1), dim=1
    )
    return (log_denominator - pos_logits).mean()

def BarlowTwins(h1: torch.Tensor, h2: torch.Tensor, lambda_=None, batch_norm=True, eps=1e-15, *args, **kwargs):
    """Barlow Twins-style redundancy-reduction loss between two views.

    Builds the feature cross-correlation matrix ``c = h1^T h2 / batch_size``
    (optionally after standardising every feature column over the batch),
    then sums the squared deviation of the diagonal of ``c`` from 1 and
    ``lambda_`` times the squared off-diagonal entries.

    Args:
        h1: First-view embeddings; dimension 0 is used as the batch size and
            dimension 1 as the feature size (``torch.mm`` requires 2-D).
        h2: Second-view embeddings, multiplied against ``h1`` column-wise.
        lambda_: Weight of the off-diagonal term.  ``None`` selects
            ``1 / feature_dim``.
        batch_norm: When true, subtract each column's batch mean and divide
            by its batch standard deviation before correlating.
        eps: Added to the standard deviation to avoid division by zero.
        *args: Accepted for call-site compatibility and ignored.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        Scalar tensor holding the loss.

    Note:
        ``Tensor.std`` applies Bessel's correction by default while ``c`` is
        divided by ``batch_size``, so two identical views produce diagonal
        entries of ``(N - 1) / N`` rather than exactly 1.  This is kept
        as-is to preserve existing numerics.
    """
    batch_size, feature_dim = h1.size(0), h1.size(1)
    if lambda_ is None:
        lambda_ = 1. / feature_dim

    if batch_norm:
        z1_norm = (h1 - h1.mean(dim=0)) / (h1.std(dim=0) + eps)
        z2_norm = (h2 - h2.mean(dim=0)) / (h2.std(dim=0) + eps)
        c = torch.mm(z1_norm.T, z2_norm) / batch_size
    else:
        c = torch.mm(h1.T, h2) / batch_size

    # Build the mask on the same device as ``c`` so that indexing never mixes
    # a CPU mask with an accelerator tensor.
    off_diagonal_mask = ~torch.eye(feature_dim, device=c.device).bool()

    on_diagonal_term = (1 - c.diagonal()).pow(2).sum()
    off_diagonal_term = c[off_diagonal_mask].pow(2).sum()
    return on_diagonal_term + lambda_ * off_diagonal_term

import math
def loss_dependence(emb1, emb2, dim, lamda1=0.01, lamda2=0.01):
    """HSIC-style dependence term between two embeddings plus Gram penalties.

    Forms the linear Gram matrices ``K1 = emb1 emb1^T`` and
    ``K2 = emb2 emb2^T``, centers them with ``R = I - (1/dim) 11^T`` and
    returns ``trace(R K1 R K2) / (dim - 1)^2`` plus ``lamda1`` / ``lamda2``
    times the sum of the off-diagonal entries of ``K1`` / ``K2``.

    The centering matrix is created on the device and dtype of ``emb1``, so
    the function works on CPU as well as on any accelerator instead of
    requiring CUDA.

    Args:
        emb1: First embedding matrix (2-D, as required by ``torch.mm``).
        emb2: Second embedding matrix with the same number of rows.
        dim: Size of the centering matrix.  It is multiplied against the
            Gram matrices, so it must equal the number of rows of ``emb1``
            and ``emb2``.
        lamda1: Weight of the off-diagonal sum of ``K1``.
        lamda2: Weight of the off-diagonal sum of ``K2``.

    Returns:
        Scalar tensor holding the combined value.
    """
    # Follow the inputs' device/dtype rather than hard-coding CUDA.
    factory = {"dtype": emb1.dtype, "device": emb1.device}
    R = torch.eye(dim, **factory) - (1 / dim) * torch.ones(dim, dim, **factory)

    K1 = torch.mm(emb1, emb1.t())
    K2 = torch.mm(emb2, emb2.t())
    RK1 = torch.mm(R, K1)
    RK2 = torch.mm(R, K2)

    K1_SumOffdiagonal = lamda1 * (K1.sum() - K1.diag().sum())
    K2_SumOffdiagonal = lamda2 * (K2.sum() - K2.diag().sum())

    HSIC = math.pow(dim - 1, -2) * torch.trace(torch.mm(RK1, RK2))
    return HSIC + K1_SumOffdiagonal + K2_SumOffdiagonal
