#!/usr/bin/env python
# coding: utf-8

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as linalg
import torch_geometric.utils as utils
import torch_geometric.nn as gnn
from prettytable import PrettyTable
from sklearn.metrics import f1_score
from torch.distributions import Bernoulli, MultivariateNormal
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils.num_nodes import maybe_num_nodes
from tqdm.notebook import tqdm

# Font type 42 is required to avoid type 3 fonts in the figure pdfs.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Compact tensor printing: two decimals, no scientific notation, wide lines.
torch.set_printoptions(precision=2, sci_mode=False, linewidth=200)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device:', device)

# Shared training loss: binary cross-entropy applied to raw logits.
loss_fn = nn.BCEWithLogitsLoss()


# Data generation
def generate_csbm_data(n_points, n_features, sigma, p, q):
    """Sample a two-class graph with Gaussian node features.

    The first ``n_points // 2`` nodes get label 0 and features centred at
    ``-mu``; the remaining nodes get label 1 and features centred at ``+mu``,
    where ``mu`` is zero except for ``mu[0] = 0.5``. Feature noise is
    i.i.d. normal with standard deviation ``sigma``.

    Every unordered node pair (including a node paired with itself) becomes an
    edge with probability ``p`` when both endpoints share a label and ``q``
    otherwise; the resulting edge set is then symmetrised.

    Returns:
        A ``Data`` object with ``x`` (float, ``n_points x n_features``),
        ``y`` (float, ``n_points x 1``) and ``edge_index``, with tensors
        created on the module-level ``device``.
    """
    half = n_points // 2

    # Class mean direction: only the first coordinate is non-zero.
    mu = torch.zeros(n_features, dtype=torch.float, device=device)
    mu[0] = 0.5

    # Gaussian noise is drawn with NumPy, then shifted by -mu / +mu per class.
    noise = np.random.normal(scale=sigma, size=(n_points, n_features))
    features = torch.tensor(noise, dtype=torch.float, device=device)
    features[:half] -= mu
    features[half:] += mu

    classes = torch.zeros(n_points, dtype=torch.long, device=device)
    classes[half:] = 1

    # The inbuilt function stochastic_blockmodel_graph does not support
    # random permutations of the nodes, hence the graph is sampled manually.
    # with_replacement=True keeps the (i, i) pairs, i.e. allows self-loops.
    block_probs = torch.tensor([[p, q], [q, p]], dtype=torch.float).to(device)
    pairs = torch.combinations(torch.arange(n_points), r=2, with_replacement=True)
    src, dst = pairs.t().to(device)
    keep = torch.bernoulli(block_probs[classes[src], classes[dst]]).to(torch.bool)
    sampled_edges = torch.stack([src[keep], dst[keep]], dim=0)

    data = Data(x=features, y=classes.to(torch.float).unsqueeze(1))
    data.edge_index = utils.to_undirected(sampled_edges, num_nodes=n_points)
    return data


# Training / testing
def train_model(model, lr, data, n_epochs=50):
    """Train ``model`` full-batch on a single graph with Adam.

    Each epoch runs one forward pass over the whole graph, evaluates the
    module-level ``loss_fn`` against ``data.y`` and takes one optimizer step.
    The model is updated in place; nothing is returned.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``.
        lr: Learning rate passed to ``torch.optim.Adam``.
        data: Graph object exposing ``x``, ``edge_index``, ``y`` and ``to()``.
        n_epochs: Number of gradient steps. Defaults to 50, the value that
            was previously hard-coded.
    """
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    data = data.to(device)
    for _ in range(n_epochs):
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = loss_fn(out, data.y)
        loss.backward()
        optimizer.step()

def test_model(model, data):
    """Return the classification accuracy of ``model`` on ``data``.

    The model is switched to eval mode and run once without gradient
    tracking. An output above 0.0 is read as class 1, anything else as
    class 0, and the fraction of entries equal to ``data.y`` is returned as a
    Python float.
    """
    model.to(device)
    model.eval()
    data = data.to(device)
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
    predicted = (logits > 0.0).float()
    matches = (predicted == data.y).float()
    return matches.mean().item()


