#!/bin/bash

############### Usage   ##############################
# Run this script with one argument naming the step (the "function") to run:
#   trainm:    train the models
#   selection: perform test selection
# History
    # Author: 

############### Host   ##############################
HOST=$(hostname)
echo "Current host is: $HOST"

# All results of this run go under ./save/<YYYY-MM-DD>/.
DATE=$(date +%Y-%m-%d)
echo "$DATE"
DIRECTORY=./save/${DATE}/
# -p also creates ./save itself on a fresh checkout and is a no-op when the
# directory already exists, so no separate existence check is needed.
if ! mkdir -p -- "$DIRECTORY"; then
    echo "Cannot create output directory $DIRECTORY" >&2
    exit 1
fi

############### Step selection   ##############################

# The first positional argument names the step to run: trainm or selection.
function="$1"
echo "Input function: $function"
# A missing step is a usage error, so report it on stderr and fail.
if [[ -z "$function" ]]; then
    echo "You must input a function (trainm or selection)" >&2
    exit 1
fi

############### Configuration   ##############################

# Settings shared by the train and selection steps.
DATA_ROOT=./datasets   # dataset root, given to both Python scripts
epoch=200              # given to train_classifier.py as --n_epochs
STEP=100               # used as the suffix of the selection save_path
RANDOM_SEED=10         # given to both Python scripts as --manualSeed


############### Train   ##############################
# ----- IP vendor: Train biased models -----
# Trains MODEL on DATASET once per class-weight setting (0, 1, 2), one run
# after another. Reads the globals DATE, epoch, DATA_ROOT and RANDOM_SEED.
# All runs share the same save_path; presumably train_classifier.py keeps the
# runs apart via --class_weight — TODO confirm.
if [[ "$function" == "trainm" ]]; then
    echo "train model for IP vendor"
    DATASET='stl10'
    MODEL='resnet34'
    save_path=save/${DATE}/${DATASET}_${MODEL}

    # Each run is a foreground command, so no `wait` is needed between them.
    # A failed run is reported but does not stop the remaining runs.
    for class_weight in 0 1 2; do
        python train_classifier.py --dataset "${DATASET}" \
                                   --model "${MODEL}" \
                                   --n_epochs "${epoch}" \
                                   --data_root "${DATA_ROOT}" \
                                   --manualSeed "${RANDOM_SEED}" \
                                   --save_path "${save_path}" \
                                   --class_weight "${class_weight}" \
            || echo "train_classifier.py failed for class_weight=${class_weight}" >&2
    done
fi


#  ------Test Center:  Selection ------
# Dataset whose biased checkpoints are tested. Each dataset is paired with a
# fixed model architecture in the case below, so to switch datasets change
# DATASET only (supported: cifar10, svhn, stl10).
DATASET='stl10'
case "$DATASET" in
    cifar10) MODEL='resnet18' ;;
    svhn)    MODEL='wide_resnet' ;;
    stl10)   MODEL='resnet34' ;;
    *)
        echo "Unsupported dataset: $DATASET" >&2
        exit 1
        ;;
esac

# Passed to selection.py as --no_neighbors.
no_neighbors=100
# Passed to selection.py as --feature_extractor_id. 0: BYOL, 1: model2test
feature_extractor_id=0

# run testing
# For each model number, run selection.py against the biased checkpoint
# ./checkpoint/<DATASET>/ckpt_bias/<MODEL>_<n>_b.t7 and write results under
# save/<DATE>/<DATASET>_<MODEL>_<STEP>. Reads the globals DATASET, MODEL, DATE,
# STEP, RANDOM_SEED, DATA_ROOT, feature_extractor_id and no_neighbors.
if [[ "$function" == "selection" ]]; then
    for MODEL_NO in 1; do
        echo "$MODEL_NO"
        MODEL2TEST=${MODEL}
        MODEL2TESTPATH=./checkpoint/${DATASET}/ckpt_bias/${MODEL}_${MODEL_NO}_b.t7
        save_path=save/${DATE}/${DATASET}_${MODEL2TEST}_${STEP}
        echo "model to test arch ${MODEL2TEST}"
        echo "model to test path ${MODEL2TESTPATH}"
        python selection.py \
            --dataset "$DATASET" \
            --manualSeed "${RANDOM_SEED}" \
            --model2test_arch "$MODEL2TEST" \
            --model2test_path "$MODEL2TESTPATH" \
            --model_number "$MODEL_NO" \
            --save_path "${save_path}" \
            --data_path "${DATA_ROOT}" \
            --graph_nn \
            --feature_extractor_id "${feature_extractor_id}" \
            --no_neighbors "${no_neighbors}" \
            --learn_mixed
    done
fi


