import os
import torch
import copy
import numpy as np
from math import ceil, log2, pi
from torch import nn
import matplotlib
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from model import QSP_Func_Poly, Discretization, QNN, QNN_with_perfect_discretization
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import LinearSegmentedColormap
from scipy import interpolate

# ======================= plotting style ==================
matplotlib.rc('text', usetex=True)
# text.latex.preamble must be a single string: the list form was deprecated in
# Matplotlib 3.3 and newer releases reject it with a ValueError. A plain string
# is accepted by old and new versions alike.
matplotlib.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"
# NOTE(review): style.use('classic') applies a complete stylesheet, which
# presumably includes text.usetex / text.latex.preamble and would therefore
# reset the two settings above - confirm. If LaTeX text rendering is actually
# intended, move this call above them. (The mathtext.* settings below only
# matter while usetex is off.)
matplotlib.style.use('classic')
# Use Arial for regular text and for mathtext roman/italic glyphs.
plt.rcParams['mathtext.rm'] = 'Arial'
plt.rcParams['mathtext.it'] = 'Arial'
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.serif'] = 'Arial'

# ======================= target function ==================
def target_func(x):
    """Return the reference value the PQC output is compared against.

    Args:
        x: indexable pair; ``x[0]`` and ``x[1]`` are the two coordinates.
           They may be plain floats or arrays/tensors of matching shape, since
           only element-wise arithmetic is applied.

    Returns:
        ((x0**2 + x1 - 1.5*pi)**2 + (x0 + x1**2 + pi)**2
         + (x0 + x1 - 0.5*pi)**2) / (5*pi**2)
    """
    first, second = x[0], x[1]
    term_a = first**2 + second - 1.5*pi
    term_b = first + second**2 + pi
    term_c = first + second - 0.5 * pi
    return (term_a**2 + term_b**2 + term_c**2) / (5*pi**2)

# 10 x 10 grid with step 0.1 per axis, flattened so each row is one (x0, x1)
# point; converted to a float32 tensor.
input_data = torch.from_numpy(np.mgrid[0:1:0.1, 0:1:0.1].reshape(2, -1).T).float()
# Reference labels: target_func evaluated row by row, as a flat float32 vector.
true_label = torch.from_numpy(
    np.array([target_func(point) for point in input_data])
).float().view(-1)

# Squared-error losses: summed over all points, and averaged over all points.
criterion = nn.MSELoss(reduction="sum")
criterion1 = nn.MSELoss(reduction="mean")

# ======================= hyperparameters ==================
# Hyper-parameter table. Keys are inserted in a fixed order; 'd' mirrors
# 'data_dim'. NOTE(review): batch_size, random_seed_3, max_shots and
# learning_rate are not used by the visible script - confirm they are needed.
paras = {'data_dim': 2}
paras.update(
    d=paras['data_dim'],
    K=10,
    eps=0.01,
    s=2,
    depth_constant=2,
    batch_size=100,
    random_seed_1=28_10_2000,
    random_seed_2=13_02_1967,
    random_seed_3=27_11_2000,
    max_shots=1000,
    learning_rate=2e-3,
)

# Build the QNN from the hyper-parameter table and report its parameter counts.
model = QNN(
    s=paras['s'],
    depth_constant=paras['depth_constant'],
    K=paras['K'],
    eps=paras['eps'],
    d=paras['data_dim'],
    random_seed_1=paras['random_seed_1'],
    random_seed_2=paras['random_seed_2'],
)
print('Localization parameters number: ', len(model.disc_model.phi))
print('Polynomial parameters number: ', model.poly_model.phi.shape, ' + ', model.poly_model.eta.shape)

# Restore pre-trained weights. The paths are built with os.path.join, which
# yields the same '.\\model\\...' path on Windows. A hard-coded backslash is
# not a separator on POSIX, where it becomes part of the file name and the
# load fails.
# map_location='cpu': nothing in this script moves the model or data to a GPU,
# so load any GPU-saved tensors straight onto the CPU (avoids failures on
# CPU-only machines).
model.load_discretization_model(os.path.join('.', 'model', 'localization', 'model_poly_K10_322.pth'))
model.load_state_dict(torch.load(os.path.join('.', 'model', 'poly', 'model_K10_s2_126.pth'), map_location='cpu'))

# Forward pass over the training grid; detach before handing values to SciPy.
# Assumes model returns one value per row of input_data, in row order - confirm.
y_pred = model(input_data)
z_pred = y_pred.detach().numpy()

# Coarse 10 x 10 grid in the same layout as input_data (bisplrep flattens x/y),
# fitted with a smoothing bivariate spline (SciPy default: cubic in both axes).
# NOTE(review): s=2 is a large smoothing budget compared with outputs plotted in
# roughly [0.58, 0.70], so the fitted surface may be much smoother than the raw
# predictions - confirm this is intended.
x, y = np.mgrid[0:1:0.1, 0:1:0.1]
tck = interpolate.bisplrep(x, y, z_pred, s=2)

# Resample the spline on a fine 100 x 100 grid for the surface plot.
xnew, ynew = np.mgrid[0:1:0.01, 0:1:0.01]
znew = interpolate.bisplev(xnew[:, 0], ynew[0, :], tck)

# Report the summed and the mean squared error against the reference labels.
print('loss sums: ', criterion(y_pred, true_label).item())
print('loss means: ', criterion1(y_pred, true_label).item())

# ======================= plot ==================
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Axis limits and tick positions.
ax.set_zlim((0.58, 0.70))
ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8])
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8])
# No text kwargs (e.g. fontsize) here: Axis.set_ticks only honours them when
# `labels` is also given, and otherwise ignores or (depending on the Matplotlib
# version) rejects them. Tick label sizes are set uniformly via tick_params.
ax.set_zticks([0.58, 0.62, 0.66, 0.7])
ax.tick_params('x', labelsize=18)
ax.tick_params('y', labelsize=18)
ax.tick_params('z', labelsize=18)

# Axis labels and title.
ax.set_xlabel('$x$', fontsize=28, labelpad=8)
ax.set_ylabel('$y$', fontsize=28, labelpad=8)
ax.set_zlabel('PQC output', fontsize=20, labelpad=12)
ax.set_title('$K=10$', fontsize=24)

# Five-stop custom colormap for the surface.
colors = [(0, '#E9C4CB'), (0.25, '#C0A3C0'), (0.5, '#9284B4'), (0.75, '#5E67AA'), (1, '#1E50A1')]
my_cmap = LinearSegmentedColormap.from_list('custom_cmap', colors)

ax.plot_surface(xnew, ynew, znew, rstride=1, cstride=1, alpha=None, linewidth=0.01, cmap=my_cmap, antialiased=True)
# os.path.join gives '.\\pqc_K10.png' on Windows and './pqc_K10.png' on POSIX.
# A hard-coded backslash is not a separator on POSIX, where it would end up in
# the file name.
plt.savefig(os.path.join('.', 'pqc_K10.png'), bbox_inches='tight', pad_inches=0.4, dpi=500)