# Model architectures
class RobConv1(MessagePassing):
    """Graph convolution that subtracts a degree-weighted rank-1 term.

    The layer applies a bias-free linear map, then repeats
    ``num_convolutions`` times: propagate with weights
    ``deg^-1/2[row] * deg^-1/2[col]`` and subtract ``v v^T h / E``, where
    ``v = sqrt(deg)`` and ``E`` is the number of edges after self-loops are
    added (which equals ``sum(deg)``). The bias is added once at the end.
    """

    def __init__(self, in_channels, out_channels, num_convolutions):
        super().__init__(aggr='add')  # Incoming messages are summed.
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.num_convolutions = num_convolutions
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear map and set the bias to zero."""
        self.lin.reset_parameters()
        nn.init.zeros_(self.bias)

    def forward(self, x, edge_index):
        num_nodes = x.size(0)

        # Linear transformation of the input features.
        h = self.lin(x)

        # Make sure every node has exactly one self-loop.
        edge_index, _ = utils.add_remaining_self_loops(edge_index, num_nodes=num_nodes)
        src, dst = edge_index

        # Symmetric normalisation weights, one per edge.
        deg = utils.degree(dst, num_nodes, dtype=h.dtype)
        inv_sqrt_deg = deg.pow(-0.5)
        inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
        edge_weight = inv_sqrt_deg[src] * inv_sqrt_deg[dst]

        # Ingredients of the rank-1 component: v = sqrt(deg), scaled by 1 / E.
        v = deg.pow(0.5).unsqueeze(1)
        num_edges = edge_index.size(1)

        for _ in range(self.num_convolutions):
            # The rank-1 term is computed from the features *before* this
            # propagation step and removed from the propagated result.
            rank1_term = v @ (v.T @ h) / num_edges
            h = self.propagate(edge_index, x=h, norm=edge_weight) - rank1_term

        return h + self.bias

    def message(self, x_j, norm):
        # Scale each source-node feature row by its per-edge weight.
        return norm.view(-1, 1) * x_j

class RobConv2(MessagePassing):
    """Graph convolution that subtracts the all-ones rank-1 component.

    The layer applies a bias-free linear map, then repeats
    ``num_convolutions`` times: propagate over the graph (with self-loops)
    using the constant weight ``1 / mean(deg)`` and subtract ``J J^T x`` with
    ``J = 1 / sqrt(n)``, i.e. the per-column mean of the features taken
    before the propagation step. The bias is added once at the end.
    """

    def __init__(self, in_channels, out_channels, num_convolutions):
        super().__init__(aggr='add')  # Incoming messages are summed.
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.num_convolutions = num_convolutions
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear map and set the bias to zero."""
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # Linear transformation of the input features.
        x = self.lin(x)
        n = x.size(0)

        # Add self-loops to the adjacency matrix (only where missing).
        edge_index, _ = utils.add_remaining_self_loops(edge_index, num_nodes=n)

        # Global normalisation: a single scalar, the inverse mean degree.
        deg = utils.degree(edge_index[1], n, dtype=x.dtype)
        norm = 1.0 / deg.mean()

        # Normalised all-ones vector used to isolate the rank-1 component.
        # It follows x's dtype (as deg does) so the matmul below also works
        # when the features are not float32.
        J = torch.ones((n, 1), dtype=x.dtype, device=x.device) / np.sqrt(n)

        # Propagate messages and remove the rank-1 component.
        for _ in range(self.num_convolutions):
            rank1_comp = J @ (J.T @ x)
            x = self.propagate(edge_index, x=x, norm=norm)
            x -= rank1_comp

        x += self.bias
        return x

    def message(self, x_j, norm):
        # norm is a scalar here, so every edge carries the same weight.
        return norm * x_j

