import torch
import numpy as np
import torch.nn as nn
from collections import defaultdict
from typing import Union, List
from utils.utils import NeighborSampler
import math
from models.MemoryModel import MessageAggregator, MemoryBank, GRUMemoryUpdater
from models.modules import TimeEncoder, MergeTimeEncoder


class RandomProjectionModule(nn.Module):
    """
    Keeps, for every node, a stack of (num_layer + 1) time-decayed projection tables together with matching
    per-hop degree counts, updates them batch by batch as links arrive, and derives pairwise features from them.
    """

    def __init__(self, node_num, edge_num, dim_factor, num_layer, time_decay_weight, device, use_matrix, beginning_time,
                 matrix_type, not_scale, enforce_dim):
        """
        :param node_num: int, number of rows of every projection / degree table
        :param edge_num: int, used to derive the projection dimension when enforce_dim == -1
        :param dim_factor: int, multiplier applied to int(log(2 * edge_num)) when deriving the dimension
        :param num_layer: int, number of hops; table 0 is the fixed base table, tables 1..num_layer are updated
        :param time_decay_weight: float, rate of the exponential time decay
        :param device: device that node ids and time weights are moved to inside update
        :param use_matrix: bool, if True table 0 is the identity matrix and the dimension is forced to node_num
        :param beginning_time: float, initial anchor time
        :param matrix_type: str, 'sum' or 'normalize' (anything else makes update raise ValueError)
        :param not_scale: bool, if True the pairwise inner products are fed to the MLP without clamp + log scaling
        :param enforce_dim: int, explicit projection dimension, or -1 to derive it from edge_num and dim_factor
        """
        super(RandomProjectionModule, self).__init__()
        self.node_num = node_num
        self.edge_num = edge_num
        if enforce_dim != -1:
            self.dim = enforce_dim
        else:
            self.dim = min(int(math.log(self.edge_num * 2)) * dim_factor, node_num)
        self.num_layer = num_layer
        self.time_decay_weight = time_decay_weight
        # NOTE: the attribute keeps its historical spelling ("begging") because it is part of the object's surface
        self.begging_time = beginning_time
        self.now_time = beginning_time
        self.device = device
        # node_degrees[:, i] is the time-decayed i-hop degree; column 0 is the constant 1
        self.node_degrees = nn.Parameter(torch.zeros((node_num, self.num_layer + 1), dtype=torch.float),
                                         requires_grad=False)
        self.node_degrees.data[:, 0] = 1
        # shift matrix: (x @ P)[j] = x[j - 1] for j >= 1 and 0 for j = 0
        self.P = nn.Parameter(torch.zeros((self.num_layer + 1, self.num_layer + 1)), requires_grad=False)
        self.P.data[:-1, 1:] = torch.eye(self.num_layer)
        self.random_projections = nn.ParameterList()
        self.use_matrix = use_matrix
        self.matrix_type = matrix_type
        self.node_feature_dim = 128
        self.not_scale = not_scale
        if self.use_matrix:
            self.dim = self.node_num
            for i in range(self.num_layer + 1):
                if i == 0:
                    self.random_projections.append(
                        nn.Parameter(torch.eye(self.node_num), requires_grad=False))
                else:
                    self.random_projections.append(
                        nn.Parameter(torch.zeros_like(self.random_projections[i - 1]), requires_grad=False))
        else:
            for i in range(self.num_layer + 1):
                if i == 0:
                    self.random_projections.append(
                        nn.Parameter(torch.normal(0, 1 / math.sqrt(self.dim), (self.node_num, self.dim)),
                                     requires_grad=False))
                else:
                    self.random_projections.append(
                        nn.Parameter(torch.zeros_like(self.random_projections[i - 1]), requires_grad=False))
        self.pair_wise_feature_dim = (2 * self.num_layer + 2) ** 2
        self.mlp = nn.Sequential(nn.Linear(self.pair_wise_feature_dim, self.pair_wise_feature_dim * 4), nn.ReLU(),
                                 nn.Linear(self.pair_wise_feature_dim * 4, self.pair_wise_feature_dim))
        # self.node_mlp1 = MLPMixer(num_tokens=self.num_layer + 1, num_channels=self.dim, token_dim_expansion_factor=4.0,
        #                           channel_dim_expansion_factor=0.5, dropout=0.1)
        # self.node_mlp2 = nn.Linear(self.num_layer + 1, 1)
        # self.node_mlp3 = nn.Linear(self.dim, self.node_feature_dim)

    def update(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray,
               node_interact_times: np.ndarray):
        """
        fold a batch of links into the degree and projection tables and move the anchor time forward
        :param src_node_ids: ndarray, shape (batch_size, ), source node ids
        :param dst_node_ids: ndarray, shape (batch_size, ), destination node ids
        :param node_interact_times: ndarray, shape (batch_size, ), interaction times; the last entry becomes the
            new anchor time (presumably the batch is sorted by time -- confirm against the caller)
        :raises ValueError: if matrix_type is neither 'sum' nor 'normalize'
        """
        if self.matrix_type not in ('sum', 'normalize'):
            raise ValueError("Not Implemented Matrix Type!")
        if len(src_node_ids) == 0:
            # nothing to add; an empty batch also carries no new anchor time
            return

        src = torch.from_numpy(src_node_ids).to(self.device)
        dst = torch.from_numpy(dst_node_ids).to(self.device)
        next_time = node_interact_times[-1]
        # subtract in the input precision first: casting large raw timestamps to float32 before the
        # subtraction would round away small time differences
        elapsed = torch.from_numpy(next_time - node_interact_times).to(dtype=torch.float, device=self.device)
        time_weight = torch.exp(-self.time_decay_weight * elapsed)[:, None]

        # move the anchor time of the degrees to next_time (hop i decays with the i-th power)
        step_decay = np.exp(-self.time_decay_weight * (next_time - self.now_time))
        degree_time_weight = torch.tensor([np.power(step_decay, i) for i in range(self.num_layer + 1)],
                                          device=self.device, dtype=torch.float)
        self.node_degrees.data = self.node_degrees.data * degree_time_weight[None, :]

        if self.matrix_type == 'sum':
            self._update_sum(src, dst, time_weight, step_decay)
        else:
            # small offset keeps the normalisers strictly positive
            self._update_normalize(src, dst, time_weight + 1e-9)
        self.now_time = next_time

    def _update_sum(self, src, dst, time_weight, step_decay):
        """accumulate time-weighted neighbor rows into the hop tables ('sum' mode)"""
        degree_index_src = src[:, None].expand(-1, self.num_layer + 1)
        degree_index_dst = dst[:, None].expand(-1, self.num_layer + 1)
        # read both endpoints before writing either one, so that every hop-i degree is built from the
        # pre-batch hop-(i-1) degrees, exactly like the projection tables below
        src_degree_messages = (self.node_degrees[dst] @ self.P) * time_weight
        dst_degree_messages = (self.node_degrees[src] @ self.P) * time_weight
        self.node_degrees.scatter_add_(dim=0, index=degree_index_src, src=src_degree_messages)
        self.node_degrees.scatter_add_(dim=0, index=degree_index_dst, src=dst_degree_messages)

        # move anchor time to now_time
        for i in range(1, self.num_layer + 1):
            self.random_projections[i].data = self.random_projections[i].data * np.power(step_decay, i)
        # add link; go from the deepest hop down so that table i - 1 is still the pre-batch one when it is read
        index_src = src[:, None].expand(-1, self.dim)
        index_dst = dst[:, None].expand(-1, self.dim)
        for i in range(self.num_layer, 0, -1):
            src_update_messages = self.random_projections[i - 1][dst] * time_weight
            dst_update_messages = self.random_projections[i - 1][src] * time_weight
            self.random_projections[i].scatter_add_(dim=0, index=index_src, src=src_update_messages)
            self.random_projections[i].scatter_add_(dim=0, index=index_dst, src=dst_update_messages)

    def _update_normalize(self, src, dst, time_weight):
        """move every hop table towards the neighbor rows, weighted by the updated 1-hop degree ('normalize' mode)"""
        degree_index_src = src[:, None].expand(-1, self.num_layer + 1)
        degree_index_dst = dst[:, None].expand(-1, self.num_layer + 1)
        delta_degrees = torch.zeros_like(self.node_degrees.data)
        delta_degrees.scatter_add_(dim=0, index=degree_index_src,
                                   src=(self.node_degrees[dst] @ self.P) * time_weight)
        delta_degrees.scatter_add_(dim=0, index=degree_index_dst,
                                   src=(self.node_degrees[src] @ self.P) * time_weight)

        # the normalisers do not change inside the hop loop, so compute them once
        src_norm = (self.node_degrees[src, 1] + delta_degrees[src, 1])[:, None]
        dst_norm = (self.node_degrees[dst, 1] + delta_degrees[dst, 1])[:, None]
        index_src = src[:, None].expand(-1, self.dim)
        index_dst = dst[:, None].expand(-1, self.dim)
        # add link
        for i in range(self.num_layer, 0, -1):
            src_update_messages = (self.random_projections[i - 1][dst]
                                   - self.random_projections[i][src]) * time_weight / src_norm
            dst_update_messages = (self.random_projections[i - 1][src]
                                   - self.random_projections[i][dst]) * time_weight / dst_norm
            self.random_projections[i].scatter_add_(dim=0, index=index_src, src=src_update_messages)
            self.random_projections[i].scatter_add_(dim=0, index=index_dst, src=dst_update_messages)
        self.node_degrees.data = self.node_degrees.data + delta_degrees

    def get_random_projections(self, node_ids):
        """
        :param node_ids: indices of the requested nodes
        :return: list of num_layer + 1 tensors, the rows of every hop table selected by node_ids
        """
        return [self.random_projections[i][node_ids] for i in range(self.num_layer + 1)]

    def get_pair_wise_feature(self, src_node_ids, dst_node_ids, edge_types=None):
        """
        compute pairwise features from the inner products between all hop projections of both endpoints
        :param src_node_ids: node ids, shape (num_pairs, )
        :param dst_node_ids: node ids, shape (num_pairs, )
        :param edge_types: optional array, shape (num_pairs, ); pairs whose entry equals 1 have their
            source and destination swapped before the features are computed
        :return: Tensor, shape (num_pairs, pair_wise_feature_dim)
        """
        if edge_types is not None:
            mask = (edge_types == 1)
            given_src, given_dst = src_node_ids, dst_node_ids
            # swap on private copies: callers hand in views of arrays they keep using afterwards
            src_node_ids = given_src.clone() if torch.is_tensor(given_src) else given_src.copy()
            dst_node_ids = given_dst.clone() if torch.is_tensor(given_dst) else given_dst.copy()
            src_node_ids[mask] = given_dst[mask]
            dst_node_ids[mask] = given_src[mask]

        src_random_projections = torch.stack(self.get_random_projections(src_node_ids), dim=1)
        dst_random_projections = torch.stack(self.get_random_projections(dst_node_ids), dim=1)
        # shape (num_pairs, 2 * (num_layer + 1), dim)
        random_projections = torch.cat([src_random_projections, dst_random_projections], dim=1)
        random_feature = torch.matmul(random_projections, random_projections.transpose(1, 2)).reshape(
            len(src_node_ids), -1)
        if not self.not_scale:
            # negative inner products are clipped to zero before the log scaling
            random_feature = torch.log(torch.clamp(random_feature, min=0) + 1.0)
        return self.mlp(random_feature)

    # def get_node_feature(self, node_ids):
    #     node_random_projections = torch.stack(self.get_random_projections(node_ids), dim=1)
    #     node_random_projections = self.node_mlp1(node_random_projections)
    #     node_random_projections = self.node_mlp2(node_random_projections.transpose(1, 2)).squeeze(dim=2)
    #     return self.node_mlp3(node_random_projections)

    def reset_random_projections(self):
        """clear all hop tables and degrees, rewind the anchor time, and redraw table 0 unless use_matrix is set"""
        self.node_degrees.data[:, 1:] = 0
        for i in range(1, self.num_layer + 1):
            nn.init.zeros_(self.random_projections[i])
        self.now_time = self.begging_time
        if not self.use_matrix:
            nn.init.normal_(self.random_projections[0], mean=0, std=1 / math.sqrt(self.dim))

    def backup_random_projections(self):
        """
        :return: tuple (anchor time, copy of the degrees, list with copies of hop tables 1..num_layer);
            table 0 is not included
        """
        return self.now_time, self.node_degrees.clone(), [self.random_projections[i].clone() for i in
                                                          range(1, self.num_layer + 1)]

    def reload_random_projections(self, random_projections):
        """
        restore a state produced by backup_random_projections
        :param random_projections: tuple (anchor time, degrees, list of hop tables 1..num_layer)
        """
        now_time, node_degrees, random_projections = random_projections
        self.now_time = now_time
        self.node_degrees.data = node_degrees.clone()
        for i in range(1, self.num_layer + 1):
            self.random_projections[i].data = random_projections[i - 1].clone()


