from dataclasses import dataclass
import eval_glue
from transformers import Trainer, TrainingArguments, TrainerCallback
from collections import defaultdict
from transformers import RobertaForSequenceClassification, AutoTokenizer
from transformers.modeling_utils import SequenceClassifierOutput
import sparsify
import torch
from param import param
import datasets
import copy
import utils

def merge_para(ps):
    """Return the weighted sum of per-model tensors for one parameter entry.

    Args:
        ps: Sequence whose leading items are same-shaped tensors (one per
            model) and whose LAST item is the weight tensor holding one
            coefficient per model. The call sites in this file build it as
            ``task_vectors + [weights]``.

    Returns:
        A tensor shaped like a single model tensor, ``sum_i ws[i] * ps[i]``.

    Raises:
        ValueError: If ``ps`` does not contain at least one tensor followed
            by the weights.
    """
    if len(ps) < 2:
        raise ValueError(
            "merge_para expects at least one tensor followed by the weights"
        )

    # The weights are appended after the model tensors, so they are the last
    # item (not the item at index 1).
    *tensors, ws = ps
    stacked = torch.stack(tensors, dim=0)  # (n_model, *param_shape)

    # Add one trailing singleton axis per parameter axis so that the weights
    # broadcast along the model axis instead of the last parameter axis.
    ws = ws.reshape(-1, *([1] * (stacked.dim() - 1)))
    return torch.sum(ws * stacked, dim=0)

# TODO: this must keep gradients (the merge has to stay differentiable)
def task_arithmetics(
    self,
    base_model,
    models_to_merge: list,
    weights: float = 1.0,
    **kwargs,
):
    """Merge models by adding weighted task vectors to the base model.

    A task vector is ``model - base_model``. The result is
    ``base_model + sum_i w_i * task_vector_i``; how ``w_i`` is derived depends
    on the type of ``weights``.

    Args:
        self: Unused. It is kept only so that the positional layout of the
            signature stays unchanged; callers must pass a placeholder.
        base_model: Object that supports ``+`` / ``-`` with the models.
        models_to_merge: Iterable of models to merge.
        weights: One of
            * ``int`` / ``float`` / single-element tensor: one coefficient
              shared by every task vector (merge-wise);
            * ``list`` / ``tuple`` / 1-D tensor: one coefficient per model
              (task-wise);
            * ``param``: per-tensor coefficients combined through
              ``param.vectorize_reduce`` (tensor-wise).
        **kwargs: Ignored.

    Returns:
        The merged parameters (same kind of object as ``base_model``).

    Raises:
        ValueError: If task-wise weights do not match the number of models.
        NotImplementedError: If ``weights`` has an unsupported type.
    """
    task_vectors = [model - base_model for model in models_to_merge]

    is_tensor = isinstance(weights, torch.Tensor)

    # merge wise: a plain number (the default 1.0 is a float) or a tensor with
    # exactly one element, e.g. a single learnable coefficient.
    if isinstance(weights, (int, float)) or (is_tensor and weights.numel() == 1):
        return base_model + weights * sum(task_vectors)

    # task wise: one coefficient per model
    if isinstance(weights, (list, tuple)) or (is_tensor and weights.dim() == 1):
        if len(weights) != len(task_vectors):
            # zip() would otherwise silently drop models or coefficients
            raise ValueError(
                f"got {len(weights)} weights for {len(task_vectors)} models"
            )
        return base_model + sum(w * tv for w, tv in zip(weights, task_vectors))

    # tensor wise: the weights travel as the last element of the list handed
    # to the reducer
    if isinstance(weights, param):
        return base_model + param.vectorize_reduce(
            merge_para,
            task_vectors + [weights]
        )

    raise NotImplementedError

