#!/usr/bin/env python
# coding: utf-8

# In[168]:


import jax
import jax.numpy as jnp
from jax import grad, hessian, jacobian, random
import matplotlib.pyplot as plt
from jax.scipy.special import logsumexp

# Synthetic 1-D regression problem shared by every experiment below.
# A fixed PRNG seed makes the noisy labels identical on every run.
key = random.PRNGKey(0)


def true_function(x):
    """Tent function: 2x for x < 0.5 and 2 - 2x otherwise (peak 1 at x = 0.5)."""
    return jnp.where(x < 0.5, x * 2, x * -2 + 2)


# 30 evenly spaced inputs on [0, 1] as a (30, 1) column. Labels are computed
# on the unshifted grid; afterwards the inputs are centred to [-0.5, 0.5].
x = jnp.linspace(0, 1, 30).reshape(-1, 1)
y_true = true_function(x)
x = x - 0.5
# Additive Gaussian label noise with standard deviation 0.5.
y_noisy = y_true + 0.5 * random.normal(key, x.shape)

# ReLU activation
def relu(x):
    """Elementwise rectified linear unit, ``max(x, 0)``."""
    return jnp.maximum(x, 0)

# Predict function
def predict(params, x, scale_factor=1.0):
    """Forward pass of the one-hidden-layer ReLU network.

    ``params`` is ``(W1, b1, W2, b2)``. The hidden layer is
    ``relu(x @ W1 + b1)``. Its linear read-out ``hidden @ W2`` is multiplied
    by ``scale_factor`` (the NTK / mean-field output scaling chosen in
    ``train_and_plot``) before the output bias ``b2`` is added.
    """
    in_weight, in_bias, out_weight, out_bias = params
    features = relu(jnp.dot(x, in_weight) + in_bias)
    readout = jnp.dot(features, out_weight)
    return readout * scale_factor + out_bias

def predict_alt(params, x, scale_factor=1.0):
    """Variant of ``predict`` in which the hidden bias is multiplied by ``W1``.

    The pre-activation is ``x @ W1 + b1 * W1``. With a single input feature
    (as built by ``initialize_weights``), unit ``j`` computes
    ``relu(W1_j * (x + b1_j))``, so its kink sits at ``x = -b1_j``.
    """
    in_weight, in_bias, out_weight, out_bias = params
    pre_act = jnp.dot(x, in_weight) + in_bias * in_weight
    readout = jnp.dot(relu(pre_act), out_weight)
    return readout * scale_factor + out_bias



# Design matrix function
def feature_map(params, x):
    """Return the hidden-layer activations ``relu(x @ W1 + b1)``.

    ``params`` must still be the 4-tuple ``(W1, b1, W2, b2)``, but only
    ``W1`` and ``b1`` are used. Each output column is one ReLU basis
    function evaluated at the rows of ``x``.
    """
    in_weight, in_bias, _, _ = params
    return relu(jnp.dot(x, in_weight) + in_bias)

def feature_map_alt(params, x):
    """Return the hidden-layer activations used by ``predict_alt``.

    The pre-activation is ``x @ W1 + b1 * W1``, matching ``predict_alt``.
    ``params`` is ``(W1, b1, W2, b2)``; the read-out parameters are ignored.
    """
    W1, b1, W2, b2 = params
    hidden = relu(jnp.dot(x, W1) + b1*W1)
    return hidden

# Loss function
def loss(params, x, y, scale_factor=1.0):
    """Mean squared error of ``predict(params, x, scale_factor)`` against ``y``."""
    residual = predict(params, x, scale_factor) - y
    return jnp.mean(residual ** 2)

def loss_alt(params, x, y, scale_factor=1.0):
    """Mean squared error of ``predict_alt(params, x, scale_factor)`` against ``y``."""
    residual = predict_alt(params, x, scale_factor) - y
    return jnp.mean(residual ** 2)

def mse_loss(y_pred, y_true):
    """Mean squared error between two arrays, via the array's own ``.mean()``."""
    squared_error = (y_pred - y_true) ** 2
    return squared_error.mean()

# Derivatives of the training losses with respect to ``params`` (argument 0),
# wrapped with jax.jit (compiled on first call). The ``*_alt`` variants
# differentiate ``loss_alt``, i.e. the ``predict_alt`` parametrization.
loss_grad = jax.jit(grad(loss))
loss_grad_alt = jax.jit(grad(loss_alt))

