import os
import torch
from torch_geometric.data import Data, Batch
from utils.utils_batching_preproc import create_batch_from_task_list, preprocess_task
import time

from data.dataloader import linearize


class Collater:
    """Collate a list of few-shot tasks into the tuple consumed by the model.

    Instances are meant to be used as the ``collate_fn`` of a
    ``torch.utils.data.DataLoader`` (``get_dataset_dataloader`` does this).
    """

    def __init__(self, device, text_embeddings_dict, use_original_features):
        """
        :param device: Device forwarded to ``create_batch_from_task_list`` as
            ``copy_to_device`` and to which the stacked label embeddings are moved.
        :param text_embeddings_dict: Mapping from each label in the batched task
            list to its pre-computed embedding tensor.
        :param use_original_features: Stored on the instance; ``__call__`` does
            not currently read it.
        """
        self.device = device
        self.text_dict = text_embeddings_dict
        self.use_original_features = use_original_features

    def __call__(self, batch):
        """Build a model-ready batch from a list of tasks.

        NOTE: the sequence construction below only works for KG tasks with
        binary classification.

        :param batch: List of tasks as yielded by the dataset.
        :return: Tuple ``(graph_batch, all_labels, y_true_matrix, metagraph_edges,
            metagraph_edge_attr, query_set_mask, input_seqs, query_seqs,
            query_seqs_gt, task_mask)``. ``all_labels`` are the stacked label
            embeddings on ``self.device``, ``y_true_matrix`` is cast to float and
            the three ``*_seqs`` entries are produced by ``linearize``.
        """
        n_task = len(batch)

        # KG features are preprocessed in advance, so no text encoder is applied
        # here (the former choice was: None if use_original_features else encoder).
        text_enc_features = None
        tasks = [preprocess_task(task, text_enc_features) for task in batch]
        # create_batch_from_task_list returns an "old"-style tuple that is
        # repackaged into the return value below.
        old_batch = create_batch_from_task_list(tasks, keep_original_ys=False, shuffle=False,
                                                copy_to_device=self.device)
        graph_batch = old_batch[0]
        all_labels = torch.stack([self.text_dict[label] for label in old_batch[2]]).to(self.device)
        y_true_matrix = graph_batch.y_task_labels
        metagraph_edges, metagraph_edge_attr = old_batch[3], old_batch[4]
        query_set_mask = old_batch[6]
        task_mask = old_batch[9]

        # One row per task for the query mask and for each row of the edge index.
        b_mask = query_set_mask.reshape(n_task, -1).bool()
        edge_index = metagraph_edges.reshape(2, n_task, -1)
        inputs_idx = edge_index[0]
        # Cloned because it is overwritten in place below; metagraph_edges itself
        # is returned and must stay intact.
        output_idx = edge_index[1].clone()

        max_idx = metagraph_edges.max()
        # Two ids beyond every index in metagraph_edges:
        #   max_idx + 1 -> placeholder target used to build the query sequences;
        #   max_idx + 2 -> fake "False" class for entries whose label is 0.
        output_idx[y_true_matrix.reshape(n_task, -1) == 0] = max_idx + 2

        input_seqs, _ = linearize(~b_mask, inputs_idx, output_idx)
        # Created on output_idx's device so it can be combined with device-resident
        # indices (a CPU-only tensor fails against CUDA indices); dtype is int32.
        query_placeholder = torch.full_like(output_idx, int(max_idx) + 1, dtype=torch.int)
        query_seqs, batch_rand_perm = linearize(b_mask, inputs_idx, query_placeholder)
        # Presumably reapplies the same permutation so ground-truth and query
        # sequences line up — confirm against linearize.
        query_seqs_gt, _ = linearize(b_mask, inputs_idx, output_idx, batch_rand_perm)
        return (graph_batch, all_labels, y_true_matrix.float(), metagraph_edges, metagraph_edge_attr,
                query_set_mask, input_seqs, query_seqs, query_seqs_gt, task_mask)


class PostProcDataLoader(torch.utils.data.DataLoader):
    """A ``DataLoader`` that can also be advanced directly with ``next(loader)``.

    On construction it creates one iterator over itself (kept in ``self.iter``);
    ``next(loader)`` pulls the next batch from that iterator. Iterating with
    ``for batch in loader`` still goes through ``DataLoader.__iter__`` and starts
    a fresh pass that is independent of ``self.iter``.
    """

    def __init__(self, dataset, batch_size: int = 1, shuffle: bool = False, text_encoder=None, collate_fn=None,
                 **kwargs):
        """
        :param dataset: Dataset to load from.
        :param batch_size: Number of samples per batch.
        :param shuffle: Whether to reshuffle the data for each new iterator.
        :param text_encoder: Required; stored as ``self.text_encoder`` (not used
            by this class itself).
        :param collate_fn: Collate function forwarded to ``DataLoader``.
        :param kwargs: Any other ``DataLoader`` keyword arguments.
        :raises AssertionError: if ``text_encoder`` is None.
        """
        super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, **kwargs)
        assert text_encoder is not None, "text_encoder is a required argument"
        self.text_encoder = text_encoder
        # NOTE(review): creating the iterator here starts worker processes right
        # away when num_workers > 0.
        self.iter = iter(self)

    def __next__(self):
        """Return the next batch from the persistent iterator.

        :raises StopIteration: once the persistent iterator is exhausted; it is
            not reset automatically.
        """
        return next(self.iter)


def get_dataset_dataloader(dataset, batchsz, device, text_dict, params):
    """Build a shuffling ``DataLoader`` that collates tasks with ``Collater``.

    :param dataset: Dataset as returned by the get_dataset function.
    :param batchsz: Batch size (number of tasks per batch).
    :param device: Device passed to ``Collater`` for the collated tensors.
    :param text_dict: Mapping from label to pre-computed text embedding, passed
        to ``Collater``.
    :param params: Configuration mapping; ``params["workers"]`` is used as
        ``num_workers`` and ``params["original_features"]`` as
        ``Collater``'s ``use_original_features``.
    :return: A PyTorch DataLoader object.
    """
    collate = Collater(device, text_dict, use_original_features=params["original_features"])
    return torch.utils.data.DataLoader(dataset, batch_size=batchsz, shuffle=True,
                                       num_workers=params["workers"], collate_fn=collate)