class GCNConv(MessagePassing):
    """Baseline graph convolution with symmetric degree normalisation.

    The layer applies a bias-free linear map, propagates the result
    ``num_convolutions`` times with per-edge weights
    ``deg^-1/2[row] * deg^-1/2[col]`` and finally adds the bias.
    """

    def __init__(self, in_channels, out_channels, num_convolutions):
        super().__init__(aggr='add')  # Incoming messages are summed.
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.num_convolutions = num_convolutions
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear map and set the bias to zero."""
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # Add only the self-loops that are missing. The input graph may
        # already contain self-loops; adding one per node unconditionally
        # would duplicate those and inflate the degree of the affected nodes.
        # This also matches how RobConv1 / RobConv2 build their graph, so all
        # three layers propagate over the same adjacency.
        edge_index, _ = utils.add_remaining_self_loops(edge_index, num_nodes=x.size(0))

        # Linear transformation of the input features.
        x = self.lin(x)

        # Symmetric normalisation weights, one per edge.
        row, col = edge_index
        deg = utils.degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        for _ in range(self.num_convolutions):
            x = self.propagate(edge_index, x=x, norm=norm)

        x += self.bias
        return x

    def message(self, x_j, norm):
        # Scale each source-node feature row by its per-edge weight.
        return norm.view(-1, 1) * x_j

class GCN(torch.nn.Module):
    """Stack of ``n_layers`` GCNConv layers with ReLU between layers.

    Layer widths are ``input_dim``, then ``hidden_dim`` repeated
    ``n_layers - 1`` times, then ``output_dim``. No activation is applied
    after the last layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, num_convolutions):
        super().__init__()
        self.n_layers = n_layers
        self.relu = nn.ReLU()
        widths = [input_dim, *([hidden_dim] * (n_layers - 1)), output_dim]
        self.module_list = nn.ModuleList(
            GCNConv(widths[k], widths[k + 1], num_convolutions)
            for k in range(n_layers)
        )

    def forward(self, x, edge_index):
        final_index = self.n_layers - 1
        for depth, layer in enumerate(self.module_list):
            x = layer(x, edge_index)
            if depth < final_index:
                x = self.relu(x)
        return x

class GCNRob1(torch.nn.Module):
    """Stack of ``n_layers`` RobConv1 layers with ReLU between layers.

    Layer widths are ``input_dim``, then ``hidden_dim`` repeated
    ``n_layers - 1`` times, then ``output_dim``. No activation is applied
    after the last layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, num_convolutions):
        super().__init__()
        self.n_layers = n_layers
        self.relu = nn.ReLU()
        widths = [input_dim, *([hidden_dim] * (n_layers - 1)), output_dim]
        self.module_list = nn.ModuleList(
            RobConv1(widths[k], widths[k + 1], num_convolutions)
            for k in range(n_layers)
        )

    def forward(self, x, edge_index):
        final_index = self.n_layers - 1
        for depth, layer in enumerate(self.module_list):
            x = layer(x, edge_index)
            if depth < final_index:
                x = self.relu(x)
        return x

class GCNRob2(torch.nn.Module):
    """Stack of ``n_layers`` RobConv2 layers with ReLU between layers.

    Layer widths are ``input_dim``, then ``hidden_dim`` repeated
    ``n_layers - 1`` times, then ``output_dim``. No activation is applied
    after the last layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, num_convolutions):
        super().__init__()
        self.n_layers = n_layers
        self.relu = nn.ReLU()
        widths = [input_dim, *([hidden_dim] * (n_layers - 1)), output_dim]
        self.module_list = nn.ModuleList(
            RobConv2(widths[k], widths[k + 1], num_convolutions)
            for k in range(n_layers)
        )

    def forward(self, x, edge_index):
        final_index = self.n_layers - 1
        for depth, layer in enumerate(self.module_list):
            x = layer(x, edge_index)
            if depth < final_index:
                x = self.relu(x)
        return x


