# A Gradient Accumulation Method for Dense Retriever under Memory Constraint

This repository is the official implementation of "A Gradient Accumulation Method for Dense Retriever under Memory Constraint", submitted to NeurIPS 2024.

## 1. Requirements
---
To install requirements:

```setup
pip install -r requirements.txt
```

## 2. Preparing the data
Preparing the data used in the paper takes two steps. First, download each dataset. Second, where applicable, preprocess the dataset by filtering the hard negatives using cross-encoder scores.
### 2-1. Download DPR data
---
The official DPR repository provides the preprocessed datasets used in the DPR paper. You can download them by running the following command:

```bash
bash data/download_dpr_datasets.sh
```

### 2-2. Download and preprocess MS Marco data
---
The official BEIR repository provides code to download the MS Marco data, and Hugging Face hosts a preprocessed version of MS Marco with annotations in the same format used by sentence-transformers.
Optionally, you can further preprocess the MS Marco data by filtering the hard negatives using the cross-encoder scores generated by sentence-transformers.
The code for downloading and preprocessing MS Marco is in `data/msmarco_download_and_preprocess.ipynb`.
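
The hard-negative filtering step can be pictured roughly as follows. This is a minimal sketch, not the notebook's actual code: the function name `filter_hard_negatives`, the `margin` parameter and its default, and the dictionary layout are all assumptions made for illustration. The rule shown (keep a negative only if its cross-encoder score is at least `margin` below the lowest-scoring positive) is one common choice and may differ from the criterion the notebook applies.

```python
from typing import Dict, List


def filter_hard_negatives(
    ce_scores: Dict[str, float],
    positives: List[str],
    negatives: List[str],
    margin: float = 3.0,
) -> List[str]:
    """Drop candidate negatives that the cross-encoder scores too close to a positive.

    All names and the default margin are hypothetical; adapt them to the real data format.
    """
    if not positives:
        # Without a positive there is no reference score, so keep nothing.
        return []
    # Lowest cross-encoder score among the annotated positives for this query.
    min_pos_score = min(ce_scores[pid] for pid in positives)
    threshold = min_pos_score - margin
    # Keep only negatives scored at or below the threshold, preserving their order.
    return [pid for pid in negatives if ce_scores[pid] <= threshold]


if __name__ == "__main__":
    # Toy scores for a single query: one positive and three candidate negatives.
    scores = {"p1": 9.0, "n1": 8.5, "n2": 2.0, "n3": -4.0}
    print(filter_hard_negatives(scores, ["p1"], ["n1", "n2", "n3"]))
```

In this toy example `n1` is discarded because its score is within the margin of the positive, which suggests it may be an unlabeled relevant passage rather than a true negative.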

## 3. Training
---
You can train DPR in each of the following settings. The commands are the same for every dataset, including MS Marco; replace `{data_name}` with the name of the dataset:

### 3-1. DPR with ContAccum in low-resource
```bash
python src/train_dpr.py --config_file config/{data_name}/train_dpr_{data_name}_contAccum_cache1_accum4.yaml
```

### 3-2. DPR in high-resource
```bash
python src/train_dpr.py --config_file config/{data_name}/train_dpr_{data_name}_bsz128.yaml
```

### 3-3. DPR in low-resource
```bash
python src/train_dpr.py --config_file config/{data_name}/vram11/train_dpr_{data_name}_bsz8.yaml
```

### 3-4. DPR with gradient accumulation in low-resource
```bash
python src/train_dpr.py --config_file config/{data_name}/vram11/train_dpr_{data_name}_gradAccum_4.yaml
```
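
If you want to run several settings in sequence, the four commands above can be generated from their config path templates. The sketch below is only a convenience wrapper and is not part of the repository: the function `build_train_command`, the setting labels, and the example dataset name `nq` are assumptions for illustration, while the templates themselves are copied from the commands above.

```python
import shlex
from typing import List

# Config path templates copied from the four training commands in this section.
# The dictionary keys are hypothetical labels, not names used by the repository.
CONFIG_TEMPLATES = {
    "contaccum": "config/{data_name}/train_dpr_{data_name}_contAccum_cache1_accum4.yaml",
    "high_resource": "config/{data_name}/train_dpr_{data_name}_bsz128.yaml",
    "low_resource": "config/{data_name}/vram11/train_dpr_{data_name}_bsz8.yaml",
    "grad_accum": "config/{data_name}/vram11/train_dpr_{data_name}_gradAccum_4.yaml",
}


def build_train_command(setting: str, data_name: str) -> List[str]:
    """Return the argument list for one training run."""
    if setting not in CONFIG_TEMPLATES:
        raise ValueError(f"unknown setting: {setting!r}")
    # Fill the dataset name into both places where the template uses it.
    config_path = CONFIG_TEMPLATES[setting].format(data_name=data_name)
    return ["python", "src/train_dpr.py", "--config_file", config_path]


if __name__ == "__main__":
    # "nq" is a placeholder dataset name; use the directory names under config/.
    for setting in CONFIG_TEMPLATES:
        print(shlex.join(build_train_command(setting, "nq")))
```

Each returned list can be passed to `subprocess.run(cmd, check=True)` to launch the run from the repository root.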

## 4. Extract embeddings of all passages
For MS Marco, you can extract embeddings of all passages as follows:
```bash
accelerate launch --num_processes=4 doc2embedding_msmarco.py \
    --embed_dir /workspace/mnt2/dpr_output/{embed_dir} \
    --model_save_dir /workspace/mnt2/dpr_logs/{model_dir}
```
For DPR datasets, you can extract embeddings of all passages as follows:
```bash
bash scripts/tools/embed.sh {model_dir} {embed_dir} 
```

## 5. Evaluation
For MS Marco, you can evaluate the performance of the model as follows:
```bash
python test_msmarco.py \
    --embedding_dir {embed_dir} \
    --model_save_dir {model_dir} \
    --data_split test \
    --result_file_path result.csv
```

For DPR datasets, you can evaluate the performance of the model as follows:
```bash
bash scripts/tools/test.sh 6 {model_dir}/query_encoder {embed_dir}/embeddings
```