
# coding: utf-8

# In[ ]:


import numpy as np
import tensorflow as tf


# ### Loss definition

# In[ ]:


def log_t(u, t):
    """Tempered logarithm of `u` at temperature `t`.

    For `t == 1` this is the ordinary logarithm (`tf.log`); otherwise it is
    `(u**(1 - t) - 1) / (1 - t)`.
    """
    if t == 1:
        return tf.log(u)
    exponent = 1.0 - t
    return (u ** exponent - 1.0) / exponent

def exp_t(u, t):
    """Tempered exponential of `u` at temperature `t`.

    For `t == 1` this is the ordinary exponential (`tf.exp`); otherwise it is
    `relu(1 + (1 - t) * u) ** (1 / (1 - t))`.
    """
    if t == 1:
        return tf.exp(u)
    one_minus_t = 1.0 - t
    # relu clamps the base at zero before it is raised to the power.
    base = tf.nn.relu(1.0 + one_minus_t * u)
    return base ** (1.0 / one_minus_t)

def compute_normalization(activations, t, num_iters=10):
    """Computes the normalization value for each example iteratively.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      t: Temperature 2 (> 1.0 for tail heaviness).
      num_iters: Number of iterations to run the method.

    Returns:
      A tensor of the same rank as `activations` with the last dimension
      being 1.
    """
    # Shift by the per-example maximum so the largest shifted activation is 0.
    mu = tf.reduce_max(activations, -1, keep_dims=True)
    shifted_0 = activations - mu

    def keep_going(step, current):
        return (step < num_iters)

    def refine(step, current):
        # Sum of exp_t over the class axis for the current iterate ...
        partition = tf.reduce_sum(exp_t(current, t), -1, keep_dims=True)
        # ... used to rescale the *initial* shifted activations.
        updated = shifted_0 * tf.pow(partition, 1 - t)
        return [step + 1, updated]

    # The iteration count is bounded both by the loop condition and by
    # `maximum_iterations`; the final counter value is not needed.
    _, shifted_final = tf.while_loop(keep_going, refine,
                                     [0, shifted_0],
                                     maximum_iterations=num_iters)
    partition = tf.reduce_sum(exp_t(shifted_final, t), -1, keep_dims=True)
    return -log_t(1.0 / partition, t) + mu


def bi_tempered_logistic_loss(activations, labels, t1, t2):
    """Computes the Bi-Tempered logistic loss.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor multiplied elementwise with the per-class terms and
        summed over the last axis, so it must broadcast against `activations`.
        The caller in this file passes one-hot labels; the constant `- 1.0`
        in the second term presumably relies on that (verify before using
        soft labels).
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail-heaviness).

    Returns:
      A loss tensor with the last (class) axis reduced away.
    """
    one_minus_t1 = (1.0 - t1)
    one_minus_t2 = (1.0 - t2)

    normalization = compute_normalization(activations, t2, num_iters=10)
    # Base of the tempered exponential of the normalized activations; raising
    # it to 1 / (1 - t2) would give exp_t(activations - normalization, t2).
    base = tf.nn.relu(1.0 + one_minus_t2 * (activations - normalization))

    # Label-weighted term: sum over classes of labels * (base^((1-t1)/(1-t2)) - 1).
    label_term = tf.reduce_sum(
        tf.multiply(tf.pow(base, one_minus_t1 / one_minus_t2) - 1.0, labels), -1)
    # Label-independent term: sum over classes of base^((2-t1)/(1-t2)).
    power_term = tf.reduce_sum(
        tf.pow(base, (1.0 + one_minus_t1) / one_minus_t2), -1)

    return (-label_term / one_minus_t1
            + 1.0 / (1.0 + one_minus_t1) * (power_term - 1.0))
    


# ### A simple example: MNIST classification

# In[ ]:


from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST with integer (not one-hot) labels.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)

train_X, train_Y = mnist.train.images, mnist.train.labels
test_X, test_Y = mnist.test.images, mnist.test.labels

# Problem size.
num_classes = 10
num_train_examples = train_X.shape[0]

# Training schedule.
num_epochs = 500
disp_epoch = 10   # report progress every `disp_epoch` epochs
batch_size = 128

# Optimisation settings.
lr = 1.0          # learning rate handed to the optimizer
dropout = 0.75    # fed to the `keep_prob` placeholder during training


# In[ ]:


def conv2d(x, W, b, k=1):
    """Apply a SAME-padded 2-D convolution with stride `k`, add bias `b`, then ReLU."""
    convolved = tf.nn.conv2d(x, W, strides=[1, k, k, 1], padding='SAME')
    return tf.nn.relu(tf.nn.bias_add(convolved, b))