# Plotting helpers
# Legend text per model class name. Insertion order matters: plot_metrics
# pairs these keys positionally with the curves it is given.
labels = dict(
    GCN='Original GCN',
    GCNRob1=r'GCN with $vv^T$ removed',
    GCNRob2=r'GCN with $\mathbf{11}^T$ removed',
)

# Four dotted entries followed by four solid ones.
linestyles = 4 * [':'] + 4 * ['-']

# Marker symbol per model class name.
markers = dict(GCN='o', GCNRob1='s', GCNRob2='*')

def plot_with_std(x, y, yerr, label, linestyle='-', marker='o'):
    """Draw one curve on the current axes, optionally with an error band.

    When ``yerr`` is not None, the region between ``y - yerr`` and
    ``y + yerr`` is shaded, with both edges clipped to the range [0.5, 1].
    """
    centre = np.asarray(y)
    plt.plot(x, centre, linewidth=2, linestyle=linestyle, marker=marker, markersize=4, label=label)
    if yerr is None:
        return
    spread = np.asarray(yerr)
    lower = np.clip(centre - spread, 0.5, 1)
    upper = np.clip(centre + spread, 0.5, 1)
    plt.fill_between(x, lower, upper, alpha=0.05)

def plot_metrics(fname, title, xlabel, ylabel, xaxis, yaxes, yerrs, scales=('linear', 'linear'), vert_lines=None):
    """Plot one curve per model, save the figure to ``fname`` and show it.

    Args:
        fname: Output path handed to ``Figure.savefig``. When it is a path
            with a directory part, that directory is created if missing.
        title: Accepted for interface compatibility; the title is currently
            not drawn on the figure.
        xlabel: Label of the x-axis.
        ylabel: Label of the y-axis.
        xaxis: Shared x values for all curves.
        yaxes: Sequence of y-value arrays, paired positionally with the keys
            of the module-level ``labels`` dict.
        yerrs: Sequence of error arrays (or None entries), aligned with
            ``yaxes``.
        scales: ``(xscale, yscale)`` names passed to matplotlib.
        vert_lines: Optional sequence of ``(x_position, legend_label, ...)``
            entries, each drawn as a vertical dash-dotted black line.
    """
    import os

    fig = plt.figure(figsize=(6, 3), facecolor=[1, 1, 1])
    plt.xscale(scales[0])
    plt.yscale(scales[1])
    plt.xlabel(xlabel, fontsize=18)
    plt.ylabel(ylabel, fontsize=18)
    for yaxis, yerr, model_type in zip(yaxes, yerrs, labels):
        plot_with_std(xaxis, yaxis, yerr, labels[model_type], linestyle='-', marker=markers[model_type])

    if vert_lines is not None:
        for vert_line in vert_lines:
            plt.axvline(x=vert_line[0], color='black', linestyle='-.', linewidth=2, label=vert_line[1])

    plt.grid()
    plt.legend()

    # savefig does not create missing directories, so make sure the target
    # directory exists. Non-path targets (e.g. file objects) are left alone.
    if isinstance(fname, (str, os.PathLike)):
        out_dir = os.path.dirname(os.fspath(fname))
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Save before showing so the file is written even if the display step
    # blocks or tears the figure down.
    fig.savefig(fname, bbox_inches='tight')
    plt.show()

def experiment(n_trials, model, n, d, sigma, p, q, pbar=None):
    """Run ``n_trials`` train/test rounds and summarise the test accuracy.

    Each round samples a fresh training graph, trains ``model`` on it with
    learning rate 0.01, samples a fresh test graph and records the accuracy.
    The same ``model`` instance is reused for every round and is not
    re-initialised in between.

    Returns:
        ``(mean, std)`` of the recorded accuracies as Python floats.
    """
    accuracies = np.zeros(n_trials)
    for trial in range(n_trials):
        model.to(device)
        train_model(model, 0.01, generate_csbm_data(n, d, sigma, p, q))
        accuracies[trial] = test_model(model, generate_csbm_data(n, d, sigma, p, q))
        if pbar is None:
            continue
        pbar.set_postfix({'Trial': trial + 1, 'Accuracy': accuracies[trial]})
    mean_acc = accuracies.mean().item()
    std_acc = accuracies.std().item()
    return mean_acc, std_acc