loss_hessian = jax.jit(hessian(loss))
loss_hessian_alt = jax.jit(hessian(loss_alt))

# Hessian computed as the (reverse-mode) Jacobian of the gradient; an
# alternative to ``hessian`` above.
loss_jacobian = jax.jit(jacobian(grad(loss)))
loss_jacobian_alt = jax.jit(jacobian(grad(loss_alt)))


def initialize_weights(hidden=1000, seed=0, FLAG='Standard'):
    """Draw initial parameters ``(W1, b1, W2, b2)`` for the 1-hidden-layer net.

    Args:
        hidden: number of hidden ReLU units.
        seed: integer seed for ``jax.random.PRNGKey``.
        FLAG: parametrization, one of ``'Standard'``, ``'NTK'``,
            ``'Meanfield'`` or ``'Mup'``. For ``'Standard'`` and ``'Mup'`` the
            output weights are scaled by ``sqrt(2 / hidden)`` here. The
            prediction-time ``scale_factor`` is chosen separately by the
            caller (see ``train_and_plot``).

    Returns:
        ``W1``: shape (1, hidden), entries +-1 (signs of Gaussian draws).
        ``b1``: shape (hidden,), uniform on [-0.5, 0.5).
        ``W2``: shape (hidden, 1), standard Gaussian, possibly rescaled.
        ``b2``: shape (1,), zeros.

    Raises:
        ValueError: if ``FLAG`` is not one of the recognised names.
    """
    if FLAG not in ('Standard', 'NTK', 'Meanfield', 'Mup'):
        raise ValueError(
            f"Unknown FLAG {FLAG!r}; expected 'Standard', 'NTK', 'Meanfield' or 'Mup'"
        )

    input_size = 1
    hidden_size = hidden
    output_size = 1

    # One independent key per tensor, each consumed exactly once. A key that
    # has already been passed to a sampler must not be split or reused.
    key_w1, key_b1, key_w2 = random.split(random.PRNGKey(seed), 3)

    W1 = jnp.sign(random.normal(key_w1, shape=(input_size, hidden_size)))
    b1 = random.uniform(key_b1, shape=(hidden_size,), minval=-0.5, maxval=0.5)
    W2 = random.normal(key_w2, shape=(hidden_size, output_size))
    if FLAG in ('Standard', 'Mup'):
        # jnp rather than np: this file only imports numpy further down.
        W2 = W2 * jnp.sqrt(2. / hidden_size)
    b2 = jnp.zeros(output_size)
    return W1, b1, W2, b2


def train_and_plot(epochs, learning_rate,hidden=1000,FLAG='Standard', LastLayerOnly=False,TrainFeatures=True, AltParams=False):
    """Train the 1-hidden-layer ReLU net with full-batch GD, then plot and save.

    Training uses the module-level data ``x`` (inputs) and ``y_noisy``
    (targets). ``y_true`` is only used for monitoring and plotting.

    Args:
        epochs: number of full-batch gradient-descent steps.
        learning_rate: step size eta. It also appears in the plot titles and
            output file names, rounded to 3 decimals.
        hidden: number of hidden units.
        FLAG: parametrization, forwarded to ``initialize_weights``. It also
            selects the output ``scale_factor``: 'NTK'/'Mup' -> sqrt(2/hidden),
            'Meanfield' -> 1/hidden, anything else -> 1.0.
        LastLayerOnly: if True, W1 and b1 are never updated.
        TrainFeatures: if False, b1 is frozen. W1 is still updated unless
            ``LastLayerOnly`` is True.
        AltParams: if True, train and predict with ``loss_grad_alt`` /
            ``predict_alt`` instead of ``loss_grad`` / ``predict``.

    Returns:
        ``(params, y_pred, loss_history, mse_history)``:
        - ``params``: the trained ``[W1, b1, W2, b2]``.
        - ``y_pred``: predictions at ``x``.
        - ``loss_history``: MSE vs ``y_noisy``.
        - ``mse_history``: MSE vs ``y_true``.
        Both histories are recorded right after the updates of epochs
        0, 100, 200, ...

    Side effects:
        Saves ``func_eta=<eta>.pdf`` and ``learning_curve_eta=<eta>.pdf`` to
        the working directory and calls ``plt.show()``.
    """

    input_size = 1
    hidden_size = hidden
    output_size = 1

    # initialize_weights is called with its default seed, so every call with
    # the same hidden/FLAG starts from the same initialization.
    W1,b1,W2,b2 = initialize_weights(hidden_size,FLAG=FLAG)
    params = [W1, b1, W2, b2]

    # Choose the gradient function and the matching forward pass.
    if AltParams:
        takegrad = loss_grad_alt
        makepred = predict_alt
    else:
        takegrad = loss_grad
        makepred = predict


    # Output scaling of the read-out layer for each parametrization.
    # NOTE(review): for 'Mup', initialize_weights already scales W2 by
    # sqrt(2/hidden), and the output is scaled by sqrt(2/hidden) again here.
    # Confirm this double scaling is intended.
    if FLAG == 'NTK' or FLAG == 'Mup':
        scale_factor = jnp.sqrt(2./hidden_size)
    elif FLAG == 'Meanfield':
        scale_factor = 1./hidden_size
    else:  # standard
        scale_factor = 1.0

    # Training parameters
