from algorithms_no_log import *

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# True -> skip every solver run and only re-plot results already pickled to disk.
plot_only = False

# Plot styling palettes: matplotlib line styles, marker codes and colour names.
linestyles = ['-', '-.', '--', ':', '-.', '-', '-', '-.', '--', ':', '-.', '-']
markers = ['o', '*', 'd', 'v', 'P', '1', 'p', 'X']
colors = [
    'tab:blue', 'tab:red', 'tab:green', 'tab:brown',
    'tab:purple', 'tab:gray', 'tab:olive', 'tab:cyan',
]

# Number of clients T per dataset; the main loop requires T to divide the
# dataset's sample count exactly.
T_dict = dict(a1a=5, mushrooms=12, phishing=11, duke=11, madelon=200)

mu = 1e-4  # passed to make_fg_logreg and several solvers (presumably the regulariser)
L = 1.0  # passed to normalize_data
experiment = "LambdaEffect"  # experiment tag used in the pickle filenames
assignment = "random"  # sample-to-client assignment method for rearrange_data

datasets = ["duke"]
# Solver identifiers and their plot labels; the two lists are index-aligned.
algs = ["iapgdkat", "al2sgd", "l2sgd", "apgd2", "iapgdagd"]
labels2 = ["IAPGD+Kat.", "AL2SGD+", "L2SGD+", "APGD 2", "IAPGD+AGD"]


# Penalty strengths swept by the experiment: 10**-7, 10**-6, ..., 10**1.
Lambdas = [10 ** power for power in range(-7, 2)]
# Target accuracy handed to every solver as its stopping criterion.
accuracy = 1e-3

for dataset in datasets:
    print("#############################")
    print(dataset)
    print("#############################")

    # One list per algorithm (index-aligned with `algs`); entry k of each list
    # is the cost measured for Lambdas[k].
    conv_com = [[] for _ in algs]  # communication rounds (aggregations)
    conv_grad = [[] for _ in algs]  # local gradient computations

    A, b = get_data(dataset)
    n, d = A.shape
    A = normalize_data(A, L)

    # Split the n samples evenly across T clients, m samples per client.
    # Checked explicitly (not via `assert`) so the check survives `python -O`.
    T = T_dict[dataset]
    m = n // T
    if m * T != n:
        raise ValueError(
            "Dataset {} has {} samples, which is not divisible by T = {}".format(dataset, n, T))

    # Aggregation probability used by L2SGD+; AL2SGD+ receives the same value as rho.
    pagg = 1 / m
    rho = pagg * 1.0

    if not plot_only:
        for lamda in Lambdas:
            print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n\nLambda = {}, getting X_opt".format(lamda))

            # NOTE(review): the data is re-partitioned (and the objective rebuilt)
            # on every lambda, starting from the previously rearranged A. With
            # assignment="random" this presumably gives each lambda a different
            # client split. Confirm this is intended; otherwise hoist these two
            # lines above the lambda loop.
            A, b = rearrange_data(A, method=assignment, b=b)
            f, g = make_fg_logreg(A, b, mu)

            # Reference solution: a long fixed-budget AL2SGD+ run whose output is
            # the target every solver below must reach to within `accuracy`.
            _, _, X_opt = al2sgd_no_log(g, d, T, m, rho, lamda, number_of_steps=10000 * m)

            for i, alg in enumerate(algs):
                print("&&&&&&&&&&&&&&&&&&&&&&&&&&&")
                print("alg = {}".format(alg))

                if alg == "l2sgd":
                    alpha = get_stepsize_saga(v=np.ones(n), p=np.ones(n) / m, pagg=pagg, pwork=1.0,
                                              omega=lamda / T, n=n, T=T, mu=mu)
                    aggregations, loc_steps, X = l2sgd_no_log(g, alpha, d, T, m, pagg, lamda / T,
                                                              X_opt=X_opt, accuracy=accuracy)
                elif alg == "al2sgd":
                    aggregations, loc_steps, X = al2sgd_no_log(g, d, T, m, rho, lamda,
                                                               X_opt=X_opt, accuracy=accuracy)
                elif alg == "iapgdkat":
                    aggregations, loc_steps, X = apgd_no_log(g, d, T, m, lamda,
                                                             X_opt=X_opt, accuracy=accuracy)
                elif alg == "iapgdagd":
                    aggregations, loc_steps, X = apgd_no_log(g, d, T, m, lamda, stochastic=False,
                                                             X_opt=X_opt, accuracy=accuracy)
                elif alg == "apgd2":
                    aggregations, loc_steps, X = apgd2_no_log(g, d, T, m, lamda, mu=mu,
                                                              X_opt=X_opt, accuracy=accuracy)
                else:
                    raise ValueError("Algorithm doesn't exist")
                # (The original `X.reshape(T*d, -1)` here discarded its result and
                # was a no-op; it has been removed.)

                print("FINISHED ### communication: {}, gradient_comptutation: {}".format(aggregations, loc_steps))

                conv_com[i].append(aggregations)
                conv_grad[i].append(loc_steps)
            print(conv_com)
            print(conv_grad)

        # Persist the measurements for both values of createfilename's `backup`
        # flag. X is whatever the last solver run returned (last algorithm,
        # last lambda). `with` closes the file even if pickling fails.
        for backup in [True, False]:
            filename = createfilename(experiment, dataset, T, 0, mu, 1.0, "testing lambda", backup)
            with open(filename, "w+b") as pickle_out:
                pickle.dump((conv_com, conv_grad, X), pickle_out)

    # Plot from the pickled results; when plot_only is True this is the only
    # data source.
    loaded_com, loaded_grad, Xlist = load_pickle(experiment, dataset, T, 0, mu, 1.0, "testing lambda")
    items = [loaded_com, loaded_grad]  # conv_com, conv_grad

    # savefig does not create missing directories, so make sure ./plots exists.
    plot_dir = os.path.join(os.getcwd(), 'plots')
    os.makedirs(plot_dir, exist_ok=True)

    for costs_per_alg, name in zip(items, ["Communication rounds", "Gradient of local summands"]):
        plt.xscale('log')
        plt.yscale('log')
        plt.xlabel("Lambda/L", fontsize='x-large')
        plt.ylabel('Number of ' + name, fontsize='x-large')
        plt.title('Dataset: {}, accuracy: {}'.format(dataset, accuracy), fontsize='x-large')
        # The smallest lambda (Lambdas[0]) is deliberately left out of the plot;
        # the reason is not stated in this script.
        for alg_costs, label in zip(costs_per_alg, labels2):
            plt.scatter(Lambdas[1:], alg_costs[1:], label=label, marker="s")
        plt.legend(fontsize='x-large', loc='best')
        # Lay out only after everything is drawn so labels/legend are accounted for.
        plt.tight_layout()
        plt.savefig(os.path.join(plot_dir, '{}LambdasDataset{}Accuracy{}.pdf'.format(name, dataset, accuracy)))
        plt.close()

print("END")
