import jax
import jax.numpy as jnp
from jax.lax import stop_gradient
import optax


def cosine_dist(x1, x2):
    """Negative mean cosine similarity between paired rows of two batches.

    Args:
        x1: the first batch of vectors. Shape: [bsz, dim].
        x2: the second batch of vectors. Shape: [bsz, dim].
    Returns:
        A scalar: minus the batch average of the row-wise cosine similarity
        (lower means better aligned).

    Note:
        Rows are divided by their L2 norm without any epsilon, so an all-zero
        row produces NaN.
    """
    def _unit_rows(v):
        # scale every row to unit L2 norm; gradients flow through the norm
        return v / jnp.linalg.norm(v, axis=-1, keepdims=True)

    per_example_cos = jnp.einsum("nd,nd->n", _unit_rows(x1), _unit_rows(x2))
    return -jnp.mean(per_example_cos)

def cosine_dist_detached_norm(x1, x2):
    """Negative mean cosine similarity with the row norms detached.

    The forward value matches the plain cosine computation, but both L2
    norms are wrapped in ``stop_gradient``, so the backward pass treats them
    as constants.

    Args:
        x1: the first batch of vectors. Shape: [bsz, dim].
        x2: the second batch of vectors. Shape: [bsz, dim].
    Returns:
        A scalar: minus the batch average of the row-wise cosine similarity.
    """
    # norms only rescale; no gradient is propagated through them
    scale1 = stop_gradient(jnp.linalg.norm(x1, axis=-1, keepdims=True))
    scale2 = stop_gradient(jnp.linalg.norm(x2, axis=-1, keepdims=True))
    cos = jnp.einsum("nd,nd->n", x1 / scale1, x2 / scale2)
    return -jnp.mean(cos)


def l2_dist(x1, x2):
    """Mean squared difference between two arrays.

    Args:
        x1: the first batch of vectors. Shape: [bsz, dim].
        x2: the second batch of vectors. Shape: [bsz, dim].
    Returns:
        A scalar: the squared element-wise difference averaged over every
        element (batch and feature axes alike).
    """
    residual = x1 - x2
    squared = residual ** 2
    return squared.mean()


def simclr_loss(projs1, projs2, temperature=0.1):
    """Compute the SimCLR (NT-Xent) loss.

    Args:
        projs1: the projections of the first view. Shape: [bsz, proj_dim].
        projs2: the projections of the second view. Shape: [bsz, proj_dim].
        temperature: softmax temperature applied to the cosine similarities.
    Returns:
        A scalar: the cross entropy of picking, for each of the 2*bsz
        projections, the other view of the same example among all *other*
        projections in the batch.
    """
    bsz = projs1.shape[0]

    # cosine similarities between every pair of projections
    proj = jnp.concatenate([projs1, projs2], axis=0)  # [2*bsz, proj_dim]
    proj = proj / jnp.linalg.norm(proj, axis=-1, keepdims=True)
    sims = jnp.einsum('ij,kj->ik', proj, proj) / temperature  # [2*bsz, 2*bsz]

    # Row i's positive is the other view of the same example: column i + bsz
    # for the first half, column i - bsz for the second half.
    positives = jnp.concatenate(
        [jnp.diagonal(sims, offset=bsz), jnp.diagonal(sims, offset=-bsz)]
    )  # [2*bsz]

    # A sample is not a candidate for itself. Without this mask the constant
    # self-similarity (1 / temperature) would sit in every softmax denominator.
    self_mask = jnp.eye(2 * bsz, dtype=bool)
    candidates = jnp.where(self_mask, -jnp.inf, sims)

    # softmax cross entropy against the positive: logsumexp(row) - positive
    loss = jnp.mean(jax.nn.logsumexp(candidates, axis=-1) - positives)

    return loss

