#!/usr/bin/env python
# coding: utf-8


import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.linalg as linalg
import torch_geometric.utils as utils
import torch_geometric.nn as gnn
from IPython.display import display, clear_output
from sklearn.metrics import f1_score
from torch_geometric.datasets import Planetoid
from torch_geometric.nn.conv import MessagePassing, GCNConv
from torch_geometric.utils.num_nodes import maybe_num_nodes
from tqdm import tqdm

# Force TrueType (Type 42) font embedding so exported PDF/PS figures contain
# no Type 3 fonts.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Two-decimal, non-scientific tensor printing on wide lines.
torch.set_printoptions(linewidth=200, precision=2, sci_mode=False)

# Prefer the GPU when CUDA is usable; the train/test helpers move models and
# data onto this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Using device:', device)

# Shared classification criterion used by train_model.
loss_fn = nn.CrossEntropyLoss()


def accuracy(output, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    ``output`` holds one row of scores per example (argmax is taken over
    dim 1) and ``labels`` the matching class indices. Returns a Python float.
    """
    hits = output.argmax(dim=1).eq(labels)
    return hits.sum().item() / labels.size(0)

def train_model(model, optimizer, data, epochs=200):
    """Run ``epochs`` full-batch optimisation steps of ``model`` on ``data``.

    Model and data are moved to ``device`` first. Each step runs the model on
    the whole graph, computes ``loss_fn`` on the entries selected by
    ``data.train_mask`` and applies one ``optimizer`` update. Returns nothing.
    """
    model.to(device)
    model.train()
    data = data.to(device)
    mask = data.train_mask
    for _ in range(epochs):
        optimizer.zero_grad()
        logits = model(data.x, data.edge_index)
        loss = loss_fn(logits[mask], data.y[mask])
        loss.backward()
        optimizer.step()

def test_model(model, data):
    """Return the accuracy of ``model`` on the entries selected by ``data.test_mask``.

    Model and data are moved to ``device``, and a single forward pass runs in
    eval mode with gradient tracking disabled.
    """
    model.to(device)
    model.eval()
    data = data.to(device)
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
    mask = data.test_mask
    return accuracy(logits[mask], data.y[mask])



class RobConv1(MessagePassing):
    """Symmetric-normalised graph convolution with the ``v v^T`` component removed.

    The input is first projected by a bias-free linear layer. Each of the
    ``num_convolutions`` steps then maps
    ``x -> D^{-1/2} Â D^{-1/2} x - v v^T x``. Here ``Â`` is the graph with a
    self-loop on every node, ``D`` is its in-degree matrix and
    ``v = D^{1/2} 1 / sqrt(sum(D))``. For an undirected graph, ``v`` is the
    leading eigenvector of the normalised adjacency. A learned bias is added
    once after all steps.
    """

    def __init__(self, in_channels, out_channels, num_convolutions):
        super().__init__(aggr='add')  # neighbour messages are summed
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.num_convolutions = num_convolutions
        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        nn.init.zeros_(self.bias)

    def forward(self, x, edge_index):
        h = self.lin(x)
        n = h.size(0)

        # Give every node a self-loop (existing loops are kept, not duplicated).
        edge_index, _ = utils.add_remaining_self_loops(edge_index, num_nodes=n)
        src, dst = edge_index

        # Per-edge weight 1/sqrt(d_src * d_dst); a zero degree maps to weight 0.
        deg = utils.degree(dst, n, dtype=h.dtype)
        inv_sqrt = deg.pow(-0.5)
        inv_sqrt = inv_sqrt.masked_fill(inv_sqrt == float('inf'), 0)
        edge_weight = inv_sqrt[src] * inv_sqrt[dst]

        # sum of in-degrees == number of (self-loop augmented) edges, i.e. ||D^{1/2} 1||^2.
        sqrt_deg = deg.pow(0.5).unsqueeze(1)
        nnz = edge_index.size(1)

        for _ in range(self.num_convolutions):
            # Project the *pre-propagation* features onto v, then subtract.
            proj = sqrt_deg @ (sqrt_deg.T @ h) / nnz
            h = self.propagate(edge_index, x=h, norm=edge_weight)
            h -= proj

        h += self.bias
        return h

    def message(self, x_j, norm):
        # Scale each neighbour feature row by its edge weight.
        return norm.unsqueeze(-1) * x_j

class RobConv2(MessagePassing):
    """Average-degree-scaled graph convolution with the ``1 1^T`` component removed.

    The input is first projected by a bias-free linear layer. Each of the
    ``num_convolutions`` steps then maps ``x -> Â x / mean(deg) - (1/n) 1 1^T x``.
    Here ``Â`` is the graph with a self-loop on every node and ``deg`` is its
    in-degree vector. A learned bias is added once after all steps.
    """

    def __init__(self, in_channels, out_channels, num_convolutions):
        super().__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.num_convolutions = num_convolutions
        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # Linear transformation of the input features.
        x = self.lin(x)

        # Add a self-loop to every node that does not already have one.
        edge_index, _ = utils.add_remaining_self_loops(edge_index, num_nodes=x.size(0))

        # Every edge gets the same weight: 1 / average in-degree (a 0-dim tensor).
        deg = utils.degree(edge_index[1], x.size(0), dtype=x.dtype)
        norm = 1.0 / deg.mean()

        # Propagate messages and remove the rank-1 component, then add the bias.
        for _ in range(self.num_convolutions):
            # (1/n) 1 1^T x is the column mean of x broadcast to every row.
            # Computing it with x.mean() keeps x's dtype and device, so the layer
            # also works for float64/half inputs. A torch.ones() projector
            # without an explicit dtype would be created in torch's default
            # dtype instead.
            rank1_comp = x.mean(dim=0, keepdim=True)
            x = self.propagate(edge_index, x=x, norm=norm)
            x -= rank1_comp

        x += self.bias
        return x

    def message(self, x_j, norm):
        # norm is a scalar tensor here, so it broadcasts over all messages.
        return norm * x_j

class GCNConv(MessagePassing):
    """Baseline GCN layer applying ``num_convolutions`` symmetric-normalised propagations.

    The input is first projected by a bias-free linear layer. Each step then
    maps ``x -> D^{-1/2} Â D^{-1/2} x``, where ``Â`` is the graph with a
    self-loop on every node and ``D`` is its in-degree matrix. A learned bias
    is added once at the end.

    NOTE: this definition shadows the ``GCNConv`` imported from
    ``torch_geometric.nn.conv`` at the top of the file. Later references to
    ``GCNConv`` in this module therefore use this class.
    """

    def __init__(self, in_channels, out_channels, num_convolutions):
        super().__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.num_convolutions = num_convolutions
        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # add_remaining_self_loops only adds a loop to nodes that lack one, so a
        # graph that already contains self-loops is not double-counted. RobConv1
        # and RobConv2 (and PyG's own gcn_norm) do the same preprocessing, which
        # keeps this baseline comparable with the robust variants.
        edge_index, _ = utils.add_remaining_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        row, col = edge_index
        deg = utils.degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Safeguard for zero-degree nodes (0^-0.5 = inf); map them to weight 0.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        for _ in range(self.num_convolutions):
            x = self.propagate(edge_index, x=x, norm=norm)

        x += self.bias
        return x

    def message(self, x_j, norm):
        # Scale each neighbour feature row by its per-edge normalisation weight.
        return norm.view(-1, 1) * x_j

class GCN(torch.nn.Module):
    """Plain GCN: ``n_layers`` stacked GCNConv layers with ReLU in between.

    Layer widths go input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim.
    The last layer's output is returned without an activation.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, num_convolutions=1):
        super().__init__()
        self.n_layers = n_layers
        self.relu = nn.ReLU()
        widths = [input_dim, *([hidden_dim] * (n_layers - 1)), output_dim]
        self.module_list = nn.ModuleList(
            GCNConv(widths[i], widths[i + 1], num_convolutions) for i in range(n_layers)
        )

    def forward(self, x, edge_index):
        for depth, conv in enumerate(self.module_list, start=1):
            x = conv(x, edge_index)
            if depth < self.n_layers:
                x = self.relu(x)
        return x

class GCNRob1(torch.nn.Module):
    """``n_layers`` stacked RobConv1 layers (``v v^T`` removed) with ReLU in between.

    Layer widths go input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim.
    The last layer's output is returned without an activation.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, num_convolutions=1):
        super().__init__()
        self.n_layers = n_layers
        self.relu = nn.ReLU()
        widths = [input_dim, *([hidden_dim] * (n_layers - 1)), output_dim]
        self.module_list = nn.ModuleList(
            RobConv1(widths[i], widths[i + 1], num_convolutions) for i in range(n_layers)
        )

    def forward(self, x, edge_index):
        for depth, conv in enumerate(self.module_list, start=1):
            x = conv(x, edge_index)
            if depth < self.n_layers:
                x = self.relu(x)
        return x

class GCNRob2(torch.nn.Module):
    """``n_layers`` stacked RobConv2 layers (``1 1^T`` removed) with ReLU in between.

    Layer widths go input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim.
    The last layer's output is returned without an activation.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, num_convolutions=1):
        super().__init__()
        self.n_layers = n_layers
        self.relu = nn.ReLU()
        widths = [input_dim, *([hidden_dim] * (n_layers - 1)), output_dim]
        self.module_list = nn.ModuleList(
            RobConv2(widths[i], widths[i + 1], num_convolutions) for i in range(n_layers)
        )

    def forward(self, x, edge_index):
        for depth, conv in enumerate(self.module_list, start=1):
            x = conv(x, edge_index)
            if depth < self.n_layers:
                x = self.relu(x)
        return x


def experiment(model, data, *, lr=0.01, weight_decay=5e-4, epochs=200):
    """Train ``model`` with Adam on ``data`` and return its test accuracy.

    Args:
        model: the network to train; its parameters are optimised in place.
        data: graph passed to ``train_model`` and ``test_model``.
        lr: Adam learning rate (default 0.01).
        weight_decay: Adam L2 weight decay (default 5e-4).
        epochs: number of full-batch training steps (default 200).

    Returns:
        The float accuracy reported by ``test_model``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    train_model(model, optimizer, data, epochs=epochs)
    return test_model(model, data)


def _sweep_layers(dataset, data, layer_counts, n_trials, hidden_dim):
    """Average the test accuracy of GCN, GCNRob1 and GCNRob2 at each depth.

    Returns ``(x_values, mean_accs)``. ``x_values`` lists the depths from
    ``layer_counts``. ``mean_accs`` holds one list of mean accuracies per
    model class, ordered (GCN, GCNRob1, GCNRob2).
    """
    model_classes = (GCN, GCNRob1, GCNRob2)
    x_values = []
    mean_accs = [[] for _ in model_classes]
    for k in tqdm(layer_counts, leave=False, desc='Layers'):
        totals = [0.0] * len(model_classes)
        for _ in range(n_trials):
            # Build all three models before training any of them, so each
            # trial's weight initialisations are drawn back-to-back from the RNG.
            models = [
                cls(input_dim=dataset.num_features, hidden_dim=hidden_dim,
                    output_dim=dataset.num_classes, n_layers=k)
                for cls in model_classes
            ]
            for idx, model in enumerate(models):
                totals[idx] += experiment(model, data)
        x_values.append(k)
        for series, total in zip(mean_accs, totals):
            series.append(total / n_trials)
    return x_values, mean_accs


def _plot_layer_sweep(x_values, mean_accs, dataset_name, fig_dir):
    """Plot mean accuracy against depth, save it as a PDF under ``fig_dir``, then show it."""
    import os

    styles = (
        ('GCN', '.'),
        (r'GCN with $vv^T$ removed', 's'),
        (r'GCN with $\mathbf{11}^T$ removed', '*'),
    )
    fig = plt.figure(figsize=(6, 3), facecolor=[1, 1, 1])
    plt.xlabel('Number of Layers', fontsize=18)
    plt.ylabel('Accuracy', fontsize=18)
    for accs, (label, marker) in zip(mean_accs, styles):
        plt.plot(x_values, accs, linestyle='-', marker=marker, label=label)
    plt.grid()
    plt.legend(loc='upper right')
    # Create the output folder so a long sweep is not lost to a missing
    # directory. Save before show(), because a blocking show() closes the figure.
    os.makedirs(fig_dir, exist_ok=True)
    fig.savefig(os.path.join(fig_dir, f'GCN_vs_GCNRobust_{dataset_name}_layers.pdf'),
                bbox_inches='tight')
    plt.show()
    # Free the figure now that it has been saved and shown.
    plt.close(fig)


def per_dataset(dataset, n_trials=10, *, layer_counts=range(2, 33, 2), hidden_dim=16,
                fig_dir='figures'):
    """Sweep network depth on one dataset and plot GCN vs. the two robust variants.

    For every depth in ``layer_counts``, a fresh GCN, GCNRob1 and GCNRob2 are
    trained ``n_trials`` times each and their test accuracies are averaged. The
    averages are plotted and saved as
    ``<fig_dir>/GCN_vs_GCNRobust_<dataset.name>_layers.pdf``.

    Args:
        dataset: a PyG dataset exposing ``num_features``, ``num_classes``,
            ``name`` and a graph at index 0 with train/test masks.
        n_trials: number of independent runs averaged per depth (must be >= 1).
        layer_counts: depths to evaluate (default 2, 4, ..., 32).
        hidden_dim: width of the hidden layers (default 16).
        fig_dir: directory for the saved PDF, created if missing.

    Raises:
        ValueError: if ``n_trials`` is smaller than 1.
    """
    if n_trials < 1:
        raise ValueError(f'n_trials must be >= 1, got {n_trials}')
    # dataset[0] is the public Dataset accessor for the graph (the ``_data``
    # attribute is private). It also applies the dataset's transform, if any.
    data = dataset[0].to(device)
    x_values, mean_accs = _sweep_layers(dataset, data, layer_counts, n_trials, hidden_dim)
    _plot_layer_sweep(x_values, mean_accs, dataset.name, fig_dir)

# The three Planetoid citation benchmarks, stored under data/.
datasets = [Planetoid(root='data/', name=name) for name in ('Cora', 'CiteSeer', 'PubMed')]
cora, citeseer, pubmed = datasets

# Run the depth sweep (50 trials per depth) on each dataset in turn.
for _dataset in (cora, citeseer, pubmed):
    per_dataset(_dataset, n_trials=50)
