import json
import os
import re
import torch
import torch.nn.functional as F
from typing import Optional, Sequence, List, Set, Dict, Any, Union
import transformers
import logging
from dataclasses import dataclass
import pathlib
from torch.utils.data import DataLoader
# Sentinel written into label sequences at positions that carry no target:
# it replaces masked-out labels in `mask_labels` and is the padding value for
# `labels` / `v_labels` in the dataset `collate_fn`.
# NOTE(review): -100 matches the default `ignore_index` of
# torch.nn.CrossEntropyLoss; the loss itself is not defined in this file — confirm.
IGNORE_INDEX = -100



def read_jsonl(path: str):
    """Load records from a JSON Lines file, falling back to a plain JSON file.

    The file is first parsed line by line (one JSON document per line).
    Blank and whitespace-only lines are skipped, so a trailing newline or an
    empty separator line no longer breaks JSONL parsing. If any line is not
    valid JSON on its own (e.g. a pretty-printed JSON document spanning
    several lines), the whole file is parsed as a single JSON document.

    Args:
        path: Path of the file to read (decoded as UTF-8).

    Returns:
        A list with one parsed object per non-blank line, or whatever
        ``json.load`` returns for the whole file in the fallback case.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is neither valid JSONL nor valid JSON.
    """
    try:
        with open(path, encoding='utf-8') as fh:
            # Iterate lazily; `line.strip()` filters "\n"-only lines, which
            # are truthy and would otherwise be handed to json.loads.
            return [json.loads(line) for line in fh if line.strip()]
    except json.JSONDecodeError:
        # Not line-delimited JSON: re-read the file as one JSON document.
        with open(path, 'r', encoding='utf-8') as fh:
            return json.load(fh)



def get_annotations(data_dir, target_set=None, *_unused_ids):
    """Read and concatenate the examples of one or more annotation files.

    Args:
        data_dir: Comma-separated list of file paths; each one is loaded with
            ``read_jsonl``. Surrounding whitespace and empty entries (e.g. a
            trailing comma) are ignored.
        target_set: Not used by this function. It is optional so that callers
            passing only ``data_dir`` keep working.
        *_unused_ids: Extra positional arguments are accepted and ignored;
            call sites in this file pass additional identifiers.

    Returns:
        A single list holding the examples of all files, in file order.
    """
    examples = []
    for dd in data_dir.split(","):
        dd = dd.strip()
        if not dd:
            continue
        examples += read_jsonl(dd)
    if examples:
        print(f"{len(examples)} examples, each with {len(examples[0]['outputs'])} solutions")
    else:
        # Nothing loaded: there is no first example to inspect.
        print(f"{len(examples)} examples")
    return examples




def make_training_dataloaders(
    data_module: Dict[str, torch.utils.data.Dataset],
    training_args: Any = None,
) -> DataLoader:
    """Build the shuffled training ``DataLoader`` for a data module.

    Args:
        data_module: Mapping that contains the training dataset under the key
            ``'train_dataset'``. The dataset must expose a ``collate_fn``
            method, which is used to assemble batches.
        training_args: Object providing ``per_device_train_batch_size``.
            It defaults to ``None`` for signature compatibility only; a real
            object is required.

    Returns:
        A single ``DataLoader`` (not a dict) over the training dataset with
        shuffling enabled and the last partial batch kept.
    """
    train_dataset = data_module['train_dataset']
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=training_args.per_device_train_batch_size,
        shuffle=True,
        drop_last=False,
        collate_fn=train_dataset.collate_fn,
    )
    return train_dataloader