def vicreg_loss(projs1, projs2, pull_coeff=1.,push_coeff=1., decorr_coeff=100.):
    """Compute the VICReg loss (variance + covariance + invariance terms).

    Args:
        projs1: the projections of the first view. Shape: [bsz, proj_dim].
        projs2: the projections of the second view. Shape: [bsz, proj_dim].
        pull_coeff: the coefficient for the invariance (L2) loss.
        push_coeff: the coefficient for the variance loss.
        decorr_coeff: the coefficient for the covariance loss.
    Returns:
        The VICReg loss (scalar).
    """
    feat_dim = projs1.shape[-1]

    def _variance_hinge(view):
        # hinge on the per-feature std (1e-4 added inside the sqrt),
        # averaged over features
        std = jnp.sqrt(jnp.var(view, axis=0) + 1e-4)  # [proj_dim]
        return jnp.mean(jnp.maximum(1. - std, 0.))

    def _covariance_penalty(view):
        # squared feature covariances; triu(k=1) keeps the strictly upper
        # triangle, so the diagonal is dropped and each pair is counted once
        sq_cov = jnp.square(jnp.cov(view, rowvar=False))  # [proj_dim, proj_dim]
        return jnp.sum(jnp.triu(sq_cov, k=1)) / feat_dim

    variance_term = 0.5 * push_coeff * (
        _variance_hinge(projs1) + _variance_hinge(projs2))
    covariance_term = 0.5 * decorr_coeff * (
        _covariance_penalty(projs1) + _covariance_penalty(projs2))
    invariance_term = pull_coeff * l2_dist(projs1, projs2)

    return variance_term + covariance_term + invariance_term

def cross_entropy_loss(logits, labels, num_classes):
    """Compute the softmax cross entropy, averaged over labeled examples only.

    Args:
        logits: the logits. Shape: [bsz, num_classes].
        labels: integer class ids. Shape: [bsz]. Negative ids (e.g. the -1
            used for unlabeled STL-10 images) mark examples to ignore.
        num_classes: the number of classes used to one-hot encode ``labels``.
    Returns:
        A scalar: the per-example cross entropy summed over examples with a
        non-negative label, divided by (number of such examples + 1e-8).
        The result is 0 when no example is labeled.
    """
    labeled = labels >= 0

    # Softmax cross entropy against one-hot targets (same math as
    # optax.softmax_cross_entropy). A negative id one-hot encodes to an
    # all-zero row, and those rows are masked out below as well.
    targets = jax.nn.one_hot(labels, num_classes=num_classes)
    per_example = -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1), axis=-1)

    num_labeled = jnp.sum(jnp.where(labeled, 1., 0.))
    # the epsilon keeps the all-unlabeled case at 0 instead of 0/0
    return jnp.sum(jnp.where(labeled, per_example, 0.)) / (num_labeled + 1e-8)


def get_lep_losses(logits_lep, labels, num_classes):
    """Compute the cross entropy loss and accuracy for every entry of ``logits_lep``.

    Args:
        logits_lep: mapping from a key to logits. Shape of each value:
            [bsz, num_classes].
        labels: integer class ids. Shape: [bsz]. Negative ids mark unlabeled
            examples; they are ignored by both the loss and the accuracy.
        num_classes: the number of classes, forwarded to ``cross_entropy_loss``.
    Returns:
        A tuple ``(loss, lep_metrics)``: ``loss`` is the sum of the
        per-key losses, and ``lep_metrics`` is
        ``{'loss': {key: loss}, 'acc': {key: accuracy}}``.
    """
    lep_metrics = {'loss': {}, 'acc': {}}
    loss = 0.

    # Normalise accuracy by the number of labeled examples, matching the
    # masking done by the loss. The max(., 1) avoids 0/0 for a batch with no
    # labels (accuracy is then 0).
    num_labeled = jnp.maximum(jnp.sum(labels >= 0), 1)

    for key, logits in logits_lep.items():
        # argmax is never negative, so unlabeled rows can never count as hits
        hits = jnp.sum(jnp.argmax(logits, axis=-1) == labels)
        lep_metrics['acc'][key] = hits / num_labeled
        lep_metrics['loss'][key] = cross_entropy_loss(logits=logits, labels=labels, num_classes=num_classes)
        loss += lep_metrics['loss'][key]
    return loss, lep_metrics


