# Custom environment
# Source the user's shell setup (presumably where `conda init` defines the
# `conda` shell function), then switch from whichever environment is active
# to the "data" environment.
source ~/.bashrc
conda deactivate
# Abort rather than silently launching training with the wrong Python if the
# activation fails. One likely cause is a ~/.bashrc that returns early in
# non-interactive shells, so `conda activate` is never set up.
if ! conda activate data; then
  echo "error: failed to activate conda environment 'data'" >&2
  exit 1
fi

# Job configuration. CONFIG is passed as the first argument to
# scripts/train.py. CHECKPOINTS_PATH is exported, presumably for the
# training code to read.
# NOTE(review): CHECKPOINTS_PATH is still the placeholder "XXX"; set a real path.
CONFIG=configs/base.yaml
CHECKPOINTS_PATH=XXX
export CONFIG CHECKPOINTS_PATH


# Size OpenMP thread pools to the CPUs Slurm allotted per task. Only export it
# when Slurm actually provides a value. Exporting an empty OMP_NUM_THREADS is
# not a valid setting, and it would clobber any value inherited from the caller.
if [[ -n "${SLURM_CPUS_PER_TASK:-}" ]]; then
  export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK}"
fi
export MPICH_GPU_SUPPORT_ENABLED=1
# MIOpen database/cache directory, keyed by user and Slurm job ID so that
# separate jobs do not share the same directory.
export MIOPEN_USER_DB_PATH="/tmp/${USER}-miopen-cache-${SLURM_JOB_ID}"
export MIOPEN_CUSTOM_CACHE_DIR="${MIOPEN_USER_DB_PATH}"

# Put the current working directory first on the Python import path. Only
# append ":$PYTHONPATH" when it is non-empty, so no empty entry is added.
export PYTHONPATH=".${PYTHONPATH:+:${PYTHONPATH}}"

# PyTorch kernel cache under /tmp, namespaced per user like the MIOpen cache.
# A single shared directory would be owned by whichever user created it first.
# With the usual 022 umask, that leaves it unwritable for other users on the
# same node.
export PYTORCH_KERNEL_CACHE_PATH="/tmp/${USER}-pytorch-kernel-cache/"
mkdir -p -- "${PYTORCH_KERNEL_CACHE_PATH}"

# Fall back to one task per node when Slurm did not define
# SLURM_NTASKS_PER_NODE. Only a truly unset variable gets the default; a
# variable that is set but empty is left as it is.
if ! [[ ${SLURM_NTASKS_PER_NODE+defined} ]]; then
  export SLURM_NTASKS_PER_NODE=1
fi

# export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512

echo "Running with srun"

# Collect srun options in an array so that each value is passed as exactly one
# argument. --cpus-per-task is forwarded only when Slurm set
# SLURM_CPUS_PER_TASK, because an empty "--cpus-per-task=" is not a valid value.
srun_opts=()
if [[ -n "${SLURM_CPUS_PER_TASK:-}" ]]; then
  srun_opts+=("--cpus-per-task=${SLURM_CPUS_PER_TASK}")
fi
srun_opts+=(--distribution=block:block --kill-on-bad-exit)

# Quoted "$@" forwards every extra script argument to train.py verbatim, even
# arguments containing spaces or glob characters. Unquoted ${@} would re-split
# and glob-expand them.
srun "${srun_opts[@]}" \
  scripts/run_with_environment.sh \
  python -u scripts/train.py "${CONFIG}" \
  "$@"