import sys

sys.path.append(".")
import argparse
import os
import time
import json
import torch
import torch.nn as nn

from accelerate import dispatch_model, load_checkpoint_in_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, get_scheduler, set_seed
# from transformers import LlamaTokenizer, LlamaForCausalLM

from binarization.binary_util import get_blocks, replace_with_mos, to_regular_linear_mos

from utils.datautils import get_qat_dataset
from utils.utils import print_trainable_parameters, prepare_model_for_training
from utils.kd_utils import KDTrainer

# Substrings of state_dict keys that identify the BinaryMoS-specific tensors
# (per-channel scales and the expert gating linear) saved to `checkpoint_dict`.
_MOS_STATE_KEYS = ("in_channel_scale", "out_channel_scale", "gate_linear")


def _load_tokenizer(model_id):
    """Load the tokenizer for ``model_id``.

    Model ids containing "llama" get the slow tokenizer (``use_fast=False``);
    all others use the library default.
    """
    # NOTE: `device_map` is a model-loading option; tokenizers have no use for
    # it, so it is deliberately not passed here.
    kwargs = {"trust_remote_code": True}
    if "llama" in model_id:
        kwargs["use_fast"] = False
    return AutoTokenizer.from_pretrained(model_id, **kwargs)


def _collect_mos_state(model):
    """Return the BinaryMoS entries of ``model.state_dict()`` in key order.

    An entry is kept when its key contains any substring in ``_MOS_STATE_KEYS``.
    """
    # Build the state dict once: calling model.state_dict() per key rebuilds
    # the full mapping each time, which is quadratic in the number of keys.
    state = model.state_dict()
    return {
        key: value
        for key, value in state.items()
        if any(part in key for part in _MOS_STATE_KEYS)
    }


def _max_memory_allocated_gib():
    """Return peak allocated CUDA memory in GiB, summed over all visible devices.

    The student and teacher are both dispatched with ``device_map='auto'`` and
    may span several GPUs, so reading only the current device under-reports.
    The sum of per-device peaks is an upper bound on simultaneous usage.
    Returns 0.0 when no CUDA device is visible.
    """
    peak_bytes = sum(
        torch.cuda.max_memory_allocated(device)
        for device in range(torch.cuda.device_count())
    )
    return peak_bytes / 1024**3


def _save_mos_model(model, save_dir, moe_config):
    """Write the trained BinaryMoS student to ``save_dir``.

    Artifacts:
      * ``checkpoint_dict`` - the BinaryMoS tensors selected by
        ``_collect_mos_state``, saved with ``torch.save`` before conversion;
      * the checkpoint produced by ``model.save_pretrained`` after
        ``to_regular_linear_mos`` has been applied to the model's blocks;
      * ``moe_config.json`` - ``moe_config`` with ``scale_init`` set to False.
    """
    model.eval()
    torch.save(_collect_mos_state(model), os.path.join(save_dir, 'checkpoint_dict'))
    # Presumably converts the MoS layers back to plain linear modules so that
    # save_pretrained emits a standard checkpoint -- confirm in binary_util.
    to_regular_linear_mos(get_blocks(model))
    model.save_pretrained(save_dir)

    # Copy rather than mutate the dict that was handed to replace_with_mos.
    # scale_init=False presumably tells a later loader to take the scales from
    # checkpoint_dict instead of re-initializing them -- TODO confirm.
    saved_config = {**moe_config, 'scale_init': False}
    with open(os.path.join(save_dir, "moe_config.json"), 'w', encoding="utf-8") as f:
        json.dump(saved_config, f)