def get_supervised_loss(num_classes):
    """Build the supervised loss function.

    Args:
        num_classes: the number of classes, forwarded to the cross entropy.
    Returns:
        A function ``get_loss_fn(params, state, batch)``.
    """
    def get_loss_fn(params, state, batch):
        """Compute the supervised loss on both views of a batch.

        Args:
            params: the parameters passed to ``state.apply_fn`` under 'params'.
            state: object providing ``apply_fn`` and ``batch_stats``.
            batch: sequence whose first two items are the two views and whose
                last item holds the labels.
        Returns:
            ``(loss, (updates, None, lep_metrics))``.
        """
        variables = {'params': params, 'batch_stats': state.batch_stats}

        # both views go through the network in a single forward pass, so the
        # labels are duplicated to line up with the stacked inputs
        stacked_inputs = jnp.concatenate([batch[0], batch[1]], axis=0)
        stacked_labels = jnp.concatenate([batch[-1], batch[-1]], axis=0)

        outputs, updates = state.apply_fn(
            variables, stacked_inputs, mutable=['batch_stats'])
        _, _, logits, logits_lep = outputs

        # one cross entropy term per entry of the logits mapping
        supervised = sum(
            (cross_entropy_loss(logits=head_logits, labels=stacked_labels, num_classes=num_classes)
             for head_logits in logits.values()),
            0.,
        )

        lep_loss, lep_metrics = get_lep_losses(logits_lep, stacked_labels, num_classes)

        return supervised + lep_loss, (updates, None, lep_metrics)
    return get_loss_fn



def get_simclr_loss(num_classes):
    """Build the SimCLR training loss function.

    Args:
        num_classes: the number of classes, used for the LEP losses.
    Returns:
        A function ``get_loss_fn(params, state, batch)``.
    """
    def get_loss_fn(params, state, batch):
        """Compute the SimCLR loss plus the LEP losses.

        Args:
            params: the parameters passed to ``state.apply_fn`` under 'params'.
            state: object providing ``apply_fn`` and ``batch_stats``.
            batch: sequence whose first two items are the two views and whose
                last item holds the labels.
        Returns:
            ``(loss, (updates, None, lep_metrics))``.
        """
        view_a, view_b = batch[0], batch[1]
        n = view_a.shape[0]

        variables = {'params': params, 'batch_stats': state.batch_stats}

        # single forward pass over the stacked views
        both_views = jnp.concatenate([view_a, view_b], axis=0)
        both_labels = jnp.concatenate([batch[-1], batch[-1]], axis=0)
        outputs, updates = state.apply_fn(
            variables, both_views, mutable=['batch_stats'])
        _, projs, _, logits_lep = outputs

        # first n rows belong to view 1, the remaining rows to view 2
        contrastive = sum(
            (simclr_loss(head[:n], head[n:]) for head in projs.values()),
            0.,
        )

        lep_loss, lep_metrics = get_lep_losses(logits_lep, both_labels, num_classes)

        return contrastive + lep_loss, (updates, None, lep_metrics)
    return get_loss_fn


def get_vicreg_loss(pull_coeff = 1., push_coeff = 1., decorr_coeff = 100., distance_metric='cosine', num_classes=10):
    """Build the VICReg training loss function.

    Args:
        pull_coeff: the coefficient for the invariance loss.
        push_coeff: the coefficient for the variance loss.
        decorr_coeff: the coefficient for the covariance loss.
        distance_metric: accepted but not used by this loss.
        num_classes: the number of classes, used for the LEP losses.
    Returns:
        A function ``get_loss_fn(params, state, batch)``.
    """
    def get_loss_fn(params, state, batch):
        """Compute the VICReg loss plus the LEP losses.

        Args:
            params: the parameters passed to ``state.apply_fn`` under 'params'.
            state: object providing ``apply_fn`` and ``batch_stats``.
            batch: sequence whose first two items are the two views and whose
                last item holds the labels.
        Returns:
            ``(loss, (updates, None, lep_metrics))``.
        """
        n = batch[0].shape[0]
        variables = {'params': params, 'batch_stats': state.batch_stats}

        # single forward pass over the stacked views
        stacked_views = jnp.concatenate([batch[0], batch[1]], axis=0)
        stacked_labels = jnp.concatenate([batch[-1], batch[-1]], axis=0)
        outputs, updates = state.apply_fn(
            variables, stacked_views, mutable=['batch_stats'])
        _, projs, _, logits_lep = outputs

        # one VICReg term per entry of the projections mapping
        total = 0.
        for head in projs.values():
            first_view, second_view = head[:n], head[n:]
            total += vicreg_loss(first_view, second_view, pull_coeff, push_coeff, decorr_coeff)

        lep_loss, lep_metrics = get_lep_losses(logits_lep, stacked_labels, num_classes)
        total += lep_loss

        # this loss has no target network, hence no target updates
        return total, (updates, None, lep_metrics)
    return get_loss_fn


