import numpy as np
import tensorflow as tf
#import tensorflow_datasets as tfds
from tensorflow.keras.datasets import cifar10
from deel.utils.ImageTransformer import ImageTransformer
import matplotlib
import matplotlib.pyplot as plt
import random

def plot_10_by_10_cifar(images, filename = None):
    """Plot 100 randomly chosen CIFAR images in a 10 by 10 grid.

    The five left-most columns are sampled from the first half of ``images``
    and the five right-most columns from the second half (presumably so two
    groups stored back to back can be compared side by side). Images are
    drawn with replacement, so the same image may appear more than once.

    Args:
        images: indexable collection of at least one image, each of which
            must be reshapeable to (32, 32, 3).
        filename: if given, the figure is saved to this path and closed;
            otherwise it is displayed with ``plt.show()``.
    """
    fig = plt.figure()
    nb = len(images)
    half = nb // 2
    for x in range(10):
        for y in range(10):
            ax = fig.add_subplot(10, 10, 10*y+x+1)
            if x < 5:
                # randrange excludes its upper bound, so the left columns never
                # sample index ``half`` (which belongs to the second half).
                # max(..., 1) keeps a single-image input valid.
                ind = random.randrange(max(half, 1))
            else:
                ind = random.randrange(half, nb)
            ax.imshow(images[ind].reshape(32, 32, 3))
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    if filename is not None:
        plt.savefig(
            filename,
            bbox_inches='tight'
        )
        plt.close(fig)
    else:
        plt.show()


def otp_generator(batch_size,X,Y):
    """Yield endless class-balanced batches for a +1/-1 binary problem.

    Each batch contains ``batch_size // 2`` samples labelled 1 followed by the
    remaining ``batch_size - batch_size // 2`` samples labelled -1, so odd
    batch sizes are supported. Within a batch, each class is sampled without
    replacement.

    Args:
        batch_size: number of samples per batch.
        X: array of samples indexed along the first axis.
        Y: 1-D array of labels in {1, -1}, aligned with ``X``.

    Yields:
        (batch_x, batch_y): float32 arrays of shapes
        ``(batch_size,) + X.shape[1:]`` and ``(batch_size, 1)``.

    Raises:
        ValueError: (from ``np.random.choice``) if a class has fewer samples
            than its share of the batch.
    """
    indices = np.arange(Y.shape[0])
    pos_ix = indices[Y == 1]
    neg_ix = indices[Y == -1]
    n_pos = batch_size // 2
    # The negative half takes the remainder so odd batch sizes fill every row.
    n_neg = batch_size - n_pos
    while True:
        ind = np.concatenate([
            np.random.choice(pos_ix, size=n_pos, replace=False),
            np.random.choice(neg_ix, size=n_neg, replace=False),
        ])
        batch_x = X[ind].astype(np.float32)
        batch_y = Y[ind].astype(np.float32).reshape(-1, 1)
        yield batch_x, batch_y



def simple_generator(batch_size,X,Y):
    """Return a factory of endless, uniformly sampled (with replacement) batches.

    The outer call captures the data; calling the returned function creates a
    fresh generator. This zero-argument-callable form is what
    ``tf.data.Dataset.from_generator`` expects.

    Args:
        batch_size: number of samples per batch.
        X: array of samples indexed along the first axis.
        Y: labels aligned with ``X``; either 1-D (sparse/binary) or 2-D
            (e.g. one-hot).

    Returns:
        A zero-argument callable. Each call returns a generator yielding
        float32 ``(batch_x, batch_y)`` pairs of shapes
        ``(batch_size,) + X.shape[1:]`` and ``(batch_size,) + Y.shape[1:]``.
    """
    def simpl_generator():
        n = Y.shape[0]
        while True:
            ind = np.random.randint(0, n, size=batch_size)
            # Deriving shapes from the data (instead of Y.shape[1]) lets 1-D
            # label vectors work as well as 2-D ones.
            yield X[ind].astype(np.float32), Y[ind].astype(np.float32)
    return simpl_generator


