from utils import read_file, generate_u
import numpy as np
import matplotlib.pyplot as plt
import random
import scipy.io as scio

# Matplotlib font dictionaries: `font` for axis labels, `legend_font` for
# legends (and some labels).
font = dict(family='Times New Roman', weight='normal', size=20)
legend_font = dict(family='Times New Roman', weight='normal', size=15)

# Line colour used for the i-th plotted curve.
cmap = dict(enumerate(['c', 'b', 'y', 'g', 'r', 'k']))

# Problem sizes: ambient dimension, a small base sparsity, and a multiple of it.
d = 5000
kstar = 5
k = 74 * kstar

# 0/1 mask that is one on the last (2*k + kstar) coordinates only.
vec = np.zeros(d)
vec[-(2 * k + kstar):] = 1

# Target vector: (i / d) / 100 for the 1-based index i, kept only on the
# last 70*kstar coordinates.
v = (1 / 100 * 1 / d) * np.arange(1, d + 1)
v[:-(70 * kstar)] = 0  # so that the gradient at x* is closer to 0
def f(x):
    """Return half the squared norm of the masked residual ``(x - v) * vec``.

    ``v`` and ``vec`` are the module-level target vector and 0/1 mask, so only
    the coordinates where ``vec`` is non-zero contribute to the value.
    """
    masked_residual = (x - v) * vec
    half_residual = 0.5 * masked_residual
    return half_residual @ masked_residual.T

# Reference point used for the distance curves: zero everywhere except the
# last kstar coordinates, which hold i / d for the 1-based index i.
sol = np.zeros(d)
sol[-kstar:] = (np.arange(1, d + 1) * 1 / d)[-kstar:]

def plot_figure(fn, qs, etas, s2s, nm, num, axes, iter=200000, r=0.1, lam=10, k=2, miu=1e-4, SEED=1, ccurves=None, ddists=None):
    """Run zeroth-order hard thresholding on ``f`` for every (q, s2) pair.

    Despite its name, this function draws nothing: it only runs the
    optimisation and records one convergence curve per configuration.

    Each outer step estimates the gradient from ``q`` forward finite
    differences ``num / miu * (f(x + miu * u) - f(x)) * u``, takes a step of
    size ``etas[q_i, s2_i]``, and keeps only the ``k`` largest-magnitude
    coordinates of the result.

    Args:
        fn, nm, axes, r, lam: Unused; kept so existing calls keep working.
        qs: 1-D integer array with the number of directions per step.
        etas: 2-D array of step sizes indexed as ``etas[q_i, s2_i]``.
        s2s: 1-D array of values forwarded to ``generate_u`` as its first
            argument.
        num: Length of each direction ``u``. It must match the module-level
            dimension ``d``, because ``u`` is added to the iterate.
        iter: Budget; ``int(iter / q)`` outer steps are run for each ``q``.
            (The name shadows the builtin but is part of the interface.)
        k: Number of coordinates kept by the hard-thresholding step.
        miu: Finite-difference smoothing radius.
        SEED: Seed applied once to ``random`` and ``numpy.random``.
        ccurves: Mapping (or list of lists) indexed by ``q_i``; the curve of
            every configuration is appended to ``ccurves[q_i]``. A new dict
            with one empty list per ``q`` is created when omitted.
        ddists: Unused; kept so existing calls keep working.

    Returns:
        ``ccurves``, where every appended curve is an array with one row
        ``[t * q, f(x_t), ||x_t - sol||]`` per recorded iterate (row 0 is the
        starting point, with a first entry of 0).
    """
    random.seed(SEED)
    np.random.seed(SEED)

    if ccurves is None:
        ccurves = {q_i: [] for q_i in range(len(qs))}

    x0 = np.ones(d) / d  # similar reasoning as the init in appendix G
    x0[-kstar:] = 0
    f0 = f(x0)
    dist0 = np.linalg.norm(x0 - sol)

    for q_i, q_val in enumerate(qs):
        # int() replaces np.int, which was removed in NumPy 1.24.
        n_steps = int(iter / q_val)
        for s2_i, s2_val in enumerate(s2s):
            current_eta = etas[q_i, s2_i]
            history = [[0, f0, dist0]]
            x_new = x0
            # f at the current iterate; reused as the finite-difference
            # baseline instead of being recomputed for every direction.
            f_current = f0
            # Every row is overwritten on each step, so one buffer suffices.
            grad_samples = np.zeros((q_val, num))

            for t in range(1, n_steps + 1):
                for i in range(q_val):
                    # Direction from utils.generate_u; presumably random —
                    # its exact distribution is defined in utils.
                    u = generate_u(s2_val, num)
                    grad_samples[i] = num / miu * (f(x_new + miu * u) - f_current) * u
                gradient = np.sum(grad_samples, axis=0) / q_val
                x_new = x_new - current_eta * gradient

                # Hard thresholding: zero all but the k largest |x_i|.
                top_k_idx = np.argsort(-np.abs(x_new))[0:k]
                x_temp = np.zeros_like(x0)
                x_temp[top_k_idx] = x_new[top_k_idx]
                x_new = x_temp

                f_current = f(x_new)
                dist = np.linalg.norm(x_new - sol)
                evals = t * q_val
                history.append([evals, f_current, dist])
                print('Estimated f(x_k): %f  iters: %d' %
                      (f_current, evals))

            ccurves[q_i].append(np.array(history))

    return ccurves