def make_testing_dataloader(
    dataset: torch.utils.data.Dataset,
    batch_size: int,
):
    """Build an evaluation ``DataLoader`` that walks ``dataset`` in order.

    Batches are produced without shuffling, the final partial batch is kept,
    and the dataset's own ``collate_fn`` assembles each batch.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=dataset.collate_fn,
    )
    return DataLoader(dataset, **loader_options)







def make_training_verifier_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: dataclass) -> Dict:
    """Create the training dataset for a process- or outcome-level verifier.

    Args:
        tokenizer: Tokenizer handed to the dataset.
        data_args: Object providing ``process``, ``data_dir``, ``target_set``
            and ``per_problem_sampling_solution``; when ``process`` is true it
            must also provide ``verifier_id``, ``data_id`` and ``generator_id``.

    Returns:
        ``dict(train_dataset=..., val_dataset=None)``; no validation set is
        built here.
    """
    dataset_kwargs = dict(
        tokenizer=tokenizer,
        data_dir=data_args.data_dir,
        target_set=data_args.target_set,
        per_problem_sampling_solution=data_args.per_problem_sampling_solution,
    )

    if data_args.process == True:
        dataset_class = Process_VerifierDataset
        # Only the process-level dataset takes these identifiers; passing them
        # to Outcome_VerifierDataset raises TypeError (unexpected keyword).
        dataset_kwargs.update(
            verifier_id=data_args.verifier_id,
            data_id=data_args.data_id,
            generator_id=data_args.generator_id,
        )
    else:
        dataset_class = Outcome_VerifierDataset

    train_dataset = dataset_class(**dataset_kwargs)
    val_dataset = None

    return dict(train_dataset=train_dataset, val_dataset=val_dataset)





class Outcome_VerifierDataset(torch.utils.data.Dataset):
    """Question/solution pairs with one outcome label per solution (right padding).

    Every loaded example is expected to provide ``example['input']`` (the
    question text) and ``example['outputs']``, a list of dicts with the keys
    ``'response'`` (solution text) and ``'label'`` (compared against ``True``).
    Each (question, solution) pair becomes one item of the dataset.
    """
    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer = None,
        data_dir: str = None,
        target_set: str = None,
        per_problem_sampling_solution: int = None,
        loss_level: str = 'token',
        loss_on_llm: bool = False,
        max_length: int = 2048,
    ):
        """Load, de-duplicate, tokenize and length-filter the training pairs.

        Args:
            tokenizer: Tokenizer used for questions and solutions; its
                ``pad_token_id`` and ``eos_token_id`` are stored.
            data_dir: Comma-separated annotation file paths, forwarded to
                ``get_annotations``.
            target_set: Stored on the instance and forwarded to
                ``get_annotations``.
            per_problem_sampling_solution: Number of solutions kept per
                problem (the first ones). ``-1`` or ``None`` keeps all.
            loss_level: Either ``'token'`` or ``'step'``; stored only.
            loss_on_llm: Whether language-modeling ``labels`` are produced in
                addition to the verifier labels.
            max_length: Pairs whose question + solution + EOS token count
                exceeds this value are dropped.

        Raises:
            AssertionError: If ``loss_level`` is invalid or the first example
                has fewer solutions than requested.
            ValueError: If no pair survives loading and length filtering.
        """
        assert loss_level in ('token', 'step')
        if per_problem_sampling_solution is None:
            # None used to crash in the comparison below; treat it as "keep all".
            per_problem_sampling_solution = -1

        self.examples = get_annotations(data_dir, target_set)
        assert len(self.examples[0]['outputs']) >= per_problem_sampling_solution

        self.tokenizer = tokenizer
        self.data_dir = data_dir
        self.target_set = target_set
        self.loss_level = loss_level
        self.loss_on_llm = loss_on_llm
        self.max_length = max_length

        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id

        if per_problem_sampling_solution != -1:
            for example in self.examples:
                example['outputs'] = example['outputs'][:per_problem_sampling_solution]
        else:
            per_problem_sampling_solution = len(self.examples[0]['outputs'])

        # Drop repeated responses within a problem, keeping first occurrences.
        for ex in self.examples:
            dedup_outputs = []
            responses = set()
            for output in ex['outputs']:
                if output['response'] in responses:
                    continue
                responses.add(output['response'])
                dedup_outputs.append(output)
            ex['outputs'] = dedup_outputs

        # Flatten to one entry per (problem, solution) pair.
        # indices1 = problem position, indices2 = solution position within it.
        indices1, indices2, qns_str, solutions_str, v_classes = [], [], [], [], []
        for i, ex in enumerate(self.examples):
            for j, output in enumerate(ex['outputs']):
                indices1.append(i)
                indices2.append(j)
                qns_str.append(ex['input'])
                solutions_str.append(output['response'])
                v_classes.append(output['label'] == True)

        if not solutions_str:
            raise ValueError(f"No solutions found in data_dir={data_dir!r}")

        qns_tokens = tokenizer(qns_str, padding=False).input_ids
        solutions_tokens = tokenizer(solutions_str, padding=False, add_special_tokens=False).input_ids

        # Remove instances that are longer than max_length (the +1 is the EOS token).
        kept = [
            row
            for row in zip(qns_tokens, solutions_tokens, indices1, indices2, qns_str, solutions_str, v_classes)
            if len(row[0]) + len(row[1]) + 1 <= max_length
        ]
        if not kept:
            raise ValueError(f"All instances exceed max_length={max_length} tokens")
        (self.qns_tokens, self.solutions_tokens, self.indices1, self.indices2,
         self.qns_str, self.solutions_str, self.v_classes) = zip(*kept)

        # Longest sequence that is actually part of the dataset.
        self.max_len = max(
            len(q) + len(s) + 1 for q, s in zip(self.qns_tokens, self.solutions_tokens))

        print(f"Max tokens: {self.max_len}")
        self.per_problem_sampling_solution = per_problem_sampling_solution
        print(f'Number of examples = {len(self.qns_str)}')
        self.n_question = len(self.examples)

    def __len__(self):
        """Return the number of (question, solution) pairs kept."""
        return len(self.solutions_tokens)

    def _flatten(self, ls):
        """Concatenate a list of lists into one flat list."""
        return [item for sublist in ls for item in sublist]

    def __getitem__(self, idx):
        """Return the tensors and raw fields of pair ``idx``.

        ``input_ids`` is question + solution + EOS. ``v_labels`` repeats the
        outcome label on solution and EOS positions and holds ``IGNORE_INDEX``
        on question positions. ``labels`` (same masking applied to the token
        ids) is a tensor only when ``loss_on_llm`` is set, otherwise ``None``.
        """
        qn_tokens = self.qns_tokens[idx]
        sol_tokens = self.solutions_tokens[idx]
        v_class = self.v_classes[idx]

        input_ids = qn_tokens + sol_tokens + [self.eos_token_id]
        # 0 = question position (ignored), 1 = solution / EOS position.
        masks = [0] * len(qn_tokens) + [1] * (len(sol_tokens) + 1)

        # language modeling labels, only built when they are requested
        labels = torch.tensor(mask_labels(input_ids, masks)) if self.loss_on_llm else None

        # verifier labels
        v_labels = torch.tensor(mask_labels([int(v_class)] * len(input_ids), masks))

        return dict(
            idx1=self.indices1[idx], idx2=self.indices2[idx],
            input_ids=torch.tensor(input_ids), labels=labels, v_labels=v_labels,
            qn_str=self.qns_str[idx], qn_tokens=self.qns_tokens[idx], sol_str=self.solutions_str[idx], sol_tokens=self.solutions_tokens[idx], v_class=self.v_classes[idx],
        )

    def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Right-pad a list of items into a batch.

        ``input_ids`` is padded with the pad token id and comes with an
        ``attention_mask``; ``labels`` / ``v_labels`` are padded with
        ``IGNORE_INDEX``. ``labels`` is ``None`` unless ``loss_on_llm`` is set.
        The remaining fields are returned as plain per-item lists.
        """
        input_ids, labels, v_labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "v_labels"))
        idx1, idx2, qn_str, qn_tokens, sol_str, sol_tokens, v_class = tuple([instance[key] for instance in instances] for key in ("idx1", "idx2", "qn_str", "qn_tokens", "sol_str", "sol_tokens", "v_class"))

        input_ids, attention_mask = right_pad_sequences(input_ids, padding_value=self.pad_token_id, return_attention_mask=True)
        labels = right_pad_sequences(labels, padding_value=IGNORE_INDEX, return_attention_mask=False) if self.loss_on_llm else None
        v_labels = right_pad_sequences(v_labels, padding_value=IGNORE_INDEX, return_attention_mask=False)

        return dict(
            idx1=idx1, idx2=idx2,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            v_labels=v_labels,
            qn_str=qn_str, qn_tokens=qn_tokens, sol_str=sol_str, sol_tokens=sol_tokens, v_class=v_class,
        )