# Parameter-specific sweep helpers: one varies $\sigma$, the other varies $\gamma$ (through $q$); both take $k=$ num_convolutions as a fixed argument
def evaluate_metrics_sigma(n_trials, n, d, sigmas, p, q, num_convolutions):
    """Measure accuracy of GCN, GCNRob1 and GCNRob2 for each value in ``sigmas``.

    For every sigma a fresh single-layer model of each type is built and
    passed to ``experiment``.

    Returns:
        ``(acc_means, acc_stds)``: two lists of three arrays (ordered GCN,
        GCNRob1, GCNRob2), each array holding one entry per sigma.
    """
    architectures = (GCN, GCNRob1, GCNRob2)
    acc_means = [np.zeros(len(sigmas)) for _ in architectures]
    acc_stds = [np.zeros(len(sigmas)) for _ in architectures]

    mbar = tqdm(sigmas, desc='Varying sigma')
    for idx, sigma in enumerate(mbar):
        for build, means, stds in zip(architectures, acc_means, acc_stds):
            net = build(input_dim=d, hidden_dim=1, output_dim=1, n_layers=1, num_convolutions=num_convolutions)
            means[idx], stds[idx] = experiment(n_trials, net, n, d, sigma, p, q)

    return acc_means, acc_stds

def evaluate_metrics_gamma(n_trials, n, d, sigma, p, qs, num_convolutions):
    """Measure accuracy of GCN, GCNRob1 and GCNRob2 for each value in ``qs``.

    For every q a fresh single-layer model of each type is built and passed
    to ``experiment`` with ``p`` held fixed.

    Returns:
        ``(acc_means, acc_stds)``: two lists of three arrays (ordered GCN,
        GCNRob1, GCNRob2), each array holding one entry per q.
    """
    architectures = (GCN, GCNRob1, GCNRob2)
    acc_means = [np.zeros(len(qs)) for _ in architectures]
    acc_stds = [np.zeros(len(qs)) for _ in architectures]

    mbar = tqdm(qs, desc='Varying gamma')
    for idx, q in enumerate(mbar):
        for build, means, stds in zip(architectures, acc_means, acc_stds):
            net = build(input_dim=d, hidden_dim=1, output_dim=1, n_layers=1, num_convolutions=num_convolutions)
            means[idx], stds[idx] = experiment(n_trials, net, n, d, sigma, p, q)

    return acc_means, acc_stds


# Varying $\sigma$
n_trials = 50
n = 2000
d = 20
sigmas = np.geomspace(0.1, 20, num=30)
# x-axis values ||mu - nu|| / sigma. generate_csbm_data puts the class means
# at -mu and +mu with mu[0] = 0.5, so the distance between the means is 1.
ratios = 1/sigmas
p = 2*(np.log(n)**3)/n
# p = 0.5
q = p/5
gamma = (p-q)/(p+q)
print(f'Condition for partial recovery: gamma={gamma:.2f} > {1/np.sqrt(n*p):.2f}')
C1, C2, C = 0.01, 2, 1
# error_term1 / error_term2 / C1 only feed the optional theoretical curve
# sketched in the comment inside the loop below.
error_term1 = 1/(n*(gamma**2)*(p+q))


def error_term2(s, k):
    """Second error term as a function of noise level ``s`` and ``k`` convolutions."""
    return s*s*np.log(n)*(C2/(gamma*np.sqrt(n*p)))**(2*k)


