Gradient Dynamics of Shallow Low-Dimensional ReLU Networks
--------------------------

This repository contains the code for the experiments in the paper "Gradient Dynamics of Shallow Low-Dimensional ReLU Networks".
All dependencies can be installed with conda by running:

```
# Windows (64-bit) and macOS only:
conda config --add channels pytorch
conda env create -f environment.yml
```

To activate the environment, run:
```
conda activate geometry-of-dnns
```

All scripts for training models and generating plots are in the `prematched_1d` directory.

#### Training a Model

Run `fit_model.py` to train a model on a target geometry (selected with the `--geometry` flag). Its usage is as follows:

```
usage: fit_model.py [-h] [--output OUTPUT] [--geometry GEOMETRY]
                    [--epochs EPOCHS] [--learning-rate LEARNING_RATE]
                    [--seed SEED] [--save-every SAVE_EVERY]
                    [--last-layer-bias] [--device DEVICE] [--double]
                    [--init INIT] [--num-samples NUM_SAMPLES]
                    [--sampling SAMPLING] [--noise NOISE]
                    [--laziness LAZINESS] [--lr-decay LR_DECAY]
                    [--min-lr MIN_LR] [-cc] [-rsa RESCALE_A] [-rsc RESCALE_C]
                    num_hidden_units [num_hidden_units ...]

positional arguments:
  num_hidden_units      The number of units in the hidden layer

optional arguments:
  -h, --help            show this help message and exit
  --output OUTPUT, -o OUTPUT
                        File to output the model to
  --geometry GEOMETRY, -g GEOMETRY
                        Which geometry to fit
  --epochs EPOCHS, -ne EPOCHS
                        Number of fitting iterations
  --learning-rate LEARNING_RATE, -lr LEARNING_RATE
                        Step size for gradient descent
  --seed SEED           Random seed
  --save-every SAVE_EVERY
                        Save the weights every k iterations
  --last-layer-bias, -b
                        Use a bias term in the last layer
  --device DEVICE       Which device to store the model on
  --double, -dbl
  --init INIT           Initialize weights using 'default', 'normal',
                        'one-over-n', 'one-over-sqrt-n'
  --num-samples NUM_SAMPLES, -n NUM_SAMPLES
                        number of samples (s) to fit
  --sampling SAMPLING, -s SAMPLING
                        How to sample the interval. One of 'uniform' or
                        'random'
  --noise NOISE         Amount of noise to add
  --laziness LAZINESS   Type of laziness. One of 'pure', 'none', 'default'.
  --lr-decay LR_DECAY   exponential learning rate decay factor
  --min-lr MIN_LR       minimum learning rate
  -cc, --clamp-c        clamp the parameter c to +/-1
  -rsa RESCALE_A, --rescale-a RESCALE_A
                        rescale the a and b parameters
  -rsc RESCALE_C, --rescale-c RESCALE_C
                        rescale the c parameter
```

By default, `fit_model.py` saves the trained model to `out.pt`. Use the `-o` flag to write to a different file.

Example: 
```
python fit_model.py 1000 --geometry parabola -ne 10000 -n 20 --noise 0.0 -lr 1e-4 --lr-decay 1e-5 --min-lr 1e-7 --sampling 'uniform' --laziness 'default' --save-every 10 --seed 1234567 
```
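The following minimal NumPy sketch may help make the flags above concrete by showing what one training run could look like. It is not the code in `fit_model.py`. It rests on several assumptions:

- The network has the form `f(x) = sum_i c_i * relu(a_i * x + b_i) + d`, as the `a`, `b`, `c` and `--last-layer-bias` flags suggest.
- The `parabola` geometry is the target `y = x**2` on the interval [-1, 1].
- The decayed step size is `lr * exp(-lr_decay * t)`, clipped at the minimum learning rate.
- The output weights are initialized with a 1/sqrt(n) scale.

The names `sample_inputs`, `learning_rate`, `forward` and `fit` are hypothetical.

```python
import numpy as np


def sample_inputs(n, sampling, rng):
    """Return n training inputs on [-1, 1] (ASSUMED interval)."""
    if sampling == "uniform":
        return np.linspace(-1.0, 1.0, n)  # evenly spaced grid
    if sampling == "random":
        return np.sort(rng.uniform(-1.0, 1.0, n))  # i.i.d. uniform draws
    raise ValueError(f"unknown sampling scheme: {sampling!r}")


def learning_rate(t, lr, lr_decay, min_lr):
    """ASSUMED schedule: exponential decay in t, never below min_lr."""
    return max(min_lr, lr * np.exp(-lr_decay * t))


def forward(a, b, c, d, x):
    """Evaluate f(x_j) = sum_i c_i * relu(a_i * x_j + b_i) + d for all j."""
    h = np.outer(a, x) + b[:, None]  # pre-activations, shape (units, samples)
    return c @ np.maximum(h, 0.0) + d, h


def fit(num_hidden, n=20, epochs=2000, lr=1e-2, lr_decay=1e-3, min_lr=1e-4,
        sampling="uniform", save_every=100, seed=0):
    """Full-batch gradient descent on 0.5 * mean squared error."""
    rng = np.random.default_rng(seed)
    x = sample_inputs(n, sampling, rng)
    y = x ** 2  # ASSUMED target for the 'parabola' geometry
    a = rng.standard_normal(num_hidden)
    b = rng.standard_normal(num_hidden)
    # ASSUMED init: output weights scaled by 1/sqrt(num_hidden).
    c = rng.standard_normal(num_hidden) / np.sqrt(num_hidden)
    d = 0.0  # output bias, cf. --last-layer-bias
    losses, snapshots = [], []
    for t in range(epochs):
        f, h = forward(a, b, c, d, x)
        r = f - y  # residual at the training samples
        losses.append(0.5 * np.mean(r ** 2))
        if t % save_every == 0:
            snapshots.append((a.copy(), b.copy(), c.copy()))
        mask = (h > 0).astype(float)  # derivative of relu at each pre-activation
        # Gradients of 0.5 * mean(r**2), all computed before any update.
        grad_c = np.maximum(h, 0.0) @ r / n
        grad_a = c * ((mask * x) @ r) / n
        grad_b = c * (mask @ r) / n
        grad_d = r.mean()
        step = learning_rate(t, lr, lr_decay, min_lr)
        a -= step * grad_a
        b -= step * grad_b
        c -= step * grad_c
        d -= step * grad_d
    return losses, snapshots
```

For example, `losses, snapshots = fit(50, sampling="random")` trains 50 hidden units on 20 randomly placed samples. The defaults use a larger step size and faster decay than the command-line example above so that the sketch finishes quickly. The parameter snapshots taken every `save_every` steps are a stand-in for the weights that `--save-every` stores.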

#### Generating Plots
There are four scripts for generating plots.

The first, `plot_reconstruction.py`, plots the network function. Its usage is as follows:
```
usage: plot_reconstruction.py [-h] [--init] [--plot-knots] [--lsq]
                              [--output OUTPUT]
                              state_file

positional arguments:
  state_file            Fitted model (out.pt) generated with fit_model.py

optional arguments:
  -h, --help            show this help message and exit
  --init                If set, plot the state at initialization
  --plot-knots          Plot the knots
  --lsq                 Plot the result after doing a least squares kernel fit
  --output OUTPUT, -o OUTPUT
                        Filename to save the figure to. Will display the
                        figure if no file name is set.
```

Example:
```
python plot_reconstruction.py out.pt  
```
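The `--plot-knots` and `--lsq` options can be illustrated with a minimal NumPy sketch. It assumes the network is `f(x) = sum_i c_i * relu(a_i * x + b_i) + d`, as the `a`, `b`, `c` names in `fit_model.py`'s flags suggest.

- **Knots.** Each unit's knot is the input where its pre-activation `a_i * x + b_i` crosses zero, that is, `x = -b_i / a_i`.
- **Least squares kernel fit.** This is sketched here under one common reading of the phrase: freeze the hidden layer and refit the output weights by linear least squares on the ReLU features. That reading is an assumption and not necessarily what `plot_reconstruction.py` does.

The names `knots` and `lsq_output_weights` are hypothetical.

```python
import numpy as np


def knots(a, b):
    """Breakpoints x = -b_i / a_i of the units c_i * relu(a_i * x + b_i).

    Units with a_i == 0 are constant in x and have no knot, so they are dropped.
    """
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    active = a != 0
    return np.sort(-b[active] / a[active])


def lsq_output_weights(a, b, x, y, bias=True):
    """Refit the output layer by least squares with the hidden layer frozen.

    ASSUMPTION: treats relu(a_i * x + b_i) as fixed features and returns the
    output weights c (and a bias d if bias=True) minimizing ||phi @ w - y||^2.
    """
    phi = np.maximum(np.outer(x, a) + b, 0.0)  # features, shape (samples, units)
    if bias:
        phi = np.hstack([phi, np.ones((len(x), 1))])  # extra column for the bias d
    w, *_ = np.linalg.lstsq(phi, y, rcond=None)
    return (w[:-1], w[-1]) if bias else (w, 0.0)
```

Given the hidden-layer weights of a trained network and its training samples, `lsq_output_weights` returns output weights that fit those samples at least as well as any other output weights for the same hidden layer.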

The second, `plot_phase.py`, generates uv plots of the network function. Its usage is as follows:

```
usage: plot_phase.py [-h] [-e EPOCH] [-s SCALE] [-o OUTPUT] state

positional arguments:
  state                 Fitted model (out.pt) generated with fit_model.py

optional arguments:
  -h, --help            show this help message and exit
  -e EPOCH, --epoch EPOCH
                        Which epoch to plot the model at
  -s SCALE, --scale SCALE
                        Scale factor for the figure
  -o OUTPUT, --output OUTPUT
                        Filename to save the figure to. Will display the
                        figure if no file name is set.
```

The third, `plot_neuron_trajectories.py`, plots the trajectories of neurons in uv space. Its usage is as follows:
```
usage: plot_neuron_trajectories.py [-h] [--save SAVE] [--plot-residual]
                                   [--samples SAMPLES]
                                   [--print-every PRINT_EVERY]
                                   [--num-trajectories NUM_TRAJECTORIES]
                                   [-o OUTPUT] [-s SCALE]
                                   state_file

positional arguments:
  state_file            Fitted model (out.pt) generated with fit_model.py

optional arguments:
  -h, --help            show this help message and exit
  --save SAVE           Save as a video
  --plot-residual       Plot the residual
  --samples SAMPLES     Number of curve samples
  --print-every PRINT_EVERY
                        Print a message every k epochs
  --num-trajectories NUM_TRAJECTORIES, -nt NUM_TRAJECTORIES
                        Number of trajectory lines to plot
  -o OUTPUT, --output OUTPUT
                        Filename to save the figure to. Will display the
                        figure if no file name is set
  -s SCALE, --scale SCALE
                        Scale factor for the figure
```

Example:
```
python plot_neuron_trajectories.py out.pt 
```

Finally, `plot_reduced_gradient.py` plots the vector field of the reduced gradient. Its usage is as follows:
```
usage: plot_reduced_gradient.py [-h] [-e EPOCH] [-s SCALE] [-n NUM_SAMPLES]
                                state

positional arguments:
  state                 Fitted model (out.pt) generated with fit_model.py

optional arguments:
  -h, --help            show this help message and exit
  -e EPOCH, --epoch EPOCH
                        Which epoch to plot the model at
  -s SCALE, --scale SCALE
                        Scale factor for the figure
  -n NUM_SAMPLES, --num-samples NUM_SAMPLES
                        Number of neurons to plot
```

Example:
```
python plot_reduced_gradient.py out.pt 
```