def ties_merge_task_arithmetics(
    base_model: param,
    models_to_merge: list[param],
    weights: param,
    mask_rate: float,
):
    """Task arithmetic with magnitude pruning and sign-conflict removal.

    Steps, in order:
      1. build the task vectors ``model - base_model``;
      2. flatten each one and prune it with
         ``sparsify.magnitude(vector, 1 - mask_rate)``;
      3. zero every entry whose sign disagrees with the sign elected for its
         position across models;
      4. unflatten and combine the vectors with ``weights`` on top of
         ``base_model``.

    Args:
        base_model: Base parameters; must support ``+`` / ``-`` with models.
        models_to_merge: Models to merge.
        weights: One of
            * ``int`` / ``float`` / single-element tensor: one coefficient
              shared by every task vector (merge-wise);
            * ``list`` / ``tuple`` / 1-D tensor: one coefficient per model
              (task-wise);
            * ``param``: per-tensor coefficients combined through
              ``param.vectorize_reduce`` (tensor-wise).
        mask_rate: Forwarded to ``sparsify.magnitude`` as ``1 - mask_rate``
            (presumably the fraction of entries that is dropped — confirm
            against ``sparsify``).

    Returns:
        The merged parameters.

    Raises:
        ValueError: If task-wise weights do not match the number of models.
        NotImplementedError: If ``weights`` has an unsupported type.
    """

    def disjoint_merge_noagg(
        tensor: torch.Tensor,  # (n_model, n_para)
    ):
        # torch.sign maps positives to 1, negatives to -1 and keeps 0 as 0
        expect_signs = torch.sign(tensor.sum(dim=0))  # (num_total_params, )
        # Majority sign over all positions: positive when the elected signs
        # sum to a positive value, negative when they sum to a negative one.
        majority_sign = torch.sign(expect_signs.sum(dim=0))
        # Positions whose elected sign is 0 fall back to the majority sign
        expect_signs[expect_signs == 0] = majority_sign

        # Keep only the entries whose sign agrees with the elected sign:
        # (1, n_para) broadcast against (n_model, n_para)
        mask = torch.where(
            expect_signs.unsqueeze(0) > 0, tensor > 0, tensor < 0
        )
        # No aggregation across models here; the weighted sum happens below
        return tensor * mask

    task_vectors = [model - base_model for model in models_to_merge]

    # The majority sign is computed over all parameters at once, so every
    # task vector is flattened first.
    # sparsify on model level => (n_model, n_para)
    flattened_param = torch.vstack([
        sparsify.magnitude(task_vector.flatten(), 1 - mask_rate)
        for task_vector in task_vectors
    ])
    flattened_param = disjoint_merge_noagg(flattened_param)

    # Restore each pruned, sign-filtered row to the layout of its task vector
    task_vectors = [
        task_vector.unflatten(_flat_param)
        for task_vector, _flat_param in zip(task_vectors, flattened_param)
    ]

    is_tensor = isinstance(weights, torch.Tensor)

    # merge wise: a plain number or a tensor with exactly one element
    if isinstance(weights, (int, float)) or (is_tensor and weights.numel() == 1):
        return base_model + weights * sum(task_vectors)

    # task wise: one coefficient per model
    if isinstance(weights, (list, tuple)) or (is_tensor and weights.dim() == 1):
        if len(weights) != len(task_vectors):
            # zip() would otherwise silently drop models or coefficients
            raise ValueError(
                f"got {len(weights)} weights for {len(task_vectors)} models"
            )
        return base_model + sum(w * tv for w, tv in zip(weights, task_vectors))

    # tensor wise: the weights travel as the last element of the list handed
    # to the reducer
    if isinstance(weights, param):
        return base_model + param.vectorize_reduce(
            merge_para,
            task_vectors + [weights]
        )

    raise NotImplementedError


