#!/bin/bash

# Positional arguments: <dataset> <seed> [extra arguments forwarded to train.py...]
# Each value is read from $1 and then shifted away, so that afterwards "$@"
# holds only the extra arguments.
DATASET="$1"
shift
# The shift above moved the seed into $1. Reading $2 here would pick up the
# first extra argument instead and then discard the real seed.
SEED="$1"
shift

# Fixed settings for this run.
MODEL_BASE='resnet18'
EPOCHS=25
GPUS_PER_NODE=4
N_WORKERS=5  # CPUs per GPU; handed to train.py as --workers

# Checkpoints are written below this directory.
CKPT_DIR='ckpt'

# Job name layout: <dataset>/<model>/<date>-<time>-<sub-second digits>-<random>.
# The timestamp plus $RANDOM gives every invocation its own name.
printf -v JOBNAME '%s/%s/%s-%s' \
    "$DATASET" "$MODEL_BASE" "$(date +%Y%m%d-%H%M%S-%6N)" "$RANDOM"

# Print a banner with the resolved configuration before training starts.
printf '\n%s\n\n' "------------------------------------------------------------------------"
# One "KEY = value" row per pair; keys are left-aligned in a 13-character
# column (the width of the longest key, GPUS_PER_NODE).
printf '%-13s = %s\n' \
    "SEED" "$SEED" \
    "DATASET" "$DATASET" \
    "EPOCHS" "$EPOCHS" \
    "MODEL_BASE" "$MODEL_BASE" \
    "GPUS_PER_NODE" "$GPUS_PER_NODE" \
    "N_WORKERS" "$N_WORKERS" \
    "CKPT_DIR" "$CKPT_DIR" \
    "JOBNAME" "$JOBNAME" \
    "EXTRA_ARGS" "$*"
printf '\n%s\n\n' "Main script begins"

# Arguments handed to train.py. Kept in an array so every value stays a
# separate, correctly quoted word and the groups can carry comments.
train_args=(
    "$DATASET"
    # Use MODEL_BASE rather than a literal model name: JOBNAME and the banner
    # are built from MODEL_BASE, so a hard-coded name here would let the
    # checkpoint path claim one model while another is trained.
    --model "$MODEL_BASE"
    --torchvision-model
    --seed "$SEED"
    # Optimisation schedule.
    --lr 0.01
    --warmup-epochs 1
    --cooldown-epochs 0
    --epochs "$EPOCHS"
    --weight-decay 1e-4
    --sched cosine
    # Data pipeline.
    --crop-pct 0
    --scale 0.7 1.0
    --smoothing 0
    --batch-size 32
    --workers "$N_WORKERS"
    # Checkpointing: output goes under $CKPT_DIR with experiment "$JOBNAME/head".
    --output "$CKPT_DIR"
    --experiment "$JOBNAME/head"
    # NOTE(review): JOBNAME embeds a fresh timestamp and $RANDOM, so this path
    # is new on every invocation -- confirm train.py tolerates a resume file
    # that does not exist yet.
    --resume "$CKPT_DIR/$JOBNAME/head/last.pth.tar"
    --checkpoint-hist 1
    --no-mlp-layer-norm
)

# Start train.py through torch.distributed.launch with GPUS_PER_NODE processes
# on this node. Whatever is left in "$@" after DATASET and SEED is appended
# last (presumably so callers can override the flags above; that depends on
# train.py's argument parser).
python3 -m torch.distributed.launch \
    --nproc_per_node="$GPUS_PER_NODE" \
    train.py \
    "${train_args[@]}" \
    "$@"
