import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class MultiheadCrossAttention(nn.Module):
    """
    Multi-head cross-attention: queries are projected from ``inputs`` and
    keys/values from ``contexts``.

    The heads are folded into the batch dimension so that the attention
    scores and the weighted sum are each computed with one batched matmul.

    Args:
        embed_dim: Total width of the projected queries/keys/values
            (all heads together). Must be divisible by ``n_head``.
        n_head: Number of attention heads.
        input_dim: Feature size of ``inputs``; falls back to ``embed_dim``
            when falsy.
        context_dim: Feature size of ``contexts``; falls back to
            ``embed_dim`` when falsy.
        output_dim: Feature size of the result; falls back to ``embed_dim``
            when falsy.
        dropout: Dropout probability applied to the attention weights.
        bias: Whether the four linear layers carry a bias (zero-initialised
            when present).
    """

    def __init__(self, embed_dim, n_head, input_dim=None, context_dim=None, output_dim=None, dropout=0.0, bias=True):
        super().__init__()

        assert embed_dim % n_head == 0

        self.embed_dim = embed_dim
        self.input_dim = input_dim or embed_dim
        self.context_dim = context_dim or embed_dim
        self.output_dim = output_dim or embed_dim
        self.n_head = n_head
        self.head_dim = embed_dim // n_head

        # Projections shared by all heads. The creation order is kept as
        # key -> query -> value so parameter registration stays the same.
        self.key = nn.Linear(self.context_dim, self.embed_dim, bias=bias)
        self.query = nn.Linear(self.input_dim, self.embed_dim, bias=bias)
        self.value = nn.Linear(self.context_dim, self.embed_dim, bias=bias)
        # Regularisation on the attention weights.
        self.attn_drop = nn.Dropout(dropout, inplace=False)
        # Maps the concatenated heads to the output width.
        self.proj = nn.Linear(self.embed_dim, self.output_dim, bias=bias)

        if bias:
            for linear in (self.key, self.query, self.value, self.proj):
                nn.init.zeros_(linear.bias)

    def forward(self, inputs, contexts):
        """
        Attend from ``inputs`` to ``contexts``.

        Args:
            inputs: Tensor of shape (B, T, C_in).
            contexts: Tensor of shape (B, T_ctx, C_ctx).

        Returns:
            Tensor of shape (B, T, C_out).

        Side effect:
            Stores a detached copy of the attention weights (taken after
            dropout), shaped (B * n_head, T, T_ctx), in
            ``self._attention_map``.
        """
        assert inputs.shape[0] == contexts.shape[0]

        batch, n_query, _ = inputs.shape
        _, n_ctx, _ = contexts.shape
        folded_batch = batch * self.n_head

        # Go sequence-first so a plain view can fold heads into the batch axis.
        seq_first_inputs = inputs.permute(1, 0, 2).contiguous()  # (T, B, C_in)
        seq_first_contexts = contexts.permute(1, 0, 2).contiguous()  # (T_ctx, B, C_ctx)

        queries = self.query(seq_first_inputs)  # (T, B, E)
        keys = self.key(seq_first_contexts)  # (T_ctx, B, E)
        values = self.value(seq_first_contexts)  # (T_ctx, B, E)

        # (L, B, E) -> (L, B*nh, hd) -> (B*nh, L, hd)
        queries = queries.view(n_query, folded_batch, self.head_dim).transpose(0, 1)
        keys = keys.view(n_ctx, folded_batch, self.head_dim).transpose(0, 1)
        values = values.view(n_ctx, folded_batch, self.head_dim).transpose(0, 1)

        # Scaled dot-product scores: (B*nh, T, hd) x (B*nh, hd, T_ctx) -> (B*nh, T, T_ctx)
        inv_sqrt_dim = 1.0 / math.sqrt(self.head_dim)
        scores = torch.bmm(queries, keys.transpose(-2, -1) * inv_sqrt_dim)
        weights = self.attn_drop(scores.softmax(dim=-1))

        self._attention_map = weights.detach().clone()

        # Weighted sum of values, then lay the heads side by side again.
        merged = torch.bmm(weights, values)  # (B*nh, T, hd)
        merged = merged.transpose(0, 1).contiguous()  # (T, B*nh, hd)
        merged = merged.view(n_query, batch, self.embed_dim)  # (T, B, E)

        projected = self.proj(merged)  # (T, B, C_out)

        return projected.transpose(0, 1).contiguous()  # (B, T, C_out)


class CrossAttentionBlock(nn.Module):
    """
    Cross-attention layer with optional input LayerNorm and residual add.

    ``inputs`` are (optionally) layer-normalised, attended against
    ``contexts`` with ``MultiheadCrossAttention``, and (optionally) added
    back to the un-normalised ``inputs``.

    Args:
        embed_dim: Total attention width across all heads.
        n_head: Number of attention heads.
        input_dim: Feature size of ``inputs``; ``embed_dim`` when falsy.
        context_dim: Feature size of ``contexts``; ``embed_dim`` when falsy.
        output_dim: Feature size of the result; ``embed_dim`` when falsy.
        dropout: Dropout probability on the attention weights.
        bias: Whether the attention's linear layers carry a bias.
        input_layernorm: Apply ``nn.LayerNorm`` to ``inputs`` before
            attention; otherwise the inputs pass through unchanged.
        residual: Add ``inputs`` to the attention output. Requires the
            resolved input and output sizes to be equal.
    """

    def __init__(
        self,
        embed_dim,
        n_head,
        input_dim=None,
        context_dim=None,
        output_dim=None,
        dropout=0.,
        bias=True,
        input_layernorm=True,
        residual=False,
    ):
        super().__init__()

        # Resolve unset sizes up front, using the same fallback rule as
        # MultiheadCrossAttention. Otherwise LayerNorm would receive None and
        # the residual check would compare unresolved values.
        input_dim = input_dim if input_dim else embed_dim
        context_dim = context_dim if context_dim else embed_dim
        output_dim = output_dim if output_dim else embed_dim

        if residual:
            assert input_dim == output_dim, (
                "residual connection requires input_dim == output_dim"
            )

        self.ln_inputs = nn.LayerNorm(input_dim) if input_layernorm else nn.Identity()
        self.cross_attention = MultiheadCrossAttention(
            embed_dim, n_head, input_dim, context_dim, output_dim, dropout=dropout, bias=bias
        )

        self.residual = residual

    def forward(self, inputs, contexts):
        """
        Run the block.

        Args:
            inputs: Query-side tensor; its last dimension is ``input_dim``.
            contexts: Key/value-side tensor passed to the attention layer.

        Returns:
            The attention output, plus ``inputs`` when ``residual`` is set.
        """
        normalized_inputs = self.ln_inputs(inputs)
        outputs = self.cross_attention(normalized_inputs, contexts=contexts)
        if self.residual:
            # The skip connection uses the raw inputs, not the normalised ones.
            outputs = inputs + outputs

        return outputs