def get_simsiam_loss(distance_metric='cosine', iso=False, num_classes=10):
    """Build the SimSiam-style training loss function.

    Args:
        distance_metric: 'cosine' or 'l2'; the distance between predictions
            and (gradient-stopped) target projections.
        iso: whether to use the "iso" variant of the loss. Only implemented
            for the 'l2' metric.
        num_classes: the number of classes, used for the LEP losses.
    Returns:
        A function ``get_loss_fn(params, state, batch)``.
    """
    def get_loss_fn(params, state, batch):
        """Compute the loss.

        Args:
            params: the online parameters passed to ``state.apply_fn``.
            state: object providing ``apply_fn``, ``batch_stats``,
                ``target_params``, ``tg_batch_stats`` and ``direct_pred``.
            batch: sequence whose first two items are the two views and whose
                last item holds the labels.
        Returns:
            ``(loss, (updates, tg_updates, lep_metrics))``.
        Raises:
            NotImplementedError: for the 'cosine' metric with ``iso=True``.
            ValueError: for an unsupported ``distance_metric``.
        """
        # online network variables
        tot_params = {'params': params, 'batch_stats': state.batch_stats}
        mutable = ['batch_stats']

        # target network variables
        tot_tg_params = {'params': state.target_params,
                         'batch_stats': state.tg_batch_stats}

        # the same direct_pred collection is handed to both networks
        if state.direct_pred is not None:
            tot_params['direct_pred'] = state.direct_pred
            tot_tg_params['direct_pred'] = state.direct_pred
            mutable.append('direct_pred')

        fwd = state.apply_fn

        # concatenate the two views
        bsz = batch[0].shape[0]
        in_batches = jnp.concatenate([batch[0], batch[1]], axis=0)
        labels = jnp.concatenate([batch[-1], batch[-1]], axis=0)

        out, updates = fwd(tot_params, in_batches, mutable=mutable)
        _, projs, preds, logits_lep = out

        # targets
        tg_out, tg_updates = fwd(tot_tg_params, in_batches, is_target_net=True, mutable=mutable)
        _, projs_tg, _, _ = tg_out

        # compute loss for every entry of the preds mapping
        loss = 0.
        for key in preds:
            z1, z2 = projs[key][:bsz], projs[key][bsz:]
            # target projections never receive gradient
            z1t = stop_gradient(projs_tg[key][:bsz])
            z2t = stop_gradient(projs_tg[key][bsz:])
            p1, p2 = preds[key][:bsz], preds[key][bsz:]

            if distance_metric == 'cosine':
                if iso:
                    raise NotImplementedError(
                        "iso is not implemented for the 'cosine' distance metric")
                # each view's prediction is matched to the other view's target
                loss += cosine_dist(p1, z2t) \
                      + cosine_dist(p2, z1t)

            elif distance_metric == 'l2':
                if iso:
                    loss += 0.5 * l2_dist(z1, stop_gradient(z1 + z2t - p1)) \
                          + 0.5 * l2_dist(z2, stop_gradient(z2 + z1t - p2)) \
                          + 0.5 * l2_dist(p1, z2t) \
                          + 0.5 * l2_dist(p2, z1t)
                else:
                    loss += l2_dist(p1, z2t) \
                          + l2_dist(p2, z1t)

            else:
                # fail loudly instead of silently training on the LEP loss only
                raise ValueError(
                    f"unsupported distance_metric: {distance_metric!r} "
                    "(expected 'cosine' or 'l2')")

        # compute lep loss
        lep_loss, lep_metrics = get_lep_losses(logits_lep, labels, num_classes)
        loss += lep_loss

        return loss, (updates, tg_updates, lep_metrics)
    return get_loss_fn


