from deel.lip.layers import (
    SpectralConv2D, 
    Identity,
    SpectralDense, 
    FrobeniusDense,
    GlobalAveragePooling2D, 
    ScaledAveragePooling2D, 
    ScaledGlobalL2NormPooling2D,
    ScaledL2NormPooling2D, 
    InvertibleDownSampling,
    CircularPadding,
    SymmetricPadding,
    ScaledGlobalAveragePooling2D)


from deel.lip.extra_layers import SpectralDepthwiseConv2D, BatchCentering

from deel.lip.cayley_layers import CayleyConv2D
from deel.lip.normalizers import DEFAULT_NITER_BJORCK, DEFAULT_NITER_SPECTRAL, DEFAULT_NITER_SPECTRAL_INIT
from deel.lip.regularizers import NormRegularizer
from deel.lip.initializers import SpectralInitializer
from deel.lip.model import Model as LipModel
import tensorflow as tf
from tensorflow.keras import Model as KerasModel
from tensorflow.keras.layers import (
    ReLU, 
    Input, 
    Flatten, 
    MaxPool2D, 
    Add,
    AveragePooling2D ,
    BatchNormalization,
    Dense,
    Lambda,
    Conv2D)
from tensorflow.keras import backend as K
#from tensorflow.python.keras.models import Model
from deel.lip.extra_layers import StiefelConv
import numpy as np


def get_conv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding='same', activation=ReLU, batchnormalization=False):
    """Return a builder applying a plain (unconstrained) Keras conv block.

    The returned callable stacks ``Conv2D`` -> optional ``BatchNormalization``
    -> optional activation layer on its input tensor.

    Args:
        filters: number of output channels of the ``Conv2D`` layer.
        kernel_size: convolution kernel size forwarded to ``Conv2D``.
        strides: convolution strides forwarded to ``Conv2D``.
        padding: Keras padding mode forwarded to ``Conv2D``.
        activation: layer class instantiated with no arguments and applied
            last; ``None`` skips the activation.
        batchnormalization: when truthy, a ``BatchNormalization`` layer is
            inserted between the convolution and the activation.

    Returns:
        A function mapping an input tensor to the block's output tensor.
    """
    def build(tensor):
        out = Conv2D(filters, kernel_size=kernel_size, strides=strides, padding=padding)(tensor)
        if batchnormalization:
            out = BatchNormalization()(out)
        return out if activation is None else activation()(out)
    return build
    
def vgg_small_images(shape, nb_classes=1, filter_size=16, batchnormalization=True, verbose=False):
    """Build a plain (non-Lipschitz) VGG-style Keras classifier.

    Layout: 4 convs with ``filter_size`` channels, 3 with ``2*filter_size``,
    3 with ``4*filter_size``. The last conv of each of the first two stages
    uses stride 2. This is followed by global average pooling and a softmax
    ``Dense`` head. Every conv uses the ``get_conv2D`` defaults for kernel
    size, padding and activation.

    Args:
        shape: input shape passed to ``Input`` (without the batch axis).
        nb_classes: number of output units of the softmax head.
        filter_size: channel count of the first stage; later stages double it.
        batchnormalization: forwarded to every ``get_conv2D`` call.
        verbose: when truthy, print ``model.summary()``.

    Returns:
        The constructed ``KerasModel``.
    """
    # (channel multiplier, number of convs); all stages but the last end with
    # a stride-2 conv.
    stages = ((1, 4), (2, 3), (4, 3))
    inputs = Input(shape)
    x = inputs
    for stage_idx, (multiplier, depth) in enumerate(stages):
        downsample_stage = stage_idx < len(stages) - 1
        for layer_idx in range(depth):
            is_last_in_stage = layer_idx == depth - 1
            strides = (2, 2) if (downsample_stage and is_last_in_stage) else (1, 1)
            x = get_conv2D(multiplier * filter_size,
                           batchnormalization=batchnormalization,
                           strides=strides)(x)
    x = GlobalAveragePooling2D()(x)
    outputs = Dense(nb_classes, activation="softmax")(x)
    model = KerasModel(inputs=inputs, outputs=outputs)
    if verbose:
        model.summary()
    return model

