import torch
import torch.nn as nn

from nfmc_jax.utils.torch_distributions import CustomDistribution


class BayesianNetwork(nn.Module, CustomDistribution):
    """Bayesian network whose nodes are (conditional) density models.

    Each node models the coordinates of ``x`` that ``mask`` assigns to it. A
    node with incoming ``edges`` is conditioned on the coordinates of all of
    its parents; a node without incoming edges is unconditional. The joint
    log-density is the sum of the per-node log-densities.

    Args:
        nodes: Node modules, topologically sorted (i.e. in sampling order),
            because sampling fills nodes in list order and a child reads its
            parents' already-sampled coordinates. Each node must provide
            ``n_dim``, ``regularization()``, ``sample(...)`` and
            ``log_prob(...)``.
        edges: Sequence of ``(parent, child)`` pairs of node indices. It is
            re-iterated on every ``in_neighbors`` call.
        mask: Integer tensor (or anything ``torch.as_tensor`` accepts) of
            length ``n_dim``; ``mask[i] == k`` means coordinate ``i`` of ``x``
            belongs to node ``k``.
    """

    def __init__(self, nodes, edges, mask):
        super().__init__()
        self.nodes = nn.ModuleList(nodes)
        self.edges = edges
        # Sum over the ModuleList rather than ``nodes``: if ``nodes`` was a
        # one-shot iterator, ModuleList has already consumed it.
        self.n_dim = sum(node.n_dim for node in self.nodes)
        # Accept lists/arrays too: the ``self.mask == k`` comparisons below
        # need elementwise tensor semantics (a list would compare as a scalar).
        self.mask = torch.as_tensor(mask)

    def regularization(self):
        """Return the sum of every node's ``regularization()`` term as a 0-dim tensor."""
        return torch.stack([node.regularization() for node in self.nodes]).sum()

    def in_neighbors(self, node_number: int) -> torch.Tensor:
        """Return a boolean mask over the ``n_dim`` coordinates of ``x`` that
        selects every coordinate owned by a parent of node ``node_number``.

        Boolean indexing keeps the original coordinate order, so
        ``x[..., mask]`` is the conditioning context handed to that node.
        """
        # Build on the mask's device so the in-place OR never mixes devices.
        parents = torch.zeros(self.n_dim, dtype=torch.bool, device=self.mask.device)
        for parent, child in self.edges:
            if child == node_number:
                parents |= self.mask == parent
        return parents

    def sample_instance(self, *args, **kwargs):
        """Draw one joint sample, returned as a tensor of shape ``(n_dim,)``.

        Nodes are visited in list order. A node with parents is conditioned on
        its parents' already-sampled coordinates, reshaped to
        ``(-1, n_parent_coords)``. Extra arguments are accepted and ignored.
        """
        x = torch.zeros(self.n_dim)
        for node_number, node in enumerate(self.nodes):
            in_neighbors = self.in_neighbors(node_number)
            if not in_neighbors.any():
                x[self.mask == node_number] = node.sample(1)
            else:
                context = x[in_neighbors].view(-1, int(in_neighbors.sum()))
                x[self.mask == node_number] = node.sample(context)
        return x

    def sample(self, n):
        """Draw ``n`` joint samples, returned as a tensor of shape ``(n, n_dim)``.

        Nodes are filled in list order. Root nodes receive the sample count
        ``n``; conditional nodes receive their parents' columns, which must
        already hold samples (hence the topological-order requirement).
        """
        x = torch.zeros(n, self.n_dim)
        for node_number, node in enumerate(self.nodes):
            in_neighbors = self.in_neighbors(node_number)
            if not in_neighbors.any():
                x[:, self.mask == node_number] = node.sample(n)
            else:
                x[:, self.mask == node_number] = node.sample(x[:, in_neighbors])
        return x

    def log_prob_instance(self, x, *args, **kwargs):
        """Single-instance log-density; not implemented (use ``log_prob``)."""
        raise NotImplementedError

    def log_prob(self, x, *args, **kwargs):
        """Return the joint log-density of each row of ``x``.

        Args:
            x: Tensor of shape ``(batch, n_dim)``.

        Returns:
            Tensor of shape ``(batch,)`` with ``x``'s dtype and device. Each
            entry is the sum over nodes of that node's log-density of its own
            coordinates, conditioned on its parents' coordinates when it has
            any. Extra arguments are accepted and ignored.
        """
        # Allocate on x's device: a CPU accumulator makes the in-place ``+=``
        # fail as soon as x (and therefore the node outputs) live on a GPU.
        total = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        for node_number, node in enumerate(self.nodes):
            in_neighbors = self.in_neighbors(node_number)
            if not in_neighbors.any():
                total += node.log_prob(x[:, self.mask == node_number])
            else:
                total += node.log_prob(x[:, self.mask == node_number], x[:, in_neighbors])
        return total

    def train_step(self, x_train, x_val, optimizer, regularization_weight=0.1):
        """Take one full-batch optimizer step and report train/validation loss.

        The training loss is the mean negative log-likelihood of ``x_train``
        plus ``regularization_weight`` times ``regularization()``. The
        validation loss is the mean negative log-likelihood of ``x_val``
        (no regularization term), computed without gradient tracking.

        Args:
            x_train: Training batch, shape ``(batch, n_dim)``.
            x_val: Validation batch, shape ``(batch, n_dim)``.
            optimizer: Optimizer whose ``zero_grad``/``step`` bracket the
                backward pass.
            regularization_weight: Multiplier for the regularization term
                (default ``0.1``).

        Returns:
            ``(train_loss, val_loss)`` as 0-dim tensors, neither attached to
            an autograd graph.
        """
        optimizer.zero_grad()
        nll = -self.log_prob(x_train).mean()
        train_loss = nll + self.regularization() * regularization_weight
        train_loss.backward()
        optimizer.step()
        with torch.no_grad():
            val_loss = -self.log_prob(x_val).mean()
        # backward() has already consumed the graph; detaching lets callers
        # keep the value without also keeping the graph nodes alive.
        return train_loss.detach(), val_loss


