import json
import argparse


# Command-line interface; the parsed values are combined below into a single
# JSON configuration file.
parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, choices=['student', 'teacher', 'distillation', 'selfdist', 'pretrain', 'finetune', "LUPI_teacher", "LUPI_dist"])
parser.add_argument('--ct', type=float)
parser.add_argument('--cb', type=float)
parser.add_argument('--pt', type=float)
parser.add_argument('--pb', type=float)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--alpha', type=float, default=0.5)
parser.add_argument('--json_idx', type=str, default="",
                    help='suffix for the default output name Data_Eval_ScoringFunction<json_idx>.json')
parser.add_argument('--repeat', type=int, default=0)
# --data_id and --loss_fn are required: both are used as dictionary keys
# (learning-rate / batch-size tables) immediately after parsing, so a missing
# value would otherwise surface as an opaque KeyError.
parser.add_argument('--data_id', type=str, required=True, choices=["Set1", "MSLRWEB30K", "Istella", "Istella_X"])
parser.add_argument('--out_root_dir', type=str, default="")
parser.add_argument('--loss_fn', type=str, required=True, choices=["RankNet", "RankBCE"])
parser.add_argument('--lr', type=float, default=0.0,
                    help='learning rate; a value <= 0 selects the built-in default for (loss_fn, data_id)')
parser.add_argument('--pri_name', type=str, default='Gumbel')
parser.add_argument('--json_name', type=str, default="",
                    help='output file name without the .json suffix; overrides --json_idx when non-empty')
parser.add_argument('--qg_mask', type=int, default=1,
                    help='integer flag converted to bool (0 disables)')

args = parser.parse_args()

# Normalise the integer CLI flag to a real bool so it serialises as true/false.
args.qg_mask = bool(args.qg_mask)

# Loss-specific path component inserted into checkpoint paths.
model_specific = {"RankBCE": "", "RankNet": "Sigma_1"}

# Default learning rate per (loss function, dataset); an explicit --lr > 0 wins.
model_lr = {
    ("RankBCE", "Set1"): 1e-3,
    ("RankBCE", "Istella"): 1e-3,
    ("RankBCE", "Istella_X"): 1e-3,
    ("RankBCE", "MSLRWEB30K"): 1e-3,
    ("RankNet", "Set1"): 3e-4,
    ("RankNet", "Istella"): 3e-4,
    ("RankNet", "MSLRWEB30K"): 3e-4,
}

batch_size_dict = {"RankBCE": 500, "RankNet": 300}

# Consult the default table only when no explicit learning rate was given, so
# combinations without a default (e.g. RankNet on Istella_X) still work with --lr.
if args.lr > 0:
    lr = args.lr
elif (args.loss_fn, args.data_id) in model_lr:
    lr = model_lr[(args.loss_fn, args.data_id)]
else:
    parser.error(f"no default learning rate for --loss_fn {args.loss_fn} with "
                 f"--data_id {args.data_id}; pass --lr explicitly")
batch_size = batch_size_dict[args.loss_fn]

# Number of preserved features for the preserve_linear* priors; other priors use 1.
pri_num_dict = {"Set1": 200, "Istella": 50, "MSLRWEB30K": 40}
if args.pri_name in ['preserve_linear', 'preserve_linear_gumbel']:
    if args.data_id not in pri_num_dict:
        parser.error(f"--pri_name {args.pri_name} is not configured for --data_id {args.data_id}")
    args.pri_num = pri_num_dict[args.data_id]
else:
    args.pri_num = 1

# Dataset / loader options written under "DataSetting". Most values are
# single-element lists; presumably the consumer iterates them as grid axes
# (TODO confirm against the config reader).
data_setting = {
    "min_docs": [10],
    "min_rele": [1],
    "train_batch_size": [1],
    "scaler_id": "SLog1P",

    "binary_rele": [False],
    "unknown_as_zero": [False],
    "train_presort": [True]
}
data_setting["data_id"] = args.data_id

# Known on-disk dataset locations. Datasets absent from this table
# (currently Istella_X) get no "dir_data" entry at all.
data_dir_by_id = {
    "Set1": "../dataset/Yahoo/",
    "MSLRWEB30K": "../dataset/MSLR-WEB30K/",
    "Istella": "../dataset/Istella/full/",
}
if args.data_id in data_dir_by_id:
    data_setting["dir_data"] = data_dir_by_id[args.data_id]

