Understanding spiking networks through convex optimization - Figure 3

Imports

In [18]:
import numpy as np
import numba as nb
import matplotlib.pyplot as plt
import snn_cvx
from mpl_toolkits.mplot3d import axes3d
%matplotlib inline
In [14]:
import holoviews as hv
hv.extension('matplotlib')

Build train and test sets

In [3]:
# generate training data
P_train = 100
x_lim = 4

x1 = np.linspace(-x_lim, x_lim, P_train)
x2 = np.linspace(-x_lim, x_lim, P_train)
X1_train, X2_train = np.meshgrid(x1, x2)
    
def func(x1, x2):
    y = 0.3*(x1**2 + x2**2)
    return y

Y_targ_train = func(X1_train, X2_train)
Y_ravelled = Y_targ_train.ravel()
X_train = np.vstack((X1_train.ravel(), X2_train.ravel())).T
Y_train = Y_ravelled[:, None]
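
As a quick sanity check (this assertion is not in the original notebook), the flattened design matrix should contain one row per (x1, x2) grid point:

# hypothetical sanity check: one row per grid point, one column per input dimension
assert X_train.shape == (P_train**2, 2)
assert Y_train.shape == (P_train**2, 1)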
In [4]:
# create test data
P_test = 500
x_lim = 4
x1 = np.linspace(-x_lim, x_lim, P_test)
x2 = np.linspace(-x_lim, x_lim, P_test)
X1_test, X2_test = np.meshgrid(x1, x2)

Y_targ_test = func(X1_test, X2_test)
Y_ravelled_test = Y_targ_test.ravel()
X_test = np.vstack((X1_test.ravel(), X2_test.ravel())).T
Y_test = Y_ravelled_test[:, None]
In [22]:
# Plot a 3D surface
fig = plt.figure(figsize=(8, 9))
ax1 = fig.add_subplot(111, projection='3d')
ax1.set_xlabel('$x_1$')
ax1.set_ylabel('$x_2$')
ax1.set_zlabel('$y$')
# Plot a 3D surface
target_surface = ax1.plot_surface(X1_test, X2_test, Y_targ_test, alpha=0.5)
fig.suptitle('Target surface', fontsize=18)
plt.show()

Network parameters

In [5]:
# set up dimensions and initial parameters
M = 1     # output dimension
K = 2     # input dimension
N = 50    # number of neurons
leak = 2  # leak (decay) rate used by the simulation

# initialize the decoder (D), gains (G), feedforward (F) and recurrent
# (omega) weights, plus the thresholds
random_state = np.random.RandomState(seed=3)
D_weights_init = random_state.rand(M, N)
D_weights_init = D_weights_init / np.linalg.norm(D_weights_init, axis=0)  # unit-norm decoding vectors
G_weights_init = D_weights_init.copy().T
F_weights_init = random_state.randn(K, N).T
omega_init = -G_weights_init @ D_weights_init  # recurrent weights
thresholds_init = 2*random_state.rand(N) - 1   # uniform in [-1, 1)

Before learning

We run the network with the initial parameters and compute a time-averaged readout for each input sample. This gives the SNN surface before learning.

Also, since the network's output dimension is one, we can use a maxout function - see equation (10) of the paper - to compute the network output without discretization error, which is faster than running the SNN. The maximizing argument additionally tells us which neuron is active for each input.
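
Note that snn_cvx.run_maxout is used as a black box below. As a rough sketch only - an assumption about its behaviour for M = 1, not the snn_cvx implementation - it plausibly computes the pointwise maximum of the neurons' affine functions:

def run_maxout_sketch(x, F_weights, G_weights, thresholds):
    # hypothetical stand-in for snn_cvx.run_maxout (M = 1): each neuron n
    # contributes an affine function (F_n . x - T_n) / G_n, and the network
    # output is their pointwise maximum (cf. equation (10) of the paper)
    vals = (F_weights @ x - thresholds) / G_weights[:, 0]
    n_act = np.argmax(vals)  # index of the maximizing (active) neuron
    return vals[n_act], n_act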

In [6]:
T = 4 # simulation time
dt = 3e-03 # time step
t_span = np.arange(0, T, dt)
num_bins = t_span.size
buffer_bins = int(1/dt)
buffer_zeros = int(buffer_bins/2)  # the input stays at zero during these initial bins
x_sample = np.zeros((K, num_bins))

# initialize network parameters
D_weights = D_weights_init.copy()
G_weights = G_weights_init.copy()
F_weights = F_weights_init.copy()
omega = omega_init.copy()
thresholds = thresholds_init.copy()


y_readout = []

for data_index in range(X_train.shape[0]):
    x_sample[:, buffer_zeros:] = X_train[data_index, :][:, None]

    rates = snn_cvx.run_snn_trial(
        x_sample,
        F_weights,
        omega,
        thresholds,
        dt,
        leak,
    )
    
    y_readout.append(D_weights[0, :] @ rates)
    
average_readouts = np.array(y_readout)[:, buffer_zeros + 500:].mean(axis=1)  # time-average after a further 500-bin transient
In [23]:
# Plot a 3D surface
fig = plt.figure(figsize=(8, 9))
ax1 = fig.add_subplot(111, projection='3d')
ax1.set_xlabel('$x_1$')
ax1.set_ylabel('$x_2$')
ax1.set_zlabel('$y$')
# Plot a 3D surface
readout_surf_snn = ax1.plot_surface(X1_train, X2_train, average_readouts.reshape(Y_targ_train.shape), alpha=0.5)
fig.suptitle('SNN surface before learning', fontsize=18)
plt.show()
In [24]:
# plot contours and active-inactive neurons
active_neurons_init = np.full(X_test.shape[0], np.nan)
y_predict_init = np.zeros(X_test.shape[0])
for i, x in enumerate(X_test):
    y_out_init, n_act = snn_cvx.run_maxout(x, F_weights_init, G_weights_init, thresholds_init)
    y_predict_init[i] = y_out_init
    active_neurons_init[i] = n_act
    
y_predict_init_reshaped = y_predict_init.reshape(Y_targ_test.shape)
active_neurons_init_reshaped = active_neurons_init.reshape(Y_targ_test.shape)

# make contour plots
zlim = 14
ticks = [(i, i) for i in np.linspace(x_lim, -x_lim, 3)]
c_ticks = [(i, i) for i in np.linspace(0, zlim, 2)]

clim = (0, zlim)
bounds = (-x_lim, -x_lim, x_lim, x_lim)
nlevels = 40
cmap_neurons = 'glasbey_dark'  # categorical colormap registered by colorcet
cmap_contour = 'gray'
alpha = 1
    
img = hv.Image(y_predict_init_reshaped, kdims=['$x_1$', '$x_2$'], vdims='$y$', bounds=bounds).opts(
                                    cmap=cmap_contour, invert_zaxis=True, clims=clim)
levels = np.linspace(-zlim, zlim, nlevels)
img_contour_init = hv.operation.contours(img, group='Y', levels=levels).opts(
                                    xticks=3, yticks=3, colorbar=True, cmap=cmap_contour, 
                                    cbar_ticks=c_ticks, clim=clim, alpha=alpha, linewidth=1.5, 
                                    )

# show which neuron is active
img_nactive_init = hv.Image(active_neurons_init_reshaped, kdims=['$x_1$', '$x_2$'], vdims='$y$', bounds=bounds).opts(cmap=cmap_neurons, alpha=0.6)

# plot contours and active neurons
img_nactive_init*img_contour_init
Out[24]:
[figure: readout contours overlaid with the active-neuron regions, before learning]

Train network parameters

In [27]:
T = 4 # simulation time
dt = 3e-03 # time step
t_span = np.arange(0, T, dt)
num_bins = t_span.size
buffer_bins = int(1/dt)
buffer_zeros = int(buffer_bins/2)
x_sample = np.zeros((K, num_bins))

# initialize network parameters
D_weights = D_weights_init.copy()
G_weights = G_weights_init.copy()
F_weights = F_weights_init.copy()
omega = omega_init.copy()
thresholds = thresholds_init.copy()

# run supervised learning
alpha_thresh_init = 1e-03
alpha_F_init = 1e-03
leak_thresh = 0.

num_epochs = 100
thresholds_array_fit = np.zeros((N, num_epochs))
F_weights_array_fit = np.zeros((N, K, num_epochs))
decrease_learning_rate = True

for epoch in range(num_epochs):
    print('epoch:', epoch + 1)
    data_index_list = np.arange(X_train.shape[0])
    np.random.shuffle(data_index_list)
    
    if decrease_learning_rate:
        alpha_thresh = alpha_thresh_init * np.exp(-0.0001 * (epoch + 1))
        alpha_F = alpha_F_init * np.exp(-0.0001 * (epoch + 1))

    else:
        alpha_thresh = alpha_thresh_init
        alpha_F = alpha_F_init
    
    for data_index in data_index_list:
        x_sample[:, buffer_zeros:] = X_train[data_index, :][:, None]
        y_sample = Y_train[data_index, :]

        thresholds, F_weights = snn_cvx.update_weights(
            x_sample,
            y_sample,
            F_weights,
            G_weights,
            omega,
            thresholds,
            buffer_bins,
            dt,
            leak,
            leak_thresh,
            alpha_thresh,
            alpha_F,
            mu=0.,
            sigma_v=0.
        )
        
    thresholds_array_fit[:, epoch] = thresholds
    F_weights_array_fit[:, :, epoch] = F_weights
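
snn_cvx.update_weights simulates the spiking network for one sample and adapts the thresholds and feedforward weights online. As intuition only - a simplified non-spiking analogue under the maxout picture above, not the snn_cvx update rule - each step would nudge the active neuron's hyperplane toward the target:

def maxout_update_sketch(x, y_targ, F, G, T, alpha_F, alpha_T):
    # hypothetical simplified update: locate the active neuron and move its
    # affine function (F_n . x - T_n) / G_n toward the target output
    vals = (F @ x - T) / G[:, 0]
    n = np.argmax(vals)                  # active neuron for this input
    err = y_targ - vals[n]               # readout error at this sample
    F[n] += alpha_F * err * x / G[n, 0]  # rotate/shift the hyperplane
    T[n] -= alpha_T * err / G[n, 0]      # adjust the threshold
    return F, T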
    

After learning

In [28]:
# run snn with learnt parameters
x_sample = np.zeros((K, num_bins))

# retrieve the learnt parameters from the final epoch
F_weights_fit = F_weights_array_fit[:, :, -1]
thresholds_fit = thresholds_array_fit[:, -1]

y_readout = []

for data_index in range(X_train.shape[0]):
    x_sample[:, buffer_zeros:] = X_train[data_index, :][:, None]

    rates = snn_cvx.run_snn_trial(
        x_sample,
        F_weights_fit,
        omega,
        thresholds_fit,
        dt,
        leak,
    )
    
    y_readout.append(D_weights[0, :] @ rates)
    
average_readouts_fit = np.array(y_readout)[:, buffer_zeros + 500:].mean(axis=1)  # time-average after a further 500-bin transient
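
A quick way to quantify the fit (this check is not in the original notebook) is the mean squared error of the averaged readout against the training targets:

# hypothetical sanity check: training MSE of the time-averaged readout
mse_fit = np.mean((average_readouts_fit - Y_train.ravel())**2)
print(f'training MSE after learning: {mse_fit:.4f}')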
In [29]:
# Plot a 3D surface
fig = plt.figure(figsize=(8, 9))
ax1 = fig.add_subplot(111, projection='3d')
ax1.set_xlabel('$x_1$')
ax1.set_ylabel('$x_2$')
ax1.set_zlabel('$y$')
# Plot a 3D surface
readout_surf_snn = ax1.plot_surface(X1_train, X2_train, average_readouts_fit.reshape(Y_targ_train.shape), alpha=0.5)
fig.suptitle('SNN surface after learning', fontsize=18)
plt.show()
In [30]:
# plot contours and active-inactive neurons after learning
active_neurons_fit = np.full(X_test.shape[0], np.nan)
y_predict_fit = np.zeros(X_test.shape[0])
for i, x in enumerate(X_test):
    y_out_fit, n_act = snn_cvx.run_maxout(x, F_weights_fit, G_weights, thresholds_fit)
    y_predict_fit[i] = y_out_fit
    active_neurons_fit[i] = n_act
    
y_predict_fit_reshaped = y_predict_fit.reshape(Y_targ_test.shape)
active_neurons_fit_reshaped = active_neurons_fit.reshape(Y_targ_test.shape)

# make contour plots
zlim = 14
ticks = [(i, i) for i in np.linspace(x_lim, -x_lim, 3)]
c_ticks = [(i, i) for i in np.linspace(0, zlim, 2)]

clim = (0, zlim)
bounds = (-x_lim, -x_lim, x_lim, x_lim)
nlevels = 40
cmap_neurons = 'glasbey_dark'
cmap_contour = 'gray'
alpha = 1
    
img = hv.Image(y_predict_fit_reshaped, kdims=['$x_1$', '$x_2$'], vdims='$y$', bounds=bounds).opts(
                                    cmap=cmap_contour, invert_zaxis=True, clims=clim)
levels = np.linspace(-zlim, zlim, nlevels)
img_contour_fit = hv.operation.contours(img, group='Y', levels=levels).opts(
                                    xticks=3, yticks=3, colorbar=True, cmap=cmap_contour, 
                                    cbar_ticks=c_ticks, clim=clim, alpha=alpha, linewidth=1.5, 
                                    )

# show which neuron is active
img_nactive_fit = hv.Image(active_neurons_fit_reshaped, kdims=['$x_1$', '$x_2$'], vdims='$y$', bounds=bounds).opts(cmap=cmap_neurons, alpha=0.6)

# plot contours and active neurons
img_nactive_fit*img_contour_fit
Out[30]:
[figure: readout contours overlaid with the active-neuron regions, after learning]