class RPNet(torch.nn.Module):
    def __init__(self, node_raw_features: np.ndarray, edge_raw_features: np.ndarray, neighbor_sampler: NeighborSampler,
                 time_feat_dim: int, dropout: float, random_projections: Union[List[RandomProjectionModule], None],
                 num_layers: int,
                 num_neighbors: int,
                 device: str,
                 embedding_type: str,
                 time_encoder_type: str):
        """
        Wrapper that owns the raw features, the time encoder and one embedding module
        (MLPEmbedding for 'node_wise', PairMLPEmbedding for 'link_wise').
        :param node_raw_features: ndarray, shape (num_nodes + 1, node_feat_dim)
        :param edge_raw_features: ndarray, shape (num_edges + 1, edge_feat_dim)
        :param neighbor_sampler: NeighborSampler, neighbor sampler
        :param time_feat_dim: int, dimension of time features (encodings)
        :param dropout: float, dropout rate
        :param random_projections: list of RandomProjectionModule or None, forwarded to the embedding module
        :param num_layers: int, number of MLP-Mixer layers of the embedding module
        :param num_neighbors: int, number of sampled historical neighbors per node
        :param device: str, device
        :param embedding_type: str, 'node_wise' or 'link_wise'
        :param time_encoder_type: str, forwarded to MergeTimeEncoder as encoder_type
        """
        super(RPNet, self).__init__()

        # raw features are cast to float32 and kept on the target device
        node_features = torch.from_numpy(node_raw_features.astype(np.float32))
        edge_features = torch.from_numpy(edge_raw_features.astype(np.float32))
        self.node_raw_features = node_features.to(device)
        self.edge_raw_features = edge_features.to(device)

        self.node_feat_dim = self.node_raw_features.shape[1]
        self.edge_feat_dim = self.edge_raw_features.shape[1]
        self.time_feat_dim = time_feat_dim
        self.dropout = dropout
        self.device = device

        # number of nodes, including the padded node
        self.num_nodes = self.node_raw_features.shape[0]

        # NOTE(review): stored exactly as given; a plain Python list is not registered as a submodule
        # container by nn.Module -- confirm the caller takes care of these modules' parameters and device
        self.random_projections = random_projections
        self.time_encoder = MergeTimeEncoder(encoder_type=time_encoder_type, time_dim=time_feat_dim)

        # embedding module: both variants take exactly the same constructor arguments
        self.embedding_type = embedding_type
        if embedding_type == 'node_wise':
            embedding_class = MLPEmbedding
        elif embedding_type == 'link_wise':
            embedding_class = PairMLPEmbedding
        else:
            raise ValueError("Not Implemented Embedding Module!")
        self.embedding_module = embedding_class(node_raw_features=self.node_raw_features,
                                                edge_raw_features=self.edge_raw_features,
                                                neighbor_sampler=neighbor_sampler,
                                                time_encoder=self.time_encoder,
                                                node_feat_dim=self.node_feat_dim,
                                                edge_feat_dim=self.edge_feat_dim,
                                                time_feat_dim=self.time_feat_dim,
                                                num_layers=num_layers,
                                                num_neighbors=num_neighbors,
                                                dropout=self.dropout,
                                                random_projections=self.random_projections)

    def compute_src_dst_node_temporal_embeddings(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray,
                                                 node_interact_times: np.ndarray):
        """
        embed the source and the destination nodes of a batch of links in a single pass
        :param src_node_ids: ndarray, shape (batch_size, )
        :param dst_node_ids: ndarray, shape (batch_size, )
        :param node_interact_times: ndarray, shape (batch_size, )
        :return: tuple (src_node_embeddings, dst_node_embeddings), the first batch_size rows and the
            remaining rows of the embedding module output
        """
        if self.embedding_type not in ('node_wise', 'link_wise'):
            raise ValueError("Not Implemented Embedding Module!")

        num_src = len(src_node_ids)
        # sources first, destinations second, shape (2 * batch_size, )
        stacked_node_ids = np.concatenate([src_node_ids, dst_node_ids])
        if self.embedding_type == 'node_wise':
            # all-zero node types: source and destination nodes are not distinguished
            node_types = np.zeros(len(src_node_ids) + len(dst_node_ids))
            stacked_times = np.concatenate([node_interact_times, node_interact_times])
            node_embeddings = self.embedding_module.compute_node_temporal_embeddings(
                node_ids=stacked_node_ids,
                node_interact_times=stacked_times,
                node_types=node_types)
        else:
            # every stacked node is paired with both endpoints of its own link
            node_embeddings = self.embedding_module.compute_node_temporal_embeddings(
                node_ids=stacked_node_ids,
                src_node_ids=np.tile(src_node_ids, 2),
                dst_node_ids=np.tile(dst_node_ids, 2),
                node_interact_times=np.tile(node_interact_times, 2))

        src_node_embeddings = node_embeddings[:num_src]
        dst_node_embeddings = node_embeddings[num_src:]
        return src_node_embeddings, dst_node_embeddings

    def set_neighbor_sampler(self, neighbor_sampler: NeighborSampler):
        """
        install neighbor_sampler on the embedding module; for the 'uniform' and 'time_interval_aware'
        strategies the sampler's random state is reset as well (for reproducing results)
        :param neighbor_sampler: NeighborSampler, neighbor sampler
        :return:
        """
        self.embedding_module.neighbor_sampler = neighbor_sampler
        sampler = self.embedding_module.neighbor_sampler
        if sampler.sample_neighbor_strategy in ['uniform', 'time_interval_aware']:
            assert sampler.seed is not None
            sampler.reset_random_state()