#     epochs = 10000
#     learning_rate = 0.001
    loss_history = []
    mse_history = []



    # Training loop using Full Batch Gradient Descent
    for epoch in range(epochs):
        grads = takegrad(params, x, y_noisy,scale_factor=scale_factor)
        # Update parameters. grads is ordered like params: [W1, b1, W2, b2].
        # The read-out layer (W2, b2) is always trained.
        if not LastLayerOnly:
            W1 -= learning_rate * grads[0]
            if TrainFeatures:
                b1 -= learning_rate * grads[1]
        W2 -= learning_rate * grads[2]
        b2 -= learning_rate * grads[3]
        params = [W1,b1,W2,b2]
    #    params = [param - learning_rate * grad for param, grad in zip(params, grads)]


        # Every 100 epochs (after this epoch's update), log the train loss
        # and the error against the noiseless target.
        if epoch%100 == 0:
            y_pred = makepred(params, x,scale_factor=scale_factor)
            # NOTE: this local name shadows the module-level loss() function
            # inside train_and_plot only.
            loss = mse_loss(y_pred, y_noisy)
            loss_history.append(loss)
            mse = mse_loss(y_pred,y_true)
            mse_history.append(mse)


    # Evaluate model on all data
    y_pred = makepred(params, x,scale_factor=scale_factor)


    # 1000-point grid over the same centred input range as x, used to draw
    # the fitted function between the training points.
    xfine = jnp.linspace(0, 1,1000).reshape(-1, 1) - 0.5

    y_pred2 = makepred(params, xfine,scale_factor=scale_factor)

    # Plotting fitted labels
    #plt.subplot(1, 2, 2)

    #plt.figure(figsize=(12, 4))

    #plt.subplot(1, 2, 1)


    
    plt.figure(figsize=(6, 4))
    plt.rcParams.update({'font.size': 12})
    plt.plot(x, y_true, label='True Function', color='blue')
    plt.scatter(x, y_noisy, marker='o', label='Noisy Labels', color='grey', alpha=0.5)
    plt.scatter(x, y_pred, marker='.',label='Fitted Labels', color='green', alpha=1.0)
    plt.plot(xfine, y_pred2,label='Fitted function', color='green', alpha=1.0)
    plt.legend()
    plt.title(r'Trained ReLU NN with $\eta$='+str(round(learning_rate, 3)))
    
    plt.savefig('func_eta='+str(round(learning_rate, 3))+".pdf",bbox_inches='tight')

    #plt.subplot(1, 2, 2)
    
    plt.figure(figsize=(6, 4))
    
    # Constant reference line at 0.25 = 0.5**2, the variance of the label
    # noise added to y_noisy.
    loss_true_func = [0.25 for a in loss_history]
    
    
    plt.rcParams.update({'font.size': 12})
    # Plotting learning curves
    plt.semilogy(loss_history, label='train Loss')
    plt.semilogy(mse_history,label='MSE vs truth')
    plt.semilogy(loss_true_func,':k',label=r'$\sigma^2$')
    plt.xlabel('Iterations (in 100s)')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(r'Learning Curves:  $\eta$='+str(round(learning_rate, 3)))

    plt.savefig('learning_curve_eta='+str(round(learning_rate, 3))+".pdf",bbox_inches='tight')
    plt.show()

    #plt.savefig('stable_minima_eta='+str(round(learning_rate, 3))+".pdf")

    return params, y_pred, loss_history, mse_history