def get_direct_loss(distance_metric='cosine', iso=False, iso2=False, num_dims=-1, num_classes=10):
    """Build the "direct" training loss function.

    The loss is computed in the basis given by ``state.direct_pred``: the
    projections are multiplied by ``U`` and the prediction is formed by
    scaling the result with ``ev``, instead of using the network's own
    predictor output.

    Args:
        distance_metric: 'cosine' selects the cosine branches;
            'pseudo_cosine' raises NotImplementedError; any other value is
            handled by the L2 branch.
        iso: selects the "iso" variant (cosine and L2 branches).
        iso2: selects the "iso2" variant (cosine branch only; mutually
            exclusive with ``iso``).
        num_dims: when positive, only the first ``num_dims`` columns of ``U``
            and entries of ``ev`` are used.
        num_classes: the number of classes, used for the LEP losses.
    Returns:
        A function ``get_loss_fn(params, state, batch)``.
    """
    def get_loss_fn(params, state, batch):
        """Compute the loss.

        Args:
            params: the online parameters passed to ``state.apply_fn``.
            state: object providing ``apply_fn``, ``batch_stats``,
                ``target_params``, ``tg_batch_stats`` and ``direct_pred``.
            batch: sequence whose first two items are the two views and whose
                last item holds the labels.
        Returns:
            ``(loss, (updates, tg_updates, lep_metrics))``.
        """
        # network params
        tot_params = {'params': params, 'batch_stats':state.batch_stats}
        mutable = ['batch_stats']

        # target network params
        tot_tg_params = {'params': state.target_params,
                        'batch_stats': state.tg_batch_stats}

        # directpred params (the same collection is handed to both networks)
        # NOTE(review): state.direct_pred is indexed unconditionally in the
        # loop below, so this loss cannot work with direct_pred=None whenever
        # preds is non-empty; the None guard only protects the forward passes.
        if state.direct_pred is not None:
            tot_params['direct_pred'] = state.direct_pred
            tot_tg_params['direct_pred'] = state.direct_pred
            mutable.append('direct_pred')

        fwd = state.apply_fn

        # concatenate the two views (labels are duplicated to match)
        bsz = batch[0].shape[0]
        in_batches = jnp.concatenate([batch[0], batch[1]], axis=0)
        labels = jnp.concatenate([batch[-1], batch[-1]], axis=0)

        # forward pass
        out, updates = fwd(tot_params, in_batches, mutable=mutable)
        _, projs, preds, logits_lep = out

        # targets
        tg_out, tg_updates = fwd(tot_tg_params, in_batches, is_target_net=True, mutable=mutable)
        _, projs_tg, _, _ = tg_out

        # compute loss for every key of the preds dict
        loss = 0.
        for key in preds.keys():
            # first bsz rows belong to view 1, the remaining rows to view 2
            z1, z2 = projs[key][:bsz], projs[key][bsz:]
            z1t, z2t = projs_tg[key][:bsz], projs_tg[key][bsz:]
            # NOTE(review): p1 and p2 are unpacked but never used below; the
            # prediction is rebuilt from U and ev instead (p1hat / p2hat).
            p1, p2 = preds[key][:bsz], preds[key][bsz:]

            # presumably U holds eigenvectors and ev eigenvalues -- confirm
            # against the code that fills state.direct_pred
            pred_key = f'pred_{key}'
            U = state.direct_pred[pred_key]['U']
            s = state.direct_pred[pred_key]['ev']
            if num_dims > 0:
                U = U[:, :num_dims] # only use the first num_dims dims
                s = s[:num_dims]

            # express online and target projections in the basis U
            z1hat = jnp.matmul(z1, U)
            z2hat = jnp.matmul(z2, U)
            z1t_hat = jnp.matmul(z1t, U)
            z2t_hat = jnp.matmul(z2t, U)
            # prediction in that basis: every coordinate scaled by s
            p1hat = jnp.matmul(z1hat, jnp.diag(s))
            p2hat = jnp.matmul(z2hat, jnp.diag(s))

            if distance_metric == 'cosine':
                if iso:
                    # NOTE(review): validation via assert is stripped under python -O
                    assert iso2 == False, "can't have both iso and iso2"
                    # coordinates scaled by sqrt(s)
                    p1sqrt_hat = jnp.matmul(z1hat, jnp.diag(jnp.sqrt(s)))
                    p2sqrt_hat = jnp.matmul(z2hat, jnp.diag(jnp.sqrt(s)))
                    # Four terms. In the first two, gradient flows only through
                    # the online projections in the numerator (targets and the
                    # whole denominator are under stop_gradient). In the last
                    # two, the coefficient is under stop_gradient and gradient
                    # flows only through the squared norm of the sqrt(s)-scaled
                    # projections.
                    sim = jnp.einsum("bd,bd->b", z1hat, stop_gradient(z2t_hat)) / stop_gradient(jnp.linalg.norm(p1hat, axis=1) * jnp.linalg.norm(z2t_hat, axis=1)) \
                        + jnp.einsum("bd,bd->b", z2hat, stop_gradient(z1t_hat)) / stop_gradient(jnp.linalg.norm(p2hat, axis=1) * jnp.linalg.norm(z1t_hat, axis=1)) \
                        - 0.5 * stop_gradient(jnp.einsum("bd,bd->b", p1hat, z2t_hat) / (jnp.linalg.norm(p1hat, axis=1)**3 * jnp.linalg.norm(z2t_hat, axis=1))) * jnp.linalg.norm(p1sqrt_hat, axis=1)**2 \
                        - 0.5 * stop_gradient(jnp.einsum("bd,bd->b", p2hat, z1t_hat) / (jnp.linalg.norm(p2hat, axis=1)**3 * jnp.linalg.norm(z1t_hat, axis=1))) * jnp.linalg.norm(p2sqrt_hat, axis=1)**2
                    loss += - jnp.mean(sim)
                elif iso2:
                    p1sqrt_hat = jnp.matmul(z1hat, jnp.diag(jnp.sqrt(s)))
                    p2sqrt_hat = jnp.matmul(z2hat, jnp.diag(jnp.sqrt(s)))
                    # Gradient flows through the online projections in the
                    # numerator and through the norm of the sqrt(s)-scaled
                    # projections in the denominator; the target norm and the
                    # trailing (norm ratio)**3 factor are under stop_gradient.
                    sim = jnp.einsum("bd,bd->b", z1hat, stop_gradient(z2t_hat)) / (jnp.linalg.norm(p1sqrt_hat, axis=1) * stop_gradient(jnp.linalg.norm(z2t_hat, axis=1))) * stop_gradient(jnp.linalg.norm(p1sqrt_hat, axis=1)**3 / jnp.linalg.norm(p1hat, axis=1)**3) \
                        + jnp.einsum("bd,bd->b", z2hat, stop_gradient(z1t_hat)) / (jnp.linalg.norm(p2sqrt_hat, axis=1) * stop_gradient(jnp.linalg.norm(z1t_hat, axis=1))) * stop_gradient(jnp.linalg.norm(p2sqrt_hat, axis=1)**3 / jnp.linalg.norm(p2hat, axis=1)**3)
                    loss += - jnp.mean(sim)
                else:
                    # each view's prediction is matched to the other view's target
                    loss += cosine_dist(p1hat, stop_gradient(z2t_hat)) \
                          + cosine_dist(p2hat, stop_gradient(z1t_hat))
            elif distance_metric == 'pseudo_cosine':
                raise NotImplementedError
            else:
                # every other distance_metric value falls through to the L2 loss
                assert iso2 == False, "iso2 only works with cosine distance"
                if iso:
                    loss += l2_dist(z1hat, stop_gradient(z1hat + z2t_hat - p1hat)) \
                          + l2_dist(z2hat, stop_gradient(z2hat + z1t_hat - p2hat))
                else:
                    loss += l2_dist(p1hat, stop_gradient(z2t_hat)) \
                          + l2_dist(p2hat, stop_gradient(z1t_hat))

        # compute lep loss
        lep_loss, lep_metrics = get_lep_losses(logits_lep, labels, num_classes)
        loss += lep_loss

        return loss, (updates, tg_updates, lep_metrics)
    return get_loss_fn