class MLPEmbedding(nn.Module):
    def __init__(self, node_raw_features: torch.Tensor, edge_raw_features: torch.Tensor,
                 neighbor_sampler: NeighborSampler,
                 time_encoder: nn.Module, node_feat_dim: int, edge_feat_dim: int, time_feat_dim: int,
                 num_layers: int, num_neighbors: int, dropout: float, random_projections: List[RandomProjectionModule]):
        """
        Node-wise embedding module: per-neighbor features are projected by an MLP, mixed by MLP-Mixer
        layers and averaged over the sampled neighbors.
        :param node_raw_features: Tensor, shape (num_nodes + 1, node_feat_dim)
        :param edge_raw_features: Tensor, shape (num_edges + 1, edge_feat_dim)
        :param neighbor_sampler: NeighborSampler, neighbor sampler
        :param time_encoder: TimeEncoder
        :param node_feat_dim: int, dimension of node features
        :param edge_feat_dim: int, dimension of edge features
        :param time_feat_dim:  int, dimension of time features (encodings)
        :param num_layers: int, number of MLP-Mixer layers
        :param num_neighbors: int, number of sampled historical neighbors per node
        :param dropout: float, dropout rate
        :param random_projections: list of RandomProjectionModule, or None to use no pairwise features
        """
        super(MLPEmbedding, self).__init__()

        self.node_raw_features = node_raw_features
        self.edge_raw_features = edge_raw_features
        self.neighbor_sampler = neighbor_sampler
        self.time_encoder = time_encoder
        self.node_feat_dim = node_feat_dim
        self.edge_feat_dim = edge_feat_dim
        self.time_feat_dim = time_feat_dim
        self.num_layers = num_layers
        self.num_neighbors = num_neighbors
        self.dropout = dropout
        self.random_projections = random_projections
        if self.random_projections is None:
            self.random_feature_dim = 0
        else:
            # one pairwise feature vector per projection module
            self.random_feature_dim = (self.random_projections[0].pair_wise_feature_dim) * len(self.random_projections)
            # self.random_feature_dim = self.random_projections[0].node_feature_dim
        self.projection_layer = nn.Sequential(
            nn.Linear(node_feat_dim + edge_feat_dim + time_feat_dim + self.random_feature_dim, self.node_feat_dim * 2),
            nn.ReLU(), nn.Linear(self.node_feat_dim * 2, self.node_feat_dim))
        self.mlp_mixers = nn.ModuleList([
            MLPMixer(num_tokens=self.num_neighbors, num_channels=self.node_feat_dim,
                     token_dim_expansion_factor=0.5,
                     channel_dim_expansion_factor=4.0, dropout=self.dropout)
            for _ in range(self.num_layers)
        ])

    def compute_node_temporal_embeddings(self, node_ids: np.ndarray,
                                         node_interact_times: np.ndarray, node_types=None):
        """
        embed every node from the features of its sampled historical neighbors
        :param node_ids: ndarray, shape (batch_size, ), node ids
        :param node_interact_times: ndarray, shape (batch_size, ), node interaction times
        :param node_types: optional ndarray, shape (batch_size, ), repeated per neighbor and passed to the
            random projection modules as edge_types
        :return: Tensor, shape (batch_size, node_feat_dim)
        """

        device = self.node_raw_features.device
        batch_size = node_ids.shape[0]

        # get temporal neighbors, including neighbor ids, edge ids and time information
        # neighbor_node_ids ndarray, shape (batch_size, num_neighbors)
        # neighbor_edge_ids ndarray, shape (batch_size, num_neighbors)
        # neighbor_times ndarray, shape (batch_size, num_neighbors)
        neighbor_node_ids, neighbor_edge_ids, neighbor_times = \
            self.neighbor_sampler.get_historical_neighbors(node_ids=node_ids,
                                                           node_interact_times=node_interact_times,
                                                           num_neighbors=self.num_neighbors)
        neighbor_node_features = self.node_raw_features[torch.from_numpy(neighbor_node_ids)]

        neighbor_delta_times = node_interact_times[:, np.newaxis] - neighbor_times
        neighbor_time_features = self.time_encoder(
            torch.from_numpy(neighbor_delta_times).float()[:, :, None].to(device))
        neighbor_edge_features = self.edge_raw_features[torch.from_numpy(neighbor_edge_ids)]

        feature_parts = [neighbor_node_features, neighbor_time_features, neighbor_edge_features]
        if self.random_projections is not None:
            # without node types there is nothing to swap, so no per-neighbor type array is needed
            edge_types = None if node_types is None else np.repeat(node_types, self.num_neighbors)
            neighbor_random_feature_list = []
            for random_projection in self.random_projections:
                # hand over a copy of the flattened neighbor ids: reshape(-1) can alias neighbor_node_ids,
                # which is still needed unchanged for the padding mask below
                pair_features = random_projection.get_pair_wise_feature(
                    src_node_ids=np.repeat(node_ids, self.num_neighbors),
                    dst_node_ids=neighbor_node_ids.reshape(-1).copy(),
                    edge_types=edge_types)
                neighbor_random_feature_list.append(pair_features.reshape(batch_size, self.num_neighbors, -1))
                # neighbor_random_feature_list.append(
                #     random_projection.get_node_feature(neighbor_node_ids.reshape(-1)).reshape(batch_size,
                #                                                                               self.num_neighbors, -1))
            feature_parts.append(torch.cat(neighbor_random_feature_list, dim=2))
        neighbor_combine_features = torch.cat(feature_parts, dim=2)

        # shape (batch, num_neighbors, node_feat_dim)
        output = self.projection_layer(neighbor_combine_features)
        # zero the rows of neighbors with id 0 (the id this mask treats as padding); masked_fill is
        # out-of-place, so its result has to be kept
        padding_mask = torch.from_numpy(neighbor_node_ids == 0)[:, :, None].to(device)
        output = output.masked_fill(padding_mask, 0)
        for mlp_mixer in self.mlp_mixers:
            output = mlp_mixer(output)
        # shape (batch, node_feat_dim)
        output = torch.mean(output, dim=1)
        return output


