import tensorflow as tf


# Copied from NTC
def get_activation(activation: str, dtype=None):
    """Resolve an activation name to a callable or layer.

    :param activation: activation name. Falsy values and ``'none'`` (any case)
        yield ``None``. ``'gdn'``/``'igdn'`` (any case) yield a
        ``tensorflow_compression`` GDN layer (inverse for ``'igdn'``). Any other
        name is looked up in ``tf.nn``, exact spelling first, then lower-cased.
    :param dtype: dtype for the GDN layer; unused for ``tf.nn`` functions.
    :return: ``None``, a ``tfc.GDN`` instance, or a ``tf.nn`` function.
    :raises AttributeError: if the name does not exist in ``tf.nn``.
    """
    if not activation:
        return None
    normalized = activation.lower()
    if normalized == 'none':
        return None
    if normalized in ("gdn", "igdn"):
        # Imported lazily so tensorflow_compression is only needed when GDN is requested.
        import tensorflow_compression as tfc
        if normalized == "igdn":
            return tfc.GDN(inverse=True, dtype=dtype)
        return tfc.GDN(dtype=dtype)
    # Exact spelling first so every name that resolved before still resolves the
    # same way; then fall back to lower case (e.g. "ReLU" -> tf.nn.relu).
    fn = getattr(tf.nn, activation, None)
    if fn is None:
        fn = getattr(tf.nn, normalized)
    return fn


def make_mlp(units, activation, name='mlp', input_shape=None, dtype=None, no_last_activation=True, return_layers=False):
    """Build a stack of biased ``Dense`` layers.

    :param units: iterable of ints, one per layer (layer ``i`` is named ``f"{name}_{i}"``).
        May be empty, in which case no layers are created.
    :param activation: activation applied to every layer (see ``no_last_activation``).
    :param name: prefix for layer names and name of the returned ``Sequential``.
    :param input_shape: if given, passed to the first layer only (ignored when ``units`` is empty).
    :param dtype: dtype forwarded to every ``Dense`` layer.
    :param no_last_activation: if True, the last layer has no activation (linear output).
    :param return_layers: if True, return the list of layers instead of a ``Sequential``.
    :return: list of ``tf.keras.layers.Dense`` or a ``tf.keras.Sequential`` wrapping them.
    """
    layer_configs = [
        dict(units=u, use_bias=True, activation=activation, name=f"{name}_{i}", dtype=dtype)
        for i, u in enumerate(units)
    ]
    # Guard: with no layers there is no first/last layer to configure.
    if layer_configs:
        if input_shape is not None:
            layer_configs[0].update(input_shape=input_shape)
        if no_last_activation:
            layer_configs[-1].update(activation=None)
    layers = [tf.keras.layers.Dense(**cfg) for cfg in layer_configs]
    if return_layers:
        return layers
    return tf.keras.Sequential(layers, name=name)


# # Custom analysis/synthesis transforms
# def make_analysis_transform(num_filters, kernel_dims, preprocess_layers=[]):
#     from model_utils import enc_conv
#     layers = preprocess_layers + [enc_conv(filters=filters, kernel_support=(kernel_dim, kernel_dim),
#                                            name=f'enc_{i}') for i, (filters, kernel_dim) in
#                                   enumerate(zip(num_filters, kernel_dims))]