if __name__ == '__main__':
    # Demo: fit a two-node BN (x0 -> x1..x{n_dim-1}) to a Funnel distribution,
    # keep the weights with the best validation loss, and plot its samples
    # against held-out test data and the training data.
    from nfmc_jax.bayesian_network.rq_spline import SplineFlow, ConditionalSplineFlow
    from nfmc_jax.utils.torch_distributions import Funnel
    import torch.optim as optim
    from copy import deepcopy
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt

    torch.manual_seed(123)

    n_dim = 4
    n_train = 250
    n_val = 50
    n_test = 1000
    n_epochs = 500
    log_every = 1  # print a progress line every `log_every` epochs

    distribution = Funnel(n_dim=n_dim)
    x_train = distribution.sample(n_train)
    x_val = distribution.sample(n_val)
    x_test = distribution.sample(n_test)

    # Node 0 models coordinate 0 unconditionally; node 1 models the remaining
    # coordinates conditioned on coordinate 0 (edge 0 -> 1).
    bn = BayesianNetwork(
        nodes=[SplineFlow(n_dim=1), ConditionalSplineFlow(n_dim=n_dim - 1, n_dim_cond=1)],
        edges=[(0, 1)],
        mask=torch.tensor([0] + [1] * (n_dim - 1)),
    )
    optimizer = optim.Adam(bn.parameters())

    @torch.no_grad()
    def reconstruction_info():
        """Round-trip x_train's coordinates 1: through node 1 (forward then
        inverse, conditioned on coordinate 0) and return the (mean, max)
        absolute reconstruction error as floats."""
        context = x_train[:, 0].reshape(-1, 1)
        z_train = bn.nodes[1].forward(x_train[:, 1:], context)[0]
        x_reconstructed = bn.nodes[1].inverse(z_train, context)[0]
        abs_error = torch.abs(x_reconstructed - x_train[:, 1:])
        return float(abs_error.mean()), float(abs_error.max())

    bn.train()
    best_loss_val = torch.inf
    best_epoch = 0
    best_state = deepcopy(bn.state_dict())
    for epoch in range(1, n_epochs + 1):
        loss_train, loss_val = bn.train_step(x_train, x_val, optimizer)
        if loss_val < best_loss_val:
            best_loss_val = float(loss_val)
            best_epoch = epoch
            best_state = deepcopy(bn.state_dict())
        if epoch % log_every == 0:
            # NOTE(review): reset_dropout_masks is kept under no_grad as in
            # the original; its grad requirements are not visible here.
            with torch.no_grad():
                mae, max_norm = reconstruction_info()
                bn.nodes[1].reset_dropout_masks()
            print(
                f'[{epoch:>4}] Train loss: {float(loss_train):7.3f}, '
                f'Val loss: {float(loss_val):7.3f}, '
                f'Best val loss: {float(best_loss_val):7.3f} @ [{best_epoch:>4}], '
                f'Mean train rec error: {mae:7.3f}, '
                f'Max train rec error: {max_norm:7.3f}, '
            )

    # Restore the weights with the lowest validation loss before sampling.
    bn.load_state_dict(best_state)
    bn.eval()

    with torch.no_grad():
        bn_samples = bn.sample(n_test)
        df = pd.concat([
            pd.DataFrame(x_test.numpy()),
            pd.DataFrame(bn_samples.numpy()),
            pd.DataFrame(x_train.numpy()),
        ]).reset_index(drop=True)
        df['distribution'] = ['Test'] * n_test + ['BN'] * n_test + ['Train'] * n_train
        sns.pairplot(df, corner=True, hue='distribution')
        plt.show()
