from typing import Optional

import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from bbo.algorithms.np.transformer.positional_embedding import get_pos_embedding_cls


class Transformer(nn.Module):
    """Transformer encoder over context (x, y) points and x-only query points.

    The first ``single_eval_pos`` positions of the sequence are context
    points, embedded as ``x_embedding(x) + y_embedding(y)``.  The remaining
    positions are query points, embedded from ``x`` alone.  The encoder
    output is passed through an MLP head and only the query positions are
    returned.

    All tensors are sequence-first, i.e. ``(seq_len, batch, feature)``,
    because the encoder layers are built with ``batch_first=False``.
    """

    def __init__(
        self,
        x_dim: int,
        n_out: int,
        d_model: int,
        n_head: int,
        n_hidden: int,
        dropout: float,
        n_layer: int,
        pos_embedding: Optional[str] = None,
    ):
        """Build the embeddings, the encoder stack and the output head.

        Args:
            x_dim: Feature size of each input point.
            n_out: Output size of the decoder head per query point.
            d_model: Embedding width used throughout the encoder.
            n_head: Number of attention heads per encoder layer.
            n_hidden: Feed-forward width of each encoder layer; also the
                hidden width of the decoder head.
            dropout: Dropout probability of the encoder layers.
            n_layer: Number of stacked encoder layers.
            pos_embedding: Name forwarded to ``get_pos_embedding_cls`` to
                select the positional-embedding class.
        """
        super().__init__()
        self.x_dim = x_dim
        self.n_out = n_out
        self.d_model = d_model
        self.n_head = n_head
        self.n_hidden = n_hidden
        self.dropout = dropout
        self.n_layer = n_layer

        self.x_embedding = nn.Linear(x_dim, d_model)
        # Targets are scalar: one value per point.
        self.y_embedding = nn.Linear(1, d_model)
        # NOTE(review): this module is constructed here but is not applied
        # anywhere in encode()/forward() of this class.
        self.pos_embedding = get_pos_embedding_cls(pos_embedding)(d_model, None)
        encoder_layers = TransformerEncoderLayer(
            d_model, n_head, n_hidden, dropout, 'gelu', batch_first=False
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, n_layer)
        self.decoder = nn.Sequential(
            nn.Linear(d_model, n_hidden),
            nn.GELU(),
            nn.Linear(n_hidden, n_out),
        )

    @staticmethod
    def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
        """Return a causal ``(sz, sz)`` additive float32 attention mask.

        Entry ``[i, j]`` is ``0.0`` when ``j <= i`` and ``-inf`` otherwise.
        """
        allowed = torch.tril(torch.ones(sz, sz)) == 1
        return torch.zeros(sz, sz, dtype=torch.float32).masked_fill(
            ~allowed, float('-inf')
        )

    @staticmethod
    def generate_D_q_matrix(sz: int, query_size: int) -> torch.Tensor:
        """Return the ``(sz, sz)`` additive float32 mask for context/query data.

        The last ``query_size`` columns (the query points) are set to
        ``-inf`` except on the diagonal; every other entry is ``0.0``.  Used
        as an attention mask this lets every position see all context
        points plus itself, and never another query point.
        """
        train_size = sz - query_size
        allowed = torch.ones(sz, sz, dtype=torch.bool)
        allowed[:, train_size:] = False
        # Re-enable the diagonal so each position can still attend to itself.
        allowed |= torch.eye(sz) == 1
        return torch.zeros(sz, sz, dtype=torch.float32).masked_fill(
            ~allowed, float('-inf')
        )

    def encode(self, x_src, y_src, single_eval_pos, src_mask=None):
        """Embed the sequence and run it through the encoder stack.

        Args:
            x_src: Inputs for context and query points stacked along dim 0.
            y_src: Targets; only the first ``single_eval_pos`` rows are used.
            single_eval_pos: Number of leading context positions.
            src_mask: Optional attention mask.  When ``None`` the mask from
                ``generate_D_q_matrix`` is used.

        Returns:
            The encoder output for every position of the sequence.
        """
        x_emb = self.x_embedding(x_src.float())
        # Cast targets the same way as inputs so that float64 or integer
        # targets do not hit a dtype mismatch inside the linear layer.
        y_emb = self.y_embedding(y_src.float())
        context = x_emb[:single_eval_pos] + y_emb[:single_eval_pos]
        src = torch.cat([context, x_emb[single_eval_pos:]])
        if src_mask is None:
            seq_len = len(x_src)
            src_mask = self.generate_D_q_matrix(seq_len, seq_len - single_eval_pos)
        # Match the device and dtype of the embedded sequence.
        src_mask = src_mask.to(src)
        return self.transformer_encoder(src, src_mask)

    def decode(self, encoder_output):
        """Apply the MLP output head to the encoder output."""
        return self.decoder(encoder_output)

    def forward(self, x_src, y_src, single_eval_pos, src_mask=None):
        """Encode the full sequence and return predictions for the queries.

        Inputs:
            x_src: (seq_len, batch, dim)
            y_src: (seq_len, batch, 1); only the first ``single_eval_pos``
                rows are read.
            single_eval_pos: number of leading context positions.
            src_mask: optional attention mask, see ``encode``.

        Returns:
            Decoder output for positions ``single_eval_pos`` onwards.
        """
        assert x_src.ndim == y_src.ndim == 3, "x_src and y_src must both be 3-D"

        encoder_output = self.encode(x_src, y_src, single_eval_pos, src_mask)
        decoder_output = self.decode(encoder_output)
        return decoder_output[single_eval_pos:]

    def predict(self, context_x, context_y, query_x):
        """Predict at ``query_x`` given the observed context points.

        Inputs:
            context_x: (seq_len, 1, dim)
            context_y: (seq_len, 1, 1)
            query_x: (query_len, 1, dim)

        Returns:
            Decoder output for the query positions only.
        """
        single_eval_pos = len(context_x)
        x_src = torch.cat([context_x, query_x], dim=0)
        # Only context targets exist; encode() reads just those rows.
        y_src = context_y
        return self.forward(x_src, y_src, single_eval_pos)