class AdaMergeMTLClassificationModel(RobertaForSequenceClassification):
    """RoBERTa classifier whose weights are re-merged on every forward pass.

    The merge coefficients (``self.lambdas``) are the only tensors created as
    trainable here; everything built by the parent constructor is frozen. The
    training signal is the mean prediction entropy of the selected
    classification head.
    """

    def __init__(
        self,
        config,
        models_to_merge: list[param],
        base_model,
        classifer_heads: list[torch.nn.Module],
        ada_type='tensorwise',
        mask_rate: float = None,
    ):
        """Create the wrapper and its merge coefficients.

        Args:
            config: Configuration forwarded to the parent constructor.
            models_to_merge: Fine-tuned models to merge; the first one must
                expose ``keys()`` when ``ada_type`` contains ``'tensorwise'``.
            base_model: Base parameters the task vectors are taken against.
            classifer_heads: Classification heads, indexed by ``data_name``
                in ``forward``.
            ada_type: Coefficient granularity. ``'tensorwise'`` gives one
                vector of ``n_model`` coefficients per key, ``'taskwise'``
                one coefficient per model, and anything else a single shared
                coefficient. If it contains ``'pp'`` the merge goes through
                ``ties_merge_task_arithmetics``.
            mask_rate: Required only when ``ada_type`` contains ``'pp'``;
                forwarded to ``ties_merge_task_arithmetics``.
        """
        super().__init__(config)
        # Freeze what the parent built; the lambdas are registered afterwards
        # and therefore stay trainable.
        self.requires_grad_(False)

        self.models_to_merge = models_to_merge
        # `base_model` is a property on transformers' PreTrainedModel (it
        # returns the backbone), so a value stored under that name cannot be
        # read back. Keep the merge base under its own attribute.
        self.merge_base = base_model
        n_model = len(models_to_merge)
        self.classifier_heads = classifer_heads
        self.ada_type = ada_type
        self.mask_rate = mask_rate

        if 'tensorwise' in ada_type:
            # ParameterDict rather than ModuleDict: the values are Parameters.
            # Parameter names may not contain ".", so dots are replaced with
            # "/" in the stored keys.
            # NOTE(review): the merge functions take the tensor-wise branch
            # only for `param` instances; these coefficients likely need to be
            # wrapped (with their original key names) before use — confirm
            # against the `param` class.
            self.lambdas = torch.nn.ParameterDict({
                n.replace('.', '/'): torch.nn.Parameter(torch.full((n_model,), 0.3))
                for n in models_to_merge[0].keys()
            })
        elif 'taskwise' in ada_type:
            self.lambdas = torch.nn.Parameter(torch.full((n_model,), 0.3))
        else:
            self.lambdas = torch.nn.Parameter(torch.tensor([0.3]))

    def copy_merged_para(self, ):
        """Merge with the current lambdas and assign the result to this module.

        Raises:
            ValueError: If the ``'pp'`` merge is selected without ``mask_rate``.
        """
        if 'pp' in self.ada_type:
            if self.mask_rate is None:
                raise ValueError(
                    "ada_type contains 'pp' but no mask_rate was given"
                )
            merged_param = ties_merge_task_arithmetics(
                base_model=self.merge_base,
                models_to_merge=self.models_to_merge,
                weights=self.lambdas,
                mask_rate=self.mask_rate,
            )
        else:
            merged_param = task_arithmetics(
                None,  # placeholder for the function's unused leading `self`
                self.merge_base,
                self.models_to_merge,
                self.lambdas,
            )
        merged_param.assign(self)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        data_name=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Re-merge the weights, run the encoder and score one head.

        The optional inputs default to ``None`` so that batches which lack
        some of them can be passed as keyword arguments.

        Args:
            data_name: Index into ``self.classifier_heads`` selecting the head
                for this batch. Required.
            return_dict: Accepted for signature compatibility; the result is
                always a ``SequenceClassifierOutput``.

        Returns:
            ``SequenceClassifierOutput`` whose ``loss`` is the mean entropy of
            the softmax over ``logits``.

        Raises:
            ValueError: If ``data_name`` is not provided.
        """
        if data_name is None:
            raise ValueError("data_name is required to select a classifier head")

        self.copy_merged_para()

        # Always request a ModelOutput: the attribute access on `outputs`
        # below would fail on the tuple returned for return_dict=False.
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        # NOTE(review): assumes data_name is a single index valid for the
        # whole batch — confirm how the data pipeline provides it.
        logits = self.classifier_heads[data_name](outputs[0])
        # Unsupervised objective: mean entropy of the predicted distribution
        loss = -(logits.softmax(1) * logits.log_softmax(1)).sum(1).mean()

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@dataclass
class AdaMerge:
    """Learn merge coefficients by training an AdaMergeMTLClassificationModel.

    ``merge()`` loads one classifier and one training subset per entry of
    ``names``, builds the merging model and trains it with the Hugging Face
    ``Trainer``.
    """

    # models_to_merge: list
    # names passed to eval_glue.load_glue_classifier / load_glue_dataset
    names: list
    # base model; its `config` and weights initialise the merging model
    base_model: torch.nn.Module
    # training examples per dataset: one int for all, or one int per name
    data_nums: "int | list"
    # coefficient granularity, forwarded to AdaMergeMTLClassificationModel
    ada_type: str
    # the three fields below are not read by the code in this class
    scaling: list = None
    norm_fish_weight: bool = True
    min_fish_weight: float = 1e-6

    def merge(self, ):
        """Run the coefficient training and return the trained model."""
        return self.get_coefficient()

    def get_coefficient(self, ):
        """Build the merging model, train its coefficients and return it.

        Raises:
            ValueError: If ``names`` is empty, or ``data_nums`` is a sequence
                whose length differs from ``names``.
        """
        # TODO: tasker
        if not self.names:
            raise ValueError("AdaMerge needs at least one entry in `names`")

        if isinstance(self.data_nums, int):
            data_nums = [self.data_nums] * len(self.names)
        else:
            data_nums = list(self.data_nums)
            if len(data_nums) != len(self.names):
                raise ValueError(
                    f"got {len(data_nums)} data_nums for {len(self.names)} names"
                )

        models_to_merge, train_datasets, classifer_heads = [], [], []
        for name, n_examples in zip(self.names, data_nums):
            model, tokenizer = eval_glue.load_glue_classifier(name, )
            train_split = eval_glue.load_glue_dataset(tokenizer, name, split='train')
            train_datasets.append(train_split.select(range(n_examples)))
            # NOTE(review): the merge functions subtract models from the base
            # model; confirm whether these need wrapping in `param` first.
            models_to_merge.append(model)
            classifer_heads.append(copy.deepcopy(model.classifier))

        # from_pretrained() needs a checkpoint path that is not available
        # here, so construct the model from the base model's config and copy
        # the base weights in. strict=False tolerates the extra entries (e.g.
        # the merge coefficients) that the base state dict does not contain.
        self.model = AdaMergeMTLClassificationModel(
            self.base_model.config,
            models_to_merge=models_to_merge,
            base_model=self.base_model,
            classifer_heads=classifer_heads,
            ada_type=self.ada_type,
        )
        self.model.load_state_dict(self.base_model.state_dict(), strict=False)

        # cycle between all datasets (multi task setting)
        train_datasets = datasets.interleave_datasets(train_datasets, stopping_strategy="first_exhausted")
        trainer = Trainer(
            model=self.model,
            args=TrainingArguments(
                per_device_train_batch_size=16,
                num_train_epochs=500,
                learning_rate=1e-3,
                # step the optimizer after one batch from every dataset
                gradient_accumulation_steps=len(self.names),
                report_to=[],  # disable wandb
            ),
            train_dataset=train_datasets,
            # tokenizer of the last loaded classifier
            tokenizer=tokenizer,
        )
        trainer.train()
        return self.model

@torch.inference_mode()
def run_eval_glue(
    model,
    args,
    outdir='debug/test',
):
    """Evaluate ``model`` on the GLUE tasks and save one metric per task.

    Args:
        model: Classifier to evaluate. Its ``classifier`` attribute is
            replaced in place when ``args.load_head`` is truthy.
        args: Namespace providing ``model_placeholder``, ``models_name``,
            ``merge_method`` and ``load_head``.
        outdir: Directory handed to ``eval_glue.eval_glue`` and
            ``utils.save_excel``.
    """
    # TODO: decouple getsavename and postprocess
    tokenizer = AutoTokenizer.from_pretrained(args.model_placeholder)
    metrics = {"model": '+'.join(args.models_name) + '-' + args.merge_method}

    glue_tasks = ("cola", "sst2", "mrpc", "stsb", "qqp", "mnli", "qnli", "rte")
    for dataset in glue_tasks:
        if not args.load_head:
            # Without a task-specific head the model can only be scored on
            # tasks that match its label count.
            if model.num_labels != eval_glue.glue_data_num_labels_map[dataset]:
                print(f' >>> num labels {model.num_labels} is not Compatible for {dataset}, skipping')
                continue
        else:
            head_path = eval_glue.head_path_template.format(name=dataset)
            print(f' >>> load classifier head from {head_path} for {dataset}')
            # torch.load unpickles the file: only point this at trusted heads
            model.classifier = torch.load(head_path)

        test_metrics = eval_glue.eval_glue(tokenizer, model, dataset, outdir)
        metric_key = f'eval_{eval_glue.glue_data_metrics_map[dataset]}'
        metrics[dataset] = test_metrics[metric_key]

    utils.save_excel(metrics, outdir)

def ada_glue(
    *, 
    models_to_merge: list[str], 
    models_name: list[str],
    yaml_file: str,
    model_placeholder: str = None, 
    model_loader: str = None,
    eval_func: str = None,
    dtype: str = None,
    exclude_param: list[str] = None, 
    load_head: bool = None,
    seed: int=10,

    base_model: str = None,
    # for task-arithmetic:
    scaling: float = None,
    # for dare-merge:
    mask_rate: float = None,
    mask_scale: float = None,
    mask_strategy: str = None,
    # for ada merge
    ada_type: str = None,
    data_nums: int = None,
):
    """Learn merge coefficients on GLUE and evaluate the merged model.

    Only the arguments named below are used by this function itself; all
    arguments are also collected into the namespace handed to
    ``run_eval_glue``.

    Args:
        models_name: Names handed to ``AdaMerge`` and used to label the
            report.
        base_model: Checkpoint loaded with
            ``RobertaForSequenceClassification.from_pretrained``. Required.
        scaling: Forwarded to ``AdaMerge``.
        ada_type: Coefficient granularity; ``None`` selects ``'tensorwise'``,
            the default of ``AdaMergeMTLClassificationModel``.
        data_nums: Number of training examples taken from each dataset.
            Required.

    Raises:
        ValueError: If ``base_model`` or ``data_nums`` is missing.
    """
    # Snapshot the call arguments before any other local name exists
    values = dict(locals())

    if base_model is None:
        raise ValueError("base_model is required")
    if data_nums is None:
        raise ValueError("data_nums is required")
    if ada_type is None:
        ada_type = 'tensorwise'
        values['ada_type'] = ada_type

    # run_eval_glue builds the report label from args.merge_method, which is
    # not one of this function's parameters.
    values['merge_method'] = f'adamerge-{ada_type}'

    args = utils.SimpleNamespace(
        **values
    )

    args.base_model = RobertaForSequenceClassification.from_pretrained(base_model)
    # Pass only the fields AdaMerge declares; it rejects any other keyword.
    # NOTE(review): `models_name` is used for `names` because AdaMerge feeds
    # each name to the eval_glue loaders — confirm it should not be
    # `models_to_merge`.
    ada_merger = AdaMerge(
        names=models_name,
        base_model=args.base_model,
        data_nums=data_nums,
        ada_type=ada_type,
        scaling=scaling,
    )
    model = ada_merger.merge()
    
    # 4. eval (w classifier head)
    run_eval_glue(model, args)


if __name__ == '__main__':
    import bdb
    import pdb
    import sys

    import defopt

    try:
        defopt.run(ada_glue)
    except bdb.BdbQuit:
        # The user quit an interactive debugger session: leave quietly
        sys.exit()
    except Exception:
        # Debugging aid: report the failure and open a post-mortem session.
        # Catching Exception (not everything) lets SystemExit, e.g. from
        # command-line parsing, and KeyboardInterrupt end the process as usual.
        exc_type, exc_value, tb = sys.exc_info()
        print(exc_type, exc_value)
        pdb.post_mortem(tb)