class Process_VerifierDataset(Outcome_VerifierDataset):
    """Question/solution pairs with per-position process labels (right padding).

    Every loaded example is expected to provide ``example['input']`` (or
    ``example['question']`` as a fallback) and ``example['outputs']``, a list
    of dicts with the keys ``'response'`` and ``'process_vscores'``.
    ``__len__``, ``_flatten`` and ``collate_fn`` are inherited.
    """
    def __init__(
            self,
            tokenizer: transformers.PreTrainedTokenizer = None,
            data_dir: str = 'data/gsm8k/model_generation',
            target_set: str = None,
            data_id: str = None,
            generator_id: str = None,
            verifier_id: str = None,
            per_problem_sampling_solution: int = None,
            dedup: bool = False,
            loss_on_llm: bool = False,
            max_length: int = 2048,
    ):
        """Load, de-duplicate, tokenize and length-filter the training pairs.

        Args:
            tokenizer: Tokenizer used for questions and solutions; its
                ``pad_token_id`` and ``eos_token_id`` are stored.
            data_dir: Comma-separated annotation file paths, forwarded to
                ``get_annotations``.
            target_set: Stored on the instance and forwarded to
                ``get_annotations``.
            data_id: Accepted for interface compatibility; not used here.
            generator_id: Stored on the instance.
            verifier_id: Accepted for interface compatibility; not used here.
            per_problem_sampling_solution: Number of solutions kept per
                problem (the first ones). ``-1`` or ``None`` keeps all.
            dedup: Accepted but not consulted; duplicate responses are always
                removed.
            loss_on_llm: Whether language-modeling ``labels`` are produced.
                ``__getitem__`` and the inherited ``collate_fn`` both read
                this attribute, so it has to be set here.
            max_length: Pairs whose question + solution + EOS token count
                exceeds this value are dropped.

        Raises:
            AssertionError: If the first example has fewer solutions than
                requested.
            ValueError: If no pair survives loading and length filtering.
        """
        if per_problem_sampling_solution is None:
            # None used to crash in the comparison below; treat it as "keep all".
            per_problem_sampling_solution = -1

        self.examples = get_annotations(data_dir, target_set)
        assert len(self.examples[0]['outputs']) >= per_problem_sampling_solution

        self.tokenizer = tokenizer
        self.data_dir = data_dir
        self.target_set = target_set
        self.generator_id = generator_id
        self.loss_on_llm = loss_on_llm
        self.max_length = max_length

        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id

        # Apply the 'question' -> 'input' fallback to every example, whether
        # or not the solutions get truncated below.
        for example in self.examples:
            if "input" not in example:
                example['input'] = example['question']

        if per_problem_sampling_solution != -1:
            for example in self.examples:
                example['outputs'] = example['outputs'][:per_problem_sampling_solution]
        else:
            per_problem_sampling_solution = len(self.examples[0]['outputs'])

        # Drop repeated responses within a problem, keeping first occurrences.
        for ex in self.examples:
            dedup_outputs = []
            responses = set()
            for output in ex['outputs']:
                if output['response'] in responses:
                    continue
                responses.add(output['response'])
                dedup_outputs.append(output)
            ex['outputs'] = dedup_outputs

        # Flatten to one entry per (problem, solution) pair.
        # indices1 = problem position, indices2 = solution position within it.
        indices1, indices2, qns_str, solutions_str, v_classes = [], [], [], [], []
        for i, ex in enumerate(self.examples):
            for j, output in enumerate(ex['outputs']):
                indices1.append(i)
                indices2.append(j)
                qns_str.append(ex['input'])
                solutions_str.append(output['response'])
                v_classes.append(output['process_vscores'])

        if not solutions_str:
            raise ValueError(f"No solutions found in data_dir={data_dir!r}")

        qns_tokens = tokenizer(qns_str, padding=False).input_ids
        solutions_tokens = tokenizer(solutions_str, padding=False, add_special_tokens=False).input_ids

        # Remove instances that are longer than max_length (the +1 is the EOS token).
        kept = [
            row
            for row in zip(qns_tokens, solutions_tokens, indices1, indices2, qns_str, solutions_str, v_classes)
            if len(row[0]) + len(row[1]) + 1 <= max_length
        ]
        if not kept:
            raise ValueError(f"All instances exceed max_length={max_length} tokens")
        (self.qns_tokens, self.solutions_tokens, self.indices1, self.indices2,
         self.qns_str, self.solutions_str, self.v_classes) = zip(*kept)

        # Longest sequence that is actually part of the dataset.
        self.max_len = max(
            len(q) + len(s) + 1 for q, s in zip(self.qns_tokens, self.solutions_tokens))

        print(f"Max tokens: {self.max_len}")
        self.per_problem_sampling_solution = per_problem_sampling_solution
        print(f'Number of examples = {len(self.qns_str)}')
        self.n_question = len(self.examples)

    def __getitem__(self, idx):
        """Return the tensors and raw fields of pair ``idx``.

        ``input_ids`` is question + solution + EOS. ``v_labels`` holds the
        stored process scores on solution / EOS positions and ``IGNORE_INDEX``
        on question positions. ``labels`` is a tensor only when
        ``loss_on_llm`` is set, otherwise ``None``.
        """
        qn_tokens = self.qns_tokens[idx]
        sol_tokens = self.solutions_tokens[idx]
        v_class = self.v_classes[idx]

        input_ids = qn_tokens + sol_tokens + [self.eos_token_id]
        # 0 = question position (ignored), 1 = solution / EOS position.
        masks = [0] * len(qn_tokens) + [1] * (len(sol_tokens) + 1)

        labels = torch.tensor(mask_labels(input_ids, masks)) if self.loss_on_llm else None

        # Prefix placeholder scores for the question; they are masked out.
        # mask_labels asserts equal lengths, so the stored process scores must
        # have len(sol_tokens) + 1 entries (one per solution token plus EOS).
        v_class = [1] * len(qn_tokens) + v_class
        v_labels = torch.tensor(mask_labels(v_class, masks))

        return dict(
            idx1=self.indices1[idx], idx2=self.indices2[idx],
            input_ids=torch.tensor(input_ids), labels=labels, v_labels=v_labels,
            qn_str=self.qns_str[idx], qn_tokens=self.qns_tokens[idx], sol_str=self.solutions_str[idx],
            sol_tokens=self.solutions_tokens[idx], v_class=self.v_classes[idx],
        )