# Scoring-function architecture options written under "SFParameter".
SF_parameter = dict(
    BN=[False],
    RD=[False],
    layers=[5],
    apply_tl_af=[False],
    hd_hn_tl_af=["R"],
)

# Per-dataset ordered lists of feature indices. For the preserve_linear_gumbel
# prior, the first `args.pri_num` entries of the selected dataset's list are
# written to pri_setting["preserve_features"]. The ordering presumably ranks
# features by importance — NOTE(review): confirm how these lists were produced.
preserve_feature_idx = {'Set1': [376, 595, 637, 446, 387, 622, 260, 345, 120, 164, 335, 664, 6, 627, 618, 238, 563, 169, 556, 447, 111, 189, 562, 527, 444, 78, 410, 100, 261, 362, 548, 167, 604, 534, 461, 614, 477, 408, 152, 497, 151, 248, 186, 414, 355, 20, 363, 621, 450, 502, 150, 153, 81, 605, 432, 366, 268, 256, 393, 489, 647, 494, 492, 507, 126, 28, 542, 285, 181, 472, 215, 244, 537, 161, 550, 91, 426, 381, 297, 379, 162, 331, 611, 8, 41, 440, 671, 476, 445, 317, 330, 521, 631, 456, 191, 178, 638, 475, 535, 511, 659, 202, 485, 83, 140, 304, 600, 332, 257, 342, 616, 283, 241, 9, 27, 255, 43, 641, 382, 495, 395, 277, 104, 620, 175, 457, 312, 96, 643, 155, 372, 21, 149, 208, 154, 603, 192, 37, 474, 508, 85, 60, 427, 69, 640, 230, 165, 7, 480, 233, 309, 1, 201, 375, 308, 559, 25, 29, 32, 579, 677, 389, 339, 114, 290, 170, 568, 176, 644, 698, 279, 400, 353, 322, 121, 71, 158, 610, 441, 174, 299, 404, 587, 421, 197, 504, 571, 10, 239, 463, 454, 204, 458, 18, 657, 452, 690, 187, 397, 692, 566, 264, 17, 305, 384, 138, 487, 86, 435, 352, 98, 436, 364, 326, 232, 265, 173, 275, 564, 139, 663, 338, 294, 674, 669, 254, 594, 282, 694, 350, 518, 575, 253, 163, 77, 483, 349, 122, 385, 123, 39, 108, 101, 665, 55, 697, 195, 505, 437, 267, 581, 329, 222, 383, 129, 585, 66, 574, 172, 648, 216, 451, 147, 196, 642, 247, 212, 453, 137, 31, 79, 325, 284, 220, 392, 206, 629, 106, 390, 356, 699, 433, 144, 45, 590, 596, 124, 344, 159, 606, 11, 74, 133, 465, 36, 607, 190, 243, 405, 271, 540, 481, 259, 514, 538, 76, 117, 479, 493, 478, 302, 541, 242, 333, 515, 583, 219, 292, 127, 656, 658, 473, 281, 519, 682, 654, 412, 578, 223, 539, 235, 696, 685, 321, 141, 234, 274, 628, 589, 498, 570, 97, 319, 87, 469, 455, 429, 145, 89, 378, 225, 64, 612, 240, 691, 399, 693, 513, 341, 44, 486, 30, 439, 639, 135, 26, 532, 417, 199, 533, 224, 58, 276, 608, 499, 442, 683, 320, 418, 347, 586, 348, 146, 286, 545, 670, 229, 592, 287, 56, 337, 531, 554, 500, 311, 431, 633, 359, 459, 680, 546, 358, 448, 316, 
251, 340, 266, 2, 62, 48, 227, 428, 602, 245, 558, 434, 53, 88, 625, 82, 398, 131, 168, 128, 650, 132, 23, 205, 443, 300, 298, 369, 361, 80, 666, 572, 591, 401, 388, 423, 231, 182, 470, 598, 246, 634, 47, 262, 22, 354, 394, 565, 236, 374, 695, 295, 177, 613, 569, 512, 288, 67, 391, 324, 377, 687, 525, 488, 33, 555, 681, 291, 425, 517, 688, 228, 367, 660, 313, 34, 102, 179, 689, 623, 561, 323, 676, 143, 678, 416, 12, 166, 580, 110, 70, 645, 468, 107, 528, 686],
"Istella": [184, 150, 143, 31, 28, 195, 38, 36, 37, 137, 27, 73, 80, 25, 26, 70, 72, 30, 79, 78, 69, 67, 142, 68, 45, 147, 71, 149, 42, 140, 77, 52, 50, 8, 172, 115, 173, 139, 51, 35, 41, 44, 29, 112, 40, 39, 121, 7, 0, 9, 130, 122, 120, 64, 154, 141, 14, 131, 75, 66, 74, 114, 65, 43, 193, 56, 59, 49, 111, 179, 18, 109, 110, 158, 58, 55, 53, 54, 144, 19, 47, 113, 159, 46, 57, 61, 63, 60, 138, 181, 145, 119, 183, 210, 151, 160, 152, 146, 123, 176, 189, 12, 196, 213, 126, 76, 211, 132, 108, 1, 107, 33, 98, 101, 106, 207, 208, 100, 34, 32, 99, 96, 97, 95, 22, 5, 105, 11, 117, 116, 6, 103, 102, 166, 134, 24, 165, 163, 203, 124, 199, 187, 21, 157, 182, 206, 219, 194, 162, 216, 178, 129, 217, 3, 168, 155, 128, 191, 170, 167, 104, 192, 190, 185, 156, 198, 23, 177, 127, 180, 148, 118, 135, 17, 171, 13, 186, 169, 16, 164, 136, 15, 212, 201, 209, 214, 20, 62, 153, 48, 174, 10, 4, 215, 188, 175, 161, 92, 133, 2, 125, 200, 91, 197, 202, 81, 204, 205, 82, 83, 84, 85, 86, 87, 88, 218, 94, 89, 93, 90],
"MSLRWEB30K": [97, 29, 77, 27, 7, 37, 87, 25, 79, 102, 62, 52, 107, 122, 47, 75, 72, 2, 22, 82, 114, 112, 89, 133, 64, 124, 63, 54, 110, 39, 8, 32, 38, 53, 120, 49, 78, 98, 88, 28, 57, 48, 103, 85, 59, 108, 99, 58, 73, 35, 104, 83, 74, 84, 3, 23, 95, 33, 119, 17, 16, 109, 60, 18, 50, 115, 45, 9, 76, 26, 123, 14, 96, 36, 6, 24, 86, 70, 34, 80, 10, 100, 113, 61, 21, 101, 71, 106, 46, 128, 55, 1, 51, 81, 105, 31, 129, 20, 30, 56, 5, 121, 117, 15, 19, 134, 125, 111, 13, 118, 126, 135, 127, 11, 91, 44, 40, 131, 132, 93, 68, 12, 41, 66, 116, 67, 0, 42, 43, 130, 90, 92, 69, 94, 4, 65]}

