import torch
from tqdm import tqdm
from utils import get_split
import random
import torch.nn as nn
from Loss import InfoNCE, BarlowTwins, loss_dependence
from utils import get_structural_encoding, LREvaluator
from Model import GCN, Encoder
from Enhance import SimFeatureEnhance
import argparse
import numpy as np
from data_utils import Data_Loader, set_seed
import sys
import os
# Prepend "<sys.path[0]>/../../" to the module search path.
# NOTE(review): this runs after the project imports above, so it can only
# influence imports executed later -- confirm that ordering is intended.
sys.path.insert(0, sys.path[0]+"/../../")
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def train(encoder_model, args, data, optimizer, epoch):
    """Run one full-graph optimisation step and return the scalar loss.

    Args:
        encoder_model: Model called as ``encoder_model(x, edge_index,
            rw_embeddings, edge_weight)`` and returning ``(z, z1, z2)``. It
            must also expose ``project``, ``enhance1.w`` and ``enhance2.w``.
        args: Parsed options; reads ``losses``, ``num_w``, ``beta``,
            ``lamda1`` and ``lamda2``.
        data: Graph object providing ``x``, ``edge_index``, ``edge_attr`` and
            ``rw_embeddings``.
        optimizer: Optimizer whose ``zero_grad`` / ``step`` are invoked here.
        epoch: Current epoch index. Kept for the caller's signature; it is
            not used inside this function.

    Returns:
        The total loss of this step, obtained with ``loss.item()``.

    Raises:
        ValueError: If ``args.losses`` is neither ``"info"`` nor ``"barlow"``.
    """
    # Fail fast on an unknown objective. Without this check the loss would
    # stay the plain float 0. and ``loss.backward()`` would raise an
    # AttributeError, or (when num_w != 1) training would silently optimise
    # the dependence term alone.
    if args.losses not in ("info", "barlow"):
        raise ValueError(
            "Unknown loss {!r}; expected 'info' or 'barlow'".format(args.losses)
        )

    encoder_model.train()
    optimizer.zero_grad()

    # Only the two view embeddings are needed for the contrastive objective.
    _, z1, z2 = encoder_model(
        data.x, data.edge_index, data.rw_embeddings, data.edge_attr
    )
    h1 = encoder_model.project(z1)
    h2 = encoder_model.project(z2)

    if args.losses == "info":
        loss = InfoNCE(h1, h2)
    else:
        loss = BarlowTwins(h1, h2)

    # The dependence term between the two enhancement weights is skipped
    # when num_w == 1.
    if args.num_w != 1:
        p1 = encoder_model.enhance1.w
        p2 = encoder_model.enhance2.w
        loss_dep = loss_dependence(
            p1, p2, p1.shape[0], lamda1=args.lamda1, lamda2=args.lamda2
        )
        loss = loss + args.beta * loss_dep

    loss.backward()
    optimizer.step()
    return loss.item()

def test(encoder_model, data, id):
    """Embed the graph with the trained encoder and score it with LREvaluator.

    Args:
        encoder_model: Model called as ``encoder_model(x, edge_index,
            rw_embeddings, edge_weight)``; the first returned tensor is used
            as the node embedding.
        data: Graph object providing ``x``, ``edge_index``, ``edge_attr``,
            ``rw_embeddings``, ``y`` and ``name``.
        id: Index of the current run. Unused here; the name (which shadows
            the builtin) is kept so existing callers keep working.

    Returns:
        Whatever ``LREvaluator()(embeds, data.y, split)`` returns; the caller
        in this file reads its ``"micro_f1"`` and ``"macro_f1"`` entries.

    Raises:
        ValueError: If ``data.name`` is not one of the supported datasets.
    """
    supported = ('wikics', 'computers', 'photo', 'physics',
                 'cora', 'citeseer', 'pubmed')
    # Previously an unlisted dataset left ``split`` unassigned and crashed
    # later with an UnboundLocalError; report the real problem instead.
    if data.name not in supported:
        raise ValueError(
            "No evaluation split defined for dataset {!r}; expected one of {}"
            .format(data.name, supported)
        )

    encoder_model.eval()
    # Evaluation needs no gradients through the encoder, so avoid building
    # the autograd graph for the forward pass.
    with torch.no_grad():
        embeds, _, _ = encoder_model(
            data.x, data.edge_index, data.rw_embeddings, data.edge_attr
        )

    # Every supported dataset uses the same 10% train / 80% test ratios.
    split = get_split(num_samples=embeds.size(0), train_ratio=0.1, test_ratio=0.8)
    return LREvaluator()(embeds, data.y, split)


def _load_structural_encoding(args, data):
    """Return the structural encoding for ``data``, cached on disk under ``rwse/``."""
    if args.order <= 0:
        return torch.empty(0, 0).to(device)

    cache_path = 'rwse/rw_{}_{}'.format(args.dataset, args.order)
    if os.path.exists(cache_path):
        # map_location lets a cache written on a GPU machine load on a
        # CPU-only one and guarantees the tensor ends up on ``device``.
        rw = torch.load(cache_path, map_location=device)
        print("Load")
    else:
        rw = get_structural_encoding(
            data.edge_index.cpu(), data.num_nodes, str_enc_dim=args.order
        ).to(device)
        # The cache directory may not exist on a fresh checkout.
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        torch.save(rw, cache_path)
        print("Save")
    return rw