def left_pad_sequences(sequences: List[torch.LongTensor], padding_value: int, return_attention_mask: bool = False):
    """Left-pad tensors to a common length and stack them into a batch.

    Args:
        sequences: Non-empty list of tensors; each is padded on the left of
            its last dimension up to the longest ``len(x)`` in the list.
        padding_value: Value written into the padded positions.
        return_attention_mask: If true, also return a boolean mask that is
            ``True`` on original positions and ``False`` on padding.

    Returns:
        The stacked, padded tensor, or ``(padded, attention_mask)`` when
        ``return_attention_mask`` is set.

    Raises:
        ValueError: If ``sequences`` is empty (raised by ``max``).
    """
    max_length = max(len(x) for x in sequences)
    pad_widths = [max_length - seq.shape[-1] for seq in sequences]
    padded_sequences = torch.stack(
        [F.pad(seq, (width, 0), value=padding_value) for seq, width in zip(sequences, pad_widths)],
        dim=0,
    )
    if return_attention_mask:
        # Build the mask from the pad widths rather than from the values, so
        # real tokens that equal padding_value (e.g. pad id == eos id) stay
        # attended.
        attention_mask = torch.zeros_like(padded_sequences, dtype=torch.bool)
        for row, width in enumerate(pad_widths):
            attention_mask[row, ..., max(width, 0):] = True
        return padded_sequences, attention_mask
    return padded_sequences

