# AlphaMath Almost Zero: Process Supervision Without Process


This is the official repository of "AlphaMath Almost Zero: Process Supervision without Process". Our approach trains the policy and value models using only the mathematical reasoning produced by the Monte Carlo Tree Search (MCTS) framework, eliminating the need for GPT-4 or human annotations. Here is an illustration of a [training instance](imgs/mcts_example.pdf) generated by MCTS in round 3.


## Installation
1. Install the dependencies listed in `requirements.txt`:
```
pip install -r requirements.txt
```
2. Install the evaluation toolkit (`../MATH_EVAL`) as a package.
3. Install our customized vllm (`../vllm`) to support the value model.


## Checkpoint Initialization
1. Download `deepseek-math-7b-base`.
2. Use `scripts/save_value_head.py` to add the value head to the LLM.


## Greedy Decoding
You can run either of the following two commands. Their accuracies may differ slightly: on our machine, the first reached 53.4% and the second 53.62%.
```
python react_batch_demo.py \
--custom_cfg configs/react_sft.yaml \
--qaf ../MATH_EVAL/data/math_testset_annotation.json
```
or
```
# use step_beam (1, 1) without value func
python solver_demo.py \
--custom_cfg configs/sbs_greedy.yaml \
--qaf ../MATH_EVAL/data/math_testset_annotation.json
```


## Step-level Beam Search
On our machine, on the MATH test set, the following command reaches ~62% with config `B1=1, B2=5` and ~65% with config `B1=3, B2=5`.
```
python solver_demo.py \
--custom_cfg configs/sbs_sft.yaml \
--qaf ../MATH_EVAL/data/math_testset_annotation.json
```
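
To make the role of `B1` and `B2` concrete, here is a minimal toy sketch of step-level beam search. It is not the code in `solver_demo.py`. It assumes that `B2` is the number of candidate next steps proposed for each kept partial solution and that `B1` is the number of partial solutions kept after ranking by value. The functions `toy_expand`, `toy_value`, and `toy_terminal` are hypothetical stand-ins for the policy model, the value model, and the end-of-solution check.

```python
import heapq
from typing import Callable, List, Tuple

Path = Tuple[str, ...]  # a partial solution, stored as a tuple of steps


def step_beam_search(
    expand: Callable[[Path, int], List[str]],  # hypothetical policy: proposes next steps
    value: Callable[[Path], float],            # hypothetical value function
    is_terminal: Callable[[Path], bool],       # hypothetical end-of-solution check
    b1: int,
    b2: int,
    max_steps: int = 8,
) -> Path:
    beams: List[Path] = [()]
    for _ in range(max_steps):
        candidates: List[Path] = []
        for path in beams:
            if is_terminal(path):
                candidates.append(path)  # finished paths are carried forward unchanged
                continue
            for step in expand(path, b2):  # b2 candidate steps per kept path
                candidates.append(path + (step,))
        # Keep only the b1 highest-valued partial solutions.
        beams = heapq.nlargest(b1, candidates, key=value)
        if all(is_terminal(p) for p in beams):
            break
    return max(beams, key=value)


# Toy problem (assumption, for illustration only): pick three digits summing to 6.
def toy_expand(path: Path, k: int) -> List[str]:
    return ["1", "2", "3"][:k]


def toy_value(path: Path) -> float:
    return -abs(6 - sum(int(s) for s in path))


def toy_terminal(path: Path) -> bool:
    return len(path) == 3


best = step_beam_search(toy_expand, toy_value, toy_terminal, b1=2, b2=3)
print(best)
```

In this toy setting, `b1=1, b2=1` degenerates to greedy decoding, which matches the `step_beam (1, 1)` comment in the Greedy Decoding section. A wider `b1` lets the search recover from a step that looks best locally but leads to a worse final answer.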


## MCTS
### Training data generation

![](./imgs/pipeline.png)

The `ground_truth` (the final answer, not the solution process) must be provided in the `qaf` file.
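
Before launching a long data-generation run, it can help to check that every record carries both fields. The sketch below is only an illustration: it assumes the `qaf` file is a JSON list of objects with `question` and `ground_truth` keys, which may not match the actual layout of `math_testset_annotation.json`. The helper `find_incomplete_records` is a hypothetical name, not part of this repository.

```python
import json
from typing import Any, Dict, List

REQUIRED_FIELDS = ("question", "ground_truth")


def find_incomplete_records(records: List[Dict[str, Any]]) -> List[int]:
    """Return the indices of records that lack a non-empty required field."""
    bad = []
    for i, rec in enumerate(records):
        for field in REQUIRED_FIELDS:
            value = rec.get(field)
            if value is None or str(value).strip() == "":
                bad.append(i)
                break  # one missing field is enough to flag the record
    return bad


# Inline sample data (assumed layout); for a real file, use json.load(open(path)).
sample = json.loads(
    '[{"question": "What is 1+1?", "ground_truth": "2"},'
    ' {"question": "What is 2+3?"}]'
)
missing = find_incomplete_records(sample)
print(missing)
```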

Round 1:
```
# Requires Checkpoint Initialization first (the value head must be added)
python solver_demo.py \
--custom_cfg configs/mcts_round1.yaml \
--qaf ../MATH_EVAL/data/math_testset_annotation.json
```

Round > 1, after SFT:
```
python solver_demo.py \
--custom_cfg configs/mcts_sft_round.yaml \
--qaf ../MATH_EVAL/data/math_testset_annotation.json
```

### Inference

Only `question` is used for inference; `ground_truth` is used solely to calculate the accuracy.
```
python solver_demo.py \
--custom_cfg configs/mcts_sft.yaml \
--qaf ../MATH_EVAL/data/math_testset_annotation.json
```
Unlike step-level beam search, MCTS inference takes two stages: first build a complete tree with the command above, then run MCTS offline on the saved tree.
```
python offline_inference.py \
--custom_cfg configs/offline_inference.yaml \
--tree_jsonl <the saved tree jsonl file by solver_demo.py>
```
Note: this script can also be run on a tree saved by step-level beam search, in which case the accuracy should remain the same.



## Value Estimation

### Distribution of Q-value for intermediate steps on training data

Because the ground truth is known for the training data, the Q-values converge very well.

<img src="imgs/Q_distribution.png" width="500">

### Distribution of Q-value for both intermediate and final steps on test data

On the test set, the ground truth is unknown, so the Q-value distribution includes both intermediate and final steps. From this figure, we can see that:
1. When the model's prediction is correct, its Q-value converges towards 1.
2. For solutions with an incorrect final answer, the Q-value distribution covers the whole range [-1, 1], because the intermediate steps may still be correct.
3. When the model's prediction is incorrect but the model itself believes it is correct, its value prediction may tend towards 1, which produces a peak near 1 in the Q-value distribution.

<img src="imgs/Q_distribution_test.png" width="500">
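
A distribution like this can be tabulated by binning Q-values over [-1, 1], separately for correct and incorrect solutions. The following is a rough sketch and not the script used to produce the figures. The function `q_histogram` is a hypothetical helper, and the Q-values below are made-up placeholders chosen only to mirror the three observations qualitatively. They are not measured results.

```python
from typing import Iterable, List


def q_histogram(q_values: Iterable[float], num_bins: int = 4) -> List[int]:
    """Count Q-values in equal-width bins spanning [-1, 1]."""
    counts = [0] * num_bins
    width = 2.0 / num_bins
    for q in q_values:
        if not -1.0 <= q <= 1.0:
            raise ValueError(f"Q-value out of range: {q}")
        # q == 1.0 would land one past the last bin, so clamp it.
        idx = min(int((q + 1.0) / width), num_bins - 1)
        counts[idx] += 1
    return counts


# Placeholder values for illustration only (not data from this repository).
correct_q = [0.9, 0.95, 0.8, 0.99]
incorrect_q = [-0.9, -0.2, 0.3, 0.92]

correct_hist = q_histogram(correct_q)      # mass concentrated in the top bin
incorrect_hist = q_histogram(incorrect_q)  # spread over the whole range
print(correct_hist, incorrect_hist)
```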


## Inference on MATH dataset

| Inference Method       | Accuracy (%) | Avg. time per question (s) | Avg. steps |
| ---------------------- | ------------ | -------------------------- | ---------- |
| Greedy                 | 53.62    | 1.6                        | 3.10       |
| Step-level Beam (1,5)  | 62.12    | 3.1                        | 3.01       |
| Step-level Beam (2,5)  | 64.98    | 2.4                        | 2.36       |
| Step-level Beam (3,5)  | 65.56    | 2.3                        | 2.21       |
| Step-level Beam (5,5)  | 65.98    | 4.7                        | 2.26       |
| MCTS (N=40)            | 63.72    | 10.1                       | 3.76       |

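
To compare the methods on cost as well as accuracy, the short sketch below computes each method's accuracy gain and slowdown relative to greedy decoding. The numbers are copied from the table above. The helper `tradeoff` is a hypothetical name introduced for this illustration.

```python
from typing import Dict, Tuple

# (accuracy in %, average seconds per question), copied from the table above.
results: Dict[str, Tuple[float, float]] = {
    "Greedy": (53.62, 1.6),
    "Step-level Beam (1,5)": (62.12, 3.1),
    "Step-level Beam (2,5)": (64.98, 2.4),
    "Step-level Beam (3,5)": (65.56, 2.3),
    "Step-level Beam (5,5)": (65.98, 4.7),
    "MCTS (N=40)": (63.72, 10.1),
}


def tradeoff(method: str, baseline: str = "Greedy") -> Tuple[float, float]:
    """Return (accuracy gain in points, slowdown factor) versus the baseline."""
    acc, seconds = results[method]
    base_acc, base_seconds = results[baseline]
    return acc - base_acc, seconds / base_seconds


for name in results:
    gain, slowdown = tradeoff(name)
    print(f"{name:24s} +{gain:5.2f} points  x{slowdown:.2f} time")
```

By this measure, step-level beam search with `(3,5)` gains about 11.9 points over greedy at roughly 1.4 times the per-question time, while MCTS with `N=40` gains about 10.1 points at roughly 6.3 times the time.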

