import numpy as np
import pickle
from simulate_NN import simulate

# Seed NumPy's global (legacy) RNG so the column subsample drawn later in
# this script is reproducible. If simulate() also draws from the global RNG,
# it is affected too -- NOTE(review): confirm against simulate_NN.
np.random.seed(123)

# Directory holding the extracted CIFAR-10 "python version" batch files.
CIFAR_DIR = 'cifar-10-batches-py'


def load_cifar_training_data(root=CIFAR_DIR, batch_ids=range(1, 6)):
    """Load CIFAR-10 batch files and stack their data column-wise.

    Each ``<root>/data_batch_<k>`` file is unpickled (bytes keys). Its
    ``b'data'`` array is transposed, and the transposed arrays are
    concatenated with ``np.hstack`` in the order given by *batch_ids*.
    The batch array is presumably (n_images, n_pixels), so images become
    columns -- confirm against the dataset docs.

    Parameters
    ----------
    root : str
        Directory containing the ``data_batch_<k>`` files.
    batch_ids : iterable of int
        Batch numbers to load, in order. The default loads batches 1..5.

    Returns
    -------
    numpy.ndarray
        Horizontal stack of the transposed ``b'data'`` arrays.

    Raises
    ------
    ValueError
        If *batch_ids* is empty, so nothing would be stacked.
    """
    columns = []
    for batch_id in batch_ids:
        path = f'{root}/data_batch_{batch_id}'
        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at a trusted copy of the CIFAR-10 archive.
        with open(path, 'rb') as fo:
            dat = pickle.load(fo, encoding='bytes')
        columns.append(np.transpose(dat[b'data']))
    if not columns:
        raise ValueError(f'no CIFAR-10 batches requested from {root!r}')
    return np.hstack(columns)


# Read CIFAR-10 data
Xcifar = load_cifar_training_data()

# Simulation parameters
n = 5000  # number of columns of Xcifar to subsample
# Forwarded to simulate() unchanged. NOTE(review): presumably one width per
# layer -- confirm against simulate_NN.
d = [10000] * 5

# Pick n distinct column indices (without replacement) from the global RNG,
# put them in ascending order, then gather those columns.
picked = np.random.choice(Xcifar.shape[1], size=n, replace=False)
inds = sorted(picked)
X0 = Xcifar[:, inds]

def _run_simulation(remove_PCs, fname):
    """Call simulate() on the sampled X0 / d with the shared plot settings.

    Only *remove_PCs* and the output name *fname* vary between runs. The
    axis limits and bin count are written as literals here, so every call
    receives its own fresh list objects.
    """
    simulate(
        X0,
        d,
        remove_PCs=remove_PCs,
        xlim=[-3, 40],
        ylim=[0, 100],
        ylimNTK=[0, 500],
        bins=100,
        fname=fname,
    )


# First run: remove_PCs=0. Second run: remove_PCs=10 ("pruned").
# NOTE(review): remove_PCs presumably counts leading principal components
# stripped from X0 -- see simulate_NN.
_run_simulation(0, 'CIFAR_raw')
_run_simulation(10, 'CIFAR_pruned')