import numpy as np

def least_square_NN_fit(hidden=1000, verbose=True):
    """Fit only the read-out layer of a random ReLU net by least squares.

    The hidden layer comes from ``initialize_weights(hidden)`` (default seed
    and 'Standard' parametrization) and is kept fixed. The read-out weights
    and bias are the ``np.linalg.lstsq`` solution of
    ``[features, 1] @ coeffs ~= y_noisy`` on the module-level training data.

    Args:
        hidden: number of hidden units.
        verbose: if True, plot the fit and save ``interpolate_h=<hidden>.pdf``.

    Returns:
        ``[W1, b1, W2, b2]`` with the fitted read-out ``W2`` of shape
        (hidden, 1) and bias ``b2`` of shape (1,), usable with ``predict`` at
        ``scale_factor=1.0``.
    """
    W1, b1, W2, b2 = initialize_weights(hidden)

    # Design matrix: one column per hidden ReLU feature evaluated at the
    # training inputs, plus a column of ones for the output bias.
    A = feature_map([W1, b1, W2, b2], x)
    A_with_bias = jnp.concatenate([A, jnp.ones(shape=(A.shape[0], 1))], axis=1)

    # np.linalg.lstsq returns the minimum-norm minimiser. This matters once
    # the system is underdetermined (hidden + 1 columns vs len(x) rows).
    coeffs, *_ = np.linalg.lstsq(A_with_bias, y_noisy, rcond=None)

    # Split the solution back into read-out weights and the bias term.
    params = [W1, b1, coeffs[0:-1], coeffs[-1]]

    if verbose:
        y_pred = np.dot(A_with_bias, coeffs)
        xfine = jnp.linspace(0, 1, 1000).reshape(-1, 1) - 0.5
        y_pred2 = predict(params, xfine)

        plt.figure(figsize=(6, 4))
        plt.plot(x, y_true, label='True Function', color='blue')
        plt.scatter(x, y_noisy, marker='o', label='Noisy Labels', color='grey', alpha=0.5)
        plt.scatter(x, y_pred, marker='.', label='Fitted Labels', color='green', alpha=1.0)
        plt.plot(xfine, y_pred2, label='Fitted function', color='green', alpha=1.0)
        plt.ylim([-0.7, 1.7])
        plt.legend()
        plt.title('Interpolating Solution h = ' + str(hidden))
        plt.savefig('interpolate_h=' + str(hidden) + '.pdf')
    return params


# In[99]:





# In[162]:


# Visualise the random ReLU basis functions of a freshly initialised
# 100-unit network: the hidden-layer activations at the training inputs x.
W1,b1,W2,b2 = initialize_weights(100)
params = [W1, b1, W2, b2]

plt.figure(figsize=(6, 6))
plt.rcParams.update({'font.size': 16})
# A has one column per hidden unit; plt.plot draws one curve per column.
A = feature_map(params, x)
plt.plot(x,A)
plt.grid()
plt.title('Random basis functions at initialization')
plt.savefig('basis_functions.pdf')
plt.show()


# In[3]:


# Least-squares read-out fits for a range of widths. Each call plots and
# saves its own figure; only the last fit stays bound to `params`.
hidden_list = [30, 60, 100, 200, 500, 1000]
for h in hidden_list:
    params = least_square_NN_fit(hidden=h, verbose=True)


# In[169]:


import pickle

# Step-size sweep. Each run gets int(10000 / eta) epochs, so eta * epochs is
# (up to truncation) the same 10000 for every run.
lrlist = [0.5, 0.4, 0.3, 0.2, 0.1, 0.05, 0.02, 0.01]
eplist = [int(10000 / lr) for lr in lrlist]

# One entry per step size: [params, y_pred, loss_history, mse_history],
# i.e. the values returned by train_and_plot, stored as a list.
results = []
for lr, ep in zip(lrlist, eplist):
    params, y_pred, loss_history, mse_history = train_and_plot(ep, lr, hidden=100)
    results.append([params, y_pred, loss_history, mse_history])

# Persist the sweep so the analysis cells can reload it without retraining.
with open('results_sp2.pickle', 'wb') as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)


# In[159]:


