from models.ssm import DiagonalSSM, SimpleRoland
from temporal_graph.transforms import RandomNodeSplit, ToTemporalUndirected, StratifyNodeSplit
from temporal_graph.datasets import DBLP, STARDataset, Tmall
from tqdm import tqdm
from torch_geometric.transforms import AddSelfLoops, Compose, ToUndirected
from torch_geometric.loader import NeighborLoader
from torch_geometric.data import Data
from torch_geometric import seed_everything
from sklearn import metrics
from logger import setup_logger
import torch.nn.functional as F
import torch
import argparse
import time
from copy import copy


parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=2024)
parser.add_argument('--dataset', type=str, default="Tmall")
parser.add_argument('--hidden_channels', type=int, default=32)
parser.add_argument('--epochs', type=int, default=51)
parser.add_argument('--learning_rate', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0)
parser.add_argument('--ssm_format', type=str, default='siso')
parser.add_argument('--token_mixer', type=str, default='interp')
# Fraction of nodes used for train + validation; the remainder is the test split.
parser.add_argument('--train_size', type=float, default=0.8)
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--log_name', type=str, default="")
# Restrict to the backbones this script knows how to import, so an
# unsupported value is rejected at parse time, before the dataset is loaded.
parser.add_argument('--model_name', type=str, default="ssm",
                    choices=["ssm", "mamba", "mambav2"])
args = parser.parse_args()

seed_everything(args.seed)

# Carve a fixed 5% validation slice out of the --train_size fraction;
# everything outside --train_size becomes the test split.
args.test_size = 1 - args.train_size
args.train_size = args.train_size - 0.05
args.val_size = 0.05


# Run configuration: read when building the model and optimizer, and logged
# at the end of the run. Note that `train_size` here is the training fraction
# *after* the 5% validation slice has been subtracted.
config = dict(
    train_size=args.train_size,
    hidden_channels=args.hidden_channels,
    learning_rate=args.learning_rate,
    weight_decay=args.weight_decay,
    ssm_format=args.ssm_format,
    token_mixer=args.token_mixer,
    epochs=args.epochs,
    seed=args.seed,
)


# Only the Tmall loader is wired up below. Otherwise `--dataset` would only
# pick the log directory, so reject other values instead of training on
# Tmall while logging the run under a different dataset name.
if args.dataset != "Tmall":
    raise ValueError(
        f"--dataset={args.dataset!r} is not supported: this script always "
        "loads Tmall from ./data/Tmall.")

logger = setup_logger(
    output=f"logs{args.log_name}/{args.dataset}/{args.train_size:.2f}", name="test")

path = './data/Tmall'

# ToTemporalUndirected (presumably symmetrises the temporal edges), then a
# stratified val/test node split; -1 is passed as the `unknown` label marker
# (presumably excluded from the split -- verify against StratifyNodeSplit).
transform = Compose(
    [ToTemporalUndirected(),
     StratifyNodeSplit(num_val=args.val_size, num_test=args.test_size, unknown=-1)])
data = Tmall(root=path, transform=transform, force_reload=False)[0]
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'


def to_data(temporal_data):
    """Build a plain PyG ``Data`` from a temporal snapshot.

    Only ``x``, ``edge_index`` and ``y`` are carried over; any other
    attribute of ``temporal_data`` is dropped.
    """
    kept = {name: getattr(temporal_data, name) for name in ('x', 'edge_index', 'y')}
    return Data(**kept)


# Time-bin boundaries; get_subgraph_snapshots iterates over all but the last.
bins = data.time_stamps
# One static graph (x / edge_index / y only) per snapshot index, used for
# full-graph evaluation.
snapshots = [
    to_data(data.snapshot(end=end, last_node_attr=True))
    for end in range(data.num_snapshots)
]


# Pick the backbone implementation; these imports shadow the module-level
# `models.ssm` names.
if args.model_name == "ssm":
    from models.ssm import DiagonalSSM, SimpleRoland
elif args.model_name == "mamba":
    from models.mamba import DiagonalSSM, SimpleRoland
elif args.model_name == "mambav2":
    from models.mamba_v2 import DiagonalS6SSM as DiagonalSSM
else:
    # Without this, an unknown name would silently train the `models.ssm`
    # DiagonalSSM imported at the top of the file.
    raise ValueError(
        f"Unknown --model_name {args.model_name!r}; "
        "expected 'ssm', 'mamba' or 'mambav2'.")

print(snapshots[:2])
print(data)
model = DiagonalSSM(
    data.x.size(-1),
    data.y.max().item()+1,  # presumably the class count (labels 0..C-1)
    hidden_channels=config["hidden_channels"],
    ssm_format=config["ssm_format"],
    token_mixer=config["token_mixer"],
).to(device)
print(model)
# Pass weight_decay so the --weight_decay flag (stored in `config`) is applied.
optimizer = torch.optim.Adam(model.parameters(),
                             lr=config['learning_rate'],
                             weight_decay=config['weight_decay'])
