# Multicalibration Post-Processing

This repository contains the official implementation of all experiments in our NeurIPS 2024 submission. We conduct the first
comprehensive empirical study of multicalibration post-processing
across a broad set of tabular, image, and language datasets, for models ranging
from simple decision trees to 90-million-parameter fine-tuned LLMs.

The repository includes all of the code needed to run the experiments from the paper, and it can also serve as a standalone tool for studying multicalibration and multicalibration algorithms.

## Requirements

To install requirements:

```setup
pip install -r requirements.txt
```

## Reproducing Results

All experimental results are provided in the `results` directory. To generate the figures and tables from these results, run the following commands in the root directory:

```bash
python scripts/generate_figures.py
python scripts/generate_tables.py
```

To run an experiment, call one of the functions available in `experiments.py`. Given a model, a dataset, a list of calibration fractions, and a list of seeds for the validation split, these functions will pretrain, train, or evaluate (depending on the function) over the specified calibration fractions and split seeds. To change the model hyperparameters for a given dataset and calibration fraction, edit the `hyperparameters` dictionary in `configs/hyperparameters.py`; it currently contains the hyperparameters used to obtain our results.
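As a rough illustration of the sweep these functions perform, the sketch below visits every (calibration fraction, seed) pair. `run_one` is a hypothetical stand-in for whichever function in `experiments.py` you call; its real signature may differ, so treat the argument names here as assumptions.

```python
from itertools import product


def sweep(run_one, model_name, dataset, calib_fracs, seeds):
    """Call run_one once per (calibration fraction, split seed) pair.

    run_one is a hypothetical callable standing in for a function from
    experiments.py; the keyword names below are assumptions.
    """
    results = {}
    for calib_frac, seed in product(calib_fracs, seeds):
        results[(calib_frac, seed)] = run_one(
            model_name, dataset, calib_frac=calib_frac, seed=seed
        )
    return results


def dummy_run(model_name, dataset, calib_frac, seed):
    # Placeholder: returns a label instead of training anything.
    return f"{model_name}/{dataset}/calib={calib_frac}/seed={seed}"


runs = sweep(dummy_run, "MLP", "ACSIncome", calib_fracs=[0.0, 0.4], seeds=[0, 1])
print(len(runs))
```

Two fractions and two seeds give four runs, one per combination.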

Once models have been trained, post-processed, and evaluated, one may reproduce the figures on these new runs. First download the results from wandb by running the script provided in `download_results.py`. This downloads the entire collection of wandb runs as CSV files, which are stored in the `results` directory. Once this information is saved, figures can be generated with the scripts listed above.
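For ad hoc inspection of the downloaded files outside the figure scripts, a minimal sketch such as the following may help. It assumes only that the `results` directory contains CSV files with a header row; the helper name `load_result_rows` and the added `source_file` key are hypothetical and not part of this repository.

```python
import csv
from pathlib import Path


def load_result_rows(results_dir="results"):
    """Read every CSV under results_dir into a list of dict rows.

    Hypothetical helper: makes no assumption about column names.
    """
    rows = []
    for csv_path in sorted(Path(results_dir).rglob("*.csv")):
        with csv_path.open(newline="") as handle:
            for row in csv.DictReader(handle):
                row["source_file"] = csv_path.name  # remember the origin file
                rows.append(row)
    return rows
```

Each row is a plain dictionary keyed by the CSV header, so it can be filtered or grouped with ordinary Python before plotting.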

## Using This Repository

Training a model and applying a post-processing algorithm is straightforward. Consider the following example, which retrieves hyperparameters we use in the paper.

First, define a multicalibration (mcb) algorithm. To see the available algorithms, examine the names and parameter dictionaries in `configs/mcb_algorithms.py`. Alternatively, look at their implementations in the `mcb_algorithms` directory.

```python
mcb_algorithm = 'HKRR'
mcb_params = {
    'lambda': 0.1,
    'alpha': 0.025,
}
```

From here, each post-processing step is run by calling `experiment.multicalibrate()` with the desired algorithm and parameters.

```python
from configs.constants import SPLIT_DEFAULT, MCB_DEFAULT
from configs.hyperparameters import get_hyperparameters
from Experiment import Experiment
from Dataset import Dataset
from Model import Model

# set constants for the experiment
model_name = 'MLP'
dataset = 'ACSIncome'
calib_frac = 0.4
seed = 0

# set the save directory and wandb project
save_dir = f'models/saved_models/{dataset}/{model_name}/calib={calib_frac}_val_seed={seed}/'
wdb_project = f'{dataset}_project'

# define config for experiment
hyp = get_hyperparameters(model_name, dataset, calib_frac)
config = {
    'model': model_name,
    'dataset': dataset,
    'calib_frac': calib_frac,
    'val_split_seed': seed,
    'split': SPLIT_DEFAULT,
    'mcb': MCB_DEFAULT,
    'save_dir': save_dir,
    **hyp
}

dataset_obj = Dataset(dataset, val_split_seed=config['val_split_seed'])
model = Model(model_name, config=config, SAVE_DIR=config['save_dir'])
experiment = Experiment(dataset_obj, model, calib_frac=config['calib_frac'])

# init logger; this saves metrics to wandb
experiment.init_logger(config, project=wdb_project)

# train and postprocess
experiment.train_model()
if config['calib_frac'] > 0:
    experiment.multicalibrate(mcb_algorithm, mcb_params)

# evaluate splits
experiment.evaluate_val()
experiment.evaluate_test()

# close logger
experiment.init_logger(finish=True)
```

Although this example applies only one multicalibration post-processing algorithm, any number may be applied sequentially. All of them are stored by the experiment object and used in evaluation.
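A minimal sketch of sequential application is shown here. It relies only on the `experiment.multicalibrate(name, params)` call from the example above; the helper `apply_postprocessors` is hypothetical, and the second parameter setting is an illustrative value rather than one used in the paper.

```python
def apply_postprocessors(experiment, steps):
    """Apply each (algorithm_name, params) pair in order.

    Only relies on experiment.multicalibrate(name, params), the call
    used in the example above. This helper is not part of the repository.
    """
    for name, params in steps:
        experiment.multicalibrate(name, params)


# Hypothetical schedule: the same algorithm with two parameter settings.
# Valid names and parameters are listed in configs/mcb_algorithms.py.
steps = [
    ('HKRR', {'lambda': 0.1, 'alpha': 0.025}),
    ('HKRR', {'lambda': 0.1, 'alpha': 0.05}),
]
```

Calling `apply_postprocessors(experiment, steps)` after `experiment.train_model()` and before the evaluation calls would take the place of the single `multicalibrate` call in the example.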


## Acknowledgements

The design of this repository is partially inspired by that of the [WILDS Benchmark](https://github.com/p-lambda/wilds).