def vgg_large_images(shape, nb_classes=1, filter_size=16, batchnormalization=True, verbose=False):
    """Build a plain (non-Lipschitz) VGG-style Keras network with linear outputs.

    Layout: four stages of 3 convs with ``filter_size``, ``2*filter_size``,
    ``4*filter_size`` and ``8*filter_size`` channels. The last conv of each of
    the first three stages uses stride 2. This is followed by global average
    pooling and a ``Dense`` head with no activation.

    Args:
        shape: input shape passed to ``Input`` (without the batch axis).
        nb_classes: number of output units of the final ``Dense`` layer.
        filter_size: channel count of the first stage; later stages double it.
        batchnormalization: forwarded to every ``get_conv2D`` call.
        verbose: when truthy, print ``model.summary()``.

    Returns:
        The constructed ``KerasModel``.
    """
    # One entry per conv layer, in creation order: (filters, strides).
    layer_plan = []
    for multiplier, ends_with_downsampling in ((1, True), (2, True), (4, True), (8, False)):
        width = multiplier * filter_size
        last_strides = (2, 2) if ends_with_downsampling else (1, 1)
        layer_plan += [(width, (1, 1)), (width, (1, 1)), (width, last_strides)]

    inputs = Input(shape)
    features = inputs
    for width, strides in layer_plan:
        features = get_conv2D(width, batchnormalization=batchnormalization, strides=strides)(features)
    features = GlobalAveragePooling2D()(features)
    logits = Dense(nb_classes, activation=None)(features)
    model = KerasModel(inputs=inputs, outputs=logits)
    if verbose:
        model.summary()
    return model

def get_lip_dense(filters, 
                 activation=ReLU, 
                 use_bias=True, 
                 kCoefLip=1.0,
                 by_constraint = False,
                 niter_spectral=DEFAULT_NITER_SPECTRAL, 
                 niter_bjorck=DEFAULT_NITER_BJORCK):
    """Return a builder applying a Lipschitz-constrained dense layer.

    A single output unit (``filters == 1``) uses ``FrobeniusDense``. Otherwise
    an orthogonally-initialised ``SpectralDense`` is used.

    Args:
        filters: number of output units.
        activation: either a string activation name (e.g. ``"sigmoid"``),
            which is passed to the dense layer's ``activation`` argument, or a
            layer class that is instantiated with no arguments and applied
            after the dense layer. ``None`` means no activation.
        use_bias: forwarded as ``use_bias`` to the dense layer.
        kCoefLip: forwarded as ``k_coef_lip`` to the dense layer.
        by_constraint: forwarded as ``by_constraint`` to the dense layer.
        niter_spectral: forwarded to ``SpectralDense`` only.
        niter_bjorck: forwarded to ``SpectralDense`` only.

    Returns:
        A function mapping an input tensor to the output tensor. The function
        may be called several times and behaves the same way on every call.
    """
    def f(x):
        # Resolve the activation into locals on every call. Rebinding the
        # enclosing ``activation`` via ``nonlocal`` (as done previously) made
        # a "sigmoid" builder lose its sigmoid after the first call.
        if isinstance(activation, str):
            # String activations are fused into the dense layer itself.
            layer_activation, post_activation = activation, None
        else:
            # Layer classes are applied as a separate layer, keeping the
            # kernel and the activation in distinct layers.
            layer_activation, post_activation = None, activation

        if filters == 1:
            x = FrobeniusDense(filters, activation=layer_activation,
                               use_bias=use_bias, k_coef_lip=kCoefLip,
                               by_constraint=by_constraint)(x)
        else:
            x = SpectralDense(filters, activation=layer_activation,
                              use_bias=use_bias, kernel_initializer="orthogonal",
                              by_constraint=by_constraint,
                              k_coef_lip=kCoefLip, niter_spectral=niter_spectral,
                              niter_bjorck=niter_bjorck)(x)
        if post_activation is not None:
            x = post_activation()(x)
        return x
    return f