def right_pad_sequences(sequences: List[torch.LongTensor], padding_value: int, return_attention_mask: bool = False):
    """Right-pad tensors to a common length and stack them into a batch.

    Args:
        sequences: List of tensors, padded along their first dimension with
            ``torch.nn.utils.rnn.pad_sequence`` (batch first).
        padding_value: Value written into the padded positions.
        return_attention_mask: If true, also return a boolean mask that is
            ``True`` on original positions and ``False`` on padding.

    Returns:
        The padded batch tensor, or ``(padded, attention_mask)`` when
        ``return_attention_mask`` is set.
    """
    padded_sequences = torch.nn.utils.rnn.pad_sequence(
        sequences,
        batch_first=True,
        padding_value=padding_value,
    )
    if return_attention_mask:
        # Build the mask from the sequence lengths rather than from the
        # values, so real tokens that equal padding_value (e.g. an EOS token
        # when pad id == eos id) stay attended.
        attention_mask = torch.zeros_like(padded_sequences, dtype=torch.bool)
        for row, seq in enumerate(sequences):
            attention_mask[row, :len(seq)] = True
        return padded_sequences, attention_mask
    return padded_sequences


def mask_labels(labels: List[int], masks: List[bool]):
    """Replace every label whose mask entry is falsy with IGNORE_INDEX.

    Both lists must have the same length; labels at truthy mask positions are
    returned unchanged, in order, in a new list.
    """
    assert len(labels) == len(masks)
    masked = []
    for token, keep in zip(labels, masks):
        masked.append(token if keep else IGNORE_INDEX)
    return masked



