import tensorflow as tf
import numpy as np
import time
import sys

n = 256  # hidden-state size
total_steps = 1000000  # number of training steps
delay = 3  # not referenced elsewhere in this file
updates = []  # per-step assign ops; filled in when the training graph is built
avg_steps = 1  # steps whose gradients are accumulated per optimizer update
b_size = 256  # training batch size (one independent stream per row)
cur_lr = 10 ** (-3.5)  # learning rate fed to the optimizer
reset_every_n = 100  # not referenced elsewhere in this file
experiment = "ptb"  # "cp" or "ptb"

val_steps = 2  # unrolled steps per validation graph evaluation
val_batch = 2  # validation batch size

# Vocabulary size per experiment. For "cp" the tokens are 0/1 (bits),
# 2 (start marker) and 3 (blank), as produced by cp_data.
if experiment == "cp":
    vocab = 4
elif experiment == "ptb":
    vocab = 50
else:
    # Fail fast: otherwise `vocab` is undefined and the graph build dies later
    # with an unrelated-looking NameError.
    raise ValueError("unknown experiment {!r}; expected 'cp' or 'ptb'".format(experiment))

# RTRL input
# Per-example factors of the approximate influence matrix d h / d W, carried
# from step to step (zero-initialised):
#   u: (b_size, 1, vocab + n + 1) -- factor over the rows of W, i.e. over the
#      [x, h, 1] input vector `concat` built below
#   A: (b_size, n, 2 * n)         -- factor over (hidden unit, pre-activation column)
# They are changed through explicit assign ops rather than by the optimizer.
u = tf.get_variable('u', shape=(b_size, 1, vocab + n + 1), initializer=tf.zeros_initializer())
A = tf.get_variable('A', shape=(b_size, n, 2 * n), initializer=tf.zeros_initializer())

# RNN input
# One token id per batch row for the current step, one-hot encoded to (b_size, vocab).
x_ = tf.placeholder(tf.int32, shape=[b_size])
x = tf.one_hot(x_, vocab)

# Recurrent hidden state, kept in a variable so it persists across session.run calls.
h = tf.get_variable('h', shape=(b_size, n), initializer=tf.zeros_initializer())
# Constant column of ones: folds the bias into W (and into W_out further down).
b = tf.ones(shape=(b_size, 1))
# Target token id per row, (b_size, 1); one-hot gives (b_size, 1, vocab).
target_ = tf.placeholder(tf.int32, shape=[b_size, 1])
target = tf.one_hot(target_, vocab)
# Learning rate, fed when the optimizer step is run.
lr = tf.placeholder(tf.float32)

# Necessary steps for running RNN
# W maps concat = [x, h, 1] (vocab + n + 1 values) to 2n pre-activations:
# the first n columns drive the candidate c, the last n drive the gate f.
W = tf.get_variable('W_hx', shape=(n + vocab + 1, 2 * n), initializer=tf.random_normal_initializer(stddev=0.01))
concat = tf.concat([x, h, b], axis=1)

# Calculate new hidden state
#   c = 2 * sigmoid(z_c) - 1   (in (-1, 1)),   f = sigmoid(z_f)
#   h_next = f * h + (1 - f) * c
z_t = tf.split(tf.matmul(concat, W), 2, axis=1)
c, f = tf.nn.sigmoid(z_t[0]), tf.nn.sigmoid(z_t[1])
# Keep sigmoid(z_c) before rescaling; its derivative is needed for RTRL.
c_pre = c
c = 2 * c - 1
h_next = f * h + (1 - f) * c

# Get derivatives and run RTRL
# Element-wise derivatives of the two non-linearities:
#   c_ = d c / d z_c = 2 * s * (1 - s)  with s = sigmoid(z_c) (= c_pre)
#   f_ = d f / d z_f = f * (1 - f)
c_, f_ = (1 - c_pre) * c_pre, (1 - f) * f
c_ = 2 * c_