class PairMLPEmbedding(nn.Module):
    def __init__(self, node_raw_features: torch.Tensor, edge_raw_features: torch.Tensor,
                 neighbor_sampler: NeighborSampler,
                 time_encoder: nn.Module, node_feat_dim: int, edge_feat_dim: int, time_feat_dim: int,
                 num_layers: int, num_neighbors: int, dropout: float, random_projections: List[RandomProjectionModule]):
        """
        Link-wise embedding module: like the node-wise variant, but every sampled neighbor carries pairwise
        features with respect to both endpoints of the link being scored.
        :param node_raw_features: Tensor, shape (num_nodes + 1, node_feat_dim)
        :param edge_raw_features: Tensor, shape (num_edges + 1, edge_feat_dim)
        :param neighbor_sampler: NeighborSampler, neighbor sampler
        :param time_encoder: TimeEncoder
        :param node_feat_dim: int, dimension of node features
        :param edge_feat_dim: int, dimension of edge features
        :param time_feat_dim:  int, dimension of time features (encodings)
        :param num_layers: int, number of MLP-Mixer layers
        :param num_neighbors: int, number of sampled historical neighbors per node
        :param dropout: float, dropout rate
        :param random_projections: list of RandomProjectionModule, or None to use no pairwise features
        """
        super(PairMLPEmbedding, self).__init__()

        self.node_raw_features = node_raw_features
        self.edge_raw_features = edge_raw_features
        self.neighbor_sampler = neighbor_sampler
        self.time_encoder = time_encoder
        self.node_feat_dim = node_feat_dim
        self.edge_feat_dim = edge_feat_dim
        self.time_feat_dim = time_feat_dim
        self.num_layers = num_layers
        self.num_neighbors = num_neighbors
        self.dropout = dropout
        self.random_projections = random_projections
        if self.random_projections is None:
            self.random_feature_dim = 0
        else:
            # two pairwise feature vectors (neighbor vs. source, neighbor vs. destination) per projection module
            self.random_feature_dim = self.random_projections[0].pair_wise_feature_dim * 2 * len(
                self.random_projections)
        self.projection_layer = nn.Sequential(
            nn.Linear(node_feat_dim + edge_feat_dim + time_feat_dim + self.random_feature_dim, self.node_feat_dim * 2),
            nn.ReLU(), nn.Linear(self.node_feat_dim * 2, self.node_feat_dim))
        self.mlp_mixers = nn.ModuleList([
            MLPMixer(num_tokens=self.num_neighbors, num_channels=self.node_feat_dim,
                     token_dim_expansion_factor=0.5,
                     channel_dim_expansion_factor=4.0, dropout=self.dropout)
            for _ in range(self.num_layers)
        ])

    def compute_node_temporal_embeddings(self, node_ids: np.ndarray, src_node_ids: np.ndarray,
                                         dst_node_ids: np.ndarray, node_interact_times: np.ndarray):
        """
        embed every node from the features of its sampled historical neighbors, including the pairwise
        features of each neighbor with the source and with the destination of the node's link
        :param node_ids: ndarray, shape (batch_size, ), node ids whose neighbors are sampled
        :param src_node_ids: ndarray, shape (batch_size, ), source node of the link each entry of node_ids belongs to
        :param dst_node_ids: ndarray, shape (batch_size, ), destination node of that link
        :param node_interact_times: ndarray, shape (batch_size, ), node interaction times
        :return: Tensor, shape (batch_size, node_feat_dim)
        """

        device = self.node_raw_features.device
        num_pairs = len(node_ids) * self.num_neighbors
        # get temporal neighbors, including neighbor ids, edge ids and time information
        # neighbor_node_ids ndarray, shape (batch_size, num_neighbors)
        # neighbor_edge_ids ndarray, shape (batch_size, num_neighbors)
        # neighbor_times ndarray, shape (batch_size, num_neighbors)
        neighbor_node_ids, neighbor_edge_ids, neighbor_times = \
            self.neighbor_sampler.get_historical_neighbors(node_ids=node_ids,
                                                           node_interact_times=node_interact_times,
                                                           num_neighbors=self.num_neighbors)
        neighbor_node_features = self.node_raw_features[torch.from_numpy(neighbor_node_ids)]
        neighbor_delta_times = node_interact_times[:, np.newaxis] - neighbor_times
        neighbor_time_features = self.time_encoder(
            torch.from_numpy(neighbor_delta_times).float()[:, :, None].to(device))
        neighbor_edge_features = self.edge_raw_features[torch.from_numpy(neighbor_edge_ids)]

        feature_parts = [neighbor_node_features, neighbor_time_features, neighbor_edge_features]
        if self.random_projections is not None:
            neighbor_random_feature_list = []
            for random_projection in self.random_projections:
                # one call for both endpoints: first half pairs each neighbor with the source,
                # second half pairs it with the destination
                single_neighbor_random_features = random_projection.get_pair_wise_feature(
                    src_node_ids=np.tile(neighbor_node_ids.reshape(-1), 2),
                    dst_node_ids=np.concatenate(
                        [np.repeat(src_node_ids, self.num_neighbors), np.repeat(dst_node_ids, self.num_neighbors)]))
                # [batch,num_neighbors,random_dim*2]
                single_neighbor_random_features = torch.cat(
                    [single_neighbor_random_features[:num_pairs],
                     single_neighbor_random_features[num_pairs:]],
                    dim=1).reshape(len(node_ids), self.num_neighbors, -1)
                neighbor_random_feature_list.append(single_neighbor_random_features)
            feature_parts.append(torch.cat(neighbor_random_feature_list, dim=2))
        neighbor_combine_features = torch.cat(feature_parts, dim=2)

        # shape (batch, num_neighbors, node_feat_dim)
        embeddings = self.projection_layer(neighbor_combine_features)
        # zero the rows of neighbors with id 0 (the id this mask treats as padding); masked_fill is
        # out-of-place, so its result has to be kept
        padding_mask = torch.from_numpy(neighbor_node_ids == 0)[:, :, None].to(device)
        embeddings = embeddings.masked_fill(padding_mask, 0)
        for mlp_mixer in self.mlp_mixers:
            embeddings = mlp_mixer(embeddings)
        # shape (batch, node_feat_dim)
        embeddings = torch.mean(embeddings, dim=1)

        return embeddings


