import os
from typing import Callable, Iterator, List, Optional

import numpy as np
import torch
from torch_geometric.data import InMemoryDataset

from temporal_graph.data import TemporalData


class Tmall(InMemoryDataset):
    r"""The Tmall temporal interaction graph, prepared as in the SpikeNet paper.

    The raw files cannot be fetched automatically: ``tmall.txt`` (one
    ``src dst timestamp`` event per line) and ``node2label.txt`` (one
    ``node label`` pair per line) must be downloaded manually from
    :attr:`url` and placed in :obj:`raw_dir`. If an optional ``tmall.npy``
    feature file is also present there, it is attached as node features and
    the events are merged with a window size of 10 (see :meth:`process`).

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): Passed through to
            :class:`InMemoryDataset` as its ``transform``.
            (default: :obj:`None`)
        pre_transform (callable, optional): Applied to the processed data
            object before it is saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Passed through to
            :class:`InMemoryDataset` to force re-processing.
            (default: :obj:`False`)
    """

    url = ("https://www.dropbox.com/sh/palzyh5box1uc1v/"
           "AACSLHB7PChT-ruN-rksZTCYa?dl=0")

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ):
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0], data_cls=TemporalData)

    def download(self):
        """Raise an error, because the raw files must be obtained manually.

        Raises:
            RuntimeError: Always. The message names the expected files, the
                source URL and the target directory.
        """
        names = ', '.join(f"'{name}'" for name in self.raw_file_names)
        raise RuntimeError(
            f"Dataset not found. Please download {names} from "
            f"'{self.url}' and move them to '{self.raw_dir}'")

    @property
    def raw_file_names(self) -> List[str]:
        return ['tmall.txt', 'node2label.txt']

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def _read_rows(self, filename: str) -> Iterator[List[str]]:
        """Yield the whitespace-separated fields of each non-blank line of
        ``raw_dir/filename``."""
        path = os.path.join(self.raw_dir, filename)
        with open(path, encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if fields:  # tolerate stray blank lines
                    yield fields

    def process(self):
        """Build a :class:`TemporalData` object from the raw files and save it.

        Steps:

        * Sort the events by timestamp.
        * Re-encode labels to ``0..C-1``. Nodes without a label get ``-1``.
        * If ``tmall.npy`` exists, attach it as ``x``, reordered to the
          original node ids, and merge the events with ``step=10``.
          Otherwise ``x`` is ``None``.

        Raises:
            RuntimeError: If ``tmall.txt`` contains no edges, or if the
                feature file does not match the number of nodes.
            ValueError: If a raw line has the wrong number of fields or a
                non-numeric field.
        """
        src, dst, t = [], [], []
        for u, v, ts in self._read_rows('tmall.txt'):
            src.append(int(u))
            dst.append(int(v))
            t.append(float(ts))
        if not src:
            raise RuntimeError(
                f"No edges found in "
                f"'{os.path.join(self.raw_dir, 'tmall.txt')}'")

        nodes, labels = [], []
        for node, label in self._read_rows('node2label.txt'):
            nodes.append(int(node))
            labels.append(int(label))

        # Labelled nodes must fit into ``y`` even if they never occur in an
        # edge; otherwise ``y[nodes]`` below would index out of bounds.
        num_nodes = max(max(src), max(dst), max(nodes, default=-1)) + 1

        src = torch.tensor(src, dtype=torch.long)
        dst = torch.tensor(dst, dtype=torch.long)
        t = torch.tensor(t, dtype=torch.float)

        # A stable sort keeps the file order of events that share a
        # timestamp, so the processed edge order is reproducible.
        t, perm = t.sort(stable=True)
        src, dst = src[perm], dst[perm]

        # Map raw label values to 0..C-1 in ascending order of value. This
        # is exactly the encoding sklearn's LabelEncoder produces, without
        # pulling in sklearn.
        _, label_ids = np.unique(np.asarray(labels, dtype=np.int64),
                                 return_inverse=True)
        y = torch.full((num_nodes, ), -1, dtype=torch.long)
        y[torch.tensor(nodes, dtype=torch.long)] = torch.from_numpy(
            label_ids.reshape(-1)).to(torch.long)

        data = TemporalData(src=src, dst=dst, t=t, y=y, num_nodes=num_nodes)

        x = None
        path = os.path.join(self.raw_dir, 'tmall.npy')
        if os.path.exists(path):
            print('Loading processed node features...')
            x = torch.tensor(np.load(path)).to(torch.float)
            # NOTE(review): the file presumably stores nodes along axis 1,
            # hence the transpose to put nodes first -- TODO confirm.
            x = x.transpose(0, 1).contiguous()
            if x.size(0) != num_nodes:
                raise RuntimeError(
                    f"'{path}' holds features for {x.size(0)} nodes, but "
                    f"the graph has {num_nodes} nodes")
            # Reindexing according to the SpikeNet paper: feature row i
            # belongs to node ``all_nodes[i]``, i.e. labelled nodes first
            # (in file order) followed by all remaining nodes. Scatter the
            # rows back to the original node ids.
            # NOTE(review): the order of the remaining nodes is Python
            # ``set`` iteration order; presumably it mirrors how the feature
            # file was generated, so do not switch to ``sorted`` unverified.
            others = set(range(num_nodes)) - set(nodes)
            all_nodes = nodes + list(others)
            new_x = torch.zeros_like(x)
            new_x[all_nodes] = x
            x = new_x
            # Merge snapshots with a window size of 10, according to the
            # SpikeNet paper.
            data = data.merge(step=10)
        data.x = x

        if self.pre_transform is not None:
            data = self.pre_transform(data)
        self.save([data], self.processed_paths[0])
