#!/usr/bin/env bash
# Run kdeknn.prob_assign for every (seed, alpha) combination below, writing
# results to ../kdeknn_data_seed<seed>.
#
# Usage: <this script> "<target task name(s)>"
#   $1 - target task name(s); several names may be given as ONE
#        whitespace-separated argument (e.g. "taskA taskB"), which is split
#        into separate values for --target_task_names.
#
# The gradient path templates contain literal "{}" placeholders that are passed
# through unchanged; presumably kdeknn.prob_assign fills them in (e.g. with the
# train-file name and checkpoint) — NOTE(review): confirm against that module.
#
# Exit status: 0 if every run succeeded, 1 if any mkdir/python3 step failed
# (remaining seeds/alphas are still attempted after a failure).

# Fail fast with one clear message instead of one python error per seed.
: "${1:?usage: $0 \"<target task name(s)>\"}"
# Split $1 on whitespace into an array, mirroring the original unquoted $1.
read -r -a TARGET_TASK_NAMES <<< "$1"

readonly DIM=8192 # decide which gradient dimension to use
readonly -a SEEDS=(3 6 9)
readonly -a ALPHAS=(0.075)
readonly SIGMA=0.2
readonly C_PARAM=5 # value passed to --C
readonly -a TRAIN_FILE_NAMES=(flan_v2 cot dolly oasst1)
readonly -a CKPTS=(105 211 317 420)                                                  # checkpoint indices
readonly -a CHECKPOINT_WEIGHTS=(1.6877e-05 1.2859e-05 7.7030e-06 2.5616e-06) # average lr of the epoch, one per checkpoint

status=0
for seed in "${SEEDS[@]}"; do
  gradient_path="/mydata/corpus_grad/p0.05_seed${seed}/adam_grads_llama2-7b-p0.05_seed${seed}_{}_{}_dim${DIM}/all_orig.pt"
  validation_gradient_path="../grads/p0.05_seed${seed}_lora/{}-ckpt{}-sgd/dim${DIM}/all_orig.pt"
  selected_data_output_path="../kdeknn_data_seed${seed}"

  # mkdir -p is already a no-op for an existing directory.
  if ! mkdir -p -- "$selected_data_output_path"; then
    printf 'error: cannot create %s; skipping seed %s\n' \
      "$selected_data_output_path" "$seed" >&2
    status=1
    continue
  fi

  for alpha in "${ALPHAS[@]}"; do
    echo "##############################"
    echo "seed: ${seed}"
    echo "alpha: ${alpha}"
    # Optional flags (currently disabled): --load_top_info --load_kde_info
    if ! python3 -m kdeknn.prob_assign \
      --gradient_path "$gradient_path" \
      --train_file_names "${TRAIN_FILE_NAMES[@]}" \
      --ckpts "${CKPTS[@]}" \
      --checkpoint_weights "${CHECKPOINT_WEIGHTS[@]}" \
      --validation_gradient_path "$validation_gradient_path" \
      --target_task_names "${TARGET_TASK_NAMES[@]}" \
      --output_path "$selected_data_output_path" \
      --alpha "$alpha" --sigma "$SIGMA" \
      --C "$C_PARAM"; then
      printf 'error: prob_assign failed for seed=%s alpha=%s\n' \
        "$seed" "$alpha" >&2
      status=1
    fi
  done
done

# Final command sets the script's exit status without calling `exit`.
[[ $status -eq 0 ]]