class cifar:  # using this class just as a namespace
    """Namespace for the CIFAR analysis/synthesis transforms."""

    class AnalysisTransform(tf.keras.Sequential):
        """The analysis transform.

        Divides inputs by 255, then applies three stride-2 ``SignalConv2D`` layers
        (5x5, 5x5, 3x3); the first two use GDN activations. Assumes inputs are
        raw pixel values in [0, 255] -- TODO confirm against the data pipeline.

        Without ``dense_units`` the last conv outputs ``num_output_filters``
        channels with no bias or activation. With ``dense_units`` the last conv
        keeps ``num_filters`` channels with bias + GDN. Its output is then
        flattened and passed through an MLP.
        """

        def __init__(self, num_filters, num_output_filters=None, dense_units=None, dense_activation='leaky_relu'):
            """
            :param num_filters: number of channels in the conv layers.
            :param num_output_filters: channels of the final conv in the conv-only
                configuration; defaults to ``num_filters``.
            :param dense_units: optional list of ints for a trailing MLP; the last
                entry determines the encoder output dimension. ``None``/empty means conv-only.
            :param dense_activation: activation name (resolved by ``get_activation``)
                for the MLP; the last MLP layer has no activation.
            """
            import tensorflow_compression as tfc
            super().__init__(name="analysis")
            if not num_output_filters:
                num_output_filters = num_filters
            self.add(tf.keras.layers.Lambda(lambda x: x / 255.))
            self.add(tfc.SignalConv2D(
                num_filters, (5, 5), name="layer_0", corr=True, strides_down=2,
                padding="same_zeros", use_bias=True,
                activation=tfc.GDN(name="gdn_0")))
            self.add(tfc.SignalConv2D(
                num_filters, (5, 5), name="layer_1", corr=True, strides_down=2,
                padding="same_zeros", use_bias=True,
                activation=tfc.GDN(name="gdn_1")))
            if not dense_units:
                self.add(tfc.SignalConv2D(
                    num_output_filters, (3, 3), name="layer_2", corr=True, strides_down=2,
                    padding="same_zeros", use_bias=False,
                    activation=None))
            else:
                self.add(tfc.SignalConv2D(
                    num_filters, (3, 3), name="layer_2", corr=True, strides_down=2,
                    padding="same_zeros", use_bias=True,
                    activation=tfc.GDN(name="gdn_2")))
                self.add(tf.keras.layers.Flatten())
                mlp = make_mlp(dense_units, activation=get_activation(dense_activation))
                self.add(mlp)

            # One entry per stride-2 conv above.
            self.downsample_factors = [2, 2, 2]

    class SynthesisTransform(tf.keras.Sequential):
        """The synthesis transform.

        Optionally starts with an MLP and a reshape to a
        ``(decoder_init_conv_dim, decoder_init_conv_dim, num_filters)`` feature map.
        Then applies three stride-2 transposed ``SignalConv2D`` layers (3x3, 5x5, 5x5).
        The first two use inverse GDN; the last outputs 3 channels with no activation.
        Finally multiplies by 255.
        """

        def __init__(self, num_filters, dense_units=None, dense_activation='leaky_relu', decoder_init_conv_dim=None):
            """
            :param num_filters: number of channels in the intermediate conv layers.
            :param dense_units: optional list of ints; the first entry is the decoder
                input dimension, the remaining entries are MLP layer sizes. The last
                entry must equal ``decoder_init_conv_dim ** 2 * num_filters``.
            :param dense_activation: activation name (resolved by ``get_activation``)
                applied to every MLP layer, including the last.
            :param decoder_init_conv_dim: spatial size of the feature map produced
                by the reshape; required when ``dense_units`` is given.
            :raises ValueError: if ``dense_units`` is given and ``decoder_init_conv_dim``
                is missing or inconsistent with ``dense_units[-1]``.
            """
            import tensorflow_compression as tfc
            super().__init__(name="synthesis")
            if dense_units:
                # Fail fast with a clear message instead of an opaque Reshape error.
                if decoder_init_conv_dim is None:
                    raise ValueError("decoder_init_conv_dim is required when dense_units is given.")
                expected_units = decoder_init_conv_dim * decoder_init_conv_dim * num_filters
                if dense_units[-1] != expected_units:
                    raise ValueError(
                        f"dense_units[-1] ({dense_units[-1]}) must equal decoder_init_conv_dim**2 * "
                        f"num_filters ({expected_units}) to reshape into the first conv input.")
                # The InputLayer is probably important; borrowed from https://www.tensorflow.org/tutorials/generative/cvae#network_architecture
                self.add(tf.keras.layers.InputLayer(input_shape=(dense_units[0],)))
                mlp = make_mlp(dense_units[1:], activation=get_activation(dense_activation), no_last_activation=False)
                self.add(mlp)
                self.add(
                    tf.keras.layers.Reshape(target_shape=(decoder_init_conv_dim, decoder_init_conv_dim, num_filters)))
            self.add(tfc.SignalConv2D(
                num_filters, (3, 3), name="layer_0", corr=False, strides_up=2,
                padding="same_zeros", use_bias=True,
                activation=tfc.GDN(name="igdn_0", inverse=True)))
            self.add(tfc.SignalConv2D(
                num_filters, (5, 5), name="layer_1", corr=False, strides_up=2,
                padding="same_zeros", use_bias=True,
                activation=tfc.GDN(name="igdn_1", inverse=True)))
            self.add(tfc.SignalConv2D(
                3, (5, 5), name="layer_2", corr=False, strides_up=2,
                padding="same_zeros", use_bias=True,
                activation=None))
            self.add(tf.keras.layers.Lambda(lambda x: x * 255.))


