import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Activation,Layer,ReLU, Input, Flatten, MaxPool2D, Add,BatchNormalization,Dense,Conv2D,GlobalAveragePooling2D
from tensorflow.keras.layers import Dropout, SpatialDropout2D, add
from tensorflow.keras.regularizers import l2
from tensorflow.python.keras.models import Model


class Swish(Layer):
    """Keras layer computing ``sigmoid(x) * x`` element-wise (Swish activation)."""

    def __init__(self, **kwargs):
        # The layer has no parameters of its own; every keyword argument is
        # handed straight to the base ``Layer`` constructor.
        super().__init__(**kwargs)

    def call(self, inputs):
        """Return ``inputs`` multiplied by their own sigmoid."""
        gate = K.sigmoid(inputs)
        return gate * inputs

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged."""
        return input_shape
    
class Mish(Layer):
    """Keras layer computing ``x * tanh(softplus(x))`` element-wise (Mish activation)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Keras masking flag, set on the instance after base initialisation.
        self.supports_masking = True

    def call(self, inputs):
        """Return ``inputs`` multiplied by ``tanh(softplus(inputs))``."""
        squashed = K.tanh(K.softplus(inputs))
        return inputs * squashed

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged."""
        return input_shape
    
    
def dense_layer(filters, 
                activation=ReLU, 
                strides=(1,1),
                dropout = 0,
                use_bias = True,
                batchnormalization=False,
                batch_momentum=0.99,
                regul=0.00,
                kernel_constraint=None,
                last_activation=None,
                name=None):
    """Return a function applying Dense -> [Dropout] -> [BatchNorm] -> [activation].

    Args:
        filters: Number of units of the Dense layer.
        activation: Callable invoked as ``activation(name=name)`` to create the
            activation layer (e.g. ``ReLU``). ``None`` means no activation
            layer is added. Ignored when ``last_activation`` is given.
        strides: Unused; kept so existing callers keep working.
        dropout: Dropout rate; ``0`` disables the Dropout layer.
        use_bias: Whether the Dense layer has a bias. The bias is dropped when
            ``batchnormalization`` is truthy.
        batchnormalization: Whether to add a BatchNormalization layer.
        batch_momentum: ``momentum`` of the BatchNormalization layer.
        regul: L2 factor for the Dense kernel; ``0`` disables regularisation.
        kernel_constraint: Passed to the Dense layer.
        last_activation: Keras activation identifier wrapped in an
            ``Activation`` layer; takes precedence over ``activation``.
        name: Name of the activation layer. When no activation layer is
            created, the name is given to the Dense layer instead.

    Returns:
        A function mapping an input tensor to the output tensor.
    """
    def f(x):
        bias = not batchnormalization and use_bias
        regularizer = l2(regul) if regul != 0 else None
        # Previously ``activation=None`` together with ``last_activation=None``
        # tried to call ``None`` and raised TypeError; it now means "linear".
        has_activation = last_activation is not None or activation is not None
        x = Dense(filters,
                  use_bias=bias,
                  kernel_regularizer=regularizer,
                  kernel_constraint=kernel_constraint,
                  name=None if has_activation else name)(x)
        if dropout != 0:
            x = Dropout(rate=dropout)(x)
        if batchnormalization:
            x = BatchNormalization(momentum=batch_momentum)(x)
        if last_activation is not None:
            x = Activation(last_activation, name=name)(x)
        elif activation is not None:
            x = activation(name=name)(x)
        return x
    return f       

def get_mlp(shape,
            hidden_layers_size=[],
            last_activation=None,
            use_bias = True,
            activation=ReLU,
            dropout = 0,
            batchnormalization=False,
            batch_momentum=0.99,
            regul=0.00,
            kernel_constraint=None,
            nb_classes = 1):
    """Build a multi-layer perceptron as a Keras functional model.

    Args:
        shape: Input shape passed to ``Input``.
        hidden_layers_size: Iterable of unit counts, one hidden block per
            entry. It is only iterated, never mutated.
        last_activation: Keras activation identifier for the output layer.
            ``None`` selects ``'linear'``, i.e. raw Dense outputs.
        use_bias: Forwarded to every ``dense_layer`` call.
        activation: Activation layer class for the hidden blocks.
        dropout: Dropout rate of the hidden blocks (the output layer uses 0).
        batchnormalization: Whether hidden blocks use batch normalisation
            (the output layer never does).
        batch_momentum: Batch-normalisation momentum of the hidden blocks.
        regul: L2 factor of the hidden blocks.
        kernel_constraint: Kernel constraint of the hidden blocks.
        nb_classes: Number of units of the output layer.

    Returns:
        The assembled Keras model.
    """
    inputs = Input(shape)
    x = inputs
    for lay_size in hidden_layers_size:
        x = dense_layer(lay_size,
                        use_bias=use_bias,
                        activation=activation,
                        dropout=dropout,
                        batchnormalization=batchnormalization,
                        batch_momentum=batch_momentum,
                        regul=regul,
                        kernel_constraint=kernel_constraint)(x)

    # The head passes ``activation=None``; with ``last_activation=None`` as
    # well there would be nothing to instantiate and the default call would
    # fail. Fall back to an explicit identity activation.
    head_activation = 'linear' if last_activation is None else last_activation
    x = dense_layer(nb_classes,
                    use_bias=use_bias,
                    activation=None,
                    dropout=0,
                    last_activation=head_activation,
                    batchnormalization=False)(x)

    # Use the public Model class so it matches the ``tensorflow.keras`` layers
    # used to build the graph.
    return tf.keras.Model(inputs=inputs, outputs=x)


def u_conv(filters, dropout=0, kernel=3,
           padding='same',batch_momentum=0.99,
           activation=ReLU, strides=(1,1), 
           initial = 'he_normal',
           batchnormalization =False,regul=0.00,
           kernel_constraint=None):
    """Return a function applying Conv2D -> [SpatialDropout2D] -> [BatchNorm] -> [activation].

    Args:
        filters: Number of convolution filters.
        dropout: SpatialDropout2D rate; ``0`` disables it.
        kernel: Side length of the square convolution kernel.
        padding: Padding mode of the convolution.
        batch_momentum: ``momentum`` of the BatchNormalization layer.
        activation: Zero-argument callable creating the activation layer
            (e.g. ``ReLU``); ``None`` means no activation layer is added.
        strides: Convolution strides.
        initial: Kernel initialiser of the convolution.
        batchnormalization: Whether to add BatchNormalization. When truthy the
            convolution is created without a bias.
        regul: L2 factor for the convolution kernel; ``0`` disables it.
        kernel_constraint: Passed to the convolution.

    Returns:
        A function mapping an input tensor to the output tensor.
    """
    def f(x):
        bias = not batchnormalization
        regularizer = l2(regul) if regul != 0 else None

        x = Conv2D(filters, (kernel, kernel),
                   kernel_initializer=initial,
                   strides=strides,
                   use_bias=bias,
                   padding=padding,
                   kernel_regularizer=regularizer,
                   kernel_constraint=kernel_constraint)(x)
        if dropout != 0:
            x = SpatialDropout2D(rate=dropout)(x)
        if batchnormalization:
            x = BatchNormalization(momentum=batch_momentum)(x)
        # ``activation=None`` used to raise TypeError ("NoneType is not
        # callable"); treat it as "no activation" instead.
        if activation is not None:
            x = activation()(x)
        return x
    return f

def VGG_layer(filters,level=3,dropout=0, kernel=3,padding='same',activation=ReLU, initial='he_normal', 
              batchnormalization = False,batch_momentum=0.99,first=False, stride = False,last = False,regul=0.00,kernel_constraint=None):
    """Return a function stacking ``level`` ``u_conv`` blocks plus a downsampling step.

    Downsampling depends on ``stride``:
      * falsy  -> a ``MaxPool2D`` with default arguments follows the convolutions;
      * truthy -> no pooling; the final convolution uses strides ``(2, 2)``
        instead, unless ``last`` is truthy, in which case nothing downsamples.

    ``first`` is accepted but not used. All remaining arguments are forwarded
    to ``u_conv``.
    """
    def f(x):
        final_idx = level - 1
        strided_downsampling = stride and not last
        for idx in range(level):
            if strided_downsampling and idx == final_idx:
                conv_strides = (2, 2)
            else:
                conv_strides = (1, 1)
            conv = u_conv(filters,
                          dropout=dropout,
                          kernel=kernel,
                          padding=padding,
                          strides=conv_strides,
                          activation=activation,
                          initial=initial,
                          batchnormalization=batchnormalization,
                          batch_momentum=batch_momentum,
                          regul=regul,
                          kernel_constraint=kernel_constraint)
            x = conv(x)
        if stride:
            return x
        return MaxPool2D()(x)
    return f

def VGG(shape, nb_filter = [], conv_layers_size=[], kernel_size=3, dense_layers_size=[], nb_classes=1, padding='same',dropout=0,regul=0,
               activation_conv=None, activation_dense=None, use_bias=True, stride = False,batchnormalization=False, activation_lastlayer=None,verbose=False):
    """Build a VGG-style convolutional classifier as a Keras functional model.

    Calls ``K.clear_session()`` before building the model.

    Args:
        shape: Input shape passed to ``Input``.
        nb_filter: Filter count of each convolutional block.
        conv_layers_size: Number of convolutions in each block.
        kernel_size: Kernel size, either one value shared by all blocks or a
            list with one value per block.
        dense_layers_size: Unit counts of the dense hidden layers. When empty,
            the convolutional features are reduced with
            ``GlobalAveragePooling2D``; otherwise they are flattened.
        nb_classes: Number of units of the output layer.
        padding: Padding mode of the convolutions.
        dropout: Dropout rate; forwarded to the dense hidden layers only.
        regul: L2 factor; forwarded to the dense hidden layers only.
        activation_conv: Activation layer class for the convolutional blocks;
            ``None`` selects ``ReLU``.
        activation_dense: Activation layer class for the dense hidden layers;
            ``None`` selects ``ReLU``.
        use_bias: Forwarded to the dense layers.
        stride: Forwarded to ``VGG_layer`` (strided convolution instead of
            max pooling).
        batchnormalization: Forwarded to convolutional and dense hidden layers.
        activation_lastlayer: Keras activation identifier for the output
            layer; ``None`` selects ``'linear'``.
        verbose: When truthy, print ``model.summary()``.

    Returns:
        The assembled Keras model.
    """
    K.clear_session()

    # The ``None`` defaults cannot be instantiated by the layer helpers, which
    # call the activation class directly; resolve them up front.
    if activation_conv is None:
        activation_conv = ReLU
    if activation_dense is None:
        activation_dense = ReLU
    head_activation = 'linear' if activation_lastlayer is None else activation_lastlayer

    inputs = Input(shape)
    x = inputs
    if not isinstance(kernel_size, list):
        kernel_size = [kernel_size] * len(conv_layers_size)

    # zip() stops at the shortest sequence, so the "last block" flag is derived
    # from the number of blocks actually built rather than from
    # len(conv_layers_size).
    blocks = list(zip(conv_layers_size, nb_filter, kernel_size))
    last_index = len(blocks) - 1
    for i, (l, nb, k) in enumerate(blocks):
        x = VGG_layer(nb,
                      level=l,
                      kernel=k,
                      padding=padding,
                      activation=activation_conv,
                      stride=stride,
                      last=(i == last_index),
                      batchnormalization=batchnormalization)(x)

    if len(dense_layers_size) == 0:
        x = GlobalAveragePooling2D()(x)
    else:
        x = Flatten()(x)
        for lay_size in dense_layers_size:
            x = dense_layer(lay_size,
                            use_bias=use_bias,
                            activation=activation_dense,
                            dropout=dropout,
                            batchnormalization=batchnormalization,
                            regul=regul)(x)

    x = dense_layer(nb_classes,
                    use_bias=use_bias,
                    activation=None,
                    dropout=0,
                    last_activation=head_activation,
                    batchnormalization=False)(x)

    # Use the public Model class so it matches the ``tensorflow.keras`` layers
    # used to build the graph.
    model = tf.keras.Model(inputs=inputs, outputs=x)
    if verbose:
        model.summary()
    return model

def res_init(filters,padding='same',initial='he_normal',activation=ReLU, batch_momentum=0.99,
              batchnormalization = False,regul=0.00,kernel_constraint=None) :
    """Return a function applying a 7x7, stride-2 Conv2D -> [BatchNorm] -> activation.

    Args:
        filters: Number of convolution filters.
        padding: Padding mode of the convolution.
        initial: Kernel initialiser of the convolution.
        activation: Zero-argument callable creating the activation layer.
        batch_momentum: ``momentum`` of the BatchNormalization layer.
        batchnormalization: Whether to add BatchNormalization. When truthy the
            convolution is created without a bias.
        regul: L2 factor for the convolution kernel; ``0`` disables it.
        kernel_constraint: Constraint applied to the convolution kernel
            (previously accepted but silently ignored).

    Returns:
        A function mapping an input tensor to the output tensor.
    """
    def f(x):
        bias = not batchnormalization
        # One Conv2D call for both cases: kernel_regularizer=None is the
        # Conv2D default, so the unregularised path is unchanged.
        regularizer = l2(regul) if regul != 0 else None
        x = Conv2D(filters, (7, 7),
                   kernel_initializer=initial,
                   strides=(2, 2),
                   use_bias=bias,
                   padding=padding,
                   kernel_regularizer=regularizer,
                   kernel_constraint=kernel_constraint)(x)

        if batchnormalization:
            x = BatchNormalization(momentum=batch_momentum, renorm=False, epsilon=0.001)(x)
        x = activation()(x)
        return x
    return f

def res_block(filters,level=3,kernel=3,padding=('same','same'),initial='he_normal',strides=(1,1),first=True,activation=ReLU, batch_momentum=0.99,
              batchnormalization = False,regul=0.00,kernel_constraint=None) :
    """Return a function chaining ``level`` residual layers built by ``res_layer``.

    Only the first layer of the block
      * is created with ``first_level=True`` (its shortcut is projected by a
        strided 1x1 convolution),
      * receives the caller's ``first`` flag (later layers get ``False``), and
      * receives the caller's ``strides`` (later layers use ``(1, 1)``).

    Later layers add their output to an unprojected shortcut, so striding them
    would make the two branches of the residual sum differ in spatial size;
    that is why ``strides`` is no longer forwarded past the first layer.

    The remaining arguments are forwarded unchanged to every ``res_layer``.
    """
    def f(x):
        for i in range(level):
            is_first_level = (i == 0)
            x = res_layer(filters,
                          kernel=kernel,
                          padding=padding,
                          activation=activation,
                          initial=initial,
                          strides=strides if is_first_level else (1, 1),
                          first=first if is_first_level else False,
                          batchnormalization=batchnormalization,
                          first_level=is_first_level,
                          batch_momentum=batch_momentum,
                          regul=regul,
                          kernel_constraint=kernel_constraint)(x)
        return x
    return f


def res_layer(filters,kernel=3,padding=('same','same'),initial='he_normal',strides=(1,1),first_level=False,activation=ReLU,first=False, bootleneck=True, batch_momentum=0.99,
              batchnormalization = False,regul=0.00,kernel_constraint=None) :
    """Return a function applying one residual layer (main path + shortcut).

    Main path:
        [BatchNorm -> activation]           (skipped when ``first`` is truthy)
        Conv2D(filters, 1x1 if bootleneck else 3x3, strides, padding[0])
        [BatchNorm] -> activation
        Conv2D(filters // 4, 3x3, padding[1])
        when ``bootleneck``:
            [BatchNorm] -> activation -> Conv2D(filters, 1x1, padding[1])
    Shortcut: the input itself, or a strided 1x1 Conv2D with ``filters``
    output channels when ``first_level`` is truthy. The two are summed.

    Args:
        filters: Number of output filters.
        kernel: Not used; kernel sizes are fixed by ``bootleneck`` as above.
        padding: Pair of padding modes, for the first and the later
            convolutions of the main path.
        initial: Kernel initialiser of all convolutions.
        strides: Strides of the first main-path convolution and of the
            shortcut projection.
        first_level: Whether to project the shortcut with a 1x1 convolution.
        activation: Zero-argument callable creating an activation layer.
        first: When truthy, skip the leading BatchNorm/activation.
        bootleneck: Selects the three-convolution variant (see above).
        batch_momentum: ``momentum`` of the BatchNormalization layers.
        batchnormalization: Whether to add BatchNormalization layers. When
            truthy the convolutions are created without a bias.
        regul: L2 factor for the main-path kernels; ``0`` disables it.
        kernel_constraint: Constraint applied to the main-path kernels
            (previously accepted but silently ignored).

    Returns:
        A function mapping an input tensor to the output tensor.
    """
    def f(x):
        bias = not batchnormalization
        regularizer = l2(regul) if regul != 0 else None

        def conv(n_filters, size, padding_mode, conv_strides=(1, 1)):
            # Main-path convolution; they all share regulariser and constraint.
            return Conv2D(n_filters, (size, size),
                          kernel_initializer=initial,
                          strides=conv_strides,
                          use_bias=bias,
                          padding=padding_mode,
                          kernel_regularizer=regularizer,
                          kernel_constraint=kernel_constraint)

        def norm_act(t):
            # Optional batch normalisation followed by the activation.
            if batchnormalization:
                t = BatchNormalization(momentum=batch_momentum, renorm=False, epsilon=0.001)(t)
            return activation()(t)

        short = x
        if first_level:
            # Project the shortcut (no regulariser or constraint, as before).
            short = Conv2D(filters, (1, 1), kernel_initializer=initial, strides=strides, use_bias=bias)(short)
        if not first:
            x = norm_act(x)

        x = conv(filters, 1 if bootleneck else 3, padding[0], strides)(x)
        x = norm_act(x)
        # NOTE(review): without ``bootleneck`` the main path ends here with
        # filters // 4 channels, so the shortcut needs that many channels for
        # the sum below to be valid - confirm this is intended.
        x = conv(filters // 4, 3, padding[1])(x)
        if bootleneck:
            x = norm_act(x)
            x = conv(filters, 1, padding[1])(x)

        # The original called an undefined lowercase ``add``; use the ``Add``
        # layer that the file imports.
        return Add()([x, short])
    return f