#!/bin/bash

# Run the 75-key KV-retrieval evaluation (evaluation/get_responses.py) with the
# gold key placed at several positions, writing one prediction file per position.
#
# Usage: <script> SCALING_FACTOR SCALING_TYPE MODEL_TYPE TASK_NAME STEP [flash_attn]
#   SCALING_FACTOR  "none" -> evaluate the base Llama-2-7b-hf checkpoint at 4k;
#                   otherwise a positive integer multiplier of the 4096-token
#                   window (prompt length = SCALING_FACTOR * 4096).
#   SCALING_TYPE    passed as --rope_scaling_type; also part of the checkpoint
#                   directory name (unused when SCALING_FACTOR is "none").
#   MODEL_TYPE      used in the checkpoint root and output directory names.
#   TASK_NAME       used in the checkpoint root and output directory names.
#   STEP            checkpoint step to load (unused when SCALING_FACTOR is "none").
#   flash_attn      if the word "flash_attn" appears anywhere in the arguments,
#                   --use_flash_attn 1 is passed; otherwise 0.
#
# Exits non-zero if any of the evaluation runs fails (all runs are still attempted).
set -u

die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

usage() {
  printf 'Usage: %s SCALING_FACTOR SCALING_TYPE MODEL_TYPE TASK_NAME STEP [flash_attn]\n' \
    "${0##*/}" >&2
  exit 2
}

# MODEL_TYPE and TASK_NAME are needed in both modes, so at least 4 args are required.
(( $# >= 4 )) || usage

scaling_factor=$1
scaling_type=$2
model_type=$3
task_name=$4
step=${5:-}

# Substring match over all arguments joined into one string (explicit "$*"
# instead of the implicit array concatenation of "$@" inside [[ ]]).
if [[ "$*" == *flash_attn* ]]; then
  flash_attn=1
else
  flash_attn=0
fi
printf 'flash_attn: %s\n' "$flash_attn"

readonly data_path_prefix=/scratch/nlp/wutong/dataset/PoSE-Datasets
readonly model_path_prefix=/scratch2/nlp/wutong/${task_name}/${model_type}_results
readonly input_path=${data_path_prefix}/kv_data/kv-retrieval-75_keys.jsonl.gz
gold_indices=(0 18 37 54 74)

if [[ "$scaling_factor" == "none" ]]; then
  # Baseline: the original 4k model, no RoPE scaling.
  model_path=/scratch2/nlp/plm/Llama-2-7b-hf
  max_prompt_length=4096
  scaling_args=()
  output_dir=eval_output/kv_predictions/${model_type}-${task_name}
  output_suffix=4k
else
  # Reject non-integers up front: inside $(( )) a non-numeric word is read as a
  # variable name (evaluating to 0) and would silently produce bogus paths.
  [[ "$scaling_factor" =~ ^[1-9][0-9]*$ ]] \
    || die "SCALING_FACTOR must be 'none' or a positive integer, got '${scaling_factor}'"
  [[ -n "$scaling_type" ]] || die "SCALING_TYPE must not be empty"
  [[ -n "$step" ]] || die "STEP is required when SCALING_FACTOR is not 'none'"

  target_k=$(( scaling_factor * 4 ))
  model_path=${model_path_prefix}/4k-${target_k}k-${scaling_type}/checkpoint-${step}
  max_prompt_length=$(( scaling_factor * 4096 ))
  scaling_args=(--rope_scaling_factor "$scaling_factor" --rope_scaling_type "$scaling_type")
  output_dir=eval_output/kv_predictions/${model_type}-${task_name}-${step}
  output_suffix=4k_${target_k}k_${scaling_type}
fi

failures=0
for gold_index in "${gold_indices[@]}"; do
  # ${arr[@]+"${arr[@]}"} expands to nothing for an empty array without
  # tripping `set -u` on bash versions older than 4.4.
  if ! python -u evaluation/get_responses.py \
      --input_path "$input_path" \
      --model_name_or_path "$model_path" \
      --task_name kv \
      --batch_size 1 \
      --gold_index "$gold_index" \
      --max_prompt_length "$max_prompt_length" \
      --model_max_position_embeddings 4096 \
      ${scaling_args[@]+"${scaling_args[@]}"} \
      --max_new_tokens 50 \
      --use_flash_attn "$flash_attn" \
      --output_path "${output_dir}/kv_75_at_${gold_index}_${output_suffix}.txt"; then
    printf 'error: evaluation failed for gold_index=%s\n' "$gold_index" >&2
    failures=$(( failures + 1 ))
  fi
done

# Report failures from any run, not just the last one.
if (( failures > 0 )); then
  printf 'error: %d of %d evaluation runs failed\n' "$failures" "${#gold_indices[@]}" >&2
  exit 1
fi
