"""
dataloader : same dataloader, mod N sentences
embedding layer averaging, BCE loss
sweep over N, sweep over length of sentence
TODO: code for iterating through all buckets, need to write a new bucket DataLoader
Think about scheduling N, run inference
Masking, loss calculation
"""
from importlib.machinery import WindowsRegistryFinder
from logging import debug
from re import X
from dataclasses import dataclass

from cffi.cffi_opcode import G_FLAGS
from tensorflow.python.keras.backend import binary_crossentropy, epsilon
from transformers.utils import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import (
    DataCollatorForLanguageModeling,
    RobertaModel,
)
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
from transformers.file_utils import ModelOutput
from transformers.models.roberta.modeling_roberta import (
    RobertaLMHead,
    RobertaPreTrainedModel,
    RobertaClassificationHead,
)
from typing import Optional, Tuple
from transformers.activations import gelu
import math

from transformers.utils.dummy_tokenizers_objects import RetriBertTokenizerFast
from .utils import (
    random_encoding,
    sinusoidal_encoding,
    random_encoding_fourier,
    gen_attn_mask,
    binary_encoding,
)
import time
import numpy as np
from scipy.stats import ortho_group

# Module-level logger, obtained through the `transformers.utils.logging` helper imported above.
logger = logging.get_logger(__name__)