# with open('results_sp2.pickle', 'rb') as handle:
#     results = pickle.load(handle)



    
# Compare the functions learned with a large (0.4) and a small (0.01) step
# size against the true function and the noisy training labels.
plt.figure(figsize=(6, 6))
plt.rcParams.update({'font.size': 16})
plt.plot(x, y_true, '-',label='True Function', color='black')
plt.scatter(x, y_noisy, marker='o', label='Noisy Labels', color='grey', alpha=0.5)

# Fine grid over the same centred input range as x.
xfine = jnp.linspace(0, 1,1000).reshape(-1, 1) - 0.5

for lr,ep,res in zip(lrlist,eplist,results):
    if lr == 0.4 or lr ==0.01:
        params = res[0]
        y_pred = res[1]
        # Default scale_factor=1.0, matching train_and_plot's default
        # 'Standard' FLAG used for the sweep.
        y_pred2 = predict(params, xfine)
        plt.plot(xfine, y_pred2,'--',label=r'$\eta$='+str(round(lr, 3)),alpha=1.0)

plt.grid()
plt.title(r'GD-trained ReLU NN with step size $\eta$')

plt.legend()
plt.savefig("Fitted_functions_vs_eta.pdf")
plt.show()




colorlist = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']


# Learning curves of the eta = 0.4 and eta = 0.01 runs on a shared
# "eta x iterations" axis: dashed = train loss, solid = MSE vs the truth.
plt.figure(figsize=(6, 6))
i = 0
for lr,ep,res in zip(lrlist,eplist,results):
    if lr == 0.4 or lr ==0.01:

        loss_history = res[2]
        mse_history = res[3]

        col = colorlist[i]
        # train_and_plot records history entry j right after the update of
        # epoch 100*j, i.e. after 100*j + 1 gradient steps. Use that exact
        # step count (times eta). A linspace ending at len*100*lr overshoots
        # by one sampling interval and stretches the spacing.
        iter_list = (100 * np.arange(len(loss_history)) + 1) * lr
        plt.semilogy(iter_list,loss_history,'--', color=col,label=r'train Loss $\eta$='+str(round(lr, 3)))
        plt.semilogy(iter_list,mse_history,color=col,label=r'MSE vs truth $\eta$='+str(round(lr, 3)))
        i +=1
plt.xlabel(r'Step size $\eta \times$ Iterations ')
plt.ylabel('Loss')
plt.title(r'Learning Curves')
plt.grid()
plt.legend()
plt.savefig("Learning_curves_vs_eta.pdf")
plt.show()




# In[144]:


# Generating illustration of complexity as a function of 1/\eta

def TV1(f):
    """Sum of absolute second differences of ``f`` along its first axis.

    ``f`` is transposed, so for a column of samples (shape (N, 1)) the
    differences run over the samples. On a uniform grid with spacing d, this
    approximates d times the total variation of ``f'``.
    """
    second_diff = np.diff(np.transpose(f), n=2)
    return np.sum(np.abs(second_diff))

def TV1_l0(f, tol=1e-6):
    """Count second differences of ``f`` whose magnitude exceeds ``tol``.

    Differences are taken along the first axis of ``f`` (``f`` is transposed
    first, so a column of samples of shape (N, 1) works).

    NOTE(review): for a piecewise-linear ``f`` sampled on a grid, a kink
    lying strictly between two grid points makes two neighbouring second
    differences non-zero. The count can therefore be up to twice the number
    of kinks; keep this in mind when reading it as "# of linear pieces".

    Args:
        f: samples of the function.
        tol: threshold at or below which a second difference counts as zero.
            The default 1e-6 is the previously hard-coded value.

    Returns:
        The number of second differences with magnitude greater than ``tol``.
    """
    return np.sum(abs(np.diff(f.T, n=2)) > tol)


# Complexity of each trained network versus 1/eta: a theoretical bound,
# a total-variation measure, and the number of linear pieces.
plt.figure(figsize=(6, 6))
plt.rcParams.update({'font.size': 16})

# Fine grid over the same centred input range as x.
xfine = jnp.linspace(0, 1,1000).reshape(-1, 1) - 0.5
# Multiplier applied to TV1 below; presumably the number of training points
# (x has 30). NOTE(review): TV1 on this 1000-point grid scales with the grid
# spacing, so confirm that n (rather than 1/spacing) is the intended factor.
n= 30
TV1_list = []
TV1_l0_list = []
MSE_list = []

