import tensorflow as tf
import numpy as np
import scipy.fftpack
import os
from keras.datasets import mnist

# Set CUDA_VISIBLE_DEVICES so that CUDA-based libraries used by this process
# only see GPU index 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def show(img):
    """Print an image to stdout as 28 rows of 28 ASCII glyphs.

    The real part of ``img`` is flattened, multiplied by 3 and rounded to
    pick a glyph from ``" .*#"``: intensities near 0 print as a blank and
    intensities near 1 or above print as ``#``. Only the first 28 * 28
    values are printed, so for a batch of images only the first one is
    shown. A ``START`` line is printed before the image.

    Args:
        img: Array-like of real or complex intensities with at least
            28 * 28 elements.
    """
    remap = " .*#" + "#" * 100
    pixels = np.asarray(img).real.flatten() * 3
    # Clamp the glyph index before using it: a negative index would wrap
    # around to the '#' padding at the end of `remap` (so slightly negative
    # pixels would render as fully bright), and an index past the end would
    # raise IndexError. NaN is mapped to 0 (blank) by nan_to_num.
    levels = np.clip(np.rint(np.nan_to_num(pixels)), 0, len(remap) - 1)
    levels = levels.astype(int)
    print("START")
    for row in range(28):
        print("".join(remap[k] for k in levels[row * 28:row * 28 + 28]))

# Fetch the MNIST train/test splits through Keras.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Give every image an explicit trailing channel axis, convert to float32
# and divide the pixel values by 255.
img_rows = img_cols = 28
x_train = x_train.reshape(len(x_train), img_rows, img_cols, 1).astype('float32') / 255
x_test = x_test.reshape(len(x_test), img_rows, img_cols, 1).astype('float32') / 255

def dct2(xs):
    """Apply an orthonormal DCT along axis 2 and then axis 1 of ``xs``.

    ``tf.spectral.dct`` is called on a transposed tensor so that the axis
    being transformed sits in the last position, and the tensor is
    transposed back afterwards.

    Args:
        xs: Rank-4 tensor; this script passes (batch, 28, 28, 1) images.

    Returns:
        Tensor with the same axis layout as ``xs`` holding the transformed
        values.
    """
    coeffs = xs
    # Each permutation swaps one axis with the last one, so applying it a
    # second time restores the original layout.
    for perm in ([0, 1, 3, 2], [0, 3, 2, 1]):
        coeffs = tf.transpose(coeffs, perm)
        coeffs = tf.spectral.dct(coeffs, norm='ortho')
        coeffs = tf.transpose(coeffs, perm)
    return coeffs

def idct2(xs):
    """Apply an orthonormal inverse DCT along axis 2 and then axis 1.

    Mirrors ``dct2``: ``tf.spectral.idct`` is called on a transposed tensor
    so that the axis being transformed sits in the last position, and the
    tensor is transposed back afterwards.

    Args:
        xs: Rank-4 tensor of coefficients; this script passes tensors laid
            out as (batch, 28, 28, 1).

    Returns:
        Tensor with the same axis layout as ``xs`` holding the inverse
        transformed values.
    """
    pixels = xs
    # Each permutation swaps one axis with the last one, so applying it a
    # second time restores the original layout.
    for perm in ([0, 1, 3, 2], [0, 3, 2, 1]):
        pixels = tf.transpose(pixels, perm)
        pixels = tf.spectral.idct(pixels, norm='ortho')
        pixels = tf.transpose(pixels, perm)
    return pixels

def keep_top_k(x, k):
    """Zero every entry of each example that is below its k-th largest value.

    The threshold is computed on the signed values, not on magnitudes.
    Entries equal to the threshold are all kept, so more than ``k`` entries
    can survive when there are ties.

    Args:
        x: Rank-4 float32 tensor with 28 * 28 entries per example, e.g.
            (batch, 28, 28, 1). Any batch size is accepted.
        k: Number of largest entries that define the per-example threshold.

    Returns:
        Tensor shaped like ``x`` with all sub-threshold entries set to zero.
    """
    # Let reshape infer the batch dimension instead of reading a
    # module-level batch-size constant, so the function works for any batch.
    flat = tf.reshape(x, [-1, 28 * 28])
    top_values = tf.math.top_k(flat, k=k)[0]
    # top_k returns its values in descending order by default, so the last
    # column is the k-th largest value of each example.
    threshold = tf.reshape(top_values[:, -1], [-1, 1, 1, 1])
    mask = tf.cast(x >= threshold, dtype=tf.float32)
    return x * mask

# Number of images fed through the graph per evaluation.
BS = 1000
xs = tf.placeholder(tf.float32, (BS, 28, 28, 1))

# Alternate between a sparse DCT-domain estimate `x` of the input and a
# sparse pixel-domain residual `e`. The Python loop is unrolled into the
# graph, so the graph contains ten copies of the two steps below.
# The residual starts at zero; zeros_like avoids embedding a multi-megabyte
# constant array in the graph. `x` needs no initial value because the first
# iteration assigns it.
e = tf.zeros_like(xs)
for _ in range(10):
    # Keep the 30 largest DCT coefficients of the input minus the residual.
    x = dct2(xs - e)
    x = keep_top_k(x, 30)

    # Keep the 30 largest pixels of what the sparse coefficients leave out.
    idct = idct2(x)
    e = (xs - idct)
    e = keep_top_k(e, 30)

reconstructed = idct2(x)

# Squared reconstruction error, one value per example.
error = tf.reduce_sum((reconstructed - xs) ** 2, axis=(1, 2, 3))

# Gradient of the errors with respect to the input pixels.
grad = tf.gradients(error, [xs])[0]

# Work on a private copy so the pixel edits below do not overwrite x_test.
batch = x_test[:BS].copy()
# `batch` is a fresh contiguous array, so this reshape is a view: writing
# through `flat_batch` changes the images that are fed to the graph.
flat_batch = batch.reshape(BS, 28 * 28)
example_ids = np.arange(BS)

with tf.Session() as sess:
    for i in range(30):
        # The fetched error goes into `err` so that the graph tensor `e`
        # is not rebound to a NumPy array.
        out, g, err = sess.run((reconstructed, grad, error), {xs: batch})

        # For every example pick the pixel whose gradient has the largest
        # magnitude.
        flat_grads = np.reshape(g, [-1, 28 * 28])
        which = np.argmax(np.abs(flat_grads), axis=1)

        # Set that pixel to 1 where its gradient is negative and to 0
        # otherwise (the boolean result is stored as 1.0 / 0.0).
        flat_batch[example_ids, which] = flat_grads[example_ids, which] < 0

        # Debug aid: show the gradient rescaled to [0, 1].
        #show((g-np.min(g))/(np.max(g)-np.min(g)))
        # `err` and `out` belong to the batch as it was before this step's
        # pixel edit; `batch` already contains the edit.
        print(err)
        show(batch)
        show(out)