class FeedForwardNet(nn.Module):

    def __init__(self, input_dim: int, dim_expansion_factor: float, dropout: float = 0.0):
        """
        MLP with one hidden layer: Linear -> GELU -> Dropout -> Linear -> Dropout.
        :param input_dim: int, dimension of input (and of the output)
        :param dim_expansion_factor: float, the hidden width is int(dim_expansion_factor * input_dim)
        :param dropout: float, dropout rate
        """
        super(FeedForwardNet, self).__init__()

        self.input_dim = input_dim
        self.dim_expansion_factor = dim_expansion_factor
        self.dropout = dropout

        hidden_dim = int(dim_expansion_factor * input_dim)
        layers = [
            nn.Linear(in_features=input_dim, out_features=hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden_dim, out_features=input_dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """
        apply the feed forward net to the last dimension of x
        :param x: Tensor, shape (*, input_dim)
        :return: Tensor, shape (*, input_dim)
        """
        transformed = self.ffn(x)
        return transformed


class MLPMixer(nn.Module):

    def __init__(self, num_tokens: int, num_channels: int, token_dim_expansion_factor: float = 0.5,
                 channel_dim_expansion_factor: float = 4.0, dropout: float = 0.0):
        """
        One MLP-Mixer layer: a token-mixing block followed by a channel-mixing block,
        each made of LayerNorm + FeedForwardNet with a residual connection.
        :param num_tokens: int, number of tokens
        :param num_channels: int, number of channels
        :param token_dim_expansion_factor: float, dimension expansion factor for tokens
        :param channel_dim_expansion_factor: float, dimension expansion factor for channels
        :param dropout: float, dropout rate
        """
        super(MLPMixer, self).__init__()

        # token mixing works along the token axis
        self.token_norm = nn.LayerNorm(num_tokens)
        self.token_feedforward = FeedForwardNet(input_dim=num_tokens,
                                                dim_expansion_factor=token_dim_expansion_factor,
                                                dropout=dropout)

        # channel mixing works along the channel axis
        self.channel_norm = nn.LayerNorm(num_channels)
        self.channel_feedforward = FeedForwardNet(input_dim=num_channels,
                                                  dim_expansion_factor=channel_dim_expansion_factor,
                                                  dropout=dropout)

    def forward(self, input_tensor: torch.Tensor):
        """
        mix information across tokens first and across channels second
        :param input_tensor: Tensor, shape (batch_size, num_tokens, num_channels)
        :return: Tensor, shape (batch_size, num_tokens, num_channels)
        """
        # token mixing: put the token axis last, shape (batch_size, num_channels, num_tokens)
        tokens_last = input_tensor.permute(0, 2, 1)
        token_update = self.token_feedforward(self.token_norm(tokens_last))
        # back to (batch_size, num_tokens, num_channels), then residual connection
        token_mixed = token_update.permute(0, 2, 1) + input_tensor

        # channel mixing on (batch_size, num_tokens, num_channels), then residual connection
        channel_update = self.channel_feedforward(self.channel_norm(token_mixed))
        return channel_update + token_mixed