for lr,ep,res in zip(lrlist,eplist,results):
    params = res[0]
    y_pred = res[1]
    y_pred2 = predict(params, xfine)
    TV1_list.append(n*TV1(y_pred2)) 
    TV1_l0_list.append(TV1_l0(y_pred2))
    # Final recorded MSE vs the noiseless target for this step size.
    MSE_list.append(res[3][-1])

oneoverlr_list = [1/a for a in lrlist]
# NOTE(review): the legend calls the constant term sigma, but the code adds
# 0.25, which is the noise variance 0.5**2 (sigma = 0.5). Confirm which the
# bound intends.
bound_list = [1/a - 0.5 + 0.25 + np.sqrt(b) for a,b in zip(lrlist,MSE_list)]


tt = len(lrlist)
plt.loglog(oneoverlr_list[:tt],bound_list[:tt],'-o',label=r'$1/\eta-1/2 + \sigma + \sqrt{MSE}$')
plt.loglog(oneoverlr_list[:tt],TV1_list[:tt],'-x',label=r'1st order Total Variation')
# Pieces = nonzero second differences + 1 (see the NOTE in TV1_l0 about
# kinks between grid points).
plt.loglog(oneoverlr_list[:tt],(np.array(TV1_l0_list)[:tt]+1),':s',label=r'# of Linear Pieces')
plt.xlabel(r'$1/\eta$')
plt.grid()
plt.title(r'Complexity of Fitted ReLU NN')

plt.legend()
plt.savefig("complexity_vs_eta.pdf")
plt.show()


# In[163]:


# Check if the learned knots are exactly on the input

    
def find_knots(params):
    """Locate each hidden unit's ReLU kink and pair it with ``W1_j * W2_j``.

    For ``predict``'s hidden unit ``relu(W1_j * x + b1_j)`` (one input
    feature) the kink is at ``x = -b1_j / W1_j``. A zero ``W1_j`` yields an
    infinite or NaN knot.

    Returns:
        ``(knots, coeffs)``: the knot locations sorted ascending, and the
        products ``W1_j * W2_j`` reordered the same way.
    """
    w_in, b_in, w_out, _ = params
    slopes = w_in.flatten()
    knot_positions = -b_in / slopes
    order = np.argsort(knot_positions)
    weights = slopes * w_out.flatten()
    return knot_positions[order], weights[order]




def knots_quantiles(x):
    """Return the 5%, 25%, 50%, 75% and 95% quantiles of ``x`` (flattened)."""
    return np.quantile(x, (0.05, 0.25, 0.5, 0.75, 0.95))


def lp_norm(x, p=1):
    """Return ``sum(|x_i| ** p)`` over all entries of ``x``.

    No ``1/p`` root is taken, so this is the p-th power of the l_p
    (quasi-)norm. For ``p = 1`` it equals the l_1 norm.
    """
    magnitudes = np.abs(x)
    return np.sum(magnitudes ** p)

# Per-step-size statistics of the learned ReLU knots:
#   knots_stats    -- 5/25/50/75/95% quantiles of the knot locations,
#   sparsity_stats -- [l1, sum |c|^0.5, sum |c|^0.2] of the W1*W2 products,
#   min_dist_*     -- min / 1st percentile / median of each knot's distance
#                     to its nearest training input.
knots_stats, sparsity_stats = [], []
min_dist_min, min_dist_1st, min_dist_median = [], [], []

oneoverlr_list = [1 / a for a in lrlist]

for lr, ep, res in zip(lrlist, eplist, results):
    params = res[0]
    knots, coeffs = find_knots(params)

    knots_stats.append(knots_quantiles(knots))
    sparsity_stats.append(
        [np.sum(np.abs(coeffs)), lp_norm(coeffs, 0.5), lp_norm(coeffs, 0.2)]
    )

    # Distance from every knot to the closest training input x.
    dist_from_input = [np.min(np.abs(x - xknot)) for xknot in knots]
    min_dist_min.append(np.min(np.array(dist_from_input)))
    # Quantile with a list of q gives a 1-element array per step size.
    min_dist_1st.append(np.quantile(np.array(dist_from_input), [0.01]))
    min_dist_median.append(np.median(np.array(dist_from_input)))


