import os

import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import ortho_group

# Global matplotlib styling shared by every figure produced in this file.
plt.rcParams.update({
    "font.size": 16,
    "legend.fontsize": 16,
    "axes.labelsize": 18,
    "lines.linewidth": 3,
})


def generate_matrix(dim, rank, seed=123):
    '''
    Build a random symmetric positive semi-definite matrix of a given rank.

    The matrix is assembled as P.T @ diag(eig) @ P, where P is a random
    orthogonal matrix drawn with scipy's ortho_group and eig holds `rank`
    values sampled uniformly from [0.1, 10) followed by `dim - rank` zeros.

    Side effect: the global NumPy random generator is re-seeded with `seed`
    and advanced by `rank` uniform draws.

    dim: number of rows/columns of the returned matrix
    rank: number of non-zero eigenvalues
    seed: seed for both the eigenvalue draw and the orthogonal basis
    returns: array of shape (dim, dim)
    '''
    np.random.seed(seed)

    # Spectrum: `rank` strictly positive entries, padded with zeros.
    positive_part = np.random.uniform(low=0.1, high=10, size=rank)
    spectrum = np.concatenate([positive_part, np.zeros(dim - rank)])

    # The basis uses its own generator (random_state=seed), so it does not
    # consume draws from the global one.
    basis = ortho_group.rvs(dim, random_state=seed)

    scaled = basis.T @ np.diag(spectrum)
    return scaled @ basis


def oracle(x, A, b, x_0, reg=0.0, inexact=0.0):
    '''
    First-order oracle for the least-squares objective at the query point x.

    obj:  1/2 ||Ax - b||^2
    grad: A.T (Ax - b) + reg * (x - x_0), shifted by `inexact` in every entry

    Note that the returned objective value contains neither the
    regularization term nor the inexactness; only the gradient does.

    x: query point
    A, b: data of the least-squares problem
    x_0: anchor point of the regularization term
    reg: regularization weight
    inexact: constant added to every gradient coordinate
    returns: (objective value, possibly perturbed gradient)
    '''
    residual = A @ x - b
    value = 0.5 * (np.linalg.norm(residual) ** 2)

    exact_grad = (A.T @ residual) + reg * (x - x_0)
    # Deterministic gradient error: the same offset on every coordinate.
    bias = inexact * np.ones_like(exact_grad)

    return value, exact_grad + bias


def gd(A, b, x_0, step, T, reg=0.0, inexact=0.0):
    '''
    Run T steps of gradient descent starting from x_0.

    Each step queries `oracle` at the current iterate and moves along the
    negative returned gradient scaled by `step`.

    A, b: data of the least-squares problem (forwarded to `oracle`)
    x_0: starting point, also used as the regularization anchor
    step: step size
    T: number of iterations
    reg, inexact: forwarded to `oracle`
    returns: list of T + 1 iterates, beginning with x_0
    '''
    iterates = [x_0]
    current = x_0

    for _ in range(T):
        direction = oracle(current, A, b, x_0, reg, inexact)[1]
        current = current - step * direction
        iterates.append(current)

    return iterates


def agd(A, b, x_0, step, T, reg=0.0, inexact=0.0):
    '''
    Run T steps of accelerated gradient descent starting from x_0.

    The gradient is queried at the extrapolated point y, a gradient step
    produces the next x, and y is then set to x plus a momentum term
    beta * (x - previous x).

    Momentum coefficient:
      * reg == 0: beta = t / (t + 3), which changes with the iteration t;
      * otherwise: the constant beta = (2 - k) / (2 + k) with
        k = sqrt(reg / (reg + ell)), where ell is the spectral norm of
        A.T @ A.

    A, b: data of the least-squares problem (forwarded to `oracle`)
    x_0: starting point, also used as the regularization anchor
    step: step size
    T: number of iterations
    reg, inexact: forwarded to `oracle`
    returns: (x_list, y_list), each holding T + 1 points and starting at x_0
    '''
    x_list, y_list = [x_0], [x_0]
    y = x_0

    if reg == 0:
        constant_beta = None
    else:
        # The regularized momentum depends only on A and reg, so compute it
        # once here rather than paying for a spectral norm at every iteration.
        ell = np.linalg.norm(A.T @ A, 2)
        kappa_ = np.sqrt(reg / (reg + ell))
        constant_beta = (2 - kappa_) / (2 + kappa_)

    for t in range(T):
        _, grad = oracle(y, A, b, x_0, reg, inexact)
        x = y - step * grad

        beta = t / (t + 3) if constant_beta is None else constant_beta

        # x_list[-1] is still the previous iterate at this point.
        y = x + beta * (x - x_list[-1])

        x_list.append(x)
        y_list.append(y)

    return x_list, y_list