# D: immediate Jacobian d h_next / d z, shape (b_size, n, 2n).
# Left n columns: w.r.t. z_c -> diag((1 - f) * c_);
# right n columns: w.r.t. z_f -> diag((h - c) * f_).
D = tf.concat([tf.matrix_diag((1 - f) * c_), tf.matrix_diag((h - c) * f_)], axis=2)

# H: recurrent Jacobian d h_next / d h, shape (b_size, n, n).
# W[vocab:n + vocab] are the rows of W that multiply h inside `concat`; after the
# transpose, entry [j, k] is the weight from h[k] to pre-activation j.
# Path through the candidate c:
c2h = tf.expand_dims((1 - f) * c_, axis=2) * tf.transpose(W[vocab:n + vocab, :n])
# Path through the gate f:
f2h = tf.expand_dims((h - c) * f_, axis=2) * tf.transpose(W[vocab:n + vocab, n:])
# Direct path f * h:
h2h = tf.matrix_diag(f)

H = c2h + f2h + h2h

# Random sign per example: round(U[0, 1)) is 0 or 1, mapped to -1 or +1.
sign = 2 * (tf.round(tf.random_uniform((b_size, 1, 1))) - 0.5)
# NOTE(review): `sign_` is not used anywhere in this file.
sign_ = 2 * (tf.round(tf.random_uniform((b_size, 1, 1))) - 0.5)

# Propagate the existing A factor through the recurrence.
temp = tf.matmul(H, A)

# Per-example scale factors, shape (b_size, 1, 1). p1 ~ sqrt(|temp| / |u|) makes
# |p1 * u| ~ |temp / p1|, and p2 ~ sqrt(|D| / |concat|) does the same for the new
# term. The 1e-5 offsets guard against division by zero (e.g. u == 0 after a reset).
p1 = 0.00001 + (
tf.sqrt(tf.norm(temp, axis=(1, 2), keepdims=True) / (tf.norm(u, axis=(1, 2), keepdims=True) + 0.00001)))
p2 = 0.00001 + (tf.sqrt(
    tf.norm(D, axis=(1, 2), keepdims=True) / tf.expand_dims((tf.norm(concat, axis=(1), keepdims=True) + 0.00001),
                                                            axis=2)))

# Rank-one factor update. Expanding the product of u_new and A_new:
#   - the wanted terms u*(H A) and concat*D have coefficient 1 (sign ** 2 == 1);
#   - the two cross terms carry one factor of `sign`, so they cancel in
#     expectation over the random +/-1.
u_new = p1 * u + p2 * sign * tf.expand_dims(concat, axis=1)
A_new = (1 / p1) * temp + (1 / p2) * (sign * D)

# Generate prediction
# Output layer: logits from [h_next, 1] (bias folded into the last row of W_out).
W_out = tf.get_variable('W_out', shape=(n + 1, vocab), initializer=tf.random_normal_initializer(stddev=0.01))
h_out = tf.concat([h_next, b], axis=1)

y = tf.matmul(h_out, W_out)
y_soft = tf.nn.softmax(y)

# Batch-mean cross-entropy (nats).
# NOTE(review): labels are (b_size, 1, vocab) while logits are (b_size, vocab); this
# relies on the op flattening outer dimensions -- confirm for the TF version in use.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=y))
# W_out only influences the current step's output, so this gradient is exact.
loss2wout = tf.gradients(loss, W_out)[0]

# dL/dh_next as (b_size, 1, n). Because `loss` is a batch mean, this already
# carries the 1 / b_size factor.
loss2h = tf.expand_dims(tf.gradients(loss, h_next)[0], axis=1)