def main(args):
    """Train a BinaryMoS student by knowledge distillation and save the result.

    ``args.model_id`` is loaded twice in float16. The first copy is the
    student, whose blocks are rewritten by ``replace_with_mos``. The second
    copy is the frozen-architecture teacher passed to ``KDTrainer``.

    After training, the following are written to
    ``outputs/<model_id>/<save_dir>``:
      * the BinaryMoS tensors;
      * the converted model;
      * ``moe_config.json``;
      * a run summary (``args.txt``).

    Args:
        args: Parsed command-line namespace (arguments are defined at the
            bottom of this file).
    """
    start = time.time()
    set_seed(args.seed)

    tokenizer = _load_tokenizer(args.model_id)

    # Student model; the KV cache is not needed during training.
    model = AutoModelForCausalLM.from_pretrained(args.model_id, device_map='auto', torch_dtype=torch.float16)
    model.config.use_cache = False
    prepare_model_for_training(model)
    print(f'Model GPU Status: {model.hf_device_map}')

    save_dir = os.path.join(f"outputs/{args.model_id}", args.save_dir)
    os.makedirs(save_dir, exist_ok=True)

    # BinaryMoS: rewrite the student's blocks according to moe_config.
    moe_config = {'scale_init': True, 'num_experts': args.num_experts, 'train_only_scale': args.train_only_scale,}
    replace_with_mos(get_blocks(model), args, moe_config)
    print_trainable_parameters(model)

    print(f"Prepare training data ({args.dataset})")
    datasets, data_collator = get_qat_dataset(args.dataset, tokenizer, args.cache_dir)

    # NOTE(review): weights are loaded as float16 but training runs with
    # bf16=True; confirm this mix is intended (prepare_model_for_training may
    # already upcast the trainable parameters).
    training_args = TrainingArguments(
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,
        bf16=True,
        logging_steps=1,
        save_steps=20000,
        save_only_model=True,
        output_dir=save_dir,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
    )

    print("Loading Teacher Model")
    teacher_model = AutoModelForCausalLM.from_pretrained(args.model_id, device_map='auto', torch_dtype=torch.float16)
    teacher_model.config.use_cache = False
    print(f'Teacher Model GPU Status: {teacher_model.hf_device_map}')

    trainer = KDTrainer(
        model=model,
        teacher_model=teacher_model,
        l2l_loss_scale=args.l2l_loss_scale,
        tokenizer=tokenizer,
        train_dataset=datasets,
        args=training_args,
        data_collator=data_collator,
    )
    trainer.train()

    _save_mos_model(model, save_dir, moe_config)

    # Run summary; elapsed time spans model loading through saving.
    end = time.time()
    max_mem_gib = _max_memory_allocated_gib()
    with open(os.path.join(save_dir, "args.txt"), "w", encoding="utf-8") as f:
        print(f'* Args\n{vars(args)}\n', file=f)
        print(f'* Training Args\n{training_args}\n', file=f)
        print(f'* Max memory allocated\n{max_mem_gib:.2f} GiB\n', file=f)
        print(f'* Training Time\n{(end-start)/3600:.2f} Hours\n', file=f)

    print(f"Max memory_allocated: {max_mem_gib:.2f} GiB")
    print(f"Model saved to {save_dir}")


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this training script.

    The default ``--save_dir`` is the local time at which the parser is built,
    formatted as ``Y_M_D_h_m_s`` (fields are not zero-padded).
    """
    tm = time.localtime()
    parser = argparse.ArgumentParser(description="Model Training Script")
    ### Model & Datasets ###
    parser.add_argument(
        "--model_id", type=str, default="huggyllama/llama-7b", help="Pretrained model ID",
    )
    parser.add_argument(
        "--dataset", type=str, default="c4_wiki", help="Dataset name"
    )
    parser.add_argument(
        "--cache_dir", type=str, default='',
        help="Cache directory forwarded to the dataset loader (get_qat_dataset)",
    )
    parser.add_argument(
        "--save_dir", type=str,
        default=f'{tm.tm_year}_{tm.tm_mon}_{tm.tm_mday}_{tm.tm_hour}_{tm.tm_min}_{tm.tm_sec}',
        help="Run sub-directory under outputs/<model_id> (default: current local timestamp)",
    )
    ### Training Parameter ###
    parser.add_argument(
        "--seed", type=int, default=42, help="Seed"
    )
    parser.add_argument(
        "--num_train_epochs", type=float, default=3.0, help="Number of training epochs"
    )
    parser.add_argument(
        "--per_device_train_batch_size", type=int, default=4,
        help="Training batch size per device",
    )
    parser.add_argument(
        "--gradient_accumulation_steps", type=int, default=1,
        help="Number of gradient accumulation steps",
    )
    parser.add_argument(
        "--warmup_ratio", type=float, default=0.03,
        help="Fraction of training steps used for learning-rate warmup",
    )
    parser.add_argument(
        "--lr", type=float, default=2e-5,
        help="Peak learning rate (cosine schedule)",
    )
    parser.add_argument(
        "--adam_beta1", type=float, default=0.9, help="AdamW beta1",
    )
    parser.add_argument(
        "--adam_beta2", type=float, default=0.999, help="AdamW beta2",
    )
    ### Binary MoS ###
    parser.add_argument(
        "--train_only_scale", action='store_true', default=False,
        help="Set train_only_scale in the BinaryMoS config",
    )
    parser.add_argument(
        "--num_experts", type=int, default=4,
        help="Number of experts in each BinaryMoS layer",
    )
    ### KD ###
    parser.add_argument(
        "--l2l_loss_scale", type=float, default=10.0,
        help="Loss scale forwarded to KDTrainer as l2l_loss_scale",
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args)