def run_min_grad(T, size, step, delta, reg=0.0, seed=123):
    '''
    Compare GD and AGD, with and without regularization, under exact and
    inexact gradients, and save the resulting figure.

    A random least-squares problem is generated from `seed`. Each of the four
    methods (AGD, GD, Reg-AGD, Reg-GD) is run twice from the zero vector:
    once with exact gradients and once with `delta` added to every gradient
    coordinate. The unregularized objective at the final GD/AGD iterates is
    printed, and a two-panel log-log figure is written to ./result/min.png:

      * left:  unregularized objective along the inexact trajectories;
      * right: distance between the exact and inexact iterates.

    T: number of iterations for every method
    size: pair (n, r) giving the dimension and the rank of the problem matrix
    step: step size shared by all methods
    delta: gradient inexactness used for the perturbed runs
    reg: regularization weight used by the "Reg-" variants
    seed: random seed for the problem instance
    '''
    np.random.seed(seed)

    n, r = size
    A = generate_matrix(dim=n, rank=r, seed=seed)
    b = np.random.normal(0, 10, size=(n, 1))
    x = np.zeros(shape=(n, 1))

    def objective(point):
        # Unregularized least-squares value (oracle called with reg=0).
        return oracle(point, A, b, x)[0]

    # Suffix _1: exact gradients; suffix _2: gradients perturbed by delta.
    x_agd_1, _ = agd(A, b, x, step, T)
    x_gd_1 = gd(A, b, x, step, T)
    x_reg_agd_1, _ = agd(A, b, x, step, T, reg)
    x_reg_gd_1 = gd(A, b, x, step, T, reg)

    x_agd_2, _ = agd(A, b, x, step, T, inexact=delta)
    x_gd_2 = gd(A, b, x, step, T, inexact=delta)
    x_reg_agd_2, _ = agd(A, b, x, step, T, reg, inexact=delta)
    x_reg_gd_2 = gd(A, b, x, step, T, reg, inexact=delta)

    print('stepsize used in theory:', 1 / np.linalg.norm(A.T @ A, 2))
    print('func val: gd={}, gd_inexact={}'.format(objective(x_gd_1[-1]), objective(x_gd_2[-1])))
    print('func val: agd={}, agd_inexact={}'.format(objective(x_agd_1[-1]), objective(x_agd_2[-1])))

    colorlist = ['#448bff', '#3bc335', '#ff9600', '#f84c00']
    # (legend label, exact trajectory, inexact trajectory), in plotting order.
    runs = [
        ('AGD', x_agd_1, x_agd_2),
        ('GD', x_gd_1, x_gd_2),
        ('Reg-AGD', x_reg_agd_1, x_reg_agd_2),
        ('Reg-GD', x_reg_gd_1, x_reg_gd_2),
    ]

    fig = plt.figure(figsize=(12, 5))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)

    iterations = list(range(T + 1))
    for color, (label, exact_path, inexact_path) in zip(colorlist, runs):
        ax1.plot(iterations, [objective(p) for p in inexact_path],
                 label=label, color=color)
        ax2.plot(iterations,
                 [np.linalg.norm(p1 - p2) for p1, p2 in zip(exact_path, inexact_path)],
                 label=label, color=color)

    ax1.set_title('Function Value')
    ax1.legend()
    ax1.set_xlabel('# Iterations')
    ax1.set_ylim(20, 3000)
    ax1.set_xscale('log')
    ax1.set_yscale('log')

    ax2.set_title('Deviation in Trajectory')
    ax2.legend()
    ax2.set_xlabel('# Iterations')
    ax2.set_xscale('log')
    ax2.set_yscale('log')

    out_path = './result/min.png'
    # savefig does not create missing directories; without this the whole
    # run would end in FileNotFoundError on a fresh checkout.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path, dpi=500, bbox_inches='tight')