# Approximate dL/dW from the factors.
# Example i contributes the outer product
#     u_new[i]^T (loss2h[i] @ A_new[i])        -> (n + vocab + 1, 2n)
# and the contributions are summed over the batch.
# Contracting the batch axis in one matmul gives exactly what the previous
# reshape(sum_i kron(u_new[i], loss2W[i])) computed, but as a single op instead of
# b_size unrolled kronecker_product ops.
# It deliberately does NOT divide by b_size again: loss2h is already batch-averaged.
# The old extra division scaled W's gradient 1/b_size below loss2wout's. Adam largely
# absorbs a constant scale, but it interacts with Adam's epsilon.
loss2W = tf.matmul(loss2h, A_new)  # (b_size, 1, 2n)
loss2W_ = tf.reshape(tf.matmul(u_new[:, 0, :], loss2W[:, 0, :], transpose_a=True),
                     (n + vocab + 1, 2 * n))

# Running gradient sums between optimizer applications (zeroed by `clear`).
loss2W_avg = tf.get_variable("loss2W_avg", shape=loss2W_.get_shape(), initializer=tf.zeros_initializer())
loss2wout_avg = tf.get_variable("loss2wout_avg", shape=loss2wout.get_shape(), initializer=tf.zeros_initializer())

optimizer = tf.train.AdamOptimizer(learning_rate=lr)
# The 1e7 threshold makes this clip effectively inactive; it only guards against blow-ups.
loss2W_avg_ = tf.clip_by_norm(loss2W_avg / avg_steps, 10000000)

# Only W and W_out are optimised; u, A and h change through the assign ops below.
train = optimizer.apply_gradients([(loss2W_avg_, W),
                                   (loss2wout_avg / avg_steps, W_out)])

clear = [loss2W_avg.assign(loss2W_avg * 0),
         loss2wout_avg.assign(loss2wout_avg * 0)]

# Per-row state reset: resets__[i] == 0 zeroes row i of u, A and h; 1 keeps it.
resets__ = tf.placeholder(tf.float32, shape=[b_size, 1, 1])
clear__ = [u.assign(u * resets__), A.assign(A * resets__), h.assign(h * resets__[:, 0])]
# Full reset of the RTRL factors only (h is kept); currently unused.
clear__ptb = [u.assign(u * 0), A.assign(A * 0)]  # , h.assign(h * 0)]

# Per-step updates, run together with `loss`. The control dependencies make the
# assigns wait until everything that reads the OLD h, u and A has been computed.
# That includes concat, D, H, temp, p1 and u_new/A_new, all ancestors of loss2W_,
# plus the output-layer gradient.
with tf.control_dependencies([H, D, loss2W_, loss2wout]):
    updates.append(loss2W_avg.assign_add(loss2W_))
    updates.append(loss2wout_avg.assign_add(loss2wout))

    updates.append(h.assign(h_next))

    updates.append(u.assign(u_new))
    updates.append(A.assign(A_new))

####VALIDATION-starts#######
############################

# Validation graph: the same recurrent cell (sharing W and W_out with training)
# unrolled for `val_steps` steps over a (val_batch, val_steps) chunk of tokens.
x_val_ = tf.placeholder(tf.int32, shape=[val_batch, val_steps])
x_val = tf.one_hot(x_val_, vocab)
b_val = tf.ones(shape=(val_batch, 1))

target_val_ = tf.placeholder(tf.int32, shape=[val_batch, val_steps])
target_val = tf.one_hot(target_val_, vocab)

# Initial hidden state of the chunk; the caller may feed a carried-over state here.
h1_val = tf.zeros(shape=(val_batch, n))


def _val_cell(x_step, h_prev):
    """Advance the validation cell one step and return (new hidden state, logits)."""
    pre = tf.split(tf.matmul(tf.concat([x_step, h_prev, b_val], axis=1), W), 2, axis=1)
    candidate = 2 * tf.nn.sigmoid(pre[0]) - 1
    gate = tf.nn.sigmoid(pre[1])
    h_new = gate * h_prev + (1 - gate) * candidate
    logits = tf.matmul(tf.concat([h_new, b_val], axis=1), W_out)
    return h_new, logits


h_val = h1_val
step_losses = []
for t in range(val_steps):
    h_val, y_val = _val_cell(x_val[:, t], h_val)
    step_losses.append(
        tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=target_val[:, t], logits=y_val)))