# TODO: add data preprocessing; set use_bias=True in the final conv even when fc (dense) layers follow
class mnist:  # using this class just as a namespace
    """Namespace for the MNIST analysis/synthesis transforms."""

    class AnalysisTransform(tf.keras.Sequential):
        """The analysis transform.

        Two stride-2 5x5 ``SignalConv2D`` layers; the first uses a GDN activation.
        No input rescaling is applied here (unlike the cifar variant).

        Without ``dense_units`` the second conv outputs ``num_output_filters``
        channels with no bias or activation. With ``dense_units`` it keeps
        ``num_filters`` channels with a GDN activation (still without bias).
        Its output is then flattened and passed through an MLP.
        """

        def __init__(self, num_filters, num_output_filters=None, dense_units=None, dense_activation='leaky_relu'):
            """
            :param num_filters: number of channels in the conv layers.
            :param num_output_filters: channels of the final conv in the conv-only
                configuration; defaults to ``num_filters``.
            :param dense_units: list of ints; last entry determines the encoder output dimension.
                ``None``/empty means conv-only.
            :param dense_activation: activation name (resolved by ``get_activation``)
                for the MLP; the last MLP layer has no activation.
            """
            import tensorflow_compression as tfc
            super().__init__(name="analysis")
            if not num_output_filters:
                num_output_filters = num_filters
            self.add(tfc.SignalConv2D(
                num_filters, (5, 5), name="layer_1", corr=True, strides_down=2,
                padding="same_zeros", use_bias=True,
                activation=tfc.GDN(name="gdn_1")))
            if not dense_units:
                self.add(tfc.SignalConv2D(
                    num_output_filters, (5, 5), name="layer_2", corr=True, strides_down=2,
                    padding="same_zeros", use_bias=False,
                    activation=None))
            else:
                # NOTE(review): this GDN reuses the name "gdn_1" from layer_1's activation.
                # It is kept as-is because renaming could affect name-based weight loading.
                # Confirm whether "gdn_2" was intended.
                self.add(tfc.SignalConv2D(
                    num_filters, (5, 5), name="layer_2", corr=True, strides_down=2,
                    padding="same_zeros", use_bias=False,
                    activation=tfc.GDN(name="gdn_1")))
                self.add(tf.keras.layers.Flatten())
                mlp = make_mlp(dense_units, activation=get_activation(dense_activation))
                self.add(mlp)

            # One entry per stride-2 conv above.
            self.downsample_factors = [2, 2]

    class SynthesisTransform(tf.keras.Sequential):
        """The synthesis transform.

        Optionally starts with an MLP and a reshape to a
        ``(decoder_init_conv_dim, decoder_init_conv_dim, num_filters)`` feature map.
        Then applies two stride-2 transposed 5x5 ``SignalConv2D`` layers. The first
        uses inverse GDN; the second outputs 1 channel with no activation.
        """

        def __init__(self, num_filters, dense_units=None, dense_activation='leaky_relu', decoder_init_conv_dim=None):
            """
            :param num_filters: number of channels in the intermediate conv layer.
            :param dense_units: list of ints; the first entry determines the decoder input dimension.
                The remaining entries are MLP layer sizes. The last entry must equal
                ``decoder_init_conv_dim ** 2 * num_filters``.
            :param dense_activation: activation name (resolved by ``get_activation``)
                applied to every MLP layer, including the last.
            :param decoder_init_conv_dim: spatial size of the feature map produced
                by the reshape; required when ``dense_units`` is given.
            :raises ValueError: if ``dense_units`` is given and ``decoder_init_conv_dim``
                is missing or inconsistent with ``dense_units[-1]``.
            """
            import tensorflow_compression as tfc
            super().__init__(name="synthesis")
            if dense_units:
                # Fail fast with a clear message instead of an opaque Reshape error.
                if decoder_init_conv_dim is None:
                    raise ValueError("decoder_init_conv_dim is required when dense_units is given.")
                expected_units = decoder_init_conv_dim * decoder_init_conv_dim * num_filters
                if dense_units[-1] != expected_units:
                    raise ValueError(
                        f"dense_units[-1] ({dense_units[-1]}) must equal decoder_init_conv_dim**2 * "
                        f"num_filters ({expected_units}) to reshape into the first conv input.")
                # The InputLayer is probably important; borrowed from https://www.tensorflow.org/tutorials/generative/cvae#network_architecture
                self.add(tf.keras.layers.InputLayer(input_shape=(dense_units[0],)))
                mlp = make_mlp(dense_units[1:], activation=get_activation(dense_activation), no_last_activation=False)
                self.add(mlp)
                self.add(
                    tf.keras.layers.Reshape(target_shape=(decoder_init_conv_dim, decoder_init_conv_dim, num_filters)))
            self.add(tfc.SignalConv2D(
                num_filters, (5, 5), name="layer_0", corr=False, strides_up=2,
                padding="same_zeros", use_bias=True,
                activation=tfc.GDN(name="igdn_0", inverse=True)))
            self.add(tfc.SignalConv2D(
                1, (5, 5), name="layer_1", corr=False, strides_up=2,
                padding="same_zeros", use_bias=True,
                activation=None))