def _build_optimizer(encoder_model, args):
    """Create Adam with a separate lr / weight decay for the enhancement weights."""
    # Tag the two enhancement weights so they can be told apart while
    # iterating over all parameters.
    encoder_model.enhance1.w.param_name = 'special_param'
    encoder_model.enhance2.w.param_name = 'special_param'

    special_params, other_params = [], []
    for _, param in encoder_model.named_parameters():
        if getattr(param, 'param_name', None) == 'special_param':
            special_params.append(param)
        else:
            other_params.append(param)

    param_groups = [
        {'params': special_params, 'lr': args.lr_p, 'weight_decay': args.weight_decay_p},
        {'params': other_params, 'lr': args.lr, 'weight_decay': args.weight_decay},
    ]
    return torch.optim.Adam(param_groups)


def main(args):
    """Train and evaluate the model ``args.runs`` times and print F1 statistics.

    For every run a seed is drawn with ``random.randint(1, 100)``, fresh
    enhancement modules, GCN and encoder are built, the encoder is trained
    for ``args.epoch`` steps and then evaluated with ``test``. Mean and
    standard deviation of micro- and macro-F1 over all runs are printed.

    Args:
        args: Parsed command-line options (see the argument parser at the
            bottom of this file).

    Raises:
        KeyError: If ``args.activation`` is not one of ``relu``, ``prelu``,
            ``lrelu`` or ``elu``.
    """
    dataset = Data_Loader(args.dataset)
    data = dataset[0].to(device)
    data.num_classes = dataset.num_classes
    data.name = args.dataset
    activation = ({'relu': nn.ReLU, 'prelu': nn.PReLU, 'lrelu': nn.LeakyReLU, 'elu': nn.ELU})[args.activation]

    data.rw_embeddings = _load_structural_encoding(args, data)

    Mi_results, Ma_results = [], []
    for run_idx in range(args.runs):
        seed = random.randint(1, 100)
        print("seed:", seed)
        set_seed(seed)

        enhance1 = SimFeatureEnhance(data.num_features, args.num_w, str_enc_dim=args.order, thre=args.theta).to(device)
        enhance2 = SimFeatureEnhance(data.num_features, args.num_w, str_enc_dim=args.order, thre=args.theta).to(device)
        gconv = GCN(input_dim=dataset.num_features, hidden_dim=args.hidden, activation=activation, num_layers=args.num_layers).to(device)
        encoder_model = Encoder(encoder=gconv, input_dim=args.hidden, hidden_dim=args.hidden, enhance=(enhance1, enhance2)).to(device)

        optimizer_Adam = _build_optimizer(encoder_model, args)

        with tqdm(total=args.epoch, desc='(T)') as pbar:
            for epoch in range(0, args.epoch):
                loss = train(encoder_model, args, data, optimizer_Adam, epoch)
                pbar.set_postfix({'loss': loss})
                pbar.update()

        result = test(encoder_model, data, run_idx)
        Mi_results.append(result["micro_f1"])
        Ma_results.append(result["macro_f1"])

    Mi_mean, Mi_std = np.mean(Mi_results), np.std(Mi_results)
    Ma_mean, Ma_std = np.mean(Ma_results), np.std(Ma_results)
    print("\nMicro-F1, Mean ± Std:", Mi_mean, "±", Mi_std)
    print("\nMacro-F1, Mean ± Std:", Ma_mean, "±", Ma_std)




# Command-line interface. Options with a fixed set of valid values declare
# ``choices`` so that a typo is rejected by argparse immediately instead of
# failing later inside ``main`` / ``train``.
parser = argparse.ArgumentParser()
parser.add_argument('--runs', type=int, default=10,
                    help='number of independent training/evaluation runs')
parser.add_argument('--epoch', type=int, default=10,
                    help='training steps per run')
parser.add_argument('--dataset', type=str, default='cora',
                    help='dataset name passed to Data_Loader')
parser.add_argument('--activation', type=str, default='relu',
                    choices=['relu', 'prelu', 'lrelu', 'elu'],
                    help='activation class handed to the GCN')
parser.add_argument('--hidden', type=int, default=512,
                    help='hidden dimension of the GCN and the encoder')
parser.add_argument('--num_layers', type=int, default=2,
                    help='num_layers argument of the GCN')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate for all parameters except the enhancement weights')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay for all parameters except the enhancement weights')
parser.add_argument('--lr_p', type=float, default=0.001,
                    help='learning rate for the enhancement weights (enhance1.w / enhance2.w)')
parser.add_argument('--weight_decay_p', type=float, default=0,
                    help='weight decay for the enhancement weights (enhance1.w / enhance2.w)')
parser.add_argument('--num_w', type=int, default=20,
                    help='second argument of SimFeatureEnhance; the dependence loss is applied when it is not 1')
parser.add_argument('--losses', type=str, default='barlow',
                    choices=['info', 'barlow'],
                    help='contrastive objective: info (InfoNCE) or barlow (BarlowTwins)')
# NOTE(review): --aug is parsed but not read anywhere in this file.
parser.add_argument('--aug', type=str, default='sim')
parser.add_argument('--beta', type=float, default=1.0,
                    help='weight of the dependence loss')
parser.add_argument('--theta', type=float, default=0.6,
                    help='value passed to SimFeatureEnhance as thre')
parser.add_argument('--lamda1', type=float, default=0.0001,
                    help='lamda1 argument of loss_dependence')
parser.add_argument('--lamda2', type=float, default=0.001,
                    help='lamda2 argument of loss_dependence')
parser.add_argument('--order', type=int, default=16,
                    help='structural-encoding dimension (str_enc_dim); values <= 0 use an empty encoding')
args = parser.parse_args()
print("data:", args.dataset)
main(args)
