from torch import nn, einsum
from einops import rearrange


class FFN(nn.Module):
    """Pre-norm position-wise feed-forward block.

    The input is layer-normalised, expanded by ``Linear(width, int(width * mult))``,
    passed through GELU, and projected back with ``Linear(int(width * mult), width)``.
    The residual connection is NOT added here; callers add it themselves.

    Args:
        width: size of the last dimension of the input and of the output.
        mult: expansion factor; the hidden size is ``int(width * mult)``.
    """

    def __init__(self,
                 width,
                 mult=4,
                 ):
        super().__init__()
        hidden = int(width * mult)
        # Build the two projections in the same order as the Sequential lists them,
        # so parameter initialisation consumes the RNG in the same sequence.
        expand = nn.Linear(width, hidden)
        activation = nn.GELU()
        contract = nn.Linear(hidden, width)
        self.net = nn.Sequential(expand, activation, contract)
        self.input_norm = nn.LayerNorm(width)

    def forward(self, x):
        """Return ``net(LayerNorm(x))``; the output has the same last dimension as ``x``."""
        normed = self.input_norm(x)
        return self.net(normed)

class Attention(nn.Module):
    """Pre-norm multi-head self-attention with an optional additive score bias.

    Args:
        width: feature size of the input's last dimension and of the output.
        attention_heads: number of heads; must be positive and divide ``width``
            exactly, since ``forward`` splits the last dimension as ``(h d)``.

    Raises:
        ValueError: if ``attention_heads`` is not positive or does not divide ``width``.
    """

    def __init__(self,
                 width,
                 attention_heads=8,
                 ):
        super().__init__()
        if attention_heads <= 0 or width % attention_heads != 0:
            # Fail at construction instead of with an opaque einops error on the
            # first forward (or a complex-valued scale for negative head counts).
            raise ValueError(
                f"width ({width}) must be divisible by a positive "
                f"attention_heads ({attention_heads})"
            )
        self.attention_heads = attention_heads
        width_head = width // attention_heads
        # Standard scaled dot-product factor 1 / sqrt(head_dim).
        self.scale = width_head ** -0.5
        self.create_qkv = nn.Linear(width, width * 3, bias=False)
        self.out = nn.Linear(width, width)
        self.input_norm = nn.LayerNorm(width)

    def forward(self, x, alibi=None):
        """Apply LayerNorm, then multi-head self-attention.

        Args:
            x: input rearranged as ``(b, n, width)``.
            alibi: optional additive bias for the ``(b, h, n, n)`` attention
                scores; must broadcast to that shape. ``None`` skips the bias.

        Returns:
            Tensor of shape ``(b, n, width)``; no residual is added here.
        """
        x = self.input_norm(x)
        q, k, v = self.create_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.attention_heads), (q, k, v))
        attention_scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        if alibi is not None:
            attention_scores = attention_scores + alibi
        attn = attention_scores.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        return self.out(rearrange(out, 'b h n d -> b n (h d)'))

