import torch
import torch.nn as nn

from nfmc_jax.bayesian_network.network import BayesianNetwork
from nfmc_jax.utils.torch_distributions import CustomDistribution
from nfmc_jax.bayesian_network.gaussian import Isotropic, ConditionalIsotropic

from nfmc_jax.bayesian_network.rq_spline import SplineFlow, ConditionalSplineFlow
from nfmc_jax.utils.torch_distributions import Funnel
import torch.optim as optim
from copy import deepcopy
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Fix the global torch RNG seed before any sampling takes place.
torch.manual_seed(1)

# Experiment configuration: dimensionality, split sizes and training length.
n_dim = 2
n_train, n_val, n_test = 250, 5000, 5000
n_epochs = 50_000

# Draw the three data splits from the funnel distribution. The splits are
# sampled in train -> val -> test order so the RNG stream is consumed in a
# fixed sequence.
distribution = Funnel(n_dim=n_dim)
x_train, x_val, x_test = (
    distribution.sample(n_samples) for n_samples in (n_train, n_val, n_test)
)

# plt.scatter(x_train[:, 0], x_train[:, 1])
# plt.scatter(x_val[:, 0], x_val[:, 1])
# plt.show()

# Standardize data for now. We can make this a "flow layer" later.
# data_mean = x_train.mean(dim=0)
# data_std = x_train.std(dim=0)
# x_train = ((x_train - data_mean) / data_std) / 10
# x_val = ((x_val - data_mean) / data_std) / 10
# x_test = ((x_test - data_mean) / data_std) / 10

# x_train /= 10
# x_val /= 10
# x_test /= 10

# Build a two-node network: an Isotropic node and a ConditionalIsotropic node,
# joined by the single edge (0, 1). The mask is 0 for the first dimension and
# 1 for the remaining n_dim - 1 dimensions.
bn = BayesianNetwork(
    nodes=[Isotropic(1), ConditionalIsotropic(1, 1)],
    edges=[(0, 1)],
    mask=torch.tensor([0] + [1] * (n_dim - 1))
)
# Candidate value for the conditional node's linear weight; it is only used
# when the manual override below is re-enabled. An earlier candidate,
# [[0.0], [0.5]], was assigned and immediately overwritten by this one, so
# that dead assignment has been dropped.
x0 = torch.tensor([[1.0], [0.0]])  # Solution
# Show the initial weight of the conditional node's linear layer.
print(bn.nodes[1].network.linear.weight.data)
# bn.nodes[1].network.linear.weight.data = x0
optimizer = optim.Adam(bn.parameters())
# optimizer = optim.SGD(bn.parameters(), lr=1e-2, momentum=0.9)
# optimizer = optim.LBFGS(bn.parameters(), lr=1e-1, max_iter=100)

# Put the network in training mode and run n_epochs optimisation steps,
# logging progress every 1000 epochs.
bn.train()

best_loss_val = torch.inf
best_epoch = 0
# best_state = deepcopy(bn.state_dict())
for epoch in range(1, n_epochs + 1):
    loss_train, loss_val = bn.train_step(x_train, x_val, optimizer)
    # Keep the running best validation loss and its epoch up to date so the
    # progress line reports real values instead of the initial "inf @ [0]".
    # Only scalars are tracked here; checkpointing the model state stays
    # disabled, so the final parameters are unaffected.
    current_loss_val = float(loss_val)
    if current_loss_val < best_loss_val:
        best_loss_val = current_loss_val
        best_epoch = epoch
        # best_state = deepcopy(bn.state_dict())
    if epoch % 1000 == 0:
        print(
            f'[{epoch:>4}] Train loss: {float(loss_train):7.3f}, '
            f'Val loss: {float(loss_val):7.3f}, '
            f'Best val loss: {float(best_loss_val):7.3f} @ [{best_epoch:>4}]'
        )

# bn.load_state_dict(best_state)

# LBFGS
# def closure():
#     optimizer.zero_grad()
#     loss = -bn.log_prob(x_train).mean()
#     loss.backward()
#     return loss
#
#
# optimizer.step(closure)
#
# with torch.no_grad():
#     loss_train = -bn.log_prob(x_train).mean()
#     loss_val = -bn.log_prob(x_val).mean()
#     print(f'Train loss: {float(loss_train):7.3f}, Val loss: {float(loss_val):7.3f}')
#     print()

# Switch the network to evaluation mode before inspecting its parameters.
bn.eval()

# Print the weight and then the bias of the linear layer inside node 1.
linear_layer = bn.nodes[1].network.linear
for label, attr_name in (('Weight:', 'weight'), ('Bias:', 'bias')):
    print(label, getattr(linear_layer, attr_name))

# with torch.no_grad():
#     bn_samples = bn.sample(n_test)
#     df = pd.concat([
#         pd.DataFrame(x_test.numpy()),
#         pd.DataFrame(bn_samples.numpy()),
#         # pd.DataFrame(x_train.numpy()),
#     ]).reset_index(drop=True)
#     df['distribution'] = ['Test'] * n_test + ['BN'] * n_test  # + ['Train'] * n_train
#     sns.pairplot(df, corner=True, hue='distribution', diag_kind="hist")
#     plt.show()
#
#     sns.kdeplot(data=df[[0, 'distribution']].rename(columns={0: '0'}), x="0", hue='distribution')
#     plt.yscale('log')
#     plt.show()
#
#     sns.kdeplot(data=df[[1, 'distribution']].rename(columns={1: '1'}), x="1", hue='distribution')
#     plt.yscale('log')
#     plt.show()
#
#     sns.kdeplot(data=df[[0, 'distribution']].rename(columns={0: '0'}), x="0", hue='distribution')
#     plt.show()
#
#     sns.kdeplot(data=df[[1, 'distribution']].rename(columns={1: '1'}), x="1", hue='distribution')
#     plt.show()