# Mini-batches of 2048 training nodes with 2-hop sampling, 5 neighbours per
# hop; num_neighbors=[-1, -1] would take every neighbour instead.
train_loader = NeighborLoader(data.to_static(),
                              num_neighbors=[5, 5],
                              batch_size=2048,
                              shuffle=True,
                              input_nodes=data.train_mask)


# Total node count (leading dimension of the node-feature tensor).
num_nodes = data.x.shape[0]


def get_subgraph_snapshots(batch):
    """Expand a sampled mini-batch into one graph per time bin.

    ``batch.x`` is indexed as a 3-D tensor whose second axis is time
    (presumably ``[num_nodes, num_bins, num_features]`` -- confirm against
    ``to_static``). For each boundary ``t`` at position ``i`` of
    ``bins[:-1]``, a new ``Data`` is built with features ``batch.x[:, i, :]``
    and only the edges whose ``edge_time <= t``.

    The final list entry is ``batch`` itself, mutated in place so that its
    ``x`` holds only the last time slice while it keeps all edges and its
    other attributes (``y``, ``batch_size``, ...).

    Returns:
        list of graphs ordered by time, ending with ``batch``.
    """
    features = batch.x
    edges = batch.edge_index
    times = batch.edge_time
    graphs = [
        Data(x=features[:, step, :], edge_index=edges[:, times <= boundary])
        for step, boundary in enumerate(bins[:-1])
    ]
    batch.x = features[:, -1, :]
    graphs.append(batch)
    return graphs


def train(snapshots):
    """Run one epoch of mini-batch training over ``train_loader``.

    ``snapshots`` is not used by the body; every mini-batch builds its own
    per-time-bin graphs via ``get_subgraph_snapshots``.

    Returns:
        The sum (not the mean) of the per-batch cross-entropy losses.
    """
    model.train()
    running = 0.
    for sampled in tqdm(train_loader):
        graphs = [g.to(device) for g in get_subgraph_snapshots(sampled)]
        last = graphs[-1]
        # NeighborLoader places the input (seed) nodes first, so only the
        # leading `batch_size` rows are supervised.
        n_seed = last.batch_size
        optimizer.zero_grad()
        logits = model(graphs)
        batch_loss = F.cross_entropy(logits[:n_seed], last.y[:n_seed])
        batch_loss.backward()
        optimizer.step()
        running += batch_loss.item()
    return running


@torch.no_grad()
def test(snapshots):
    """Run full-graph inference and score the validation and test splits.

    Args:
        snapshots: sequence of per-snapshot graphs passed to ``model``. Each is
            shallow-copied before being moved to ``device``.

    Returns:
        ``(metric_macros, metric_micros)``: each a two-element list
        ``[val, test]`` of macro- / micro-averaged F1 scores against
        ``data.y``. A split whose mask selects no nodes scores 0.
    """
    model.eval()
    # PyG's `Data.to` moves attributes in place; copy first so the caller's
    # CPU snapshots are reused unchanged on the next epoch.
    snapshots = [copy(snapshot).to(device) for snapshot in snapshots]
    # Transfer predictions and labels to host memory once, instead of
    # twice per split (once for each F1 average).
    pred = model(snapshots).argmax(dim=-1).cpu()
    y = data.y.cpu()
    metric_macros = []
    metric_micros = []
    for mask in (data.val_mask, data.test_mask):
        mask = mask.cpu()
        if not mask.any():
            metric_macros.append(0)
            metric_micros.append(0)
            continue
        y_true = y[mask].numpy()
        y_pred = pred[mask].numpy()
        metric_macros.append(metrics.f1_score(y_true, y_pred, average='macro'))
        metric_micros.append(metrics.f1_score(y_true, y_pred, average='micro'))
    return metric_macros, metric_micros


best_val = -1e5
best_test = -1e5
best_metric_macros = None

# perf_counter is monotonic and intended for measuring elapsed time.
start_time = time.perf_counter()
for epoch in range(1, config['epochs']+1):
    loss = train(snapshots)  # summed over mini-batches
    metric_macros, metric_micros = test(snapshots)
    val_acc, test_acc = metric_micros
    # Model selection on validation micro-F1; keep that epoch's test scores.
    if best_val < val_acc:
        best_val = val_acc
        best_test = test_acc
        best_metric_macros = metric_macros
    logger.info(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    logger.info(
        f'MICROS: Val: {val_acc:.2%}, Test: {test_acc:.2%}, Best Test: {best_test:.2%}')
    # Mirror the MICROS line: current-epoch scores, then the scores from the
    # epoch selected by validation micro-F1 (previously the latter were
    # printed under the bare "Val"/"Test" labels).
    logger.info(
        f'MACROS: Val: {metric_macros[0]:.2%}, Test: {metric_macros[1]:.2%}, '
        f'Best Val: {best_metric_macros[0]:.2%}, Best Test: {best_metric_macros[1]:.2%}')
end_time = time.perf_counter()
logger.info(f'Time: {end_time - start_time:.2f}s')
logger.info(config)