class CrossAttention(nn.Module):
    """Pre-norm multi-head cross-attention from ``x`` (queries) to ``context`` (keys/values).

    Args:
        width: feature size of both ``x`` and ``context`` and of the output.
        attention_heads: number of heads; must be positive and divide ``width``
            exactly, since ``forward`` splits the last dimension as ``(h d)``.

    Raises:
        ValueError: if ``attention_heads`` is not positive or does not divide ``width``.
    """

    def __init__(self,
                 width,
                 attention_heads=8,
                 ):
        super().__init__()
        if attention_heads <= 0 or width % attention_heads != 0:
            # Fail at construction instead of with an opaque einops error on the
            # first forward (or a complex-valued scale for negative head counts).
            raise ValueError(
                f"width ({width}) must be divisible by a positive "
                f"attention_heads ({attention_heads})"
            )
        self.attention_heads = attention_heads
        width_head = width // attention_heads
        # Standard scaled dot-product factor 1 / sqrt(head_dim).
        self.scale = width_head ** -0.5
        self.create_q = nn.Linear(width, width, bias=False)
        self.create_k = nn.Linear(width, width, bias=False)
        self.create_v = nn.Linear(width, width, bias=False)
        self.to_out = nn.Linear(width, width)
        # NOTE(review): this single LayerNorm (including its affine parameters) is
        # applied to both x and context. A separate context norm may have been
        # intended; changing it would alter the state_dict, so it is kept as-is.
        self.input_norm = nn.LayerNorm(width)

    def forward(self, x, context, alibi=None):
        """Normalise both inputs, then attend from ``x`` to ``context``.

        Args:
            x: query input rearranged as ``(b, n, width)``.
            context: key/value input rearranged as ``(b, m, width)``, with the
                same batch size as ``x``.
            alibi: optional additive bias for the ``(b, h, n, m)`` attention
                scores; must broadcast to that shape. ``None`` skips the bias.

        Returns:
            Tensor of shape ``(b, n, width)``; no residual is added here.
        """
        x = self.input_norm(x)
        context = self.input_norm(context)
        q = self.create_q(x)
        k = self.create_k(context)
        v = self.create_v(context)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.attention_heads), (q, k, v))
        attention_scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        if alibi is not None:
            attention_scores = attention_scores + alibi
        attn = attention_scores.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class BaseTransformer(nn.Module):
    """Stack of pre-norm (self-attention, feed-forward) layers with residual connections.

    Args:
        width: model dimension shared by every sub-layer.
        layers: number of (Attention, FFN) pairs in the stack.
        attention_heads: number of heads in each Attention module.
        ff_mult: hidden-size multiplier forwarded to each FFN as ``mult``.
        final_norm: if truthy, a trailing LayerNorm is applied to the output.
    """

    def __init__(self,
                 width,
                 layers,
                 attention_heads=8,
                 ff_mult=4,
                 final_norm=True,
                 ):
        super().__init__()
        self.final_norm = final_norm
        # Each entry holds one [Attention, FFN] pair, built in that order.
        self.layers = nn.ModuleList(
            nn.ModuleList([
                Attention(width=width, attention_heads=attention_heads),
                FFN(width=width, mult=ff_mult),
            ])
            for _ in range(layers)
        )
        if final_norm:
            self.norm_out = nn.LayerNorm(width)

    def forward(self, x, alibi):
        """Run every layer as ``x <- sublayer(x) + x``, then the optional final norm.

        ``alibi`` is passed unchanged to each self-attention sub-layer.
        """
        for attention_block, feed_forward in self.layers:
            x = attention_block(x, alibi) + x
            x = feed_forward(x) + x
        return self.norm_out(x) if self.final_norm else x

class BaseTransformerCrossAttn(nn.Module):
    """Stack of pre-norm (self-attention, cross-attention, feed-forward) layers.

    Every sub-layer is applied residually (``x <- sublayer(x) + x``) and the
    result of the whole stack is passed through a final LayerNorm.

    Args:
        width: model dimension shared by every sub-layer (and by ``context``).
        layers: number of (Attention, CrossAttention, FFN) triples.
        attention_heads: number of heads in each attention module.
        ff_mult: hidden-size multiplier forwarded to each FFN as ``mult``.
    """

    def __init__(self,
                 width,
                 layers,
                 attention_heads=8,
                 ff_mult=4,
                 ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(layers):
            self.layers.append(nn.ModuleList([
                Attention(width=width, attention_heads=attention_heads),
                CrossAttention(width=width, attention_heads=attention_heads),
                FFN(width=width, mult=ff_mult),
            ]))
        self.norm_out = nn.LayerNorm(width)

    def forward(self, x, context, alibi, context_alibi=None):
        """Run the stack over ``x`` while attending to ``context``.

        Args:
            x: query-side input, rearranged by the attention modules as ``(b, n, width)``.
            context: input attended to by every cross-attention sub-layer.
            alibi: additive bias for the self-attention scores ``(b, h, n, n)``.
            context_alibi: optional additive bias for the cross-attention scores
                ``(b, h, n, m)``, where ``m`` is the context length. When ``None``
                (the default), ``alibi`` is reused, as before. That only
                broadcasts if ``m == n`` or ``alibi`` has size-1 key dimension.

        Returns:
            LayerNorm-ed tensor with the same shape as ``x``.
        """
        # Self-attention and cross-attention scores differ in their key length,
        # so allow the caller to supply a separately shaped cross-attention bias.
        cross_alibi = alibi if context_alibi is None else context_alibi
        for self_attn, cross_attn, ffn in self.layers:
            x = self_attn(x, alibi) + x
            x = cross_attn(x, context, cross_alibi) + x
            x = ffn(x) + x
        return self.norm_out(x)


