#!/usr/bin/env bash
# Launch 4-process distributed fine-tuning of GPT-2 on the E2E data
# (./data/e2e/{train,valid}.jsonl) via src/gpt2_ft.py.
#
# Usage:
#   <script> EXP_NAME MODEL MODE ENABLE PB MB SQMB SCALE_TYPE QUANT_TYPE SEED ROUND_TYPE
#
# Positional arguments (each is forwarded verbatim to src/gpt2_ft.py):
#   1  EXP_NAME    -> --work_dir
#   2  MODEL       gpt2-small | gpt2-medium | gpt2-large | gpt2-xl; selects
#                  --model_card and the checkpoint file
#                  ./pretrained_checkpoints/<MODEL>-pytorch_model.bin.
#                  Any other value falls back to gpt2-small (with a warning).
#   3  MODE        -> --optim_mode
#   4  ENABLE      -> --lpmm_enable
#   5  PB          -> --pb
#   6  MB          -> --mb
#   7  SQMB        -> --sqmb
#   8  SCALE_TYPE  -> --scale_type
#   9  QUANT_TYPE  -> --q_oracle
#   10 SEED        -> --random_seed
#   11 ROUND_TYPE  -> --round_type
#
# All data/checkpoint/source paths are relative to the current directory
# (presumably the repository root -- run the script from there).
# Written to be POSIX-sh compatible so `sh script.sh` also works.

set -eu

if [ "$#" -lt 11 ]; then
    printf 'Usage: %s EXP_NAME MODEL MODE ENABLE PB MB SQMB SCALE_TYPE QUANT_TYPE SEED ROUND_TYPE\n' \
        "${0##*/}" >&2
    exit 2
fi

EXP_NAME=$1
MODEL=$2
MODE=$3
ENABLE=$4
PB=$5
MB=$6
SQMB=$7
SCALE_TYPE=$8
QUANT_TYPE=$9
# Positional parameters above 9 MUST be braced: a bare "$10" expands to
# "${1}0" (the first argument with a literal 0 appended), not the 10th argument.
SEED=${10}
ROUND_TYPE=${11}

# Map the Hugging Face-style model name to the model card expected by
# gpt2_ft.py. Unrecognized names fall back to gpt2-small; MODEL is rewritten
# too so the checkpoint path below stays consistent with the chosen card.
case "$MODEL" in
    gpt2-large)  MODEL_CARD="gpt2.lg" ;;
    gpt2-medium) MODEL_CARD="gpt2.md" ;;
    gpt2-xl)     MODEL_CARD="gpt2.xl" ;;
    *)
        if [ "$MODEL" != "gpt2-small" ]; then
            printf 'warning: unrecognized MODEL "%s"; falling back to gpt2-small\n' \
                "$MODEL" >&2
        fi
        MODEL_CARD="gpt2.sm"
        MODEL="gpt2-small"
        ;;
esac

# All expansions are quoted so an empty or space-containing argument cannot
# shift the remaining flags/values in argv.
python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 \
    src/gpt2_ft.py \
    --train_data ./data/e2e/train.jsonl \
    --valid_data ./data/e2e/valid.jsonl \
    --train_batch_size 2 \
    --grad_acc 1 \
    --valid_batch_size 4 \
    --seq_len 512 \
    --model_card "$MODEL_CARD" \
    --init_checkpoint "./pretrained_checkpoints/${MODEL}-pytorch_model.bin" \
    --platform local \
    --clip 0.0 \
    --lr 4e-5 \
    --weight_decay 0.01 \
    --correct_bias \
    --adam_beta2 0.999 \
    --scheduler linear \
    --warmup_step 500 \
    --max_epoch 5 \
    --save_interval 1000 \
    --lora_dim 0 \
    --lora_alpha 32 \
    --lora_dropout 0.1 \
    --label_smooth 0.1 \
    --work_dir "$EXP_NAME" \
    --lpmm_enable "$ENABLE" \
    --pb "$PB" \
    --mb "$MB" \
    --sqmb "$SQMB" \
    --optim_mode "$MODE" \
    --scale_type "$SCALE_TYPE" \
    --q_oracle "$QUANT_TYPE" \
    --round_type "$ROUND_TYPE" \
    --random_seed "$SEED"