class RobertaMultiSentenceSequenceClassificationParallel(RobertaPreTrainedModel):
    """RoBERTa sequence classifier that encodes ``num_sentences`` inputs per pass.

    Every group of ``config.num_sentences`` consecutive rows of the batch is
    embedded, transformed with a per-slot sentence embedding and combined into
    one sequence before it is fed through the encoder. ``lm_head`` produces the
    classification logits and ``retrieval_head`` is trained with an auxiliary
    cross-entropy loss on input tokens taken from a randomly drawn slot.

    Config attributes read by this class: ``num_labels``, ``num_sentences``,
    ``sentence_loss_fct``, ``random_encoding_norm``, ``retrieval_loss_coeff``,
    ``task_loss_coeff``, ``freeze_encoder_layers``, ``num_hidden_layers``,
    ``hidden_size``, ``learnt_embeddings``, ``retrieval_percentage``,
    ``vocab_size``, ``use_return_dict`` and, for the binary encoding only,
    ``binary_encoding_epsilon``.
    """

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        """Build the encoder, both heads and the per-slot sentence embedding.

        Raises:
            NotImplementedError: if ``config.sentence_loss_fct`` is not one of
                the supported ``mlm_multisentence_*`` options.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.num_sentences = config.num_sentences
        self.sentence_loss_fct = config.sentence_loss_fct
        self.random_encoding_norm = config.random_encoding_norm
        self.retrieval_loss_coeff = config.retrieval_loss_coeff

        self.task_loss_coeff = config.task_loss_coeff

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.freeze_encoder_layers = config.freeze_encoder_layers
        assert (
            self.freeze_encoder_layers <= config.num_hidden_layers
        ), "number of layers to freeze exceeds total number of layers"
        # Freeze the first `freeze_encoder_layers` encoder layers. The layer
        # index is recovered by concatenating every digit found in the
        # parameter name, so this relies on encoder parameter names carrying
        # no digits other than the layer index.
        for name, param in self.roberta.named_parameters():
            if "encoder" in name:
                layer = "".join(filter(str.isdigit, name))
                layer = int(layer)
                assert (
                    layer <= config.num_hidden_layers
                ), "extracted layer exceeds total number of layer, check for bug"
                if layer <= self.freeze_encoder_layers - 1:
                    param.requires_grad = False

        self.lm_head = RobertaClassificationHeadParallelLikeLM(config)
        self.retrieval_head = RobertaLMHeadConditional(config)
        self.init_weights()

        # Per-slot sentence embedding. It is created after `init_weights()` so
        # the weight initialisation above does not touch it.
        d_model = config.hidden_size

        if self.sentence_loss_fct == "mlm_multisentence_sinusoidal":
            sentence_embedding = sinusoidal_encoding(self.num_sentences, d_model)
        elif self.sentence_loss_fct == "mlm_multisentence_sinusoidal_morespread":
            # pick `num_sentences` random rows (never row 0) out of 10000
            all_embeds = sinusoidal_encoding(10000, d_model)
            sentence_embedding = all_embeds[
                torch.randint(10000 - 1, (self.num_sentences,)) + 1, :
            ]
        elif self.sentence_loss_fct == "mlm_multisentence_random":
            sentence_embedding = random_encoding(
                self.num_sentences, d_model, norm=self.random_encoding_norm
            )
        elif self.sentence_loss_fct == "mlm_multisentence_random_orthogonal":
            # one random (hidden_size x hidden_size) orthogonal matrix per slot
            sentence_embedding = torch.stack(
                [
                    torch.from_numpy(ortho_group.rvs(config.hidden_size)).float()
                    for _ in range(self.num_sentences)
                ],
                dim=0,
            )
        elif self.sentence_loss_fct == "mlm_multisentence_binary":
            sentence_embedding = binary_encoding(
                self.num_sentences, d_model, epsilon=config.binary_encoding_epsilon
            )
        else:
            raise NotImplementedError(
                f"unsupported sentence_loss_fct: {self.sentence_loss_fct!r}"
            )

        # The former "default" fallback after this chain could never run
        # (unknown options raise above) and, had it run, it never assigned
        # `self.sentence_embedding`; it has been dropped.
        self.sentence_embedding = torch.nn.Parameter(sentence_embedding)
        self.sentence_embedding.requires_grad = bool(config.learnt_embeddings)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Run the multi-sentence forward pass and compute the combined loss.

        Only the first ``(batch_size // num_sentences) * num_sentences`` rows of
        :obj:`input_ids` and :obj:`labels` are used; trailing rows are dropped.
        :obj:`attention_mask`, :obj:`head_mask`, :obj:`output_attentions` and
        :obj:`output_hidden_states` are accepted for interface compatibility
        but are not used by this method (the encoder is called without an
        attention mask).

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, required):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss)
            and the retrieval loss is not added. If :obj:`config.num_labels > 1` the loss is
            :obj:`task_loss_coeff * task_loss + retrieval_loss_coeff * retrieval_loss` (both Cross-Entropy).

        Returns:
            :obj:`SequenceClassifierOutputWithAuxiliary` when :obj:`return_dict` resolves to true, otherwise the
            tuple :obj:`(loss, logits, *outputs[2:])`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        assert (
            labels is not None
        ), "labels need to be supplied for multi-sentence objective"

        batch_size, seq_length = input_ids.size()
        num_sentences = self.num_sentences
        device = input_ids.device
        past_key_values_length = 0

        modified_batch_size = batch_size // num_sentences
        used_rows = modified_batch_size * num_sentences
        # every row becomes: <num_sentences prefix tokens> <cls token> <original tokens>
        total_length = seq_length + num_sentences + 1

        # Prefix layout for num_sentences == 5 (F = filler id, Ci = slot token):
        # [C1, F,  F,  F,  F ]
        # [F,  C2, F,  F,  F ]
        # [F,  F,  C3, F,  F ]
        # [F,  F,  F,  C4, F ]
        # [F,  F,  F,  F,  C5]
        # The slot tokens reuse ids at the end of the embedding table
        # (num_embeddings - 2, num_embeddings - 3, ...), the easiest way to
        # make use of the existing vocabulary.
        filler_token_id = 50000
        cls_token_id = 49923
        num_embeddings = self.roberta.embeddings.word_embeddings.weight.shape[0]
        slot_ids = torch.arange(num_sentences, device=device)

        prefix = torch.full(
            (num_sentences, num_sentences),
            filler_token_id,
            dtype=torch.long,
            device=device,
        )
        prefix[slot_ids, slot_ids] = num_embeddings - (slot_ids + 2)
        cls_tokens = torch.full(
            (num_sentences, 1), cls_token_id, dtype=torch.long, device=device
        )
        prefix = torch.cat([prefix, cls_tokens], dim=1)
        prefix = prefix.repeat(modified_batch_size, 1)

        input_ids = input_ids[:used_rows]
        input_ids = torch.cat([prefix, input_ids], dim=1)

        embedding_output = self.roberta.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        embedding_dim = embedding_output.shape[-1]

        # group the rows: (modified_batch, slot, token, hidden)
        embedding_output = embedding_output.view(
            modified_batch_size, num_sentences, total_length, embedding_dim
        )

        if self.sentence_loss_fct == "mlm_multisentence_random_orthogonal":
            # apply each slot's matrix to its token vectors
            embedding_output = torch.matmul(
                self.sentence_embedding, embedding_output.permute(0, 1, 3, 2)
            )
            # swap the last 2 dimensions back
            embedding_output = embedding_output.permute(0, 1, 3, 2)
            # combine the slots, scaled by sqrt(num_sentences)
            embedding_output = torch.sum(embedding_output, dim=1) / math.sqrt(
                self.num_sentences
            )
        else:
            # element-wise product with each slot's vector, then average the slots
            sentence_embed = self.sentence_embedding[:num_sentences, :].to(
                embedding_output.device
            )
            # (1, slot, 1, hidden) broadcasts over batch and token axes
            sentence_embed = sentence_embed.unsqueeze(1).unsqueeze(0)
            embedding_output = torch.mean(embedding_output * sentence_embed, dim=1)

        outputs = self.roberta(
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=position_ids,
            inputs_embeds=embedding_output,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        labels = labels[:used_rows]
        assert len(labels.shape) == 1  # assert one dimension

        logits = self.lm_head(sequence_output)

        # ---- retrieval auxiliary loss ----
        # For every non-prefix position draw the slot whose token must be
        # recovered; prefix/cls positions keep slot 0 and are masked out below.
        sentence_labels = torch.zeros(
            (modified_batch_size, total_length), dtype=torch.long, device=device
        )
        sentence_labels[:, 1 + num_sentences :] = torch.randint(
            num_sentences, (modified_batch_size, seq_length)
        ).to(device)

        # (modified_batch, token, slot)
        input_ids = input_ids.view(modified_batch_size, num_sentences, -1)
        input_ids = input_ids.permute(0, 2, 1)
        # retrieval_labels[b, t] = input_ids[b, t, sentence_labels[b, t]]
        retrieval_labels = torch.gather(
            input_ids, 2, sentence_labels.unsqueeze(-1)
        ).squeeze(-1)
        retrieval_labels[:, : (num_sentences + 1)] = -100

        # Targets equal to id 1 (treated as padding here) are never predicted;
        # of the remaining targets each one is dropped with probability
        # 1 - retrieval_percentage.
        pad_mask = retrieval_labels == 1
        drop_mask = (
            torch.bernoulli(
                torch.full(
                    retrieval_labels.shape, 1 - self.config.retrieval_percentage
                )
            )
            .bool()
            .to(device)
        )
        retrieval_labels[pad_mask | drop_mask] = -100

        retrieval_predictions = self.retrieval_head(sequence_output, sentence_labels)

        retrieval_loss = None
        task_loss = None
        if self.num_labels == 1:
            # regression: only the MSE term contributes to the loss
            loss_fct = MSELoss()
            loss = loss_fct(logits.view(-1), labels.view(-1))
        else:
            loss_fct = CrossEntropyLoss()
            task_loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            retrieval_loss = loss_fct(
                retrieval_predictions.view(-1, self.config.vocab_size),
                retrieval_labels.view(-1),
            )
            loss = (self.task_loss_coeff * task_loss) + (
                self.retrieval_loss_coeff * retrieval_loss
            )

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithAuxiliary(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            task_loss=task_loss,
            retrieval_loss=retrieval_loss,
        )

class RobertaMultiSentenceTokenClassificationParallel(RobertaPreTrainedModel):
    """RoBERTa token classifier that encodes ``num_sentences`` inputs per pass.

    Every group of ``config.num_sentences`` consecutive rows of the batch is
    embedded, transformed with a per-slot sentence embedding and combined into
    one sequence before it is fed through the encoder. ``lm_head`` produces
    per-token classification logits for every original row and
    ``retrieval_head`` is trained with an auxiliary cross-entropy loss on input
    tokens taken from a randomly drawn slot.

    Config attributes read by this class: ``num_labels``, ``num_sentences``,
    ``sentence_loss_fct``, ``random_encoding_norm``, ``retrieval_loss_coeff``,
    ``task_loss_coeff``, ``freeze_encoder_layers``, ``num_hidden_layers``,
    ``hidden_size``, ``learnt_embeddings``, ``retrieval_percentage``,
    ``vocab_size``, ``use_return_dict`` and, for the binary encoding only,
    ``binary_encoding_epsilon``.
    """

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        """Build the encoder, both heads and the per-slot sentence embedding.

        Raises:
            NotImplementedError: if ``config.sentence_loss_fct`` is not one of
                the supported ``mlm_multisentence_*`` options.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.num_sentences = config.num_sentences
        self.sentence_loss_fct = config.sentence_loss_fct
        self.random_encoding_norm = config.random_encoding_norm
        self.retrieval_loss_coeff = config.retrieval_loss_coeff

        self.task_loss_coeff = config.task_loss_coeff

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.freeze_encoder_layers = config.freeze_encoder_layers
        assert (
            self.freeze_encoder_layers <= config.num_hidden_layers
        ), "number of layers to freeze exceeds total number of layers"
        # Freeze the first `freeze_encoder_layers` encoder layers. The layer
        # index is recovered by concatenating every digit found in the
        # parameter name, so this relies on encoder parameter names carrying
        # no digits other than the layer index.
        for name, param in self.roberta.named_parameters():
            if "encoder" in name:
                layer = "".join(filter(str.isdigit, name))
                layer = int(layer)
                assert (
                    layer <= config.num_hidden_layers
                ), "extracted layer exceeds total number of layer, check for bug"
                if layer <= self.freeze_encoder_layers - 1:
                    param.requires_grad = False

        self.lm_head = RobertaTokenClassificationHeadConditional(config)
        self.retrieval_head = RobertaLMHeadConditional(config)
        self.init_weights()

        # Per-slot sentence embedding. It is created after `init_weights()` so
        # the weight initialisation above does not touch it.
        d_model = config.hidden_size

        if self.sentence_loss_fct == "mlm_multisentence_sinusoidal":
            sentence_embedding = sinusoidal_encoding(self.num_sentences, d_model)
        elif self.sentence_loss_fct == "mlm_multisentence_sinusoidal_morespread":
            # pick `num_sentences` random rows (never row 0) out of 10000
            all_embeds = sinusoidal_encoding(10000, d_model)
            sentence_embedding = all_embeds[
                torch.randint(10000 - 1, (self.num_sentences,)) + 1, :
            ]
        elif self.sentence_loss_fct == "mlm_multisentence_random":
            sentence_embedding = random_encoding(
                self.num_sentences, d_model, norm=self.random_encoding_norm
            )
        elif self.sentence_loss_fct == "mlm_multisentence_random_orthogonal":
            # one random (hidden_size x hidden_size) orthogonal matrix per slot
            sentence_embedding = torch.stack(
                [
                    torch.from_numpy(ortho_group.rvs(config.hidden_size)).float()
                    for _ in range(self.num_sentences)
                ],
                dim=0,
            )
        elif self.sentence_loss_fct == "mlm_multisentence_binary":
            sentence_embedding = binary_encoding(
                self.num_sentences, d_model, epsilon=config.binary_encoding_epsilon
            )
        else:
            raise NotImplementedError(
                f"unsupported sentence_loss_fct: {self.sentence_loss_fct!r}"
            )

        # The former "default" fallback after this chain could never run
        # (unknown options raise above) and, had it run, it never assigned
        # `self.sentence_embedding`; it has been dropped.
        self.sentence_embedding = torch.nn.Parameter(sentence_embedding)
        self.sentence_embedding.requires_grad = bool(config.learnt_embeddings)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Run the multi-sentence forward pass and compute the combined loss.

        Only the first ``(batch_size // num_sentences) * num_sentences`` rows of
        :obj:`input_ids`, :obj:`labels` and :obj:`attention_mask` are used;
        trailing rows are dropped. :obj:`head_mask`, :obj:`output_attentions`
        and :obj:`output_hidden_states` are accepted for interface
        compatibility but are not used by this method.

        attention_mask (:obj:`torch.Tensor`, `optional`):
            Used only to select the positions that contribute to the token classification loss (entries equal to
            1); it is not passed to the encoder. Assumed to have the same shape as :obj:`input_ids`. When omitted,
            every position whose label is not the ignore index contributes.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, required):
            Labels for computing the token classification loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`.

        Returns:
            :obj:`TokenClassifierOutputWithAuxiliary` when :obj:`return_dict` resolves to true, otherwise the tuple
            :obj:`(loss, logits, *outputs[2:])`. The returned logits exclude the prefix and cls positions.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        assert (
            labels is not None
        ), "labels need to be supplied for multi-sentence objective"

        batch_size, seq_length = input_ids.size()
        num_sentences = self.num_sentences
        device = input_ids.device
        past_key_values_length = 0

        modified_batch_size = batch_size // num_sentences
        used_rows = modified_batch_size * num_sentences
        # every row becomes: <num_sentences prefix tokens> <cls token> <original tokens>
        total_length = seq_length + num_sentences + 1

        # Prefix layout for num_sentences == 5 (F = filler id, Ci = slot token):
        # [C1, F,  F,  F,  F ]
        # [F,  C2, F,  F,  F ]
        # [F,  F,  C3, F,  F ]
        # [F,  F,  F,  C4, F ]
        # [F,  F,  F,  F,  C5]
        # The slot tokens reuse ids at the end of the embedding table
        # (num_embeddings - 2, num_embeddings - 3, ...), the easiest way to
        # make use of the existing vocabulary.
        filler_token_id = 50000
        cls_token_id = 49923
        num_embeddings = self.roberta.embeddings.word_embeddings.weight.shape[0]
        slot_ids = torch.arange(num_sentences, device=device)

        prefix = torch.full(
            (num_sentences, num_sentences),
            filler_token_id,
            dtype=torch.long,
            device=device,
        )
        prefix[slot_ids, slot_ids] = num_embeddings - (slot_ids + 2)
        cls_tokens = torch.full(
            (num_sentences, 1), cls_token_id, dtype=torch.long, device=device
        )
        prefix = torch.cat([prefix, cls_tokens], dim=1)
        prefix = prefix.repeat(modified_batch_size, 1)

        input_ids = input_ids[:used_rows]
        input_ids = torch.cat([prefix, input_ids], dim=1)

        embedding_output = self.roberta.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        embedding_dim = embedding_output.shape[-1]

        # group the rows: (modified_batch, slot, token, hidden)
        embedding_output = embedding_output.view(
            modified_batch_size, num_sentences, total_length, embedding_dim
        )

        if self.sentence_loss_fct == "mlm_multisentence_random_orthogonal":
            # apply each slot's matrix to its token vectors
            embedding_output = torch.matmul(
                self.sentence_embedding, embedding_output.permute(0, 1, 3, 2)
            )
            # swap the last 2 dimensions back
            embedding_output = embedding_output.permute(0, 1, 3, 2)
            # combine the slots, scaled by sqrt(num_sentences)
            embedding_output = torch.sum(embedding_output, dim=1) / math.sqrt(
                self.num_sentences
            )
        else:
            # element-wise product with each slot's vector, then average the slots
            sentence_embed = self.sentence_embedding[:num_sentences, :].to(
                embedding_output.device
            )
            # (1, slot, 1, hidden) broadcasts over batch and token axes
            sentence_embed = sentence_embed.unsqueeze(1).unsqueeze(0)
            embedding_output = torch.mean(embedding_output * sentence_embed, dim=1)

        outputs = self.roberta(
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=position_ids,
            inputs_embeds=embedding_output,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        labels = labels[:used_rows]

        logits = self.lm_head(sequence_output)

        # ---- retrieval auxiliary loss ----
        # For every non-prefix position draw the slot whose token must be
        # recovered; prefix/cls positions keep slot 0 and are masked out below.
        sentence_labels = torch.zeros(
            (modified_batch_size, total_length), dtype=torch.long, device=device
        )
        sentence_labels[:, 1 + num_sentences :] = torch.randint(
            num_sentences, (modified_batch_size, seq_length)
        ).to(device)

        # (modified_batch, token, slot)
        input_ids = input_ids.view(modified_batch_size, num_sentences, -1)
        input_ids = input_ids.permute(0, 2, 1)
        # retrieval_labels[b, t] = input_ids[b, t, sentence_labels[b, t]]
        retrieval_labels = torch.gather(
            input_ids, 2, sentence_labels.unsqueeze(-1)
        ).squeeze(-1)
        retrieval_labels[:, : (num_sentences + 1)] = -100

        # Targets equal to id 1 (treated as padding here) are never predicted;
        # of the remaining targets each one is dropped with probability
        # 1 - retrieval_percentage.
        pad_mask = retrieval_labels == 1
        drop_mask = (
            torch.bernoulli(
                torch.full(
                    retrieval_labels.shape, 1 - self.config.retrieval_percentage
                )
            )
            .bool()
            .to(device)
        )
        retrieval_labels[pad_mask | drop_mask] = -100

        retrieval_predictions = self.retrieval_head(sequence_output, sentence_labels)

        # ---- token classification loss ----
        loss_fct = CrossEntropyLoss()
        # remove the logits corresponding to the prefix and the CLS token
        logits = logits[:, (num_sentences + 1) :, :]
        flat_logits = logits.reshape(-1, self.num_labels)
        flat_labels = labels.reshape(-1)
        if attention_mask is not None:
            # Truncate the mask exactly like input_ids/labels; otherwise the
            # shapes disagree whenever batch_size is not a multiple of
            # num_sentences.
            attention_mask = attention_mask[:used_rows]
            active_loss = attention_mask.reshape(-1) == 1
            flat_labels = torch.where(
                active_loss,
                flat_labels,
                torch.tensor(loss_fct.ignore_index).type_as(labels),
            )
        # Without a mask the loss covers every position (labels equal to the
        # ignore index are still skipped by CrossEntropyLoss).
        task_loss = loss_fct(flat_logits, flat_labels)
        retrieval_loss = loss_fct(
            retrieval_predictions.view(-1, self.config.vocab_size),
            retrieval_labels.view(-1),
        )
        loss = (self.task_loss_coeff * task_loss) + (
            self.retrieval_loss_coeff * retrieval_loss
        )

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutputWithAuxiliary(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            task_loss=task_loss,
            retrieval_loss=retrieval_loss,
        )


class RobertaTokenClassificationHeadConditional(nn.Module):
    """Token classification head conditioned on the leading sentence positions.

    Each of the first ``config.num_sentences`` hidden states is paired with
    every position of the sequence; the concatenated pair is passed through
    dense -> GELU -> LayerNorm -> linear decoder with ``config.num_labels``
    outputs.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.num_sentences = config.num_sentences
        # the input is a [sentence state ; token state] pair, hence 2 * width
        self.dense = nn.Linear(2 * width, width)
        self.layer_norm = nn.LayerNorm(width, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(width, config.num_labels)
        # the decoder shares its bias with a zero-initialised module-level parameter
        self.bias = nn.Parameter(torch.zeros(config.num_labels))
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        """Return logits of shape (batch * num_sentences, seq_len, num_labels).

        ``features`` must be a 3-D tensor (batch, seq_len, hidden); extra
        keyword arguments are accepted and ignored.
        """
        n_batch, n_tokens, width = features.shape
        n_sent = self.num_sentences
        paired_shape = (n_batch, n_sent, n_tokens, width)

        # state of sentence position i, repeated for every token position
        sentence_states = features[:, :n_sent, :].unsqueeze(2).expand(*paired_shape)
        # all token states, repeated for every sentence position
        token_states = features.unsqueeze(1).expand(*paired_shape)

        paired = torch.cat((sentence_states, token_states), dim=3)
        # fold the sentence axis into the batch axis
        paired = paired.view(-1, n_tokens, 2 * width)

        hidden = self.layer_norm(gelu(self.dense(paired)))
        # project to the label space (bias included)
        return self.decoder(hidden)

    def _tie_weights(self):
        # Re-link the module-level bias to the decoder's bias if the two get
        # disconnected (on TPU or when the bias is resized).
        self.bias = self.decoder.bias


class SentenceHead(nn.Module):
    """Two-layer head projecting hidden states to ``config.num_sentences`` scores.

    Pipeline: dropout -> dense -> tanh -> dropout -> output projection.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(width, config.num_sentences)

    def forward(self, features, **kwargs):
        """Map ``features`` (..., hidden_size) to (..., num_sentences); kwargs are ignored."""
        hidden = torch.tanh(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(hidden))


class RobertaClassificationHeadParallel(nn.Module):
    """Classification head producing one prediction per leading sentence position.

    The hidden state at index ``num_sentences`` is paired with each of the
    first ``num_sentences`` hidden states; every pair goes through
    dropout -> dense -> tanh -> dropout -> output projection.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.num_sentences = config.num_sentences
        # the input is a [shared state ; sentence state] pair, hence 2 * width
        self.dense = nn.Linear(2 * width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        """Return logits of shape (batch * num_sentences, num_labels).

        ``features`` must be a 3-D tensor (batch, seq_len, hidden); extra
        keyword arguments are accepted and ignored.
        """
        _, _, width = features.shape
        n_sent = self.num_sentences

        # the first `num_sentences` positions, flattened to one row per (batch, sentence)
        sentence_states = features[:, :n_sent, :].reshape(-1, width)

        # the position right after them, repeated once per sentence
        shared_state = features[:, n_sent, :]
        shared_rows = shared_state.unsqueeze(1).expand(-1, n_sent, -1).reshape(-1, width)

        # shared state first, sentence state second
        joined = torch.cat((shared_rows, sentence_states), dim=1)
        hidden = torch.tanh(self.dense(self.dropout(joined)))
        return self.out_proj(self.dropout(hidden))


class RobertaClassificationHeadParallelLikeLM(nn.Module):
    """LM-head-style classification head with one prediction per sentence position.

    Each of the first ``num_sentences`` hidden states is paired with the
    hidden state at index ``num_sentences``; every pair goes through
    dense -> GELU -> LayerNorm -> dense -> GELU -> output projection.
    A dropout module is constructed but not applied in ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.num_sentences = config.num_sentences
        # the input is a [sentence state ; shared state] pair, hence 2 * width
        self.dense = nn.Linear(2 * width, width)
        self.layer_norm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense_before_out_proj = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        """Return logits of shape (batch * num_sentences, num_labels).

        ``features`` must be a 3-D tensor (batch, seq_len, hidden); extra
        keyword arguments are accepted and ignored.
        """
        _, _, width = features.shape
        n_sent = self.num_sentences

        # the first `num_sentences` positions, flattened to one row per (batch, sentence)
        sentence_states = features[:, :n_sent, :].reshape(-1, width)

        # the position right after them, repeated once per sentence
        shared_state = features[:, n_sent, :]
        shared_rows = shared_state.unsqueeze(1).expand(-1, n_sent, -1).reshape(-1, width)

        # sentence state first, shared state second
        joined = torch.cat((sentence_states, shared_rows), dim=1)

        hidden = self.layer_norm(gelu(self.dense(joined)))
        hidden = gelu(self.dense_before_out_proj(hidden))
        # project to the label space
        return self.out_proj(hidden)


@dataclass
class SentenceLMOutput(ModelOutput):
    """
    Base class for masked language models outputs, extended with separate slots for individual loss terms.

    Every field defaults to :obj:`None`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Masked language modeling (MLM) loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        sent_loss (:obj:`torch.FloatTensor`, `optional`):
            Sentence classification loss, i.e. the loss for choosing which sentence a masked token comes from.
        lm_loss (:obj:`torch.FloatTensor`, `optional`):
            Presumably the language-modeling component of :obj:`loss` -- TODO confirm against the model that
            populates this output.
        retrieval_loss (:obj:`torch.FloatTensor`, `optional`):
            Presumably an auxiliary token-retrieval loss component -- TODO confirm against the model that populates
            this output.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    sent_loss: Optional[torch.FloatTensor] = None
    lm_loss: Optional[torch.FloatTensor] = None
    retrieval_loss: Optional[torch.FloatTensor] = None


@dataclass
class SequenceClassifierOutputWithAuxiliary(SequenceClassifierOutput):
    """
    Sequence classification output extended with the individual loss terms.

    ``loss``, ``logits``, ``hidden_states`` and ``attentions`` keep the meaning they have in
    :class:`SequenceClassifierOutput`.

    Args:
        task_loss (:obj:`torch.FloatTensor`, `optional`):
            Classification loss of the task head alone, before weighting.
        retrieval_loss (:obj:`torch.FloatTensor`, `optional`):
            Loss of the auxiliary token-retrieval head alone, before weighting.

    In ``RobertaForMaskedLMMultiSentenceNoPrefixBSMultipleHeadsSequenceClassification`` both terms are only filled
    in on the classification path (``num_labels > 1``), where ``loss`` is their weighted sum; on the regression
    path they stay ``None``.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    task_loss: Optional[torch.FloatTensor] = None
    retrieval_loss: Optional[torch.FloatTensor] = None

@dataclass
class TokenClassifierOutputWithAuxiliary(ModelOutput):
    """
    Output container for token classification models that also report auxiliary loss terms.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
            Classification loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        task_loss (:obj:`torch.FloatTensor`, `optional`):
            Presumably the token classification loss on its own, mirroring
            ``SequenceClassifierOutputWithAuxiliary.task_loss`` -- confirm against the model that fills it in.
        retrieval_loss (:obj:`torch.FloatTensor`, `optional`):
            Presumably the auxiliary retrieval loss on its own -- confirm against the model that fills it in.
    """

    # Field order is part of the ModelOutput contract (tuple conversion follows it).
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    task_loss: Optional[torch.FloatTensor] = None
    retrieval_loss: Optional[torch.FloatTensor] = None

class RobertaLMHeadConditional(nn.Module):
    """Masked-LM head conditioned on a per-token sentence representation.

    Every position is concatenated with one of the first ``num_sentences``
    encoder states (chosen by ``sentence_labels``) before the usual
    dense -> gelu -> layer norm -> vocabulary projection.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.num_sentences = config.num_sentences
        # input is [sentence representation ; token state], hence 2 * hidden
        self.dense = nn.Linear(2 * hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(hidden, config.vocab_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # the head and its decoder share a single bias parameter
        self.decoder.bias = self.bias

    def forward(self, features, sentence_labels, **kwargs):
        """Return vocabulary scores of shape ``(batch, sequence, vocab_size)``.

        Args:
            features: encoder output ``(batch, sequence, hidden)``.
            sentence_labels: integer tensor ``(batch, sequence)`` giving, for
                each position, the index of the sentence representation to
                condition on; ``-100`` entries fall back to index 0.
            **kwargs: accepted and ignored.
        """
        batch, _, _ = features.shape
        prefix_states = features[:, : self.num_sentences, :]

        # work on a copy so the caller's sentence_labels are left untouched
        safe_labels = sentence_labels.masked_fill(sentence_labels == -100, 0)

        # row index (batch, 1) broadcasts against safe_labels (batch, sequence)
        row_index = torch.arange(batch, device=features.device).unsqueeze(1)
        picked = prefix_states[row_index, safe_labels]

        joined = torch.cat([picked, features], dim=2)
        hidden_states = self.layer_norm(gelu(self.dense(joined)))

        # project back to size of vocabulary with bias
        return self.decoder(hidden_states)

    def _tie_weights(self):
        # Re-link the two bias references if they get disconnected
        # (on TPU or when the bias is resized).
        self.bias = self.decoder.bias


class RobertaLMHeadConditionalBeefy(nn.Module):
    """Deeper variant of the conditional masked-LM head.

    Same conditioning as ``RobertaLMHeadConditional`` (each position is paired
    with one of the first ``num_sentences`` encoder states), followed by two
    dense -> gelu -> layer norm stages before the vocabulary projection.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        eps = config.layer_norm_eps
        self.num_sentences = config.num_sentences
        # input is [sentence representation ; token state], hence 2 * hidden
        self.dense = nn.Linear(2 * hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=eps)
        self.dense2 = nn.Linear(hidden, hidden)
        self.layer_norm2 = nn.LayerNorm(hidden, eps=eps)
        self.decoder = nn.Linear(hidden, config.vocab_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # the head and its decoder share a single bias parameter
        self.decoder.bias = self.bias

    def forward(self, features, sentence_labels, **kwargs):
        """Return vocabulary scores of shape ``(batch, sequence, vocab_size)``.

        Args:
            features: encoder output ``(batch, sequence, hidden)``.
            sentence_labels: integer tensor ``(batch, sequence)`` giving, for
                each position, the index of the sentence representation to
                condition on; ``-100`` entries fall back to index 0.
            **kwargs: accepted and ignored.
        """
        batch, _, _ = features.shape
        prefix_states = features[:, : self.num_sentences, :]

        # work on a copy so the caller's sentence_labels are left untouched
        safe_labels = sentence_labels.masked_fill(sentence_labels == -100, 0)

        # row index (batch, 1) broadcasts against safe_labels (batch, sequence)
        row_index = torch.arange(batch, device=features.device).unsqueeze(1)
        picked = prefix_states[row_index, safe_labels]

        joined = torch.cat([picked, features], dim=2)
        hidden_states = self.layer_norm(gelu(self.dense(joined)))
        hidden_states = self.layer_norm2(gelu(self.dense2(hidden_states)))

        # project back to size of vocabulary with bias
        return self.decoder(hidden_states)

    def _tie_weights(self):
        # Re-link the two bias references if they get disconnected
        # (on TPU or when the bias is resized).
        self.bias = self.decoder.bias


class RobertaLMHeadConditionalRetrieval(nn.Module):
    """Conditional masked-LM head that scores several retrieval targets per position.

    The encoder output is split into ``num_sentences`` leading conditioning
    states and the remaining token states.  Each token state is scored
    ``num_retrieval_words`` times, every time paired with the conditioning
    state selected by ``sentence_labels``.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.num_sentences = config.num_sentences
        self.num_retrieval_words = config.num_retrieval_words
        # input is [token state ; conditioning state], hence 2 * hidden
        self.dense = nn.Linear(2 * hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(hidden, config.vocab_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # the head and its decoder share a single bias parameter
        self.decoder.bias = self.bias

    def forward(self, features, sentence_labels):
        """Return scores of shape ``(B' * num_retrieval_words, L, vocab_size)``.

        Args:
            features: encoder output ``(B', num_sentences + L, dim)``.
            sentence_labels: integer tensor ``(B', num_retrieval_words, L)``
                with the index of the conditioning state to use; it is fed to
                ``torch.gather`` directly, so every entry must be a valid index
                into the ``num_sentences`` conditioning states.
        """
        batch_size, seqlength, embedding_dim = features.shape
        n_sent = self.num_sentences
        n_words = self.num_retrieval_words
        token_len = seqlength - n_sent

        prefix_states = features[:, :n_sent, :]  # B' x N x dim
        token_states = features[:, n_sent:, :]  # B' x L x dim

        # every batch row is repeated once per retrieval word:
        # (B' x num retrieval words) x L x dim
        token_states = torch.repeat_interleave(token_states, n_words, dim=0)

        # broadcast the conditioning states along the token axis: B' x N x L x dim
        prefix_grid = prefix_states.unsqueeze(2).expand(
            -1, n_sent, token_len, embedding_dim
        )
        # pick, per (retrieval word, position), the labelled conditioning state
        gather_index = sentence_labels.unsqueeze(-1).expand(
            batch_size, n_words, token_len, embedding_dim
        )
        selected = torch.gather(
            prefix_grid, 1, gather_index
        )  # B' x num retrieval words x L x dim
        selected = selected.reshape(-1, token_len, embedding_dim)

        joined = torch.cat([token_states, selected], dim=2)
        hidden_states = self.layer_norm(gelu(self.dense(joined)))

        # project back to size of vocabulary with bias
        return self.decoder(hidden_states)

    def _tie_weights(self):
        # Re-link the two bias references if they get disconnected
        # (on TPU or when the bias is resized).
        self.bias = self.decoder.bias


class RobertaForMaskedLMMultiSentenceNoPrefixBSMultipleHeads(RobertaPreTrainedModel):
    """Masked-LM RoBERTa that processes ``num_sentences`` inputs in one encoder pass.

    Groups of ``num_sentences`` consecutive batch rows are embedded, scaled
    element-wise by a per-sentence vector (``sentence_embedding``) and averaged
    into a single sequence before being fed to the encoder.  The LM head
    (``RobertaLMHeadMultipleHeads``) then routes every target position through
    the sub-network of the sentence that target belongs to.
    """

    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        """Build the encoder, the LM head and the per-sentence vectors.

        ``config.sentence_loss_fct`` selects how the per-sentence vectors are
        initialised; any value other than the four handled below raises
        ``NotImplementedError``.  Only ``"mlm_multisentence_learnt"`` leaves the
        vectors trainable.
        """
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHeadMultipleHeads(config)
        self.num_sentences = config.num_sentences
        self.random_encoding_norm = config.random_encoding_norm
        self.sentence_loss_fct = config.sentence_loss_fct

        self.init_weights()

        # The sentence vectors are created after init_weights() and are frozen
        # unless the "learnt" option is selected.
        freeze_sent_embedding = True
        d_model = config.hidden_size
        if self.sentence_loss_fct == "mlm_multisentence_sinusoidal_morespread":
            # random rows (index >= 1) of a 10000-row sinusoidal table
            all_embeds = sinusoidal_encoding(10000, d_model)
            sentence_embedding = all_embeds[
                torch.randint(10000 - 1, (self.num_sentences,)) + 1, :
            ]

        elif self.sentence_loss_fct == "mlm_multisentence_sinusoidal":
            sentence_embedding = sinusoidal_encoding(self.num_sentences, d_model)

        elif self.sentence_loss_fct == "mlm_multisentence_learnt":
            # same initialisation as "morespread", but trainable
            all_embeds = sinusoidal_encoding(10000, d_model)
            sentence_embedding = all_embeds[
                torch.randint(10000 - 1, (self.num_sentences,)) + 1, :
            ]
            freeze_sent_embedding = False

        elif self.sentence_loss_fct == "mlm_multisentence_random":
            sentence_embedding = random_encoding(
                self.num_sentences, d_model, norm=self.random_encoding_norm
            )

        else:
            raise NotImplementedError()

        self.sentence_embedding = torch.nn.Parameter(sentence_embedding)
        self.sentence_embedding.requires_grad = not freeze_sent_embedding

    def get_output_embeddings(self):
        """Return the vocabulary projection layer of the LM head."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection layer of the LM head."""
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Run the multiplexed masked-LM objective on a batch.

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Token ids. Only the first ``(batch_size // num_sentences) * num_sentences`` rows are used; when the
            batch holds fewer than ``num_sentences`` rows, the whole batch forms one group.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
            Required: the objective cannot be evaluated without them.

        ``attention_mask``, ``head_mask``, ``encoder_hidden_states``, ``encoder_attention_mask``,
        ``output_attentions`` and ``output_hidden_states`` are accepted for interface compatibility but are not
        used by this implementation.

        Returns a :class:`SentenceLMOutput` (or the equivalent tuple when ``return_dict`` is false) whose logits
        have shape ``(batch_size // num_sentences, sequence_length + 1, vocab_size)``.
        """

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # Labels are sliced and padded right below, so they must be validated
        # before any of that happens.
        assert (
            labels is not None
        ), "labels need to be supplied for multi-sentence objective"

        # get input embeddings and average over N sentence
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        num_sentences = self.num_sentences
        if batch_size < num_sentences:
            num_sentences = batch_size
        past_key_values_length = 0

        modified_batch_size = batch_size // num_sentences

        # one extra leading token (id 49923, used here as the CLS marker) per row
        cls_tokens = torch.full((num_sentences, 1), 49923).to(input_ids.device)
        prefix = cls_tokens
        prefix = prefix.repeat(modified_batch_size, 1)

        # drop the rows that do not fill a complete group of sentences
        input_ids = input_ids[: (modified_batch_size * num_sentences)]
        labels = labels[: (modified_batch_size * num_sentences)]

        input_ids = torch.cat([prefix, input_ids], dim=1)

        # replace the mask token (id 50264) with a sentence-specific mask token:
        # 50046 + <index of the sentence inside its group>
        input_ids = input_ids.view(modified_batch_size, num_sentences, -1)
        new_mask_tokens = torch.arange(50046, 50046 + num_sentences).to(
            input_ids.device
        )
        new_mask_tokens = (
            new_mask_tokens.unsqueeze(1)
            .unsqueeze(0)
            .repeat(modified_batch_size, 1, input_ids.shape[-1])
        )
        is_mask_token = input_ids == 50264
        input_ids[is_mask_token] = new_mask_tokens[is_mask_token]
        input_ids = input_ids.view((modified_batch_size * num_sentences), -1)

        # Shorten every row by a random 1-5 positions: write the end-of-sentence
        # token (id 50163) at the new last position and fill what follows with
        # the PAD token (id 1).
        # NOTE: the random lengths are drawn on CPU and then moved, which keeps
        # them tied to the CPU random generator.
        trim_lengths = torch.randint(1, 6, (input_ids.shape[0],)).to(input_ids.device)
        new_lengths = input_ids.shape[-1] - trim_lengths
        input_ids[
            torch.arange(input_ids.shape[0]).to(input_ids.device), new_lengths - 1
        ] = 50163
        # need to fill the tokens following the last sentence with the PAD token
        # (gen_attn_mask presumably marks positions below each length as True -- confirm in utils)
        attn_mask = gen_attn_mask(new_lengths, len=input_ids.shape[1])
        input_ids[~attn_mask] = 1

        # the prepended token never carries an MLM target
        labels = torch.cat(
            [
                torch.full((modified_batch_size * num_sentences, 1), -100).to(
                    input_ids.device
                ),
                labels,
            ],
            dim=1,
        )

        # don't have MLM loss over PAD tokens
        labels[~attn_mask] = -100

        # account for the single prepended token
        # (an earlier variant used: seq_length = seq_length + num_sentences)
        seq_length = seq_length + 1

        embedding_output = self.roberta.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        _, _, embedding_dim = embedding_output.shape
        embedding_output = embedding_output.view(
            modified_batch_size,
            num_sentences,
            seq_length,
            embedding_dim,
        )

        # scale every sentence by its own vector, then average the group into
        # one sequence of shape (B', L, dim)
        sentence_embed = self.sentence_embedding[:num_sentences, :]
        sentence_embed = sentence_embed.unsqueeze(1).expand(
            num_sentences, seq_length, embedding_dim
        )
        sentence_embed = sentence_embed.to(embedding_output.device)
        embedding_output = embedding_output * sentence_embed.unsqueeze(0)

        embedding_output = torch.mean(embedding_output, dim=1)
        # the averaged embeddings go through the full model without an attention mask
        outputs = self.roberta(
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=position_ids,
            inputs_embeds=embedding_output,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        # regroup the labels so the sentence axis comes last
        labels = labels.view(modified_batch_size, num_sentences, seq_length)
        labels = labels.permute(0, 2, 1)  # (B' x L x S); B' = B / S

        # Aggregate B' x L x S into B' x L (select -100 if all the S tokens are -100 else choose the one which isn't -100)
        # positions with more than one target are only reported and then ignored by the loss
        if not torch.all(torch.sum(labels != -100, dim=2) <= 1):
            logger.warning("Multiple targets found")
        masked_indices = torch.sum(labels != -100, dim=2) == 1
        # create new labels
        labels_copy = labels.clone()
        labels_copy[labels == -100] = 0

        # with exactly one target per position, the sum over sentences is that target
        legal_labels = torch.full((modified_batch_size, seq_length), -100).to(
            masked_indices.device
        )
        legal_labels[masked_indices] = torch.sum(labels_copy, dim=2)[masked_indices]

        # sentence labels: index of the sentence that owns the target at each position
        sentence_labels = torch.full((modified_batch_size, seq_length), -100).to(
            masked_indices.device
        )
        # HACK: a genuine target id of 0 would be indistinguishable from the
        # zeros written over -100 above, so mark it with a non-zero value
        # before calling nonzero()
        labels_copy[labels == 0] = -2
        sentence_labels_non_zero = labels_copy.nonzero(as_tuple=True)
        sentence_labels[
            sentence_labels_non_zero[0], sentence_labels_non_zero[1]
        ] = sentence_labels_non_zero[2]
        sentence_labels[~masked_indices] = -100

        # the LM head picks the per-sentence sub-network from sentence_labels
        prediction_scores = self.lm_head(sequence_output, sentence_labels)  # B' X L X V
        loss = None

        # single target word, this is just masked mlm loss post the averaged embeddings
        loss_fct = nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            prediction_scores.view(-1, prediction_scores.shape[-1]),
            legal_labels.view(-1),
        )

        loss = lm_loss

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SentenceLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            lm_loss=lm_loss,
        )


class RobertaLMHeadMultipleHeads(nn.Module):
    """Masked-LM head with one small sub-network per multiplexed sentence.

    Positions are routed through the sub-network of the sentence named by
    ``sentence_labels``; the routed states then share a common
    dense -> gelu -> layer norm -> vocabulary projection.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.num_sentences = config.num_sentences
        # holds the per-sentence layers (dense_<i>, layer_norm_<i>, dropout_<i>, ...)
        self.demux_module = RobertaDemuxModule(config)

        # vocabulary layers shared by all sentences
        self.dense_pre_vocab = nn.Linear(hidden, hidden)
        self.layer_norm_pre_vocab = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.decoder = nn.Linear(hidden, config.vocab_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # the head and its decoder share a single bias parameter
        self.decoder.bias = self.bias

    def forward(self, features, sentence_labels, **kwargs):
        """Return vocabulary scores of shape ``(batch, sequence, vocab_size)``.

        Args:
            features: encoder output ``(batch, sequence, hidden)``.
            sentence_labels: integer tensor ``(batch, sequence)`` holding the
                owning sentence index per position.  Positions whose label
                matches no index in ``range(num_sentences)`` (e.g. ``-100``)
                enter the shared layers as all-zero vectors.
            **kwargs: accepted and ignored.
        """
        batch, seqlength, _ = features.shape
        demux = self.demux_module

        # flat (batch * sequence, hidden) buffer, filled sentence by sentence
        routed = torch.zeros_like(features).view(-1, features.shape[-1])

        for sent_id in range(self.num_sentences):
            first_dense = getattr(demux, f"dense_{sent_id}")
            # the second stage is still looked up but currently not applied
            _second_dense = getattr(demux, f"dense2_{sent_id}")
            first_norm = getattr(demux, f"layer_norm_{sent_id}")
            _second_norm = getattr(demux, f"layer_norm2_{sent_id}")
            sent_dropout = getattr(demux, f"dropout_{sent_id}")

            owned = sentence_labels == sent_id
            state = sent_dropout(features[owned])
            state = first_norm(gelu(first_dense(state)))

            routed[owned.view(-1), :] = state

        # back to batch x sequence x hidden
        routed = routed.view(batch, seqlength, -1)

        # project back to size of vocabulary with bias
        hidden_states = self.layer_norm_pre_vocab(gelu(self.dense_pre_vocab(routed)))
        return self.decoder(hidden_states)

    def _tie_weights(self):
        # Re-link the two bias references if they get disconnected
        # (on TPU or when the bias is resized).
        self.bias = self.decoder.bias


class RobertaForMaskedLMMultiSentenceNoPrefixBSMultipleHeadsSequenceClassification(
    RobertaPreTrainedModel
):
    """Multiplexed RoBERTa sequence classifier with an auxiliary retrieval loss.

    Groups of ``num_sentences`` consecutive batch rows are embedded, combined
    with a per-sentence transformation (``sentence_embedding``) and merged into
    one sequence for the encoder.  The classification head produces one
    prediction per original sentence; a second head is trained to recover
    randomly chosen input tokens of randomly chosen sentences.
    """

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        """Build encoder, heads and per-sentence transformation.

        ``config.sentence_loss_fct`` selects the initialisation of
        ``sentence_embedding``; unknown values raise ``NotImplementedError``.
        ``config.freeze_encoder_layers`` encoder layers (counted from layer 0)
        are frozen, and ``config.learnt_embeddings`` decides whether
        ``sentence_embedding`` is trainable.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.num_sentences = config.num_sentences
        self.sentence_loss_fct = config.sentence_loss_fct
        self.random_encoding_norm = config.random_encoding_norm
        self.retrieval_loss_coeff = config.retrieval_loss_coeff
        self.task_loss_coeff = config.task_loss_coeff

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.freeze_encoder_layers = config.freeze_encoder_layers
        assert (
            self.freeze_encoder_layers <= config.num_hidden_layers
        ), "number of layers to freeze exceeds total number of layers"
        # freeze the first `freeze_encoder_layers` encoder layers
        for name, param in self.roberta.named_parameters():
            if "encoder" in name:
                # the layer index is recovered from the digits of the parameter
                # name, i.e. this relies on the index being the only digits in it
                layer = "".join(filter(str.isdigit, name))
                layer = int(layer)
                assert (
                    layer <= config.num_hidden_layers
                ), "extracted layer exceeds total number of layer, check for bug"
                if layer <= self.freeze_encoder_layers - 1:
                    param.requires_grad = False

        self.lm_head = RobertaClassificationHeadParallelLikeLMNoPrefixMultipleHeads(
            config
        )
        self.retrieval_head = RobertaLMHeadMultipleHeads(config)
        self.init_weights()
        # both heads share one set of per-sentence layers
        # TODO this mignt not work if we are transferring weights from a pretrained model
        self.retrieval_head.demux_module = self.lm_head.demux_module

        d_model = config.hidden_size
        sentence_embedding = None

        if self.sentence_loss_fct == "mlm_multisentence_sinusoidal":
            sentence_embedding = sinusoidal_encoding(self.num_sentences, d_model)

        elif self.sentence_loss_fct == "mlm_multisentence_sinusoidal_morespread":
            # random rows (index >= 1) of a 10000-row sinusoidal table
            all_embeds = sinusoidal_encoding(10000, d_model)
            sentence_embedding = all_embeds[
                torch.randint(10000 - 1, (self.num_sentences,)) + 1, :
            ]

        elif self.sentence_loss_fct == "mlm_multisentence_random":
            sentence_embedding = random_encoding(
                self.num_sentences, d_model, norm=self.random_encoding_norm
            )
        elif self.sentence_loss_fct == "mlm_multisentence_random_orthogonal":
            # one random orthogonal (hidden x hidden) matrix per sentence
            sentence_embedding = [
                torch.from_numpy(ortho_group.rvs(config.hidden_size)).float()
                for _ in range(self.num_sentences)
            ]
            sentence_embedding = torch.stack(sentence_embedding, dim=0)
        elif self.sentence_loss_fct == "mlm_multisentence_binary":
            sentence_embedding = binary_encoding(
                self.num_sentences, d_model, epsilon=config.binary_encoding_epsilon
            )
        else:
            raise NotImplementedError()

        if sentence_embedding is None:
            # Safety net in case a helper above yields nothing: fall back to the
            # spread-out sinusoidal initialisation.
            all_embeds = sinusoidal_encoding(10000, d_model)
            sentence_embedding = all_embeds[
                torch.randint(10000 - 1, (self.num_sentences,)) + 1, :
            ]
        # Always register the parameter; previously the fallback path skipped
        # this and the requires_grad assignment below would have failed.
        self.sentence_embedding = torch.nn.Parameter(sentence_embedding)

        self.sentence_embedding.requires_grad = bool(config.learnt_embeddings)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Classify a batch of sequences, ``num_sentences`` at a time.

        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Token ids. Only the first ``(batch_size // num_sentences) * num_sentences`` rows are used. The tensor
            passed in is not modified.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). Required.

        ``attention_mask``, ``head_mask``, ``output_attentions`` and ``output_hidden_states`` are accepted for
        interface compatibility but are not used by this implementation.

        Returns a :class:`SequenceClassifierOutputWithAuxiliary` (or the equivalent tuple when ``return_dict`` is
        false). On the classification path ``loss`` is
        ``task_loss_coeff * task_loss + retrieval_loss_coeff * retrieval_loss``; on the regression path it is the
        MSE alone and ``task_loss`` / ``retrieval_loss`` stay ``None``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # get input embeddings and average over N sentence
        input_shape = input_ids.size()

        batch_size, seq_length = input_shape
        num_sentences = self.num_sentences
        past_key_values_length = 0
        modified_batch_size = batch_size // num_sentences

        # Overwrite the first token of every sequence with token id 49923
        # (used here as the CLS marker).
        cls_tokens = torch.full((num_sentences, 1), 49923).to(input_ids.device)
        cls_tokens = cls_tokens.repeat(modified_batch_size, 1)
        # Slicing returns a view of the caller's tensor, so copy before the
        # in-place write to avoid changing the caller's batch.
        input_ids = input_ids[: (modified_batch_size * num_sentences)].clone()
        input_ids[:, 0:1] = cls_tokens

        embedding_output = self.roberta.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        _, _, embedding_dim = embedding_output.shape

        if self.sentence_loss_fct == "mlm_multisentence_random_orthogonal":
            embedding_output = embedding_output.view(
                modified_batch_size,
                num_sentences,
                seq_length,
                embedding_dim,
            )
            # apply each sentence's (dim x dim) matrix to every token embedding:
            # (N, dim, dim) @ (B', N, dim, L) -> (B', N, dim, L)
            embedding_output = torch.matmul(
                self.sentence_embedding, embedding_output.permute(0, 1, 3, 2)
            )
            # swap the last 2 dimension again
            embedding_output = embedding_output.permute(0, 1, 3, 2)
            # merge the sentences: sum scaled by 1 / sqrt(num_sentences)
            embedding_output = torch.sum(embedding_output, dim=1) / math.sqrt(
                self.num_sentences
            )
        else:
            embedding_output = embedding_output.view(
                modified_batch_size,
                num_sentences,
                seq_length,
                embedding_dim,
            )

            # scale every sentence by its own vector, then average the group
            sentence_embed = self.sentence_embedding[:num_sentences, :]
            sentence_embed = sentence_embed.unsqueeze(1).expand(
                num_sentences, seq_length, embedding_dim
            )
            sentence_embed = sentence_embed.to(embedding_output.device)
            embedding_output = embedding_output * sentence_embed.unsqueeze(0)

            embedding_output = torch.mean(embedding_output, dim=1)

        # the merged embeddings go through the full model without an attention mask
        outputs = self.roberta(
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=position_ids,
            inputs_embeds=embedding_output,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        assert (
            labels is not None
        ), "labels need to be supplied for multi-sentence objective"
        labels = labels[: (modified_batch_size * num_sentences)]
        assert len(labels.shape) == 1  # assert one dimension

        logits = self.lm_head(sequence_output)

        # retrieval auxiliary loss head: for every position pick a random
        # sentence of the group whose token at that position must be recovered
        sentence_labels = torch.full(
            (modified_batch_size, seq_length),
            0,
            device=input_ids.device,
        ).long()
        # position 0 (the CLS marker) keeps sentence 0; it is excluded from the loss below
        # NOTE: drawn on CPU and then moved, which keeps the draw tied to the
        # CPU random generator.
        sentence_labels[:, 1:] = torch.randint(
            num_sentences, (modified_batch_size, seq_length - 1)
        ).to(input_ids.device)
        # index into input ids to get the corresponding labels:
        # retrieval_labels[b, l] = token at position l of the chosen sentence
        input_ids = input_ids.view(modified_batch_size, num_sentences, -1)
        input_ids = input_ids.permute(0, 2, 1)
        retrieval_labels = input_ids[
            torch.arange(modified_batch_size)
            .to(input_ids.device)
            .unsqueeze(1)
            .expand(modified_batch_size, seq_length),
            torch.arange(seq_length)
            .to(input_ids.device)
            .unsqueeze(0)
            .expand(modified_batch_size, seq_length),
            sentence_labels,
        ]
        retrieval_labels[:, :1] = -100

        # PAD tokens (id 1) are never retrieval targets
        pad_mask = retrieval_labels == 1
        pad_mask_wipe = pad_mask
        # don't want to predict too many retrieval tokens: each non-PAD target
        # is dropped with probability 1 - config.retrieval_percentage
        non_pad_mask_wipe = ~pad_mask & torch.bernoulli(
            torch.full(retrieval_labels.shape, 1 - self.config.retrieval_percentage)
        ).bool().to(input_ids.device)
        retrieval_labels[non_pad_mask_wipe] = -100

        retrieval_labels[pad_mask_wipe] = -100

        retrieval_predictions = self.retrieval_head(sequence_output, sentence_labels)
        retrieval_loss = None
        task_loss = None
        loss = None
        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                # NOTE(review): this path neither adds the retrieval loss nor
                # reports task_loss -- confirm that this is intended.
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                task_loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                retrieval_loss = loss_fct(
                    retrieval_predictions.view(-1, self.config.vocab_size),
                    retrieval_labels.view(-1),
                )
                loss = (self.task_loss_coeff * task_loss) + (
                    self.retrieval_loss_coeff * retrieval_loss
                )

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithAuxiliary(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            task_loss=task_loss,
            retrieval_loss=retrieval_loss,
        )


class RobertaClassificationHeadParallelLikeLMNoPrefixMultipleHeads(nn.Module):
    """Classification head producing one row of logits per (example, sentence).

    The feature vector at position 0 of ``features`` is passed through a
    separate dropout -> dense -> GELU -> LayerNorm stack for every sentence
    index (the stacks live on ``self.demux_module``), and every resulting
    vector then goes through a projection that is shared across sentences:
    dense -> GELU -> LayerNorm -> ``out_proj``.
    """

    def __init__(self, config):
        """Build the per-sentence demultiplexing stacks and the shared projection.

        Reads ``num_sentences``, ``hidden_size``, ``num_labels`` and
        ``layer_norm_eps`` from ``config``.
        """
        super().__init__()
        self.num_sentences = config.num_sentences
        # One independent set of layers per sentence index.
        self.demux_module = RobertaDemuxModule(config)
        # Layers below are shared by all sentence indices.
        self.dense_before_out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
        self.layernorm_presoftmax = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def forward(self, features):
        """Compute logits for every sentence index from ``features``.

        Args:
            features: rank-3 tensor whose last dimension is ``hidden_size``.
                Only index 0 along dim 1 is read (presumably the CLS
                position of a ``(batch, seq_len, hidden)`` encoder output --
                confirm against the caller).

        Returns:
            Tensor of shape ``(features.shape[0] * num_sentences, num_labels)``.
            Rows are ordered with the sentence index varying fastest.

        Raises:
            ValueError: if ``features`` is not rank 3.
        """
        if features.dim() != 3:
            raise ValueError(
                f"features must be a 3-D tensor, got shape {tuple(features.shape)}"
            )

        # Every sentence-specific stack consumes the same position-0 vector,
        # so slice it once instead of once per sentence.
        first_position = features[:, 0, :]

        per_sentence = []
        for sent_id in range(self.num_sentences):
            dropout = getattr(self.demux_module, f"dropout_{sent_id}")
            dense = getattr(self.demux_module, f"dense_{sent_id}")
            layer_norm = getattr(self.demux_module, f"layer_norm_{sent_id}")
            # The demux module also defines dense2_*/layer_norm2_*; this head
            # does not apply that second stage.
            hidden = dense(dropout(first_position))
            hidden = layer_norm(gelu(hidden))
            per_sentence.append(hidden)

        # (dim0, num_sentences, hidden) -> (dim0 * num_sentences, hidden)
        stacked = torch.stack(per_sentence, dim=1)
        flat = stacked.view(-1, stacked.shape[-1])

        # Shared projection down to num_labels logits.
        x = self.dense_before_out_proj(flat)
        x = gelu(x)
        x = self.layernorm_presoftmax(x)
        x = self.out_proj(x)

        return x


class RobertaDemuxModule(nn.Module):
    """Holder for per-sentence layers, registered under indexed attribute names.

    For every ``sent_id`` in ``range(config.num_sentences)`` the constructor
    creates and attaches:

    * ``dense_{sent_id}`` and ``dense2_{sent_id}``: ``nn.Linear`` layers mapping
      ``hidden_size`` to ``hidden_size``
    * ``layer_norm_{sent_id}`` and ``layer_norm2_{sent_id}``: ``nn.LayerNorm``
      over ``hidden_size`` with ``eps=config.layer_norm_eps``
    * ``dropout_{sent_id}``: ``nn.Dropout(config.hidden_dropout_prob)``

    Consumers retrieve the layers by name with ``getattr`` (see
    ``RobertaClassificationHeadParallelLikeLMNoPrefixMultipleHeads.forward``).
    """

    def __init__(self, config):
        """Create the per-sentence layers described in the class docstring."""
        super().__init__()
        self.num_sentences = config.num_sentences
        # initialize different MLPs for different sentences
        # (assigning an nn.Module via setattr registers it as a submodule, so
        # its parameters are tracked under the formatted attribute name)
        for sent_id in range(self.num_sentences):
            # first dense layer for this sentence index
            setattr(
                self,
                f"dense_{sent_id}",
                nn.Linear(config.hidden_size, config.hidden_size),
            )
            # second dense layer for this sentence index
            setattr(
                self,
                f"dense2_{sent_id}",
                nn.Linear(config.hidden_size, config.hidden_size),
            )
            # layer norm paired by name with dense_{sent_id}
            setattr(
                self,
                f"layer_norm_{sent_id}",
                nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
            )
            # layer norm paired by name with dense2_{sent_id}
            setattr(
                self,
                f"layer_norm2_{sent_id}",
                nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
            )
            # dropout instance for this sentence index
            setattr(self, f"dropout_{sent_id}", nn.Dropout(config.hidden_dropout_prob))