# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
# =====================================================================================
from tensorflow.keras.initializers import Initializer, Orthogonal, GlorotUniform
from tensorflow.keras import initializers
import tensorflow as tf
from .normalizers import reshaped_kernel_orthogonalization,stiefel_project
from tensorflow.keras.utils import register_keras_serializable


@register_keras_serializable("deel-lip", "SpectralInitializer")
class SpectralInitializer(Initializer):
    def __init__(
        self,
        niter_spectral=3,
        niter_bjorck=15,
        k_coef_lip=1.0,
        base_initializer="orthogonal",
    ) -> None:
        """
        Initialize a kernel to be 1-lipschitz orthogonal using bjorck
        normalization.

        The kernel is first drawn from ``base_initializer`` and then passed to
        ``reshaped_kernel_orthogonalization``, which runs the iterative power
        method for ``niter_spectral`` iterations and the Bjorck algorithm for
        ``niter_bjorck`` iterations.

        Args:
            niter_spectral: number of iterations to do with the iterative power
                method.
            niter_bjorck: number of iterations to do with the Bjorck algorithm.
            k_coef_lip: coefficient forwarded to
                ``reshaped_kernel_orthogonalization`` as its third positional
                argument; presumably the Lipschitz constant the kernel is
                rescaled to (TODO confirm against ``normalizers``).
            base_initializer: method used to generate weights before applying the
                orthonormalization. Anything accepted by
                ``tf.keras.initializers.get`` (name, config dict or instance).
        """
        self.niter_spectral = niter_spectral
        self.niter_bjorck = niter_bjorck
        self.k_coef_lip = k_coef_lip
        self.base_initializer = initializers.get(base_initializer)
        super(SpectralInitializer, self).__init__()

    def __call__(self, shape, dtype=None, partition_info=None):
        """
        Draw a kernel of the given shape and orthonormalize it.

        Args:
            shape: shape of the kernel to create.
            dtype: dtype forwarded to ``base_initializer``.
            partition_info: unused here; kept for call-signature compatibility.

        Returns:
            The kernel returned by ``reshaped_kernel_orthogonalization``.
        """
        w = self.base_initializer(shape=shape, dtype=dtype)
        # The second argument is None: presumably the power-iteration vector
        # ``u`` (there is no previous estimate at init time). Only the
        # normalized kernel is needed; the other two returned values are
        # discarded.
        wbar, _, _ = reshaped_kernel_orthogonalization(
            w,
            None,
            self.k_coef_lip,
            self.niter_spectral,
            self.niter_bjorck,
        )
        return wbar

    def get_config(self):
        """Return the constructor arguments needed to rebuild this initializer."""
        return {
            "niter_spectral": self.niter_spectral,
            "niter_bjorck": self.niter_bjorck,
            "k_coef_lip": self.k_coef_lip,
            "base_initializer": initializers.serialize(self.base_initializer),
        }
    
    
@register_keras_serializable("deel-lip", "StiefelInitializer")
class StiefelInitializer(Initializer):
    def __init__(
        self,
        base_initializer=None,
    ) -> None:
        """
        Initialize a kernel by projecting a base initialization with
        ``stiefel_project``.

        Args:
            base_initializer: method used to generate weights before applying the
                projection. Anything accepted by ``tf.keras.initializers.get``
                (name, config dict or instance). When None (the default), a new
                ``Orthogonal(gain=1.0, seed=None)`` is created for this instance.
                Previously that default was a single object created once at
                function-definition time and shared by every StiefelInitializer.
        """
        if base_initializer is None:
            base_initializer = Orthogonal(gain=1.0, seed=None)
        self.base_initializer = initializers.get(base_initializer)
        super(StiefelInitializer, self).__init__()

    def __call__(self, shape, dtype=None, partition_info=None):
        """
        Draw a kernel of the given shape and project it with ``stiefel_project``.

        Args:
            shape: shape of the kernel to create.
            dtype: dtype forwarded to ``base_initializer``.
            partition_info: unused here; kept for call-signature compatibility.

        Returns:
            The projected kernel, reshaped back to the shape of the drawn kernel.
        """
        w = self.base_initializer(shape=shape, dtype=dtype)
        w_shape = w.shape
        # Flatten every leading axis so the projection acts on a 2D matrix whose
        # columns follow the last axis. For Keras kernels the last axis is
        # presumably the output-units axis.
        wbar = tf.reshape(w, [-1, w_shape[-1]])
        wbar = stiefel_project(wbar)
        return tf.reshape(wbar, w_shape)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this initializer."""
        return {
            "base_initializer": initializers.serialize(self.base_initializer),
        }

@register_keras_serializable("deel-lip", "FrobenusInitializer")
class FrobenusInitializer(Initializer):
    def __init__(
        self,
        axis_norm=None,
        base_initializer=None,
    ) -> None:
        """
        Initialize a kernel by dividing a base initialization by its norm.

        With ``axis_norm=None`` the whole kernel is divided by its global
        (Frobenius) norm. Otherwise it is normalized along the given axis
        (or axes).

        Args:
            axis_norm: axis (or pair of axes) along which the norm is computed,
                forwarded to ``tf.norm``. None means the norm of the whole
                tensor.
            base_initializer: method used to generate weights before
                normalization. Anything accepted by ``tf.keras.initializers.get``
                (name, config dict or instance). When None (the default), a new
                ``GlorotUniform(seed=None)`` is created for this instance.
                Previously that default was a single object created once at
                function-definition time and shared by every FrobenusInitializer.
        """
        if base_initializer is None:
            base_initializer = GlorotUniform(seed=None)
        self.axis_norm = axis_norm
        self.base_initializer = initializers.get(base_initializer)
        super(FrobenusInitializer, self).__init__()

    def __call__(self, shape, dtype=None, partition_info=None):
        """
        Draw a kernel of the given shape and divide it by its norm.

        Args:
            shape: shape of the kernel to create.
            dtype: dtype forwarded to ``base_initializer``.
            partition_info: unused here; kept for call-signature compatibility.

        Returns:
            The normalized kernel, with the same shape as the drawn kernel.
        """
        w = self.base_initializer(shape=shape, dtype=dtype)
        # keepdims=True keeps the reduced axes as size-1 dimensions, so the
        # division broadcasts along the axis that was actually normalized.
        # Without it, e.g. axis_norm=1 on an (n, m) kernel yields a norm of
        # shape (n,). That either fails to broadcast against (n, m) or, when
        # n == m, silently divides columns by row norms.
        norm = tf.norm(w, axis=self.axis_norm, keepdims=True)
        return w / norm

    def get_config(self):
        """Return the constructor arguments needed to rebuild this initializer."""
        return {
            "axis_norm": self.axis_norm,
            "base_initializer": initializers.serialize(self.base_initializer),
        }
