import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt
from itertools import permutations
import time
from tqdm import tqdm
import seaborn as sns
import scipy
# Figure in main text: mean signal strength Delta_k^2 against dimension p for
# the 2- and 4-polynomial decay settings.  Each point is the column mean of a
# saved matrix, with error bars of one standard deviation along axis 0.
plt.style.use('default')
fig, ax = plt.subplots(figsize=(6, 4))
color = sns.color_palette("husl", 2)
lw = 1.5

# Dimensions shown on the x-axis.  Points sit at evenly spaced positions 1..N
# and the ticks are relabelled with these values.  The saved matrices have one
# more column than this, dropped below with [:, :-1]
# (presumably the p = 50000 run -- TODO confirm).
dims = np.array([100, 300, 1000, 2000, 5000, 10000])
positions = np.arange(1, len(dims) + 1)

signal_mat_1 = np.load('signal_theory_2.npy')[:, :-1]
signal_mat_2 = np.load('signal_theory_4.npy')[:, :-1]
# Fail with a clear message rather than an opaque matplotlib shape error.
for _fname, _mat in (('signal_theory_2.npy', signal_mat_1),
                     ('signal_theory_4.npy', signal_mat_2)):
    if _mat.shape[1] != len(dims):
        raise ValueError(
            f"{_fname}: expected {len(dims)} columns after dropping the last "
            f"one, got shape {_mat.shape}"
        )

# Series 1: dash-dot line with circle markers.  The legend label is attached to
# errorbar directly, since errorbar already draws the connecting line.
x = positions
y = signal_mat_1.mean(axis=0)
yerr = signal_mat_1.std(axis=0)
ax.errorbar(x, y, yerr=yerr, marker='o', color=color[0], lw=lw,
            linestyle='-.', label='2-polynomial decay')

# Series 2: nudged right by 0.05 so its error bars do not sit on top of series 1.
x = positions + 0.05
y = signal_mat_2.mean(axis=0)
yerr = signal_mat_2.std(axis=0)
ax.errorbar(x, y, yerr=yerr, marker='^', color=color[1], lw=lw,
            linestyle='-', label='4-polynomial decay')

# Reference level Delta_k^2 = 1, drawn across the full axis width
# (independent of either series' shifted x positions).
ax.axhline(1.0, linestyle='--', color='grey')

ax.set_xlabel('Dimension $p$')
# Raw string: in a plain literal '\D' is an invalid escape sequence.
ax.set_ylabel(r'Signal Strength $\Delta_k^2$')

ax.set_xticks(positions)
ax.set_xticklabels([str(d) for d in dims])
ax.legend(loc="lower right")
ax.set_ylim(0.3, 1.1)
plt.savefig('plot4.pdf')
plt.show()

