import numpy as np
import matplotlib.pyplot as plt

def get_fvals(result, idx):
    """Summarise ``result[idx]`` by its mean and standard deviation along axis 0.

    Args:
        result: indexable container; ``result[idx]`` must be array-like.
        idx: key/index selecting the entry of ``result`` to summarise.

    Returns:
        Tuple ``(mean, std)`` computed with ``np.mean`` / ``np.std`` over axis 0
        (population std, i.e. numpy's default ``ddof=0``).
    """
    samples = result[idx]
    return np.mean(samples, axis=0), np.std(samples, axis=0)

def process_result(ris, l, T, idx=1):
    """Expand per-step mean/std curves by repeating each value ``2 * l`` times.

    The mean and std are taken along axis 0 of ``ris[idx]`` (via ``get_fvals``).
    Each entry is then repeated ``2 * l`` times. Presumably each step costs
    ``2 * l`` function evaluations; TODO confirm with the caller. Expansion stops
    as soon as the accumulated length reaches ``T``. The output is NOT truncated
    to ``T``, so it may be up to ``2 * l - 1`` entries longer.

    Args:
        ris: container passed to ``get_fvals``.
        l: half the repetition count per entry.
        T: target length at which expansion stops.
        idx: entry of ``ris`` to summarise.

    Returns:
        Tuple ``(mean_array, std_array)`` of equal length.
    """
    means, devs = get_fvals(ris, idx=idx)
    reps = 2 * l
    out_mu, out_std = [], []
    for m_val, s_val in zip(means, devs):
        out_mu.extend([m_val] * reps)
        out_std.extend([s_val] * reps)
        if len(out_mu) >= T:
            break
    return np.asarray(out_mu), np.asarray(out_std)


def process_result2(ris, l, T, idx=1):
    """Return the mean/std of ``ris[idx]`` along axis 0 with no resampling.

    ``l`` and ``T`` are unused; they are accepted only so the signature matches
    ``process_result`` and the two can be swapped by callers.

    Returns:
        Tuple ``(mean, std)`` exactly as produced by ``get_fvals``.
    """
    return get_fvals(ris, idx=idx)

def plot_results(title, means, stds, labels, xlabel, legend=False, out_file="./test.pdf",
                 x_axis_label="function evaluations"):
    """Plot mean curves with +/- one-std shaded bands on a log y-axis and save them.

    Args:
        title: figure title.
        means: sequence of 1-D array-likes, one curve per entry.
        stds: sequence of 1-D array-likes aligned with ``means``; the shaded band
            spans ``mean - std`` to ``mean + std``.
        labels: legend label for each curve (indexed like ``means``).
        xlabel: text used for the **y-axis** label. The name is historical and
            kept for backward compatibility.
        legend: draw a legend when True.
        out_file: path passed to ``fig.savefig``.
        x_axis_label: text for the x-axis label.
    """
    fig, ax1 = plt.subplots(1, 1, figsize=(7, 6))
    try:
        ax1.set_title(title, fontsize=20)
        for i, mean in enumerate(means):
            # Convert so plain lists also support the element-wise +/- below.
            mean = np.asarray(mean)
            std = np.asarray(stds[i])
            ax1.plot(mean, '-', lw=3, label=labels[i])
            ax1.fill_between(np.arange(len(mean)), mean - std, mean + std, alpha=0.45)

        ax1.set_xlabel(x_axis_label, fontsize=16)
        ax1.set_ylabel(xlabel, fontsize=16)
        ax1.set_yscale("log")
        if legend:
            ax1.legend(fontsize=12)
        fig.tight_layout()
        fig.savefig(out_file, bbox_inches='tight')
    finally:
        # Release the figure; otherwise pyplot keeps every figure alive across calls.
        plt.close(fig)

class TargetFunction:
    """Base class for objective functions; subclasses implement ``__call__``."""

    def __init__(self, x_star: np.ndarray, seed: int = 1212) -> None:
        """Store the reference point and create a seeded random generator.

        Args:
            x_star: reference point supplied by the subclass (presumably the
                minimizer of the objective; confirm per subclass).
            seed: seed for this instance's ``np.random.RandomState``.
        """
        self.x_star = x_star
        self.seed = seed
        # Per-instance generator so subclasses can draw reproducible random data.
        self.rnd_state = np.random.RandomState(seed=seed)

    def __call__(self, x):
        """Evaluate the objective at ``x``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement __call__")


class ConvexSmooth(TargetFunction):
    """Smooth convex quadratic f(x) = 0.5 * ||A x||_2^2 with a random Gaussian A.

    The reference point passed to the base class is the origin, where f is 0.
    """

    def __init__(self, d: int, seed: int = 121212):
        """Build a ``d``-dimensional instance; ``A`` is drawn from the seeded generator."""
        super().__init__(np.zeros(d), seed=seed)
        # d x d matrix with standard-normal entries.
        self.A = self.rnd_state.randn(d, d)

    def __call__(self, x):
        """Return 0.5 * ||A x||_2^2."""
        projected = self.A.dot(x)
        euclid = np.linalg.norm(projected, ord=2)
        return 0.5 * np.square(euclid)

    def get_L(self):
        """Return the largest eigenvalue of A^T A.

        This equals the Lipschitz constant of the gradient of f.
        """
        gram = self.A.T @ self.A
        return np.linalg.eigvalsh(gram).max()


class ConvexNonSmooth(TargetFunction):
    """Non-smooth convex function f(x) = ||x - y||_1 with y = [0, 1, ..., d-1]."""

    def __init__(self, d: int, seed: int = 121212):
        """Build a ``d``-dimensional instance.

        Args:
            d: dimension.
            seed: seed forwarded to the base class (this subclass draws no random data).
        """
        y = np.arange(d)
        # f reaches its minimum value 0 exactly at x = y, so y, not the origin,
        # is the optimum. Passing zeros was wrong for every d > 1.
        super().__init__(y.astype(float), seed=seed)
        self.y = y

    def __call__(self, x):
        """Return the L1 distance ||x - y||_1."""
        return np.linalg.norm(x - self.y, ord=1)