def get_wass_MLP(shape, hidden_layers_size=[],
                 last_activation=None, use_bias=True, last_bias=True,
                 activation=ReLU,
                 nb_classes=1,
                 kCoefLip=1.0,
                 flatten=False,
                 by_constraint=False,
                 niter_spectral=DEFAULT_NITER_SPECTRAL,
                 niter_bjorck=DEFAULT_NITER_BJORCK):
    """Build a multilayer perceptron made of ``get_lip_dense`` layers.

    Args:
        shape: input shape passed to ``Input`` (without the batch axis).
        hidden_layers_size: width of each hidden layer, in order.
        last_activation: activation of the output layer (see ``get_lip_dense``).
        use_bias: bias flag for the hidden layers.
        last_bias: bias flag for the output layer.
        activation: activation of the hidden layers.
        nb_classes: width of the output layer.
        kCoefLip, by_constraint, niter_spectral, niter_bjorck: forwarded
            unchanged to every ``get_lip_dense`` call.
        flatten: when truthy, a ``Flatten`` layer is applied to the input first.

    Returns:
        A ``LipModel`` from the input to the output layer.
    """
    inputs = Input(shape)
    net = Flatten()(inputs) if flatten else inputs
    # Settings shared by the hidden layers and the output layer.
    shared = dict(kCoefLip=kCoefLip, by_constraint=by_constraint,
                  niter_spectral=niter_spectral, niter_bjorck=niter_bjorck)
    for width in hidden_layers_size:
        net = get_lip_dense(width, use_bias=use_bias, activation=activation, **shared)(net)
    net = get_lip_dense(nb_classes, use_bias=last_bias, activation=last_activation, **shared)(net)
    return LipModel(inputs=inputs, outputs=net)






def get_lipConv2D(filters, kernel_size=(3, 3),
                  padding='same',
                  activation=ReLU,
                  use_bias=True,
                  strides=(1, 1),
                  conv_first=False,
                  by_constraint=False,
                  normconstaint=False,
                  batch_centering=False,
                  pixelwise=True,
                  channelwise=False,
                  stiefel=False,
                  kCoefLip=1.0,
                  k_coef_grad=1.0,
                  batchnormalization=False,
                  regul_type="spectral_conv",
                  norm_coeff=0,
                  lambda_orth=1.,
                  rescale=1,
                  niter_spectral=DEFAULT_NITER_SPECTRAL,
                  niter_bjorck=DEFAULT_NITER_BJORCK):
    """Return a builder applying a constrained convolution block.

    Pipeline of the returned callable:
      1. For ``padding`` equal to ``'circular'`` or ``'symmetric'``, an
         explicit ``CircularPadding`` / ``SymmetricPadding`` layer is applied.
         The convolution then uses ``'valid'`` padding.
      2. Convolution: ``StiefelConv`` when ``stiefel`` is truthy. Otherwise
         ``SpectralConv2D``, optionally followed by ``NormRegularizer`` when
         ``norm_coeff > 0``.
      3. Optional ``BatchNormalization``, then optional ``BatchCentering``.
      4. Optional activation layer (class instantiated with no arguments).

    Bias is disabled on the convolution whenever ``batchnormalization`` is set.
    ``SpectralConv2D`` uses ``regul_type="trans_bjork_coeff"`` when
    ``conv_first`` is truthy, and ``regul_type`` otherwise.

    Returns:
        A function mapping an input tensor to the block's output tensor.
    """
    def apply(tensor):
        conv_padding = padding
        if padding in ('circular', 'symmetric'):
            # NOTE(review): both padding amounts come from kernel_size[0];
            # non-square kernels may need kernel_size[1] for one axis -
            # confirm the padding layers' argument semantics.
            half = kernel_size[0] // 2
            pad_layer = CircularPadding if padding == 'circular' else SymmetricPadding
            tensor = pad_layer(padding=(half, half))(tensor)
            conv_padding = 'valid'

        # The bias is redundant when a BatchNormalization layer follows.
        layer_bias = use_bias and not batchnormalization

        if stiefel:
            tensor = StiefelConv(filters, kernel_size,
                                 strides=strides,
                                 use_bias=layer_bias,
                                 normconstaint=normconstaint,
                                 conv_first=conv_first,
                                 padding=conv_padding, k_coef_lip=kCoefLip)(tensor)
        else:
            # NOTE(review): k_coef_lip is hard-coded to 1 here, so kCoefLip
            # only affects the StiefelConv branch - confirm this is intended.
            tensor = SpectralConv2D(filters, kernel_size,
                                    kernel_initializer="orthogonal",
                                    regul_type="trans_bjork_coeff" if conv_first else regul_type,
                                    strides=strides, use_bias=layer_bias,
                                    conv_first=conv_first,
                                    k_coef_grad=k_coef_grad,
                                    lambda_orth=lambda_orth,
                                    normconstaint=normconstaint,
                                    padding=conv_padding, k_coef_lip=1, by_constraint=by_constraint,
                                    niter_spectral=niter_spectral, niter_bjorck=niter_bjorck)(tensor)
            if norm_coeff > 0:
                tensor = NormRegularizer(coeff=norm_coeff, rescale=rescale)(tensor)

        if batchnormalization:
            tensor = BatchNormalization()(tensor)
        if batch_centering:
            tensor = BatchCentering(pixelwise=pixelwise, channelwise=channelwise)(tensor)
        if activation is not None:
            tensor = activation()(tensor)
        return tensor
    return apply