# Average over the unrolled steps of the batch-mean cross-entropy (nats).
loss_val = sum(step_losses) / val_steps

########################
####VALIDATION-end######

# Copy-task curriculum length; advanced inside cp_data.
T_ = 1


def _next_token_targets(tokens):
    """Return next-token targets for a 1-D token stream.

    targets[t] = tokens[t + 1]. The last position has no successor and keeps a
    random id drawn from [0, vocab - 3).
    """
    targets = np.random.randint(vocab - 3, size=(tokens.shape[0],))
    targets[:-1] = tokens[1:]
    return targets


if experiment == "ptb":
    # Training stream as a (data_s, 1) column of token ids.
    train_tokens = np.load('ptb.train.npy')
    inp_ptb = train_tokens[:, None]
    data_s = inp_ptb.shape[0]
    out_ptb = _next_token_targets(train_tokens)[:, None]

    # b_size evenly spaced starting offsets: one stream per batch row.
    idx = np.linspace(0, data_s - 1, b_size, dtype=np.int32)

    # Validation stream, truncated to a whole number of (val_batch * val_steps)
    # chunks and laid out as (val_batch, n_chunks, val_steps), so each batch row
    # reads one contiguous span chunk by chunk.
    inp_val = np.load('ptb.valid.npy')
    out_val = _next_token_targets(inp_val)

    chunk_val = val_steps * val_batch
    n_chunks = inp_val.shape[0] // chunk_val
    usable = n_chunks * chunk_val

    inp_val = inp_val[:usable].reshape(val_batch, n_chunks, val_steps)
    out_val = out_val[:usable].reshape(val_batch, n_chunks, val_steps)

    # Number of validation chunks per batch row.
    data_s_val = inp_val.shape[1]

elif experiment == "cp":
    # -1 means "no sequence yet": cp_data samples a fresh one on first use.
    idx = np.full(b_size, -1, dtype=np.int32)
    data_i = [[0] for _ in range(b_size)]
    data_o = [[0] for _ in range(b_size)]


#######Task4######

def _make_copy_sequence(T):
    """Build one copy-task (input, target) pair, each of length 2 * T + 2.

    input:  [2, b_1, ..., b_T, 3, ..., 3]   start marker, T random bits, blanks
    target: [3, ..., 3, 2, b_1, ..., b_T]   blanks while reading, then echo input[:T + 1]
    """
    inp = np.random.randint(2, size=(2 * T + 2))
    inp[0] = 2
    inp[T + 1:] = 3
    out = np.concatenate([np.full(T + 1, 3, dtype=inp.dtype), inp[:T + 1]])
    return inp, out


def cp_data():
    """Produce one time step of copy-task data for every batch row.

    Row i reads position idx[i] of its own sequence pair data_i[i] / data_o[i].
    When a row has no sequence yet (idx[i] == -1) or has consumed it, a fresh pair
    is sampled, idx[i] becomes 0 and resets[i] is 0 so the caller can clear that
    row's recurrent state. Curriculum: at such a boundary, if avg_log_loss < 0.15
    then T_ grows by one and avg_log_loss is set back to 0.3. Each new sequence
    uses T = max(1, T_ - randint(5)).

    Returns:
        (cp_inp, cp_out, resets), each of shape (b_size,): integer input ids,
        integer target ids, and 1.0 = keep state / 0.0 = reset.

    Side effects: mutates the module globals idx, data_i, data_o, T_ and
    avg_log_loss. Advancing idx after the step is left to the caller.
    """
    global idx, data_i, data_o, T_, avg_log_loss
    resets = np.ones((b_size,))
    # Token ids are integers; keep them integral rather than relying on a float cast.
    cp_inp = np.zeros((b_size,), dtype=np.int32)
    cp_out = np.zeros((b_size,), dtype=np.int32)

    for i in range(b_size):
        # len() works both for the initial [0] placeholder list and for numpy arrays.
        if idx[i] == -1 or idx[i] == len(data_i[i]):
            if avg_log_loss < 0.15:
                T_ += 1
                avg_log_loss = 0.3
            T = max(1, T_ - np.random.randint(5))
            data_i[i], data_o[i] = _make_copy_sequence(T)
            idx[i] = 0
            resets[i] = 0
        cur_idx = idx[i]
        cp_inp[i] = data_i[i][cur_idx]
        cp_out[i] = data_o[i][cur_idx]

    return cp_inp, cp_out, resets


