import enum
import json
import os, sys
sys.path.append('../')
from inference import Inferencer
import torch

# Dataset name -> index into the inferencer's output that is scored for that
# dataset. WrongExampleCollector.collector() decodes index 2 with an argmax
# over the last dimension and index 1 with a 0.5 threshold.
TASK_HEAD = {'mnli': 2}



class WrongExampleCollector():
    """Split evaluation examples into correctly and wrongly predicted ones.

    The datasets named in ``dataset_config`` are loaded when the object is
    constructed. ``collector()`` then runs the inferencer over each dataset,
    compares its predictions with the stored labels and writes both groups
    to JSON-lines files.
    """

    def __init__(self, ckpt_path, dataset_config, model='bert-base-uncased', device='cuda') -> None:
        """Create the inferencer and load every configured dataset.

        Args:
            ckpt_path: Checkpoint path forwarded to ``Inferencer``.
            dataset_config: Mapping of dataset name to a dict that provides
                at least ``'data_path'`` (JSON-lines file) and ``'size'``
                (see ``dataloader``).
            model: Model name forwarded to ``Inferencer``.
            device: Device string forwarded to ``Inferencer``.
        """
        self.dataset_config = dataset_config
        self.inferencer = Inferencer(ckpt_path=ckpt_path, model=model, batch_size=32, device=device)
        self.dataset = dict()
        self.wrong_examples = dict()
        self.correct_examples = dict()

        self.dataloader()

    def collector(self):
        """Run inference on every loaded dataset and save the split results.

        For each dataset the predictions are compared with ``orig_label``
        and every example is stored in either ``self.wrong_examples`` or
        ``self.correct_examples``. A dataset without any examples is skipped
        and simply yields two empty lists.

        Raises:
            KeyError: If a non-empty dataset has no entry in ``TASK_HEAD``.
            ValueError: If the number of predictions differs from the number
                of examples.
        """
        for each_dataset, examples in self.dataset.items():
            wrong, correct = [], []
            self.wrong_examples[each_dataset] = wrong
            self.correct_examples[each_dataset] = correct
            if not examples:
                # Nothing to score; also avoids indexing examples[0] below.
                continue

            head = TASK_HEAD[each_dataset]
            task_type = examples[0]['task']
            premise = [example['text_a'] for example in examples]
            # Only the first entry of 'text_b' is used as the hypothesis.
            hypothesis = [example['text_b'][0] for example in examples]
            # A stored label of -1 is compared as class 1.
            label = [1 if example['orig_label'] == -1 else example['orig_label'] for example in examples]

            predictions = self.inferencer.inference(premise=premise, hypo=hypothesis)[head]
            if head == 2:
                predictions = torch.argmax(predictions, dim=-1).tolist()
            elif head == 1:
                predictions = (predictions > 0.5).int().tolist()

            # Explicit check instead of ``assert`` so it survives ``python -O``;
            # otherwise ``zip`` below would silently truncate.
            if len(label) != len(predictions):
                raise ValueError(
                    f'{each_dataset}: got {len(predictions)} predictions for {len(label)} examples'
                )

            for example, one_label, one_prediction in zip(examples, label, predictions):
                record = {
                    'task': task_type,
                    'text_a': example['text_a'],
                    'text_b': example['text_b'][0],
                    # The raw label is saved, not the -1 -> 1 remapped one.
                    'orig_label': example['orig_label'],
                }
                if one_label != one_prediction:
                    wrong.append(record)
                else:
                    correct.append(record)

        self.save_result()

    @staticmethod
    def _write_jsonl(path, examples):
        """Write ``examples`` to ``path``, one JSON object per line."""
        # Create the target directory first so a fresh checkout does not fail.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'w', encoding='utf8') as f:
            for example in examples:
                json.dump(example, f, ensure_ascii=False)
                f.write('\n')

    def save_result(self):
        """Write the wrong and correct examples of every dataset to disk.

        Files are named ``wrong_<dataset>.json`` and ``correct_<dataset>.json``
        inside ``../data/correct_wrong_examples``; the directory is created
        if it does not exist.
        """
        for each_dataset in self.wrong_examples:
            self._write_jsonl(
                f'../data/correct_wrong_examples/wrong_{each_dataset}.json',
                self.wrong_examples[each_dataset],
            )

        for each_dataset in self.correct_examples:
            self._write_jsonl(
                f'../data/correct_wrong_examples/correct_{each_dataset}.json',
                self.correct_examples[each_dataset],
            )

    def dataloader(self):
        """Load the configured JSON-lines datasets into ``self.dataset``.

        ``size`` selects how many leading examples are kept: an ``int`` is an
        absolute count, any other number is a fraction of the file's
        non-blank lines. A non-positive limit yields an empty dataset. Blank
        lines are ignored.
        """
        for each_dataset in self.dataset_config:
            config = self.dataset_config[each_dataset]
            data_path = config['data_path']
            size = config['size']

            if isinstance(size, int):
                dataset_length_limit = size
            else:
                # The total is only needed to resolve a fractional size.
                with open(data_path, 'r', encoding='utf8') as f:
                    dataset_length = sum(1 for line in f if line.strip())
                dataset_length_limit = int(size * dataset_length)

            examples = []
            with open(data_path, 'r', encoding='utf8') as f:
                for each_line in f:
                    if len(examples) >= dataset_length_limit:
                        break
                    if not each_line.strip():
                        continue
                    examples.append(json.loads(each_line))
            self.dataset[each_dataset] = examples