def deel_lip_vgg_graph( nb_classes=1,kernel_size=(3,3),
                       coeffs=1.0,filter_size=16,
                       layers_per_depth=(),dense_layers_size=(),
                       padding='same',
                       regul_type = "spectral_conv",
                       by_constraint=False,
                       batch_centering = False,
                       normconstaint = False,
                       pixelwise=True,
                       channelwise=False,
                       batchnormalization = False,
                       k_coef_grad = 1.0,
                       stiefel = False,
                       out_activation = False,
                       l2_average = False,
                       altern_size = False,
                       norm_coeff = 0,
                       bias_output = True,
                       rescale = 1,
                       lambda_orth = 1.,
                       last_bjork = False,
                       activation_conv=None,activation_dense=None,use_bias=True,
                       use_stride=False,poolType="avg",batchNorm=0.0,
                       niter_spectral=DEFAULT_NITER_SPECTRAL, niter_bjorck=DEFAULT_NITER_BJORCK, 
                       splitLastLayer=False, activation_lastlayer = None):
    """Return a builder for a Lipschitz VGG-style graph.

    The returned ``f(x)`` builds the following on the tensor ``x``:

    * Conv stages. ``layers_per_depth``, ``filter_size`` and ``kernel_size``
      are zipped together, so all three must be per-depth sequences (an int
      ``filter_size`` raises ``TypeError`` in ``zip``). Each stage has
      ``layers`` convs built by ``get_lipConv2D``. Only the first conv of the
      whole network gets ``conv_first=True``. With ``altern_size`` every
      other conv gets 2 extra filters.
    * Downsampling between stages. With ``use_stride`` it is the last conv of
      every stage except the final one, using stride 2. Otherwise a
      ``poolType`` pooling layer ('avg', 'max', 'l2norm' or 'inv') is placed
      after every stage.
    * Global pooling: ``ScaledGlobalL2NormPooling2D`` when ``l2_average`` is
      set, else ``ScaledGlobalAveragePooling2D``. With ``out_activation`` it
      is followed by optional ``BatchCentering`` and ``activation_dense``.
    * Hidden dense layers from ``dense_layers_size`` (``get_lip_dense``).
    * Output head:
      - ``nb_classes == 1``: a ``FrobeniusDense(1)``;
      - ``last_bjork``: a ``SpectralDense(nb_classes)``;
      - ``splitLastLayer``: a ``FrobeniusDense(nb_classes,
        disjoint_neurons=True)``;
      - otherwise: per class, a 128-unit ``get_lip_dense`` followed by a
        ``FrobeniusDense(1)``, with the results concatenated.
    * Optional ``activation_lastlayer`` applied last.

    ``batchNorm`` is accepted for interface compatibility but is not used.

    Returns:
        A function mapping an input tensor to the output tensor.
    """
    def f(x):
        poll2fct = {'avg': ScaledAveragePooling2D, 'max': MaxPool2D,
                    'l2norm': ScaledL2NormPooling2D, 'inv': InvertibleDownSampling}
        # Activation fused into the output dense layers; always None here
        # (use ``activation_lastlayer`` for a separate output activation).
        last_activation = None
        conv_first = True
        change = 0
        for pos, (layers, nb, kernel) in enumerate(zip(layers_per_depth, filter_size, kernel_size)):
            for i in range(layers):
                strides = (1, 1)
                if altern_size:
                    # Alternate +0 / +2 extra filters between consecutive convs.
                    change = 2 * (i % 2)
                if use_stride and i == layers - 1 and pos != len(layers_per_depth) - 1:
                    strides = (2, 2)
                x = get_lipConv2D(nb + change, kernel_size=kernel,
                                  kCoefLip=coeffs,
                                  padding=padding,
                                  by_constraint=by_constraint,
                                  strides=strides,
                                  batch_centering=batch_centering,
                                  k_coef_grad=k_coef_grad,
                                  normconstaint=normconstaint,
                                  pixelwise=pixelwise,
                                  channelwise=channelwise,
                                  batchnormalization=batchnormalization,
                                  stiefel=stiefel,
                                  lambda_orth=lambda_orth,
                                  regul_type=regul_type,
                                  # Previously omitted: conv layers always used
                                  # the default, ignoring the caller's value.
                                  niter_spectral=niter_spectral,
                                  niter_bjorck=niter_bjorck,
                                  conv_first=conv_first,
                                  norm_coeff=norm_coeff,
                                  rescale=rescale,
                                  activation=activation_conv,
                                  use_bias=use_bias)(x)
                conv_first = False
            if not use_stride:
                x = poll2fct[poolType](pool_size=(2, 2))(x)

        if l2_average:
            x = ScaledGlobalL2NormPooling2D()(x)
        else:
            x = ScaledGlobalAveragePooling2D()(x)
        if out_activation:
            if batch_centering:
                x = BatchCentering()(x)
            if activation_dense is not None:
                x = activation_dense()(x)

        for lay_size in dense_layers_size:
            # NOTE(review): hidden dense layers use fixed iteration counts
            # (5 / 7) rather than niter_spectral / niter_bjorck - confirm.
            x = get_lip_dense(lay_size, use_bias=use_bias,
                              by_constraint=by_constraint,
                              activation=activation_dense, kCoefLip=coeffs,
                              niter_spectral=5, niter_bjorck=7)(x)

        if nb_classes == 1:
            x = FrobeniusDense(1,
                               activation=last_activation,
                               disjoint_neurons=True,
                               by_constraint=by_constraint,
                               use_bias=bias_output, k_coef_lip=1.,
                               kernel_initializer="orthogonal")(x)
        else:
            if last_bjork:
                x = SpectralDense(nb_classes, activation=None,
                                  use_bias=bias_output, kernel_initializer="orthogonal", by_constraint=False,
                                  k_coef_lip=1., niter_spectral=niter_spectral,
                                  niter_bjorck=niter_bjorck)(x)
            elif splitLastLayer and (nb_classes > 1):
                x = FrobeniusDense(nb_classes,
                                   activation=last_activation,
                                   disjoint_neurons=True,
                                   by_constraint=by_constraint,
                                   use_bias=bias_output, k_coef_lip=1.,
                                   kernel_initializer="orthogonal")(x)
            else:
                # One independent 128-unit sub-network per class, each
                # reduced to a single output, then concatenated.
                vals = []
                for _ in range(nb_classes):
                    sub_class = get_lip_dense(128, use_bias=True, by_constraint=by_constraint,
                                              activation=activation_dense,
                                              kCoefLip=1.,
                                              niter_spectral=niter_spectral,
                                              niter_bjorck=niter_bjorck)(x)
                    sub_class = FrobeniusDense(1,
                                               activation=last_activation,
                                               disjoint_neurons=True,
                                               by_constraint=by_constraint,
                                               use_bias=bias_output, k_coef_lip=1.,
                                               kernel_initializer="orthogonal")(sub_class)
                    vals.append(sub_class)
                x = tf.keras.layers.Concatenate()(vals)

        if activation_lastlayer is not None:
            x = activation_lastlayer()(x)
        return x
    return f