def maxpool2d(x, k=2):
    """Apply SAME-padded max pooling with a `k` x `k` window and stride `k`."""
    window = [1, k, k, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')

def cnn2d(x, weight, bias, dropout):
    """Build the two-conv, two-dense network and return its output activations.

    Args:
      x: Input reshaped here to `(-1, 28, 28, 1)`.
      weight: Dict with entries 'c1', 'c2' (conv filters) and 'f1', 'f2'
        (dense matrices).
      bias: Dict with the same keys as `weight`.
      dropout: Passed as the second positional argument of `tf.nn.dropout`
        (the keep probability in the TF1 API this file is written against).

    Returns:
      The output of the final dense layer (no activation applied).
    """
    net = tf.reshape(x, (-1, 28, 28, 1))

    # Two identical stages: convolution + ReLU followed by max pooling.
    for stage in ('c1', 'c2'):
        net = maxpool2d(conv2d(net, weight[stage], bias[stage]))

    # Flatten to the input width expected by the first dense layer.
    flat_width = weight['f1'].get_shape().as_list()[0]
    hidden = tf.reshape(net, (-1, flat_width))
    hidden = tf.nn.relu(tf.add(tf.matmul(hidden, weight['f1']), bias['f1']))
    hidden = tf.nn.dropout(hidden, dropout)

    return tf.add(tf.matmul(hidden, weight['f2']), bias['f2'])


# #### Network weights

# In[ ]:


def _random_normal_variable(shape):
    """Create a float32 variable initialised with `np.random.normal` samples of `shape`."""
    return tf.Variable(np.random.normal(size=shape), dtype=tf.float32)


# Conv filters ('c1', 'c2') and dense matrices ('f1', 'f2').
weight = {
    'c1': _random_normal_variable((5, 5, 1, 32)),
    'c2': _random_normal_variable((5, 5, 32, 64)),
    'f1': _random_normal_variable((7*7*64, 1024)),
    'f2': _random_normal_variable((1024, num_classes)),
}

# One bias vector per layer, sized to that layer's output width.
bias = {
    'c1': _random_normal_variable(32),
    'c2': _random_normal_variable(64),
    'f1': _random_normal_variable(1024),
    'f2': _random_normal_variable(num_classes),
}


# #### Bi-tempered loss temperatures

# In[ ]:


# t1: boundedness, 0 <= t_1 < 1
# t2: tail-heaviness, t_2 > 1
t1, t2 = 0.5, 4.0


# In[ ]:


# Graph inputs. X holds flattened images (reshaped to 28x28x1 inside cnn2d).
X = tf.placeholder(tf.float32, (None, 784))
# `(None)` is just `None`, which leaves the placeholder's shape completely
# unconstrained; `(None,)` declares the intended rank-1 vector of integer
# class ids with a variable batch size.
Y = tf.placeholder(tf.int32, (None,))
# Forwarded to tf.nn.dropout by cnn2d; fed with `dropout` for training steps
# and 1.0 for evaluation.
keep_prob = tf.placeholder(tf.float32)

logits = cnn2d(X, weight, bias, keep_prob)
preds = tf.argmax(logits, axis=1)
# Fraction of examples whose arg-max class equals the fed label.
accu = tf.reduce_mean(tf.cast(tf.equal(Y, tf.cast(preds, tf.int32)), tf.float32))
# Mean bi-tempered loss over the batch, computed against one-hot labels.
loss = tf.reduce_mean(bi_tempered_logistic_loss(logits, tf.one_hot(Y, num_classes), t1, t2))
train_op = tf.train.AdadeltaOptimizer(learning_rate=lr).minimize(loss)

init = tf.global_variables_initializer()


# #### Label noise level

# In[ ]:


# Fraction of training labels to corrupt.
noise_level = 0.4
num_noisy_examples = int(num_train_examples * noise_level)

# Pick which examples get a wrong label; everything else keeps its true label.
noisy_examples_index = np.random.permutation(num_train_examples)[:num_noisy_examples]
train_Y_noisy = train_Y.copy()

# assign a random class label other than the true class:
# the offset is a float in [1, num_classes), so after the modulo and the
# truncating int cast the label moves by 1 .. num_classes - 1 classes.
label_offset = np.random.uniform(low=1, high=num_classes, size=num_noisy_examples)
noisy_labels = (train_Y[noisy_examples_index] + label_offset) % num_classes
train_Y_noisy[noisy_examples_index] = noisy_labels.astype(np.int32)

# verify the corrupted labels
corrupted_fraction = np.mean(train_Y != train_Y_noisy)
print("percent of corrupted labels: %2.2f%%" % (corrupted_fraction * 100.0))


# #### Network training

# In[ ]:


with tf.Session() as sess:
    sess.run(init)

    # Only full batches are used; any remainder of the shuffled epoch is skipped.
    num_batches = num_train_examples // batch_size

    for epoch in range(num_epochs):
        # Fresh shuffle of the training set for every epoch.
        training_index = np.random.permutation(num_train_examples)
        for batch_num in range(num_batches):
            start = batch_num * batch_size
            batch_index = training_index[start:start + batch_size]
            train_x = train_X[batch_index, :]
            train_y_noisy = train_Y_noisy[batch_index]
            # Train on the noisy labels with dropout enabled.
            sess.run(train_op,
                     feed_dict={X: train_x, Y: train_y_noisy, keep_prob: dropout})

        if (epoch + 1) % disp_epoch != 0:
            continue

        # Progress report: re-score the last training batch of this epoch
        # against its clean labels, with dropout disabled.
        train_y = train_Y[batch_index]
        accu_value, loss_value = sess.run(
            (accu, loss), feed_dict={X: train_x, Y: train_y, keep_prob: 1.0})
        print("epoch %d, clean training batch accuracy %f, loss %f" % (epoch+1, accu_value, loss_value))

    # Final evaluation on the whole test set in a single run.
    accu_value, loss_value = sess.run(
        (accu, loss), feed_dict={X: test_X, Y: test_Y, keep_prob: 1.0})
    print("test accuracy %f, test loss %f" % (accu_value, loss_value))