def _save_mean_curve_figure(curves, qs, column, ylabel, xlabel_font, path):
    """Plot the run-averaged ``column`` of every q's curves and save the figure.

    ``curves[i]`` is the list of recorded runs for ``qs[i]``; column 0 of a
    run is used as the x axis. The output directory is created if needed.
    """
    import os

    plt.figure()
    for i, q in enumerate(qs):
        runs = curves[i]
        avg = np.mean([run[:, column] for run in runs], axis=0)
        plt.plot(runs[0][:, 0], avg, color=cmap[i], label=f"q={q}")
    plt.xlabel('Function Evaluations', xlabel_font)
    plt.ylabel(ylabel, legend_font)
    plt.legend(prop=legend_font)

    # savefig does not create missing directories.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(path)


if __name__ == "__main__":

    # One curve per number of directions q; a single s2 value equal to d.
    qs = np.array([1, 20, 200, 500, int(d/2), int(d)])
    s2s = np.array([d])
    s = 2*k + kstar
    # Per-(q, s2) constant, shape (len(qs), len(s2s)), used to set the step size.
    epsilon_f = 2*d / (qs[:, None] * (s2s[None]+2)) * ((s - 1) * (s2s[None] - 1) / (d-1) + 3) + 2

    etas = np.ones((qs.shape[0], s2s.shape[0])) / (4 * epsilon_f + 1)

    # curves[i] collects one recorded run per seed for qs[i]; dists is passed
    # along but is not filled by plot_figure.
    curves = {i: [] for i in range(len(qs))}
    dists = {i: [] for i in range(len(qs))}
    for i in range(1):
        plot_figure(fn=None, qs=qs, etas=etas, s2s=s2s, nm='am5_q', num=d, axes=[0, 6000, 0.04, 0.106], miu=1e-5, k=k, SEED=i, ccurves=curves, ddists=dists)

    # Column 1 of each run holds f(x).
    # NOTE(review): this figure uses `font` for the x label while the next one
    # uses `legend_font`; kept as in the original.
    _save_mean_curve_figure(curves, qs, 1, '$f(x)$', font,
                            'result/quadric_f_final_generation_bis.pdf')

    print(curves)

    # Column 2 holds the distance to `sol`. The raw string keeps the LaTeX
    # backslashes without relying on an invalid escape sequence.
    _save_mean_curve_figure(curves, qs, 2, r'$\|x - x^*\|$', legend_font,
                            'result/quadric_dist_final_generation_bis.pdf')
