import torch
import mmcv
import sys
from mmengine.config import Config
from mmengine.runner import Runner
from mmengine import MODELS

def split_student_model(cfg_path, checkpoint_path, original_checkpoint_path, save_path=None,
                        *, map_location='cpu', key_dump_path='student_key_output.txt'):
    """Extract the student model from a distillation checkpoint.

    Each key of the distillation checkpoint's ``state_dict`` has a leading
    ``architecture.`` removed. Keys that then match an entry of the state
    dict of the model built from ``cfg_path`` overwrite that entry; all other
    checkpoint keys are dropped. Model entries with no match keep the values
    of the freshly built model.

    The saved checkpoint contains:

    * ``state_dict``: the merged student weights described above;
    * ``meta`` and ``optimizer``: copied from ``checkpoint_path``;
    * ``message_hub`` and ``param_schedulers``: copied from
      ``original_checkpoint_path``.

    NOTE(review): the optimizer state comes from the distillation run; confirm
    it matches the student's parameters before resuming training from it.

    :param cfg_path: config file of the normal (non-distillation) model.
    :param checkpoint_path: distillation checkpoint to extract the student from.
    :param original_checkpoint_path: checkpoint supplying the ``message_hub``
        and ``param_schedulers`` entries of the output.
    :param save_path: output path of the student checkpoint. Required; the
        ``None`` default is only kept for backward compatibility.
    :param map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so
        checkpoints holding GPU tensors can be converted on any machine.
    :param key_dump_path: text file that receives the extracted keys, one per
        line; ``None`` skips writing it.
    :raises ValueError: if ``save_path`` is ``None``.
    """
    # Fail fast instead of erroring inside torch.save after all the work.
    if save_path is None:
        raise ValueError('save_path is required')

    prefix = 'architecture.'

    # Build the plain model to get the student's expected keys and default values.
    cfg = Config.fromfile(cfg_path)
    model = MODELS.build(cfg.model)
    model_dict = model.state_dict()

    model_ckpt = torch.load(checkpoint_path, map_location=map_location)
    new_dict = {}
    for key, value in model_ckpt['state_dict'].items():
        # Strip only a *leading* prefix; str.replace would also remove
        # occurrences in the middle of a key.
        if key.startswith(prefix):
            key = key[len(prefix):]
        if key in model_dict:
            new_dict[key] = value
    model_dict.update(new_dict)

    if key_dump_path is not None:
        with open(key_dump_path, 'w', encoding='utf-8') as f:
            recursive_print_keys(new_dict, f)

    # Training-state entries (message hub, schedulers) come from the original checkpoint.
    original_model_ckpt = torch.load(original_checkpoint_path, map_location=map_location)

    torch.save({'state_dict': model_dict, 'meta': model_ckpt['meta'],
                'optimizer': model_ckpt['optimizer'], 'message_hub': original_model_ckpt['message_hub'],
                'param_schedulers': original_model_ckpt['param_schedulers']}, save_path)


def architecture_wrapper(source_ckpt, target_ckpt, *, map_location='cpu'):
    """Prefix every ``state_dict`` key of a checkpoint with ``architecture.``.

    All other top-level entries of the checkpoint are saved unchanged.

    :param source_ckpt: path of the checkpoint to read; it must contain a
        ``'state_dict'`` entry.
    :param target_ckpt: path the rewritten checkpoint is saved to.
    :param map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so
        checkpoints holding GPU tensors can be converted on any machine.
    :raises KeyError: if the checkpoint has no ``'state_dict'`` entry.
    """
    model_ckpt = torch.load(source_ckpt, map_location=map_location)
    model_ckpt['state_dict'] = {
        f"architecture.{key}": value for key, value in model_ckpt['state_dict'].items()
    }
    torch.save(model_ckpt, target_ckpt)


def recursive_print_keys(d, file_obj, parent_key=''):
    """Recursively write every non-dict-valued key of a nested dict to ``file_obj``.

    Nested keys are joined with ``.`` and written one per line. Keys are
    converted with ``str()``, so non-string keys (e.g. integer indices) work.

    :param d: the (possibly nested) dict to walk.
    :param file_obj: writable text stream that receives the keys.
    :param parent_key: dotted prefix for the keys of ``d``; empty at top level.
    """
    for k, v in d.items():
        # Always build a str: a raw non-str key would make ``write`` fail, and a
        # falsy one (e.g. 0) would drop the prefix from its children's keys.
        new_key = f'{parent_key}.{k}' if parent_key else str(k)
        if isinstance(v, dict):
            recursive_print_keys(v, file_obj, new_key)
        else:
            file_obj.write(new_key + '\n')


def split_mask_generation_network(checkpoint_path, save_path, prompt, *, map_location='cpu'):
    """Extract mask-generation sub-networks from a distillation checkpoint.

    For each name in ``prompt``, the ``state_dict`` entries whose keys start
    with ``distiller.distill_losses.<name>.`` are collected. That prefix is
    stripped, and ``{'state_dict': weights}`` is saved to the corresponding
    entry of ``save_path``.

    Entries of ``prompt`` are expected to be chosen from
    ``loss_pts_backbone``, ``loss_pts_neck`` and ``loss_head``.

    :param checkpoint_path: distillation checkpoint to read.
    :param save_path: output path(s), one per entry of ``prompt``. A single
        path (``str`` or ``os.PathLike``) is treated as a one-element list.
        Extra paths beyond ``len(prompt)`` are ignored.
    :param prompt: loss name(s) whose sub-network is extracted. A single
        string is treated as one name, not iterated character by character.
    :param map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so
        checkpoints holding GPU tensors can be converted on any machine.
    :raises ValueError: if fewer save paths than prompt entries are given;
        checked before anything is written.
    """
    import os
    import warnings

    prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    save_paths = [save_path] if isinstance(save_path, (str, os.PathLike)) else list(save_path)
    if len(save_paths) < len(prompts):
        raise ValueError(
            f'got {len(prompts)} prompt entries but only {len(save_paths)} save paths')

    prefix = "distiller.distill_losses."
    model_ckpt = torch.load(checkpoint_path, map_location=map_location)
    pretrained_dict = model_ckpt['state_dict']

    for item, path in zip(prompts, save_paths):
        item_prefix = prefix + item + '.'
        filtered_dict = {
            key[len(item_prefix):]: value
            for key, value in pretrained_dict.items()
            if key.startswith(item_prefix)
        }
        # An empty result almost always means a misspelled prompt entry.
        if not filtered_dict:
            warnings.warn(f'no keys found with prefix {item_prefix!r}; '
                          f'saving an empty state_dict to {path!r}', stacklevel=2)

        torch.save({'state_dict': filtered_dict}, path)

