from typing import Optional, Union

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Activation, Conv1D, Dropout

from rinokeras.layers import Stack, SelfAttention, PositionEmbedding


class EluConvElu(Stack):
    """Stack applying ELU, then a same-padded stride-1 Conv1D, then ELU again.

    Args:
        filters: number of output channels of the convolution.
        kernel_size: width of the 1-D convolution kernel.
    """

    def __init__(self,
                 filters: int,
                 kernel_size: int) -> None:
        # Layers are created left to right, in the order they are applied.
        super().__init__([
            Activation('elu'),
            Conv1D(filters, kernel_size, strides=1, padding='same'),
            Activation('elu'),
        ])


class ProteinResidual(Model):
    """Gated residual unit: x + dropout(mask(out * sigmoid(gate))).

    The input goes through an ELU-Conv-ELU stack. A second same-padded conv
    then emits ``2 * filters`` channels, which are split into a candidate
    output and a sigmoid gate. The gated result is added back to the
    (masked) input.

    Args:
        filters: channel count of the residual branch. The final addition
            expects the input to be shape-compatible with ``filters`` channels.
        kernel_size: kernel width of both convolutions.
        dropout: dropout rate on the gated branch; ``None`` is treated as 0.
    """

    def __init__(self,
                 filters: int,
                 kernel_size: int,
                 dropout: Optional[float] = None) -> None:
        super().__init__()
        rate = dropout if dropout is not None else 0
        self.forward = EluConvElu(filters, kernel_size)
        # One conv produces both the candidate output and its gate.
        self.output_and_gate = Conv1D(2 * filters, kernel_size, strides=1, padding='same')
        self.dropout = Dropout(rate)

    def call(self, inputs, padding_mask=None):
        """Apply the gated residual unit.

        Args:
            inputs: input tensor; presumably (batch, length, filters) — TODO confirm.
            padding_mask: optional multiplicative mask; assumed to broadcast
                against ``inputs`` (e.g. (batch, length, 1)) — verify with callers.

        Returns:
            ``masked_inputs + dropout(masked gated branch)``.
        """
        def apply_mask(tensor):
            # Zero out padded positions when a mask is supplied.
            return tensor if padding_mask is None else tensor * padding_mask

        masked_inputs = apply_mask(inputs)
        hidden = apply_mask(self.forward(masked_inputs))

        values, gate_logits = tf.split(
            self.output_and_gate(hidden), num_or_size_splits=2, axis=-1)
        gated = apply_mask(values * tf.nn.sigmoid(gate_logits))

        return masked_inputs + self.dropout(gated)


class ProteinSelfAttention(Model):
    """Thin wrapper around rinokeras ``SelfAttention`` using 'scaled_dot' attention.

    Args:
        n_heads: number of attention heads.
        key_size: key dimension passed to ``SelfAttention``.
        value_size: value dimension passed to ``SelfAttention``.
        dropout: dropout rate forwarded unchanged to ``SelfAttention``.
        kernel_initializer: initializer name (e.g. ``'glorot_uniform'``) or
            initializer instance, forwarded to ``SelfAttention``.
        kernel_regularizer, bias_regularizer, activity_regularizer:
            forwarded unchanged to ``SelfAttention``.
    """

    def __init__(self,
                 n_heads: int,
                 key_size: int,
                 value_size: int,
                 dropout: Optional[float],
                 # The default is a string identifier, so the annotation must admit str.
                 kernel_initializer: Union[str, tf.keras.initializers.Initializer, None] = 'glorot_uniform',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None) -> None:
        super().__init__()
        self.self_attention = SelfAttention(
            'scaled_dot',
            n_heads=n_heads,
            dropout=dropout,
            key_size=key_size,
            value_size=value_size,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer)

    def call(self, inputs, mask, return_attention_weights=False):
        """Run self-attention over ``inputs``.

        Args:
            inputs: tensor to attend over.
            mask: attention mask forwarded to ``SelfAttention``.
            return_attention_weights: if True, also return the attention weights.

        Returns:
            ``output``, or ``(output, attention_weights)`` when
            ``return_attention_weights`` is True.
        """
        # Always request the weights so the underlying layer returns a
        # consistent pair; drop them here if the caller doesn't want them.
        output, attention_weights = self.self_attention(
            inputs, mask=mask, return_attention_weights=True)

        if return_attention_weights:
            return output, attention_weights
        return output