def cifar_generator_aug(batch_size,X,Y,random = False,rotation_range=15,
                               zoom_range=0.,
                               fill_mode='nearest',
                               flip_horizontal=True,
                               height_shift_range=0.1,
                               contrast_level= 0.2,
                               shear_range = 0,
                               random_crop = 0,
                               salt_and_pepper= 0.0,
                               gaussian_noise= 0.01,
                               width_shift_range=0.1):
    """Return a factory of endless batches augmented with ``ImageTransformer``.

    Args:
        batch_size: number of samples per batch.
        X: array of images indexed along the first axis.
        Y: labels aligned with ``X``; either 1-D or 2-D (e.g. one-hot).
        random: if True, each batch samples indices uniformly with
            replacement. Otherwise consecutive batches walk through ``X`` in
            order, wrapping around at the end.
        rotation_range, zoom_range, fill_mode, flip_horizontal,
        height_shift_range, contrast_level, shear_range, random_crop,
        salt_and_pepper, gaussian_noise, width_shift_range:
            forwarded unchanged to ``ImageTransformer``.

    Returns:
        A zero-argument callable. Each call builds its own transformer and
        returns a generator yielding float32 ``(batch_x, batch_y)`` pairs in
        which every image has gone through ``trans.random_transform``.
    """
    def au_generator():
        trans=ImageTransformer(rotation_range=rotation_range,
                               zoom_range=zoom_range,
                               fill_mode=fill_mode,
                               shear_range = shear_range,
                               flip_horizontal=flip_horizontal,
                               random_crop = random_crop,
                               height_shift_range=height_shift_range,
                               contrast_level=contrast_level,
                               salt_and_pepper= salt_and_pepper,
                               gaussian_noise=gaussian_noise,
                               width_shift_range=width_shift_range)
        nb = X.shape[0]
        pos = 0
        while True:
            if random:
                ind = np.random.randint(0, nb, size=batch_size)
            else:
                # Sequential window starting at ``pos``, wrapping modulo nb.
                ind = (pos + np.arange(batch_size)) % nb
            # Fancy indexing returns a copy, so augmenting in place is safe.
            batch_x = X[ind].astype(np.float32)
            for i in range(batch_size):
                batch_x[i] = trans.random_transform(batch_x[i])
            # Shape follows Y, so 1-D label vectors are accepted too.
            batch_y = Y[ind].astype(np.float32)
            pos = (pos + batch_size) % nb
            yield batch_x, batch_y

    return au_generator

def cifar10_dataset_oneclass(selected_classes):
    """Load CIFAR-10 for one-class learning on a single class.

    The training set keeps only the images whose label equals
    ``selected_classes`` (a single class id despite the plural name). The
    test set keeps every image and gets a binary target marking that class.
    Both sets are centered with the per-channel mean of the kept training
    images and divided by 255.

    Args:
        selected_classes: CIFAR-10 class id treated as the positive class.

    Returns:
        (X_all, X_test, y_b_test, y_test):
        - float32 training images of the selected class;
        - float32 test images;
        - float64 binary test targets (1.0 for the selected class, else 0.0);
        - the original flattened test labels.
    """
    (X_all, y_all), (X_test, y_test) = cifar10.load_data()
    X_all = X_all.reshape((-1, 32, 32, 3))
    X_test = X_test.reshape((-1, 32, 32, 3))

    y_all = np.reshape(y_all, (-1,))
    y_test = np.reshape(y_test, (-1,))

    # Vectorized masks instead of per-sample Python loops.
    X_all = X_all[y_all == selected_classes]
    y_b_test = (y_test == selected_classes).astype(np.float64)

    X_all = X_all.astype('float32')
    X_test = X_test.astype('float32')
    # This shift alone maps pixels to [-128, 127]. It is cancelled by the
    # mean subtraction below, because the mean is taken on the shifted data.
    X_all = X_all - 128
    X_test = X_test - 128
    means = X_all.mean(axis=(0, 1, 2))
    X_all = X_all - means
    X_test = X_test - means
    # Centered pixels lie in [-255, 255], so dividing by 255 gives [-1, 1].
    X_all = X_all / 255.
    X_test = X_test / 255.

    return X_all, X_test, y_b_test, y_test


