import numpy as np
import matplotlib.pyplot as plt
from IPython import embed
import scipy.stats


# Load every pre-computed experiment array from the data/ directory.
# Files are read in the same order as listed; each path maps positionally
# onto the module-level name on the left-hand side.
(ns, ds, data_slope, data_holdout, data_test, data_all,
 ns_holdout) = [np.load(npy_path) for npy_path in (
    'data/ns.npy',
    'data/ds.npy',
    'data/data_slope.npy',
    'data/data_holdout.npy',
    'data/data_test.npy',
    'data/data_all.npy',
    'data/ns_holdout.npy',
)]

# Quick sanity check of the SLOPE result array's shape.  The parenthesised
# call is valid under both Python 2 and Python 3 and prints the same tuple.
print(data_slope.shape)

# Mean curve per method, averaged over axis 0 (presumably the repeated-trial
# axis -- confirm against the script that produced data/*.npy).
means_slope = np.mean(data_slope, axis=0)
means_holdout = np.mean(data_holdout, axis=0)
means_test = np.mean(data_test, axis=0)

# NOTE: despite the "std_" prefix, these hold the standard error of the mean
# (scipy.stats.sem) over the same axis; they size the shaded bands below.
std_holdout = scipy.stats.sem(data_holdout, axis=0)
std_test = scipy.stats.sem(data_test, axis=0)
std_slope = scipy.stats.sem(data_slope, axis=0)

# Line width shared by every curve in the figure.
lw = 3

# One mean curve per candidate value in `ds`: slice column k of data_all and
# average over axis 0 (presumably data_all is (trials, len(ds), len(ns)) --
# TODO confirm).
means_all = [np.mean(data_all[:, k, :], axis=0) for k in range(len(ds))]

# Draw each per-d curve as a translucent dashed line so the main methods
# plotted afterwards stand out.
for d_value, curve in zip(ds, means_all):
    plt.plot(ns, curve, label='d: ' + str(d_value), linestyle='--',
             linewidth=lw, alpha=.75)


import os

# Font size for the title and axis labels.
fs = 16

# Each main method: (x positions, mean curve, SEM half-width, legend label,
# colour).  The hold-out and ModBE results share the ns_holdout grid.
series = [
    (ns, means_slope, std_slope, 'SLOPE', 'green'),
    (ns_holdout, means_holdout, std_holdout, 'Hold-out', 'red'),
    (ns_holdout, means_test, std_test, 'ModBE', 'blue'),
]
for x_vals, mean_curve, sem_band, series_label, series_color in series:
    # Shaded +/- one-SEM band first, then the solid mean line on top.
    plt.fill_between(x_vals, mean_curve - sem_band, mean_curve + sem_band,
                     color=series_color, alpha=.2)
    plt.plot(x_vals, mean_curve, label=series_label, color=series_color,
             linewidth=lw)

plt.yscale('log')
plt.title('Contextual Bandit', fontsize=fs)
plt.ylabel('Log Regret', fontsize=fs)
plt.xlabel('Dataset size', fontsize=fs)
plt.grid(linestyle=':', alpha=.9)
plt.legend(fontsize=12, loc='upper right')
plt.tight_layout()

# savefig does not create missing directories; make sure the output folder
# exists (isdir/makedirs form works on both Python 2 and 3).
out_path = 'figures/cb.pdf'
out_dir = os.path.dirname(out_path)
if out_dir and not os.path.isdir(out_dir):
    os.makedirs(out_dir)
plt.savefig(out_path)
plt.show()