# Distance from the learned knots to the nearest training input versus 1/eta,
# using the min_dist_* statistics computed per step size earlier in the file.
plt.figure(figsize=(6, 6))
plt.rcParams.update({'font.size': 16})

tt = len(lrlist)
plt.loglog(oneoverlr_list[:tt],np.array(min_dist_min),'o-',label=r'Minimum')
plt.loglog(oneoverlr_list[:tt],np.array(min_dist_1st),'o-',label=r'1st percentile')
plt.loglog(oneoverlr_list[:tt],np.array(min_dist_median),'o-',label=r'Median')
#plt.loglog(oneoverlr_list[:tt],TV1_list[:tt],'-x',label=r'1st order Total Variation')
#plt.loglog(oneoverlr_list[:tt],(np.array(TV1_l0_list)[:tt]+1),':s',label=r'# of Linear Pieces')
plt.xlabel(r'$1/\eta$')
plt.ylabel(r'Distance to the closest input x')
plt.grid()
plt.title(r'Learned knots to the closest input knot')
plt.legend()
plt.savefig("closest_knots.pdf",bbox_inches='tight')
plt.show()


# Quantiles of the learned knot locations versus 1/eta.
plt.figure(figsize=(6, 6))
plt.rcParams.update({'font.size': 16})
tt = len(lrlist)
# Each row of knots_stats holds the 5 quantiles for one step size, so this
# draws five curves, labelled in the same order by the legend below.
plt.semilogx(oneoverlr_list[:tt],np.array(knots_stats),'.-')
plt.xlabel(r'$1/\eta$')
plt.ylabel(r'location on input axis')
plt.grid()
plt.title(r'Quantiles of the Learned knots')
plt.legend(["0.05 quantile", "0.25 quantile","0.5 quantile","0.75 quantile","0.95 quantile"])
# Knot quantiles outside [-3, 3] are clipped from view.
plt.ylim([-3,3])
plt.savefig("quantiles_knots.pdf",bbox_inches='tight')
plt.show()



# Sparsity of the learned output coefficients (W1_j * W2_j) versus 1/eta.
plt.figure(figsize=(6, 6))
plt.rcParams.update({'font.size': 16})

tt = len(lrlist)
# Each row of sparsity_stats is [l1, sum |c|^0.5, sum |c|^0.2], matching the
# legend order below. The "L_p norm" entries are sums of |c|^p (no 1/p root).
plt.semilogx(oneoverlr_list[:tt],np.array(sparsity_stats),'o-') 
#plt.loglog(oneoverlr_list[:tt],TV1_list[:tt],'-x',label=r'1st order Total Variation')
#plt.loglog(oneoverlr_list[:tt],(np.array(TV1_l0_list)[:tt]+1),':s',label=r'# of Linear Pieces')
plt.xlabel(r'$1/\eta$')
plt.ylabel(r'Sparsity')
plt.grid()
plt.title(r'Sparsity of the learned coefficients')
plt.legend([r"$L_1$-norm", r"$L_p$ norm $p=0.5$", r"$L_p$ norm $p=0.2$"])
plt.savefig("sparsity_vs_eta.pdf",bbox_inches='tight')
plt.show()


# plt.figure()

# plt.hist(knots.reshape((1,100)))
# coeffs[0]

# dist_from_input = []
# for xknot in knots:
#     dist_from_input.append(np.min(np.abs(x - xknot)))

# plt.figure()
# plt.hist(dist_from_input)

# print(np.min(np.array(dist_from_input)))


# In[158]:


