import random
import torch
import torch.nn as nn
from math import sqrt
import numpy as np


class SeqTokenAug(nn.Module):
    """Token-level mixing between a sequence tensor and its augmented view.

    During training, and with probability ``p`` per forward call, positions
    along the last axis ("tokens") of ``x`` are replaced by the corresponding
    positions of ``x_aug``. The token mask has a singleton channel axis
    (dim 2), so one decision per token is shared by all channels. A second,
    per-sample Bernoulli mask (``batch_prob``) decides which batch elements
    receive the mixed result; the others are returned unchanged.

    Tokens are picked either at random (Bernoulli with ``aug_token_prob``) or,
    when ``token_attention_flag`` is non-zero, from a score tensor passed to
    ``forward`` as ``Bx`` (the top-scoring ``aug_token_prob`` fraction).

    All masks are created on the device of their input tensors, so the module
    works on CPU as well as on GPU.
    """

    def __init__(self, p=0.5, aug_token_prob=0.75, batch_prob=1.0, token_attention_flag=0, seq_token_flag=0, eps=1e-6):
        """
        Args:
          p (float): probability of applying the augmentation on a given
            forward call (training mode only).
          aug_token_prob (float): probability (random mode) or target fraction
            (score mode) of tokens taken from ``x_aug``.
          batch_prob (float): probability that a given sample in the batch
            receives the mixed result.
          token_attention_flag (int): if non-zero, select tokens from the
            scores ``Bx`` passed to ``forward`` instead of at random.
          seq_token_flag (int): stored on the module; not read by the methods
            of this class.
          eps (float): stored on the module; not read by the methods of this
            class.
        """
        super().__init__()
        self.p = p

        self.eps = eps
        self.aug_token_prob = aug_token_prob
        self.batch_prob = batch_prob

        self.token_attention_flag = token_attention_flag
        self.seq_token_flag = seq_token_flag

    def scores_to_mask(self, scores, mask_prob=0.75):
        """Build a binary token mask that selects the highest-scoring tokens.

        Scores are averaged over the channel axis. For every (B, K) row, the
        tokens whose mean score is strictly greater than the value at index
        ``k`` of the descending sort are selected, where
        ``k = int(N * mask_prob)`` clamped to ``[0, N - 1]``. This keeps about
        a ``mask_prob`` fraction of tokens, fewer when scores tie at the
        threshold. ``mask_prob >= 1`` maps to ``k = N - 1``, which selects
        every token scoring above the row minimum.

        Args:
          scores (Tensor): shape (B, K, C, N).
          mask_prob (float): target fraction of selected tokens.

        Returns:
          Tensor: float32 mask of shape (B, K, 1, N) on ``scores.device``;
            1.0 marks a selected token, 0.0 an unselected one.
        """
        B, K, C, N = scores.shape
        channel_mean = scores.reshape(B * K, C, N).mean(dim=1)  # (B*K, N)

        # Index into the descending sort. Clamping keeps it in range, so
        # mask_prob >= 1 gives N - 1 and negative values give 0 instead of
        # raising IndexError or wrapping around.
        k_phase = min(max(int(N * mask_prob), 0), N - 1)

        threshold = torch.sort(channel_mean, dim=1, descending=True)[0][:, k_phase]
        # Comparing against the broadcast threshold keeps the mask on the
        # same device as `scores` (no hard-coded .cuda() constants).
        mask_phase = (channel_mean > threshold.unsqueeze(1)).float()
        return mask_phase.view(B, K, 1, N)

    def forward(self, x, x_aug, Bx=None):
        """Mix tokens of ``x_aug`` into ``x``.

        Args:
          x (Tensor): input of shape (B, K, C, L).
          x_aug (Tensor): augmented view; presumably the same shape as ``x``
            (TODO confirm with callers).
          Bx (Tensor, optional): token scores of shape (B, K, C_s, N), with N
            broadcast-compatible with L. Required when
            ``token_attention_flag`` is non-zero.

        Returns:
          Tensor: ``x`` itself in eval mode or when this call is skipped
            (probability ``1 - p``); otherwise the mixed tensor.

        Raises:
          ValueError: if ``token_attention_flag`` is set but ``Bx`` is None.
        """
        if not self.training or (random.random() > self.p):
            return x

        B, K, C, L = x.shape

        if self.token_attention_flag != 0:
            if Bx is None:
                raise ValueError("Bx (token scores) is required when token_attention_flag is set")
            token_mask = self.scores_to_mask(scores=Bx, mask_prob=self.aug_token_prob)
        else:
            # One Bernoulli draw per token, shared across the channel axis.
            token_mask = x.new_empty((B, K, 1, L)).bernoulli_(self.aug_token_prob).float()

        x_final = x_aug * token_mask + x * (1 - token_mask)
        # Per-sample gate: samples drawing 0 keep the untouched `x`.
        # new_empty already allocates on x's device, so no .cuda() is needed.
        batch_mask = x.new_empty((B, 1, 1, 1)).bernoulli_(self.batch_prob).float()
        return batch_mask * x_final + (1 - batch_mask) * x