#! /bin/bash
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Run nvidia-smi first so its report appears at the top of the job output.
nvidia-smi

#######################################
# Map the script's positional arguments onto named global variables.
# Globals (written):
#   EXP_NAME ENABLE PB GB MB SQMB OPTIM QUANT_TYPE ROUND_TYPE Q_ORACLE SEED
# Arguments:
#   $1   EXP_NAME    experiment name; used as the results sub-directory
#   $2   ENABLE      value forwarded to train.py as --lpmm_enable
#   $3   PB          value forwarded as --pb
#   $4   GB          value forwarded as --gb
#   $5   MB          value forwarded as --mb
#   $6   SQMB        value forwarded as --sqmb
#   $7   OPTIM       value forwarded as --optimizer
#   $8   QUANT_TYPE  value forwarded as --quant_type
#   $9   ROUND_TYPE  value forwarded as --round_type
#   $10  Q_ORACLE    value forwarded as --q_oracle
#   $11  SEED        value forwarded as --seed (optional, defaults to 1)
#######################################
parse_args() {
  EXP_NAME=$1
  ENABLE=$2
  PB=$3
  GB=$4
  MB=$5
  SQMB=$6
  OPTIM=$7
  QUANT_TYPE=$8
  ROUND_TYPE=$9
  # Braces are mandatory past the ninth parameter: a bare $10 expands as
  # ${1} followed by a literal "0", which would yield "<EXP_NAME>0".
  Q_ORACLE=${10}
  SEED=${11:-1}
}

parse_args "$@"

# Output locations for this experiment.
# NOTE(review): "/path-to-results" looks like a placeholder that must be edited
# to a real directory before use - confirm.
RESULTS_DIR="/path-to-results/${EXP_NAME}"
CHECKPOINTS_DIR="/path-to-results/${EXP_NAME}/checkpoints"
# Quoted so an experiment name containing whitespace or glob characters still
# produces exactly one directory; -p also creates RESULTS_DIR on the way.
mkdir -p -- "${CHECKPOINTS_DIR}"

# Training defaults. Each one is assigned only when the variable is unset or
# empty, so a value exported by the caller takes precedence.
: "${LR:=0.0006}"
: "${WARMUP:=4000}"
: "${NUM_EPOCHS:=30}"
: "${BS:=5120}"
: "${NUM_GPU:=8}"

# File passed to train.py via --stat-file.
STAT_FILE="${RESULTS_DIR}/DGX1_amp_${NUM_GPU}GPU_log.json"
# Interpreter flags that start train.py through torch.distributed.run with one
# process per GPU. Kept as a single space-separated string because the launch
# command relies on word-splitting it.
DISTRIBUTED="-m torch.distributed.run --nproc_per_node=${NUM_GPU}"

# Launch train.py on the data/wmt14_en_de_joined_dict dataset with the
# transformer_wmt_en_de architecture, forwarding the optimizer, schedule and
# quantization settings collected earlier in this script.

# Every flag below is given exactly one value, so an empty variable is treated
# as a usage error here instead of being forwarded as an empty argument.
for required_var in ENABLE PB GB MB SQMB OPTIM QUANT_TYPE ROUND_TYPE Q_ORACLE \
  SEED LR WARMUP NUM_EPOCHS BS RESULTS_DIR STAT_FILE; do
  if [[ -z "${!required_var}" ]]; then
    printf 'error: %s is empty or unset; refusing to launch train.py\n' \
      "${required_var}" >&2
    exit 2
  fi
done

# Arguments are collected in an array so each value reaches train.py as a
# single word, whatever characters it contains.
train_args=(
  data/wmt14_en_de_joined_dict
  --arch transformer_wmt_en_de
  --share-all-embeddings
  --optimizer "${OPTIM}"
  --adam-betas 0.9 0.997
  --adam-eps 1e-9
  --clip-norm 0.0
  --lr-scheduler inverse_sqrt
  --warmup-init-lr 0.0
  --warmup-updates "${WARMUP}"
  --lr "${LR}"
  --min-lr 0.0
  --dropout 0.1
  --weight-decay 0.0
  --criterion label_smoothed_cross_entropy
  --label-smoothing 0.1
  --max-tokens "${BS}"
  --seed "${SEED}"
  --max-epoch "${NUM_EPOCHS}"
  --no-epoch-checkpoints
  --fuse-layer-norm
  --online-eval
  --log-interval 100
  --save-dir "${RESULTS_DIR}"
  --stat-file "${STAT_FILE}"
  --amp --save-interval 10
  --lpmm_enable "${ENABLE}"
  --pb "${PB}"
  --gb "${GB}"
  --mb "${MB}"
  --sqmb "${SQMB}"
  --q_oracle "${Q_ORACLE}"
  --quant_type "${QUANT_TYPE}"
  --round_type "${ROUND_TYPE}"
)

# DISTRIBUTED holds several space-separated interpreter flags and must be
# word-split, so it is intentionally left unquoted.
# shellcheck disable=SC2086
python ${DISTRIBUTED} train.py "${train_args[@]}"