##########################

# Exponential moving average of the training loss in bits per symbol. cp_data reads
# (and may reset) this global for its curriculum, so it must exist before training.
avg_log_loss = -np.log2(0.5)

best_val = 99  # best validation BPC so far; a checkpoint is written on improvement

NATS_TO_BITS = np.log2(np.e)  # TF cross-entropy is in nats

cur_time = time.time()
gpu_options = tf.GPUOptions(allow_growth=True)

saver = tf.train.Saver()


def _validation_bpc(sess):
    """Evaluate the validation graph over the whole validation split.

    The hidden state is carried from one (val_batch, val_steps) chunk to the next,
    starting from zeros. Returns the mean per-chunk loss converted to bits.
    """
    total = 0.0
    carried_h = np.zeros((val_batch, n), dtype=np.float32)
    for chunk in range(data_s_val):
        feed = {x_val_: inp_val[:, chunk],
                target_val_: out_val[:, chunk],
                h1_val: carried_h}
        chunk_loss, carried_h = sess.run([loss_val, h_val], feed_dict=feed)
        total += chunk_loss * NATS_TO_BITS
    return total / data_s_val


with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    sess.run(tf.global_variables_initializer())
    # saver.restore(sess, sys.argv[1])
    steps = 0
    while steps < total_steps:
        steps += 1

        if experiment == "ptb" and (steps % 10000) == 0:
            cur_loss_val = _validation_bpc(sess)

            print('VAL BPC: ', cur_loss_val)

            if best_val > cur_loss_val:
                best_val = cur_loss_val
                saver.save(sess, "./model.ckpt")

        if experiment == "ptb":
            idx = idx % data_s
            feed_dict = {x_: inp_ptb[idx, 0],
                         target_: out_ptb[idx]}
            # Each row is reset (h, u, A zeroed) with probability 0.01 per step.
            resets = np.random.choice([0, 1], size=(b_size,), p=[0.01, 0.99])
            sess.run(clear__, feed_dict={resets__: resets[:, None, None]})
        elif experiment == "cp":
            cp_inp, cp_out, resets = cp_data()
            sess.run(clear__, feed_dict={resets__: resets[:, None, None]})
            feed_dict = {x_: cp_inp, target_: cp_out[:, None]}

        idx += 1

        # Fetch only what is used. The old extra fetch of tf.norm(u) ran unordered
        # with u.assign in the same call, so its value was racy; it was unused anyway.
        cur_loss, _ = sess.run([loss, updates], feed_dict=feed_dict)

        # `steps` is incremented at the top of the loop, so after avg_steps
        # accumulations it is a multiple of avg_steps. The old test
        # (== avg_steps - 1) applied the first update one accumulation early
        # while still dividing by avg_steps.
        if steps % avg_steps == 0:
            sess.run(train, {lr: cur_lr})
            sess.run(clear)

        avg_log_loss = avg_log_loss * 0.999 + 0.001 * cur_loss * NATS_TO_BITS

        # if steps == 499999:
        #    cur_lr /= 3.

        if steps % 1000 == 999:
            if experiment == 'cp':
                print('Steps:', steps, 'sequence length:', T_, 'BPC:', avg_log_loss)
            else:
                print('Steps:', steps, 'BPC:', avg_log_loss)
            print((time.time() - cur_time))
            cur_time = time.time()
            sys.stdout.flush()