# Settings for the selected "pri" variant, written under "PriSetting".
# Only two variants are supported; anything else (including 'preserve_linear',
# which only sets args.pri_num above) is rejected with a usage error instead of
# leaving pri_setting unbound.
if args.pri_name == "Gumbel":
    pri_setting = {
        "type": [args.pri_name],
        "temperature": [args.ct],
        "click_bias": [args.cb],
        "purchase_bias": [args.pb],
        "num_features": [args.pri_num],
        "mix_alpha": [args.alpha],
        "batch_size": [batch_size],
        "learning_rate": lr,
        "qg_mask": args.qg_mask
    }
elif args.pri_name == "preserve_linear_gumbel":
    pri_setting = {
        "type": [args.pri_name],
        # Keep the first pri_num indices of the dataset's ordered feature list.
        "preserve_features": preserve_feature_idx[args.data_id][:args.pri_num],
        "temperature": [args.ct],
        "purchase_bias": [args.pb],
        "pri_num": [args.pri_num],
        "mix_alpha": [args.alpha],
        "batch_size": [batch_size],
        "learning_rate": lr,
        "qg_mask": args.qg_mask
    }
else:
    parser.error(f"unsupported --pri_name {args.pri_name!r}; "
                 "expected 'Gumbel' or 'preserve_linear_gumbel'")

# Gumbel-specific argument constraints. These are real checks, not asserts, so
# they still run under `python -O`.
if pri_setting['type'][0] == "Gumbel":
    if args.ct != args.pt:
        parser.error("--ct and --pt must be equal for --pri_name Gumbel")
    if args.cb is None or args.pb is None or not args.cb <= args.pb:
        parser.error("--cb and --pb are required for --pri_name Gumbel, "
                     "and --cb must not exceed --pb")

alpha = pri_setting['mix_alpha'][0]

