import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pickle
import os
from predict_spectrum import CK, NTK

# Larger default font size for every figure created after this point.
font = dict(size=16)
matplotlib.rc('font', **font)

# Sigmoid Nonlinearity
def sigma(x):
  return np.arctan(x)/np.sqrt(0.4497)

def dsigma(x):
  """Derivative of sigma: 1 / ((1 + x^2) * sqrt(0.4497)), elementwise."""
  inv = 1 / (1 + x**2)
  return inv / np.sqrt(0.4497)

# Moments of the activation derivative; xi presumably ~ N(0,1) -- TODO confirm.
# a = E[sigma'(xi)^2], b = E[sigma'(xi)]
a, b = 1.11185, 0.977755

# Helper function
def cross(X):
  return np.dot(np.transpose(X),X)

# Histogram eigenvalues of each X'X
def histeig(eigs,ax,xgrid=None,dgrid=None,xlim=None,ylim=None,bins=50,title=None):
  if xlim is not None:
    eigs = eigs[np.nonzero(eigs <= xlim[1])[0]]
    h = ax.hist(eigs,bins=np.linspace(xlim[0],xlim[1],num=bins))
  else:
    h = ax.hist(eigs,bins=bins)
  if xgrid is not None:
    space = h[1][1]-h[1][0]
    ax.plot(xgrid,dgrid*len(eigs)*space,'r',linewidth=2)
  ax.set_title(title)
  ax.set_xlim(xlim)
  if ylim is None:
    ax.set_ylim([0,max(h[0])*1.5])
  else:
    ax.set_ylim(ylim)
  return ax

def summarize_input(X0,xlim=None,ylim=None,fname=''):
  """Plot the eigenvalue histogram of X0'X0 and save it to '<fname>_X0.png'.

  Args:
    X0: data matrix; the spectrum is that of the Gram matrix of its columns.
    xlim, ylim: axis limits forwarded to histeig.
    fname: prefix for the saved image file.

  Returns:
    The matplotlib Figure (left open so the caller can show or close it).
  """
  f, ax = plt.subplots(1,1)
  # Build the Gram matrix only once; it is n x n and can be costly.
  eigs = np.linalg.eigvalsh(cross(X0))
  histeig(eigs,ax,xlim=xlim,ylim=ylim,title='Input data spectrum')
  # Save this specific figure rather than whatever figure is "current".
  f.savefig('%s_X0.png' % fname)
  return f

# Ws = [W1, W2, ..., WL, w]
# Xs = [X0, X1, ..., XL]
# d = [d0, d1, ..., dL]
def compute_NTK(Ws,Xs,d):
  """Compute the empirical neural tangent kernel on the n input columns.

  Assumes the forward pass used in simulate:
    X_l = sigma(W_l X_{l-1}) / sqrt(d_l),   f(x) = w' X_L,
  where the columns of every X_l index the n samples.

  Args:
    Ws: [W1, ..., WL, w]; W_l has shape (d_l, d_{l-1}), w has d_L entries.
    Xs: [X0, ..., XL]; X_l has shape (d_l, n).
    d: [d0, ..., dL] layer widths.

  Returns:
    (n, n) array K = X_L'X_L + sum_{l=1}^{L} (S_l'S_l) * (X_{l-1}'X_{l-1}),
    where column i of S_l is df/d(W_l x_{l-1}) at sample i and * is the
    elementwise (Hadamard) product.
  """
  L = len(Xs)-1
  # Contribution of the output weights w: df/dw = x_L.
  KNTK = cross(Xs[L])
  if L == 0:
    return KNTK
  # Ds[l] = sigma'(W_l X_{l-1}); index 0 is a placeholder so Ds[l] matches layer l.
  Ds = [None] + [dsigma(np.dot(Ws[l],Xs[l])) for l in range(L)]
  w = np.reshape(Ws[-1], -1)
  # Backpropagate for all n samples at once: S_L = D_L * w / sqrt(d_L).
  S = Ds[L] * (w[:,None]/np.sqrt(d[L]))
  for l in range(L,0,-1):
    # <S_l[:,i] x_{l-1,i}', S_l[:,j] x_{l-1,j}'>_F = (S_l'S_l)_ij (X_{l-1}'X_{l-1})_ij
    KNTK += cross(S) * cross(Xs[l-1])
    if l > 1:
      # S_{l-1} = D_{l-1} * (W_l' S_l) / sqrt(d_{l-1})
      S = Ds[l-1] * (np.dot(np.transpose(Ws[l-1]),S)/np.sqrt(d[l-1]))
  return KNTK

def _normalize_columns(X, d0):
  """Center each column of X and scale it by 1/(std * sqrt(d0)) (unit norm when d0 = X.shape[0])."""
  X = X - X.mean(axis=0)
  return X / X.std(axis=0) / np.sqrt(d0)

