import numpy as np
import warnings
import matplotlib

from algorithms import *
from utils import *

# Embed TrueType (Type 42) fonts in PDF/PS figures instead of Type 3 fonts.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

# NOTE(review): this silences *all* warnings process-wide, including numerical
# ones (overflow, invalid value) that could reveal a diverging run.
warnings.simplefilter('ignore')

seed = 121212  # shared seed passed to every optimizer

d = 10  # dimension of the objective and of the starting point x0
l = 10  # passed as `l=` to the multi-direction optimizers (single ones use l=1)

c = 0.65  # base step-size constant shared by most schedules

# Exponent of every step-size schedule: just below -1/2, so the steps decay
# slightly faster than 1/sqrt(t).
_STEP_EXPONENT = -1/2 - 1e-5


def _decaying_step(scale):
    """Return the step-size schedule ``t -> scale * t**(-1/2 - 1e-5)``.

    ``scale`` is evaluated once here. The product is computed in the same order
    as the original inline lambdas, so the values are bit-identical.
    """
    return lambda t: scale * (t ** _STEP_EXPONENT)


# Objective under test. NOTE: despite its name this is the *nonsmooth* convex
# objective (ConvexNonSmooth); the name is kept because the runs below use it.
conv_smooth = ConvexNonSmooth(d)

# Step-size schedules. Single-direction methods scale by 1/d and multi-direction
# ones by l/d. The multi Gaussian method uses the smaller constant 0.08 instead
# of c.
alpha_single_gaus = _decaying_step(c * (1/d))
alpha_single_sph = _decaying_step(c * (1/d))

alpha_multi_sph = _decaying_step(c * (l/d))
alpha_multi_gaus = _decaying_step(0.08 * (l/d))

alpha_multi = _decaying_step(c * (l/d))

# Second schedule, passed as `h=` to every optimizer: (1/d**2) / (t + 1).
h = lambda t: (1/d**2) * (1.0/(t + 1))

T = 1000   # horizon passed to every optimize() / process_result() call
reps = 10  # repetition count passed to every optimize() call
x0 = np.full(d, 1.0, dtype=np.float64)  # shared starting point (all ones)

# Single-direction (l=1) baselines.
cgaus = RandomCenterGaussian(l=1, alpha=alpha_single_gaus, h=h, seed=seed)
csph = RandomCenteredSphere(l=1, alpha=alpha_single_sph, h=h, seed=seed)

# Multi-direction (l=l) baselines.
csphm = RandomCenteredSphere(l=l, alpha=alpha_multi_sph, h=h, seed=seed)
cgausm = RandomCenterGaussian(l=l, alpha=alpha_multi_gaus, h=h, seed=seed)

# Proposed method.
our = RandomCenteredStructured(l=l, alpha=alpha_multi, h=h, seed=seed)


def _run(method):
    """Run one optimizer on the shared objective from a fresh copy of x0.

    Every method receives an independent copy of the starting point, the same
    f_star=0, horizon T, and repetition count reps.
    """
    return method.optimize(conv_smooth, x0=x0.copy(), f_star=0, T=T, reps=reps)


# Raw runs, in the original execution order.
ris_cgaus = _run(cgaus)
ris_csph = _run(csph)

ris_csphm = _run(csphm)
ris_cgausm = _run(cgausm)

ris_our = _run(our)

# Post-process each run. The second argument matches the `l` each optimizer was
# constructed with: 1 for the single-direction methods, l for the others.
ris_csph = process_result(ris_csph, 1, T)
ris_cgaus = process_result(ris_cgaus, 1, T)

ris_csphm = process_result(ris_csphm, l, T)
ris_cgausm = process_result(ris_cgausm, l, T)

ris_our = process_result(ris_our, l, T)

# Curves to plot, in legend order. Each label sits next to its processed result,
# so labels, means and stds cannot fall out of sync.
# (The unused get_fvals(...) summaries that used to be computed here were
# removed. Two of them also passed 1 for the multi-direction runs, which were
# processed with l.)
curves = [
    ('Single Gaussian', ris_cgaus),
    ('Single Spherical', ris_csph),
    ('Multi Gaussian', ris_cgausm),
    ('Multi Spherical', ris_csphm),
    ('Ours', ris_our),
]

labels = [label for label, _ in curves]
# Element 0 of each processed result is plotted as the mean curve and element 1
# as its spread.
means = [ris[0] for _, ris in curves]
stds = [ris[1] for _, ris in curves]


plot_results('NonSmooth Convex Target', 
             means, stds, labels, "$f(x_k) - f(x^*)$", 
             legend=True, out_file="./conv_nonsmooth_comp.pdf")