# Part 2 of Code repository for "Path Independent Equilibrium Networks can better exploit test-time computation"
Code to train and test all the networks used for experiments with mazes. Part of this code has been adapted from the prior work by Schwarzschild et al., [End-to-end Algorithm Synthesis with Recurrent Networks: Logical Extrapolation Without Overthinking](https://arxiv.org/abs/2202.05826).

## Getting Started

### Requirements
This code was developed and tested with Python 3.8.2.

To install requirements:

```$ pip install -r requirements.txt```

## Training 
To train models, run [train_model.py](train_model.py) (for unrolled + backprop models) or [train_deq.py](train_deq.py) (for DEQ models) with the desired command line arguments. These arguments let you choose a model architecture and set all the pertinent hyperparameters. The default values for all the arguments are defined in the configuration files in the Hydra config directory.
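Overrides such as `problem.hyp.lr=0.001` replace individual entries of the nested configuration loaded from those files. The sketch below illustrates that idea in plain Python. It is not Hydra's implementation: it handles only dotted value overrides (not config-group selections such as `problem/model=deq_net_v3`), and the values in `DEFAULTS` are hypothetical placeholders, not this repository's actual defaults. Similar to Hydra, which requires a `+` prefix to add a new key, it rejects keys that do not already exist.

```python
import copy
from typing import Any, Dict, List

# Hypothetical defaults standing in for the values in the config files.
DEFAULTS: Dict[str, Any] = {
    "problem": {
        "hyp": {"epochs": 100, "lr": 0.0001},
        "deq": {"f_solver": "anderson", "wnorm": False},
    }
}


def parse_value(raw: str) -> Any:
    """Roughly parse a command-line value as bool, int, float, or string."""
    if raw in ("True", "true"):
        return True
    if raw in ("False", "false"):
        return False
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(defaults: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Return a copy of `defaults` with each `dotted.key=value` override applied."""
    config = copy.deepcopy(defaults)
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            node = node[name]  # raises KeyError for an unknown group
        if leaf not in node:
            raise KeyError(f"unknown config key: {dotted_key}")
        node[leaf] = parse_value(raw_value)
    return config


if __name__ == "__main__":
    cfg = apply_overrides(DEFAULTS, ["problem.hyp.epochs=200", "problem.hyp.lr=0.001",
                                     "problem.deq.f_solver=broyden", "problem.deq.wnorm=True"])
    print(cfg["problem"]["hyp"])  # {'epochs': 200, 'lr': 0.001}
```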

```$ python train_model.py```

An example command to train a DEQ model is:
```python train_deq.py problem.hyp.train_mode=deq problem.hyp.test_mode=deq problem/model=deq_net_v3 problem=mazes name=<name_of_experiment> problem.deq.fp_init=x_proj problem.hyp.epochs=200 problem.hyp.lr=0.001 problem.train.pretrain_steps=200 problem.deq.loss.jac_loss=False problem.deq.num_layers=15 problem.deq.wnorm=True problem.deq.f_solver=broyden problem.deq.b_solver=broyden```
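The DEQ commands choose Broyden's method for the forward and backward solves (`problem.deq.f_solver=broyden`, `problem.deq.b_solver=broyden`), and the test command below also sets `problem.deq.f_thres=100` and `problem.deq.stop_mode=abs`. The NumPy sketch below shows a forward solve of `z = f(z)` with Broyden's method. We assume, following the convention of the original DEQ codebase, that `f_thres` caps the number of solver iterations and that `stop_mode` selects an absolute (`abs`) or relative (`rel`) residual for the stopping test; check the code in this repository before relying on that reading. The function name and its `tol` parameter are hypothetical, and it works on flat vectors, whereas a real DEQ would flatten feature maps and run batched in PyTorch.

```python
from typing import Callable, Tuple

import numpy as np


def broyden_fixed_point(f: Callable[[np.ndarray], np.ndarray], z0: np.ndarray,
                        f_thres: int = 100, tol: float = 1e-8,
                        stop_mode: str = "abs") -> Tuple[np.ndarray, int, float]:
    """Solve z = f(z) for a 1-D vector z by applying Broyden's method to g(z) = f(z) - z.

    Returns (z, number_of_steps, final_residual).
    """
    def residual(z: np.ndarray, g: np.ndarray) -> float:
        abs_res = float(np.linalg.norm(g))
        if stop_mode == "abs":
            return abs_res
        # "rel": residual relative to the size of f(z) = g + z
        return abs_res / (float(np.linalg.norm(g + z)) + 1e-9)

    z = np.asarray(z0, dtype=float).copy()
    g = f(z) - z
    h = -np.eye(z.size)  # inverse-Jacobian guess: dg/dz = f'(z) - I, roughly -I
    res = residual(z, g)
    steps = 0
    while res >= tol and steps < f_thres:
        dz = -h @ g  # quasi-Newton step
        z_new = z + dz
        g_new = f(z_new) - z_new
        dg = g_new - g
        h_dg = h @ dg
        denom = float(dz @ h_dg)
        # Skip the update when it would be numerically degenerate.
        if abs(denom) > 1e-12 * float(np.linalg.norm(dz) * np.linalg.norm(h_dg)):
            # "Good" Broyden rank-one update of the inverse Jacobian
            h = h + np.outer(dz - h_dg, dz @ h) / denom
        z, g = z_new, g_new
        res = residual(z, g)
        steps += 1
    return z, steps, res


if __name__ == "__main__":
    z, steps, res = broyden_fixed_point(lambda v: 0.5 * np.cos(v), np.zeros(3))
    print(steps, res)
```

Raising `f_thres` gives the solver more iterations to reach the fixed point, which is one way a DEQ can spend extra test-time computation.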

An example command to train a feedforward model is:
```python train_model.py problem.hyp.alpha=0.01 problem.hyp.epochs=200 problem.hyp.lr=0.001 problem/model=ff_net_2d problem=mazes name=mazes_ablation```

This command will train and save a model. For more examples see the [launch](launch) directory, where we have left several files corresponding to our main experiments.
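If you prefer to script several runs from Python, the hypothetical helpers below assemble an argument list from a dictionary of overrides and print it as a shell command; they only call `subprocess.run` when `dry_run=False`. The function names and the experiment name are illustrative and not part of this repository; the overrides are copied from the DEQ training command above. `shlex.join` requires Python 3.8 or newer.

```python
import shlex
import subprocess
from typing import Dict, List, Optional


def build_command(script: str, overrides: Dict[str, object],
                  python: str = "python") -> List[str]:
    """Turn {"problem.hyp.lr": 0.001, ...} into ["python", script, "problem.hyp.lr=0.001", ...]."""
    return [python, script] + [f"{key}={value}" for key, value in overrides.items()]


def launch(argv: List[str], dry_run: bool = True) -> Optional[int]:
    """Print the shell-quoted command; run it and return its exit code unless dry_run."""
    print(shlex.join(argv))
    if dry_run:
        return None
    return subprocess.run(argv, check=False).returncode


DEQ_TRAIN = {
    "problem.hyp.train_mode": "deq",
    "problem.hyp.test_mode": "deq",
    "problem/model": "deq_net_v3",
    "problem": "mazes",
    "name": "my_deq_run",  # hypothetical experiment name
    "problem.deq.fp_init": "x_proj",
    "problem.hyp.epochs": 200,
    "problem.hyp.lr": 0.001,
    "problem.train.pretrain_steps": 200,
    "problem.deq.loss.jac_loss": False,
    "problem.deq.num_layers": 15,
    "problem.deq.wnorm": True,
    "problem.deq.f_solver": "broyden",
    "problem.deq.b_solver": "broyden",
}

if __name__ == "__main__":
    launch(build_command("train_deq.py", DEQ_TRAIN))  # dry run: prints only
```

Python booleans format as `True`/`False`, which matches the spelling used in the commands above.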

## Testing

To test a saved model, run [test_model.py](test_model.py) (for unrolled + backprop models) or [test_deq.py](test_deq.py) (for DEQ models) as follows.

```$ python test_model.py problem.model.model_path=<dir_with_checkpoint>```

An example command to test a DEQ model is:

```python test_deq.py problem.hyp.train_mode=deq problem.hyp.test_mode=deq problem/model=deq_net_v3 problem=mazes name=<name_of_experiment> problem.deq.norm=none problem.hyp.epochs=200 problem.train.pretrain_steps=200 problem.deq.wnorm=True problem.deq.num_layers=32 problem.model.model_path=<path_to_checkpoint_folder> problem.test_data=13 quick_test=True model_name=<checkpoint_to_load> problem.hyp.test_batch_size=100 problem.deq.fp_init=zeros problem.deq.f_solver=broyden problem.deq.f_thres=100 problem.deq.stop_mode=abs```

Use the flags in the example above to pass the command line arguments that were used during training and to select the checkpoint to load: `problem.model.model_path` points to the checkpoint folder and `model_name` to the checkpoint file. Other command line arguments are described in the code itself and generally follow the same structure as for training. As with training, the `outputs` folder will contain performance metrics in JSON format. (See the saving protocol below.)
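To gather those metrics after several runs, a minimal sketch like the following may help. It assumes only that the metrics are stored as `.json` files somewhere under `outputs/`; the exact file names and keys are not specified here, so the hypothetical `load_json_metrics` function simply returns each file's parsed contents keyed by its path relative to the outputs folder.

```python
import json
from pathlib import Path
from typing import Any, Dict


def load_json_metrics(outputs_dir: str = "outputs") -> Dict[str, Any]:
    """Map each *.json file under outputs_dir (searched recursively) to its parsed contents."""
    results: Dict[str, Any] = {}
    root = Path(outputs_dir)
    for path in sorted(root.rglob("*.json")):
        with path.open() as fh:
            results[str(path.relative_to(root))] = json.load(fh)
    return results


if __name__ == "__main__":
    for name, metrics in load_json_metrics().items():
        print(name, metrics)
```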