if __name__ == '__main__':
    # Portion of every dataset file to evaluate. WrongExampleCollector reads a
    # non-int value as a fraction of the file and an int as an absolute count.
    DATA_SIZE = 1.0

    checkpoint = 'checkpoints/roberta-base/roberta-base_no_mlm_mnli_nli_fever_doc_nli_paws_paws_qqp_vitaminc_anli_r1_anli_r2_anli_r3_snli_wikihow_paws_unlabeled_wiki103_qqp_stsb_sick_500000_32x2x1_epoch=00_step=110000.ckpt'

    # Only MNLI is active. The remaining entries are kept for reference.
    # NOTE(review): the disabled entries use 'data/...' paths while the active
    # one uses '../data/...' - confirm the paths before enabling any of them.
    dataset_config = {
        # 'xsum': {'task_type': 'summarization', 'data_path': 'data/xsum.json', 'size':DATA_SIZE},
        # 'cnndm': {'task_type': 'summarization', 'data_path': 'data/cnndm.json', 'size': DATA_SIZE},
        'mnli': dict(task_type='nli', data_path='../data/mnli.json', size=DATA_SIZE),
        # 'nli_fever': {'task_type': 'fact_checking', 'data_path': 'data/nli_fever.json', 'size': DATA_SIZE},
        # 'doc_nli': {'task_type': 'paraphrase', 'data_path': 'data/doc_nli.json', 'size': DATA_SIZE},
        # 'squad': {'task_type': 'qa', 'data_path': 'data/squad.json', 'size': DATA_SIZE},
        # 'paws': {'task_type': 'paraphrase', 'data_path': 'data/paws.json', 'size':DATA_SIZE},
        # 'paws_qqp': {'task_type': 'paraphrase', 'data_path': 'data/paws_qqp.json', 'size':DATA_SIZE},
        # 'vitaminc': {'task_type': 'fact_checking', 'data_path': 'data/vitaminc.json', 'size':DATA_SIZE},
        # 'race': {'task_type': 'multiple_choice_qa', 'data_path': 'data/race.json', 'size': DATA_SIZE},
        # 'anli_r1': {'task_type': 'nli', 'data_path': 'data/anli_r1.json', 'size': DATA_SIZE},
        # 'anli_r2': {'task_type': 'nli', 'data_path': 'data/anli_r2.json', 'size': DATA_SIZE},
        # 'anli_r3': {'task_type': 'nli', 'data_path': 'data/anli_r3.json', 'size': DATA_SIZE},
        # 'snli': {'task_type': 'nli', 'data_path': 'data/snli.json', 'size': DATA_SIZE},
        # 'wikihow': {'task_type': 'summarization', 'data_path': 'data/wikihow.json', 'size': DATA_SIZE},
        # 'msmarco': {'task_type': 'ir', 'data_path': 'data/msmarco.json', 'size': DATA_SIZE},
        # 'paws_unlabeled': {'task_type': 'paraphrase', 'data_path': 'data/paws_unlabeled.json', 'size': DATA_SIZE},
        # 'wiki103': {'task_type': 'paraphrase', 'data_path': 'data/wiki103.json', 'size': DATA_SIZE},
        # 'qqp': {'task_type': 'paraphrase', 'data_path': 'data/qqp.json', 'size': DATA_SIZE},
        # 'stsb': {'task_type': 'sts', 'data_path': 'data/stsb.json', 'size': DATA_SIZE},
        # 'sick': {'task_type': 'sts', 'data_path': 'data/sick.json', 'size': DATA_SIZE},
        # 'ctc': {'task_type': 'ctc', 'data_path': 'data/ctc.json', 'size': DATA_SIZE},
    }

    # Construction loads the datasets; collector() scores them and writes the
    # wrong/correct splits to disk.
    collector = WrongExampleCollector(
        ckpt_path=checkpoint,
        dataset_config=dataset_config,
        model='roberta-base',
        device='cuda:2',
    )
    collector.collector()