def get_training_loss(args, num_classes):
    """Build the training loss function selected by ``args.loss``.

    Only the attributes needed by the selected loss are read from ``args``.

    Args:
        args: configuration object. ``args.loss`` must be one of 'simsiam',
            'simclr', 'vicreg', 'directloss' or 'supervised'.
        num_classes: the number of classes, forwarded to the loss factory.
    Returns:
        The loss function ``get_loss_fn(params, state, batch)``.
    Raises:
        ValueError: if ``args.loss`` is not a supported loss name.
    """
    if args.loss == 'simsiam':
        return get_simsiam_loss(args.distance_metric, args.iso, num_classes=num_classes)
    if args.loss == 'simclr':
        return get_simclr_loss(num_classes=num_classes)
    if args.loss == 'vicreg':
        return get_vicreg_loss(args.pull_coeff, args.push_coeff, args.decorr_coeff, num_classes=num_classes)
    if args.loss == 'directloss':
        return get_direct_loss(args.distance_metric, args.iso, args.iso2, args.dp_pc_num_components, num_classes=num_classes)
    if args.loss == 'supervised':
        return get_supervised_loss(num_classes)
    # previously an unknown name fell through to an unbound local variable
    raise ValueError(
        f"unknown loss: {args.loss!r} (expected one of 'simsiam', 'simclr', "
        "'vicreg', 'directloss', 'supervised')")