def get_post_process_function(center,rescale,means):
    """Build a function that undoes the ``center``/``rescale`` preprocessing for display.

    Args:
        center: whether ``means`` had been subtracted from the data.
        rescale: whether the data had been divided by 255.
        means: per-channel means to add back when ``center`` is True.

    Returns:
        A function taking a NumPy array of preprocessed images and returning
        the restored pixel values divided by 255 and clipped to [0, 1].
    """
    def post_process(x):
        img = x.astype(np.float32)
        img = img * 255. if rescale else img
        img = img + means if center else img
        return np.clip(img / 255., 0., 1.)
    return post_process


def _as_column(labels):
    """Return *labels* as 2-D: a 1-D label vector becomes an (N, 1) column.

    Arrays that are already 2-D (e.g. one-hot labels) are returned unchanged,
    so the batch generators always receive labels of shape (N, k).
    """
    labels = np.asarray(labels)
    return labels.reshape(-1, 1) if labels.ndim == 1 else labels


def cifar10_dataset(batch_size, to_categorical,
                    selected_classes=None,
                    gtValues=None,
                    center=True,
                    rescale=False,
                    aug=False,
                    tf_dataset=False,
                    rotation_range=15,
                    zoom_range=0.,
                    fill_mode='nearest',
                    flip_horizontal=True,
                    height_shift_range=0.1,
                    contrast_level=0.2,
                    shear_range=0,
                    random_crop=0,
                    salt_and_pepper=0.0,
                    gaussian_noise=0.00,
                    width_shift_range=0.1):
    """Load CIFAR-10, keep a subset of classes and wrap it in batch generators.

    Args:
        batch_size: number of samples per generated batch.
        to_categorical: if True, labels are one-hot encoded with
            ``len(selected_classes)`` columns.
        selected_classes: CIFAR-10 class ids to keep (default: all ten).
            With exactly two classes the problem is treated as binary.
        gtValues: optional label value for each selected class, in the same
            order. Defaults to the class position ``0 .. len-1``.
        center: subtract the per-channel mean of the kept training images.
        rescale: divide pixel values by 255 (after centering).
        aug: use the augmenting generator for the training split.
        tf_dataset: only with ``aug``; wrap the 'train' and 'test'
            generators in prefetching ``tf.data.Dataset`` objects.
        rotation_range, zoom_range, fill_mode, flip_horizontal,
        height_shift_range, contrast_level, shear_range, random_crop,
        salt_and_pepper, gaussian_noise, width_shift_range:
            forwarded to ``cifar_generator_aug`` when ``aug`` is True.

    Returns:
        dict with keys 'train', 'trainSize', 'valid', 'validSize', 'test',
        'testSize', 'test_XY', 'batch_size' and, when ``aug`` is True, also
        'post_process'. The generator entries are zero-argument callables
        that return infinite batch generators. Labels passed to the
        generators are 2-D; 'test_XY' holds the unreshaped test labels.

    Raises:
        ValueError: if ``gtValues`` does not have one entry per selected class.
    """
    if selected_classes is None:
        selected_classes = range(10)
    nb_classes = len(selected_classes)
    if nb_classes == 2:
        nb_classes = 1  ## binary
    if gtValues is None:
        index_selected_class = {c: i for i, c in enumerate(selected_classes)}
    else:
        # Compare with the number of selected classes, not nb_classes: the
        # latter is 1 for binary problems, yet two gtValues are still needed.
        if len(gtValues) != len(selected_classes):
            raise ValueError(
                "gtValues must have one entry per selected class "
                f"({len(selected_classes)}), got {len(gtValues)}"
            )
        index_selected_class = dict(zip(selected_classes, gtValues))

    print(index_selected_class)

    # the data, shuffled and split between train and test sets
    (X_all, y_all), (X_test, y_test) = cifar10.load_data()

    X_all = X_all.reshape((-1, 32, 32, 3))
    X_test = X_test.reshape((-1, 32, 32, 3))

    y_all = np.reshape(y_all, (-1,))
    y_test = np.reshape(y_test, (-1,))

    print("Select only "+str(nb_classes)+" classes:"+str(selected_classes))
    class_ids = list(selected_classes)
    # Vectorized class filtering, then map each class id to its target value.
    select_all = np.isin(y_all, class_ids)
    X_all = X_all[select_all]
    y_all = np.asarray([index_selected_class[y] for y in y_all[select_all]])

    select_test = np.isin(y_test, class_ids)
    X_test = X_test[select_test]
    y_test = np.asarray([index_selected_class[y] for y in y_test[select_test]])
    print(y_test.shape)

    X_all = X_all.astype('float32')
    X_test = X_test.astype('float32')

    # Means come from the training images only and are reused for the test
    # set; they are also needed by post_process.
    means = X_all.mean(axis=(0, 1, 2))
    if center:
        X_all = X_all - means
        X_test = X_test - means
    if rescale:
        X_all = X_all / 255.
        X_test = X_test / 255.
    print(X_all.shape, X_all.mean(axis=(0, 1, 2)), X_all.std(axis=(0, 1, 2)))
    print(X_test.shape, X_test.mean(axis=(0, 1, 2)), X_test.std(axis=(0, 1, 2)))

    # NOTE(review): no train/validation split is made; the "valid" split is
    # the training set itself.
    X_train = X_valid = X_all
    Y_train = Y_valid = y_all
    Y_test = y_test
    if to_categorical:
        n_cat = len(selected_classes)
        Y_test = tf.keras.utils.to_categorical(Y_test, n_cat)
        Y_train = tf.keras.utils.to_categorical(Y_train, n_cat)
        Y_valid = tf.keras.utils.to_categorical(Y_valid, n_cat)

    print(X_train.shape[0], 'train samples')
    print(X_valid.shape[0], 'valid samples')
    print(X_test.shape[0], 'test samples')

    # Only reshape labels that are still 1-D; one-hot labels are kept as-is.
    if aug:
        train = cifar_generator_aug(batch_size, X_train, _as_column(Y_train),
                                    rotation_range=rotation_range,
                                    fill_mode=fill_mode,
                                    shear_range=shear_range,
                                    zoom_range=zoom_range,
                                    random_crop=random_crop,
                                    flip_horizontal=flip_horizontal,
                                    height_shift_range=height_shift_range,
                                    contrast_level=contrast_level,
                                    salt_and_pepper=salt_and_pepper,
                                    gaussian_noise=gaussian_noise,
                                    width_shift_range=width_shift_range)
    else:
        train = simple_generator(batch_size, X_train, _as_column(Y_train))

    dtset = {
        'train': train,
        'trainSize': X_train.shape[0],
        'valid': simple_generator(batch_size, X_valid, _as_column(Y_valid)),
        'validSize': X_valid.shape[0],
        'test': simple_generator(batch_size, X_test, _as_column(Y_test)),
        'testSize': X_test.shape[0],
        'test_XY': (X_test, Y_test),
        'batch_size': batch_size,
    }
    if aug:
        dtset['post_process'] = get_post_process_function(center, rescale, means)
        if tf_dataset:
            dtset['train'] = tf.data.Dataset.from_generator(dtset['train'], (tf.float32, tf.float32)).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
            dtset['test'] = tf.data.Dataset.from_generator(dtset['test'], (tf.float32, tf.float32)).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dtset
    