ratio_thres_1 = 3*np.sqrt(np.log(n)/n)


def ratio_thres_2(k):
    """Threshold on the x-axis ratio that depends on the number of convolutions ``k``."""
    return np.sqrt(np.log(n))*(C2/(gamma*np.sqrt(n*p)))**k


for num_convs in [1, 2, 4, 8, 10, 12, 16]:
    print(p, q)
    print(f'Condition for exact recovery: gamma={gamma:.2f} > {0.7*num_convs*np.sqrt(np.log(n)/(n*p)):.2f}')
    accs_means, accs_stds = evaluate_metrics_sigma(n_trials, n, d, sigmas, p, q, num_convolutions=num_convs)
    yaxes = accs_means
    yerrs = accs_stds

    # 3-point moving average over the interior points (endpoints untouched).
    # The average is taken from a snapshot of the raw values; updating the
    # array element by element would feed already-smoothed left neighbours
    # into later points and skew the curve.
    for yaxis in yaxes:
        raw = yaxis.copy()
        yaxis[1:-1] = (1/3)*(raw[:-2] + raw[1:-1] + raw[2:])

    # Optional theoretical accuracy curve (disabled):
    # theoretical_acc = np.clip(1 - np.array([C1*(error_term1 + error_term2(sigma, num_convs)) for sigma in sigmas]), 0.5, 1)
    # yaxes = [accs_means[0], accs_means[1], theoretical_acc]
    # yerrs = [accs_stds[0], accs_stds[1], None]
    ratio_vert = max(ratio_thres_1, C*ratio_thres_2(num_convs))
    vert_lines = [[ratio_vert, 'Exact recovery threshold']]
    plot_metrics(
        fname=f'figures/sigma_n={n}_d={d}_p={p:.2f}_q={q:.2f}_k={num_convs}.pdf',
        title=f'{num_convs} Convolutions' if num_convs > 1 else '1 Convolution',
        xlabel=r'$\frac{\|\mu-\nu\|}{\sigma}$', ylabel='Accuracy',
        xaxis=ratios, yaxes=yaxes, yerrs=yerrs,
        scales=('log', 'linear'), vert_lines=vert_lines)


# Varying $\gamma$
n_trials = 50
n = 2000
d = 20
sigma = 1
p = 2*(np.log(n)**2)/n
# Sweep the cross-class edge probability from 0 up to p.
qs = np.linspace(0, p, num=30)
# x-axis values: gamma = (p - q) / (p + q), one per swept q.
gammas = (p - qs) / (p + qs)
C = 5


def gamma_thres(k):
    """Position of the vertical threshold line for ``k`` convolutions."""
    return 0.5 * ((sigma*sigma*np.log(n))**(0.5/k)) * C / np.sqrt(n*p)


for num_convs in [1, 2, 4, 8, 10, 12, 16]:
    print(p, 1./sigma)
    accs_means, accs_stds = evaluate_metrics_gamma(n_trials, n, d, sigma, p, qs, num_convolutions=num_convs)
    vert_lines = [[gamma_thres(num_convs), 'Exact recovery threshold']]

    # 3-point moving average over the interior points (endpoints untouched).
    # The average is taken from a snapshot of the raw values; updating the
    # array element by element would feed already-smoothed left neighbours
    # into later points and skew the curve.
    for yaxis in accs_means:
        raw = yaxis.copy()
        yaxis[1:-1] = (1/3)*(raw[:-2] + raw[1:-1] + raw[2:])

    plot_metrics(
        fname=f'figures/gamma_n={n}_d={d}_sigma={sigma}_p={p:.2f}_k={num_convs}.pdf',
        title=f'{num_convs} Convolutions' if num_convs > 1 else '1 Convolution',
        xlabel=r'$\gamma$', ylabel='Accuracy',
        xaxis=gammas, yaxes=accs_means, yerrs=accs_stds,
        vert_lines=vert_lines)