def deel_lip_vgg(shape, nb_classes=1,kernel_size=3,coeffs=1,filter_size=16,
                 layers_per_depth=[], dense_layers_size=[],regul_type = "spectral_conv",
                 padding='same',activation_conv=None,
                 activation_dense=None,
                 use_bias=True,use_stride=False,
                 poolType="avg",lambdaE=0.0, 
                 by_constraint=False,
                 batch_centering = False,
                 l2_average = False,
                 normconstaint = False,
                 out_activation = False,
                 pixelwise=True,
                 channelwise=False,
                 k_coef_grad = 1.0,
                 stiefel = False,
                 norm_coeff = 0,
                 altern_size = False,
                 bias_output = True,
                 rescale = 1,
                 last_bjork = False,
                 lambda_orth = 1.,
                 batchnormalization = False,
                 niter_spectral=DEFAULT_NITER_SPECTRAL, 
                 niter_bjorck=DEFAULT_NITER_BJORCK, splitLastLayer=False,
                 activation_lastlayer = None,
                verbose= False):
    """Clear the Keras session and build a Lipschitz VGG ``LipModel``.

    Per-depth arguments are expanded before delegating to
    ``deel_lip_vgg_graph``:

    * ``filter_size`` that is not a list becomes
      ``[filter_size * 2**d for d in range(len(layers_per_depth))]``.
    * ``kernel_size`` that is not a list is repeated once per depth.
    * Each kernel size ``k`` is turned into the square shape ``(k, k)``.

    Every other argument is forwarded unchanged. ``batchNorm`` is always
    passed as ``0.0``. ``lambdaE`` is accepted but not used.

    Returns:
        A ``LipModel`` from an ``Input(shape)`` to the graph output.
        ``model.summary()`` is printed when ``verbose`` is truthy.
    """
    K.clear_session()

    if isinstance(filter_size, list):
        filters_per_depth = filter_size
    else:
        # Channel count doubles at each depth.
        filters_per_depth = [filter_size * 2 ** depth for depth in range(len(layers_per_depth))]

    if isinstance(kernel_size, list):
        kernels_per_depth = kernel_size
    else:
        kernels_per_depth = [kernel_size] * len(layers_per_depth)
    square_kernels = [(k, k) for k in kernels_per_depth]

    inputs = Input(shape)
    graph = deel_lip_vgg_graph(
        # architecture
        nb_classes=nb_classes,
        filter_size=filters_per_depth,
        kernel_size=square_kernels,
        layers_per_depth=layers_per_depth,
        dense_layers_size=dense_layers_size,
        padding=padding,
        use_bias=use_bias,
        bias_output=bias_output,
        use_stride=use_stride,
        poolType=poolType,
        l2_average=l2_average,
        altern_size=altern_size,
        splitLastLayer=splitLastLayer,
        last_bjork=last_bjork,
        # activations
        activation_conv=activation_conv,
        activation_dense=activation_dense,
        activation_lastlayer=activation_lastlayer,
        out_activation=out_activation,
        # normalisation / centering
        batchnormalization=batchnormalization,
        batchNorm=0.0,
        batch_centering=batch_centering,
        pixelwise=pixelwise,
        channelwise=channelwise,
        # Lipschitz constraint settings
        coeffs=coeffs,
        by_constraint=by_constraint,
        regul_type=regul_type,
        normconstaint=normconstaint,
        stiefel=stiefel,
        k_coef_grad=k_coef_grad,
        norm_coeff=norm_coeff,
        lambda_orth=lambda_orth,
        rescale=rescale,
        niter_spectral=niter_spectral,
        niter_bjorck=niter_bjorck,
    )(inputs)

    model = LipModel(inputs=inputs, outputs=graph)
    if verbose:
        model.summary()
    return model


def cifar_vgg(shape, nb_classes=1,filter_size=16):
    inputs=Input(shape)