def get_hessian_matrix(params, scale_factor=1.0):
    """Assemble the Hessian of ``loss`` w.r.t. all parameters as one dense matrix.

    ``loss_hessian`` returns nested blocks ``H[i][j]`` shaped
    ``params[i].shape + params[j].shape``. Each block is flattened and placed
    into a single NumPy array, with parameters ordered as in ``params``. For
    the ``[W1, b1, W2, b2]`` produced by ``initialize_weights``, the result
    is ``(3h + 1) x (3h + 1)``, where ``h`` is the number of hidden units.

    The loss is evaluated on the module-level training data ``x`` and
    ``y_noisy``.

    Args:
        params: list of parameter arrays, e.g. ``[W1, b1, W2, b2]``.
        scale_factor: output scale forwarded to ``loss``. The default 1.0
            matches the 'Standard' parametrization.

    Returns:
        The dense Hessian as a float64 NumPy array.
    """
    blocks = loss_hessian(params, x, y_noisy, scale_factor=scale_factor)

    # Offsets of each parameter's slice in the flattened parameter vector.
    sizes = [int(np.size(p)) for p in params]
    offsets = np.concatenate([[0], np.cumsum(sizes)]).astype(int)
    total = int(offsets[-1])

    H_matrix = np.zeros(shape=(total, total))
    for i in range(len(params)):
        for j in range(len(params)):
            # C-order reshape drops the singleton axes exactly like the
            # squeeze-then-reshape it replaces.
            H_matrix[offsets[i]:offsets[i + 1], offsets[j]:offsets[j + 1]] = (
                np.reshape(np.asarray(blocks[i][j]), (sizes[i], sizes[j]))
            )
    return H_matrix
        
# Sharpness (largest Hessian eigenvalue) of each trained network versus 1/eta,
# plotted together with the bound and the total-variation curve.
lamb_max_list = []
TV1_list = []
xfine = jnp.linspace(0, 1,1000).reshape(-1, 1) - 0.5
n= 30

for lr,ep,res in zip(lrlist,eplist,results):
    params = res[0]
    # get_hessian_matrix evaluates loss_hessian itself, so no separate
    # Hessian call is needed here.
    H = get_hessian_matrix(params)
    lamb_max_list.append(np.max(np.linalg.eigvalsh(H)))

    y_pred2 = predict(params, xfine)
    TV1_list.append(n*TV1(y_pred2))

# NOTE(review): MSE_list comes from the complexity cell earlier in the file.
# The constant 0.25 is the noise variance 0.5**2, although the legend says
# sigma.
bound_list = [1/a - 0.5 + 0.25 + np.sqrt(b) for a,b in zip(lrlist,MSE_list)]


plt.figure(figsize=(6, 6))
plt.rcParams.update({'font.size': 16})

tt = len(lrlist)
# NOTE(review): the eigenvalue is halved, presumably to convert from the
# plain mean-squared loss used here to the bound's loss convention; confirm.
plt.loglog(oneoverlr_list[:tt],np.array(lamb_max_list)/2,'.-',label=r'$\lambda_{\max}$(Hessian)') 
#plt.loglog(oneoverlr_list[:tt],TV1_list[:tt],'-x',label=r'1st order Total Variation')
#plt.loglog(oneoverlr_list[:tt],(np.array(TV1_l0_list)[:tt]+1),':s',label=r'# of Linear Pieces')

plt.loglog(oneoverlr_list[:tt],bound_list[:tt],'-o',label=r'$1/\eta-1/2 + \sigma + \sqrt{MSE}$')
plt.loglog(oneoverlr_list[:tt],TV1_list[:tt],'-x',label=r'1st order Total Variation')
plt.xlabel(r'$1/\eta$')
plt.grid()
plt.title(r'Our upper bound on $\lambda_{\max}(Hessian)$')
plt.legend()

plt.savefig("flatness_vs_eta.pdf")

plt.show()


# In[170]:


# For every step size, plot the trained hidden-layer activations (one curve
# per hidden unit) at the training inputs and save one PDF per run.
for lr,ep,res in zip(lrlist,eplist,results):
    params = res[0]
    A = feature_map(params, x)
        
    plt.figure(figsize=(6, 4))
    plt.rcParams.update({'font.size': 12})
    plt.plot(x,A)
    plt.title(r'Learned basis functions ($\eta$='+str(round(lr, 2))+")")
    plt.grid()
    # The file name uses the unrounded lr; the title rounds to 2 decimals.
    plt.savefig("basis_function_eta="+str(lr)+"short.pdf",bbox_inches='tight')
    plt.show()




# In[154]:


# Log-magnitude heat map of the full Hessian of the eta = 0.2 run.
# NOTE(review): H_matrix is only assigned when 0.2 is in lrlist; otherwise
# the imshow below may raise NameError.
for lr,ep,res in zip(lrlist,eplist,results):
    if lr == 0.2:
        params = res[0]
        H_matrix = get_hessian_matrix(params)   
        
plt.figure(figsize=(6,6))
# The small offset keeps log() finite for exactly-zero entries.
plt.imshow(np.log(np.abs(H_matrix)+0.0001))