def _xgrid(eigs, xlim, num=1000):
  """Evaluation grid for a predicted density: xlim if given, else [-0.05, 1.2] * max(eigs)."""
  if xlim is None:
    return np.linspace(max(eigs)*(-0.05),max(eigs)*1.2,num=num)
  return np.linspace(xlim[0],xlim[1],num=num)

def simulate(X0,d,remove_PCs=0,xlim=None,ylim=None,ylimNTK=None,bins=50,fname=''):
  """Compare empirical CK/NTK spectra of a random network with predictions.

  Steps:
    1. Center each column of X0 and scale it to unit norm. Optionally subtract
       the leading `remove_PCs` singular components, then renormalize.
    2. Run the forward pass X_l = sigma(W_l X_{l-1}) / sqrt(d_l) with i.i.d.
       N(0,1) weights drawn from numpy's global RNG. Alternatively, load Xs and
       the NTK from '<fname>_matrices.pkl' if that file exists.
    3. For each layer, save a histogram of eig(X_l'X_l) against CK(...) to
       '<fname>_X<l>.png'.
    4. Compute the NTK (unless cached), save its histogram against NTK(...) to
       '<fname>_NTK.png', and pickle {'Xs', 'NTK'} to '<fname>_matrices.pkl'.

  Args:
    X0: (d0, n) data matrix whose n columns are the samples.
    d: hidden-layer widths [d1, ..., dL].
    remove_PCs: number of leading singular components to remove from X0.
    xlim, ylim: axis limits for all histograms.
    ylimNTK: y-limits for the NTK histogram (falls back to ylim).
    bins: histogram bin count forwarded to histeig.
    fname: prefix for every output file.

  Notes:
    The cache is keyed only on `fname`. A stale pickle from a run with
    different X0/d/remove_PCs is reused silently. Figures are not closed.
  """
  n = X0.shape[1]
  d0 = X0.shape[0]
  L = len(d)
  d = np.array([d0] + list(d))
  X0 = _normalize_columns(X0, d0)
  if remove_PCs > 0:
    # Thin SVD: only the leading components are needed, so avoid the
    # d0 x d0 and n x n factors that full_matrices=True would allocate.
    u,s,vh = np.linalg.svd(X0, full_matrices=False)
    for i in range(remove_PCs):
      X0 -= s[i]*np.outer(u[:,i],vh[i,:])
  X0 = _normalize_columns(X0, d0)
  summarize_input(X0,xlim=xlim,ylim=ylim,fname=fname)
  # Compute X_1,...,X_L
  print('Computing forward propagation')
  cache_path = '%s_matrices.pkl' % fname
  # Check the cache once so the forward pass and the NTK step agree.
  cached = os.path.exists(cache_path)
  if cached:
    # NOTE: pickle is unsafe on untrusted files; this cache is written by simulate itself.
    with open(cache_path, 'rb') as fh:
      dat = pickle.load(fh)
    Xs = dat['Xs']
    KNTK = dat['NTK']
  else:
    Ws = []
    Xs = [X0]
    for l in range(1,L+1):
      W = np.random.normal(size=(d[l],d[l-1]))
      Ws.append(W)
      Xs.append(sigma(np.dot(W,Xs[l-1]))/np.sqrt(d[l]))
    # Output weights w; the prediction w'X_L itself is not needed here.
    Ws.append(np.random.normal(size=d[L]))
  # Compare with prediction for spectrum of X'X
  spec = np.linalg.eigvalsh(cross(Xs[0]))
  gamma = n/np.array(d[1:])
  for l in range(1,L+1):
    print('Computing CK prediction, layer %d' % l)
    f, ax = plt.subplots(1,1)
    eigs = np.linalg.eigvalsh(cross(Xs[l]))
    xgrid = _xgrid(eigs, xlim)
    dgrid = CK(l,gamma[:l],b,xgrid,spec=spec)
    histeig(eigs,ax,xgrid=xgrid,dgrid=dgrid,xlim=xlim,ylim=ylim,bins=bins,title='CK spectrum, layer %d' % l)
    f.savefig('%s_X%d.png' % (fname,l))
  # Compute NTK
  print('Computing NTK')
  if not cached:
    KNTK = compute_NTK(Ws, Xs, d)
  # Compare with prediction for spectrum of NTK
  print('Computing NTK prediction')
  eigs = np.linalg.eigvalsh(KNTK)
  xgrid = _xgrid(eigs, xlim)
  dgrid = NTK(L,gamma,a,b,xgrid,spec=spec)
  f, ax = plt.subplots(1,1)
  histeig(eigs,ax,xgrid=xgrid,dgrid=dgrid,xlim=xlim,
          ylim=ylim if ylimNTK is None else ylimNTK,bins=bins,title='NTK spectrum')
  f.savefig('%s_NTK.png' % fname)
  mats = {'Xs':Xs, 'NTK':KNTK}
  with open(cache_path, 'wb') as fh:
    pickle.dump(mats, fh)