# Final-layer activation tag used in output paths ('No' when disabled), and the
# number of hidden layers (total layers minus input and output layers).
final_layer_AF = SF_parameter['hd_hn_tl_af'][0] if SF_parameter['apply_tl_af'][0] else 'No'
hidden_layers = SF_parameter["layers"][0] - 2


def get_ckpt_name():
    """Return the checkpoint path template for the current CLI configuration.

    The path combines these module-level settings:
    - the output root;
    - the loss function and dataset;
    - the scoring-function architecture tag;
    - the epoch count;
    - the loss-specific sub-directory;
    - the ``pri_*`` hyper-parameters;
    - a sub-directory selected from ``args.mode``.

    The literal ``@FOLD_NUM@`` placeholder is left in the string; presumably
    the consumer substitutes the fold number (TODO confirm).

    Mode -> checkpoint sub-directory:
        selfdist  -> student
        finetune  -> pretrain
        LUPI_dist -> LUPI_teacher
        pretrain  -> pretrain   (preserve_linear_gumbel only, kept as before)
        otherwise -> teacher

    Returns:
        str: path ending in ``optimal.pkl``.

    Raises:
        ValueError: if ``args.pri_name`` is not "Gumbel" or
            "preserve_linear_gumbel".
    """
    if args.pri_name not in ("Gumbel", "preserve_linear_gumbel"):
        raise ValueError(f"unsupported pri_name for checkpoint lookup: {args.pri_name!r}")

    qg_suffix = '' if pri_setting['qg_mask'] else '-qg_mask_disable'
    subdir_by_mode = {"selfdist": "student", "finetune": "pretrain", "LUPI_dist": "LUPI_teacher"}

    if args.pri_name == "Gumbel":
        pri_tag = (f"{args.pri_name}-t:{args.ct}-cb:{args.cb}-pb:{args.pb}"
                   f"-lr:{lr}-bs:{batch_size}{qg_suffix}")
    else:
        pri_tag = (f"{args.pri_name}-d:{args.pri_num}-t:{args.ct}-pb:{args.pb}"
                   f"-lr:{lr}-bs:{batch_size}{qg_suffix}")
        # This branch has always mapped "pretrain" to the pretrain checkpoint;
        # "finetune" now maps there too, matching the Gumbel branch instead of
        # falling through to the teacher checkpoint.
        subdir_by_mode["pretrain"] = "pretrain"

    subdir = subdir_by_mode.get(args.mode, "teacher")

    act = SF_parameter['hd_hn_tl_af'][0]
    run_dir = (f"../{args.out_root_dir}/gpu_grid_{args.loss_fn}/{data_setting['data_id']}/"
               f"{args.loss_fn}_SF_{act}.{act}{hidden_layers}.{final_layer_AF}_{args.data_id}"
               f"_MiD_10_MiR_1_TrBat_1_TrPresort_EP_{args.epoch}_V_True_QS_SLog1P/"
               f"{model_specific[args.loss_fn]}/no_thresholding")
    return f"{run_dir}/{pri_tag}/{subdir}/repeat-{args.repeat}/Fold-@FOLD_NUM@/optimal.pkl"

# Training / evaluation options written under "EvalSetting". "teacher_ckpt" is
# the checkpoint path template built by get_ckpt_name().
eval_setting = {
    "repeat_idx": args.repeat,
    "mode": args.mode,
    "teacher_ckpt": get_ckpt_name(),
    "student_dir":"deprecated",
    "distillation_dir":"deprecated",

    "dir_output": f"../{args.out_root_dir}/",

    "epochs": args.epoch,
    "do_validation": True,
    "vali_k": 8,
    "cutoffs": [8, 16, 32],
    "loss_guided": False,
    "do_log": False,
    "log_step": 1,
    "do_summary": True,

    "mask": {
        "mask_label": False,
        "mask_type": ["rand_mask_all"],
        "mask_ratio": [0.2]
    },
    "thresholding": False,
    "thresholding_val": [3.5]
}

full_data = {"DataSetting": data_setting, "EvalSetting": eval_setting,
             "SFParameter": SF_parameter, "PriSetting": pri_setting}

# A non-empty --json_name (given without the extension) takes precedence;
# otherwise the default name is suffixed with --json_idx.
if args.json_name:
    out_path = f'{args.json_name}.json'
else:
    out_path = f'Data_Eval_ScoringFunction{args.json_idx}.json'

with open(out_path, 'w', encoding='utf-8') as outfile:
    json.dump(full_data, outfile, indent=4)