class ProteinConvAttentionBlock(Model):
    """Conv/attention block: gated residual conv stack plus self-attention.

    The output of four ``ProteinResidual`` layers is concatenated with
    ``original`` along the last axis and fed to a self-attention layer. The
    residual and attention branches are each projected with a 1x1
    ELU-Conv-ELU, summed, and passed through a final 1x1 ELU-Conv-ELU.

    Args:
        filters: channel count of the residual stack and of every 1x1 projection.
        kernel_size: kernel width of the residual convolutions.
        dropout: dropout rate; ``None`` is treated as 0. Only the first of the
            four residual layers uses it; the attention layer always does.
        n_heads: number of attention heads (default 1).
        key_size: attention key dimension (default 16).
        value_size: attention value dimension (default 128).
    """

    def __init__(self,
                 filters: int,
                 kernel_size: int,
                 dropout: Optional[float] = None,
                 n_heads: int = 1,
                 key_size: int = 16,
                 value_size: int = 128) -> None:
        super().__init__()
        dropout = 0 if dropout is None else dropout
        # int(i == 0) * dropout: only the first residual layer applies dropout.
        self.resnet = Stack([
            ProteinResidual(filters, kernel_size, int(i == 0) * dropout) for i in range(4)])
        # NOTE: built but not used in call(); kept so tracked layers and
        # checkpoints stay unchanged.
        self.posembed = PositionEmbedding(concat=True)
        self.attention_block = ProteinSelfAttention(
            n_heads=n_heads,
            key_size=key_size,
            value_size=value_size,
            dropout=dropout)
        self.res_conv = EluConvElu(filters, 1)
        self.attention_conv = EluConvElu(filters, 1)
        self.output_layer = EluConvElu(filters, 1)

    def call(self, inputs, original, padding_mask=None, attention_mask=None):
        """Apply the block.

        Args:
            inputs: features from the previous layer.
            original: tensor concatenated (last axis) with the residual output
                before attention; ProteinSNAIL passes its un-projected inputs.
            padding_mask: optional multiplicative mask applied inside the
                residual stack and to the attention input; assumed to broadcast
                over the channel axis — TODO confirm.
            attention_mask: mask forwarded to the self-attention layer.

        Returns:
            Output of the final 1x1 ELU-Conv-ELU projection.
        """
        resout = self.resnet(inputs, padding_mask=padding_mask)
        # NOTE: positional embedding (self.posembed) is currently not applied here.
        attention_in = tf.concat((resout, original), -1)
        if padding_mask is not None:
            attention_in = attention_in * padding_mask
        attention_out = self.attention_block(attention_in, mask=attention_mask)

        resout = self.res_conv(resout)
        attention_out = self.attention_conv(attention_out)

        combined = resout + attention_out

        return self.output_layer(combined)


class ProteinSNAIL(Model):
    """Encoder: a same-padded Conv1D projection followed by stacked conv/attention blocks.

    Args:
        filters: channel width of the initial projection and of every block.
        kernel_size: kernel width of the projection and of the residual convs.
        num_blocks: number of ``ProteinConvAttentionBlock`` layers to stack.
        dropout: dropout rate forwarded to every block (``None`` means 0).
    """

    def __init__(self,
                 filters: int,
                 kernel_size: int,
                 num_blocks: int,
                 dropout: Optional[float] = None) -> None:
        super().__init__()
        self.initial_proj = Conv1D(filters, kernel_size, strides=1, padding='same')
        blocks = [ProteinConvAttentionBlock(filters, kernel_size, dropout)
                  for _ in range(num_blocks)]
        self.snail = Stack(blocks)

    def call(self, inputs, padding_mask=None, mask=None):
        """Project ``inputs`` and run them through the block stack.

        Every block also receives the raw, un-projected ``inputs`` as
        ``original``. ``mask`` is forwarded as each block's ``attention_mask``.
        Stack is assumed to forward these keyword arguments to each block —
        verify against rinokeras ``Stack``.
        """
        projected = self.initial_proj(inputs)
        return self.snail(
            projected,
            original=inputs,
            padding_mask=padding_mask,
            attention_mask=mask)
