
import numpy as np
import matplotlib.pyplot as plt
import random
import numpy.linalg as la
import math
import random
from math import pi
from datetime import datetime

def fig_multigamma(circuit_name, n_qubits, n_params, iteration, time, optimizer, noise_con, noise_rate):
    """Plot loss and gradient-norm curves for several Gaussian noise rates.

    For each entry ``noise`` of ``noise_rate`` this reads ``time`` runs from
    ``./{circuit_name}_{n_qubits}_{n_params}_{optimizer}_gaussian_{noise}_{noise_con}/``
    (files ``loss_{i}.npy`` and ``grad_norm_{i}.npy`` for ``i`` in
    ``range(time)``). It keeps the first ``iteration`` values of each run and
    draws the across-run mean as a line, with the across-run min/max as a
    shaded band.

    Two PDFs are written to the current working directory:
    ``xxx_multiloss_{optimizer}_{n_qubits}_{n_params}_{noise_con}.pdf`` (linear
    y-axis) and ``xxx_multigradnorm_{optimizer}_{n_qubits}_{n_params}_{noise_con}.pdf``
    (log y-axis).

    Args:
        circuit_name: Prefix of the result folder names.
        n_qubits: Embedded in folder and output file names.
        n_params: Embedded in folder and output file names.
        iteration: Number of leading iterations to plot from each run.
        time: Number of repeated runs per noise rate.
        optimizer: Optimizer tag used in paths. ``'gd'`` and ``'adam'`` get a
            descriptive x-axis label; any other value is shown verbatim.
        noise_con: Embedded verbatim in folder and output file names.
        noise_rate: Sequence of noise-rate strings, one curve per entry.
            Colours cycle if there are more than five entries.

    Raises:
        FileNotFoundError: If a run file is missing (raised by ``np.load``).
        ValueError: If a run file holds fewer than ``iteration`` values.
    """
    line = 1.5
    colors = ["blue", "brown", "red", "grey", "black"]
    x = np.arange(iteration, dtype=float)
    xlabel = {
        "gd": "training iteration with GD",
        "adam": "training iteration with Adam",
    }.get(optimizer, f"training iteration with {optimizer}")
    # NOTE(review): the output prefix is the literal "xxx", not circuit_name;
    # outputs for different circuits would overwrite each other. Confirm intent.
    suffix = f"_{optimizer}_{n_qubits}_{n_params}_{noise_con}.pdf"

    def load_runs(folder, stem):
        """Return a (time, iteration) array stacking ``{stem}_{i}.npy`` for each run."""
        runs = np.zeros((time, iteration))
        for i in range(time):
            path = f"{folder}{stem}_{i}.npy"
            values = np.load(path)[:iteration]
            # A short run would otherwise fail with an opaque broadcast error.
            # A single-value run would be silently broadcast to every iteration.
            if np.size(values) != iteration:
                raise ValueError(
                    f"{path}: expected at least {iteration} values, found {np.size(values)}"
                )
            runs[i] = values
        return runs

    def plot_panel(stem, ylabel, prefix, log_y):
        """Draw one mean/min-max band per noise rate and save to ``prefix + suffix``."""
        figure = plt.figure()
        try:
            for k, noise in enumerate(noise_rate):
                folder = f"./{circuit_name}_{n_qubits}_{n_params}_{optimizer}_gaussian_{noise}_{noise_con}/"
                runs = load_runs(folder, stem)
                color = colors[k % len(colors)]
                plt.plot(x, runs.mean(axis=0), label=noise + r"$\gamma$", color=color, lw=line)
                plt.fill_between(x, runs.min(axis=0), runs.max(axis=0), color=color, alpha=0.3)
            plt.xlabel(xlabel, fontsize=17)
            plt.ylabel(ylabel, fontsize=17)
            plt.xticks(fontsize=12)
            plt.yticks(fontsize=12)
            if log_y:
                plt.yscale("log")
            plt.grid(visible=True, axis="y", which="major", color="#666666", linestyle="--")
            plt.legend(fontsize=17, loc="upper right")
            plt.savefig(prefix + suffix)
        finally:
            # Close even on failure so a half-drawn figure cannot leak into the
            # next plot through pyplot's implicit "current figure".
            plt.close(figure)

    plot_panel("loss", r"$f$", "xxx_multiloss", log_y=False)
    plot_panel("grad_norm", "gradient norm", "xxx_multigradnorm", log_y=True)

def fig(circuit_name, n_qubits, n_params, iteration, time, optimizer, noise_con):
    """Compare the ``gaussian``, ``uniform`` and ``zero`` runs for one configuration.

    Result folders are ``./{circuit_name}_{n_qubits}_{n_params}_{optimizer}``
    followed by ``_gaussian_1_{c}/``, ``_uniform_{c}/`` or ``_zero_{c}/``. Here
    ``c`` is ``0`` when ``noise_con == 0`` and ``1`` for any other value.

    From each folder the files ``loss_{i}.npy`` and ``grad_norm_{i}.npy`` (for
    ``i`` in ``range(time)``) are read and truncated to ``iteration`` values.
    The across-run mean is drawn as a line and the min/max as a shaded band.

    Writes ``xxx_loss_{optimizer}_{n_qubits}_{n_params}_{noise_con}.pdf`` and
    ``xxx_gradnorm_{optimizer}_{n_qubits}_{n_params}_{noise_con}.pdf`` (both
    with a linear y-axis) to the current working directory.

    Args:
        circuit_name: Prefix of the result folder names.
        n_qubits: Embedded in folder and output file names.
        n_params: Embedded in folder and output file names.
        iteration: Number of leading iterations to plot from each run.
        time: Number of repeated runs per series.
        optimizer: Optimizer tag used in paths. ``'gd'`` and ``'adam'`` get a
            descriptive x-axis label; any other value is shown verbatim.
        noise_con: Selects the folder suffix (0 vs. non-zero). Embedded
            verbatim in the output file names.

    Raises:
        FileNotFoundError: If a run file is missing (raised by ``np.load``).
        ValueError: If a run file holds fewer than ``iteration`` values.
    """
    legend_size = 20
    line = 1.5
    x = np.arange(iteration, dtype=float)
    # Folder suffix is only 0 or 1; any non-zero noise_con maps to 1.
    flag = 0 if noise_con == 0 else 1
    base = f"./{circuit_name}_{n_qubits}_{n_params}_{optimizer}"
    # (legend label, folder, line style, colour), in drawing/legend order.
    series = (
        ("gaussian", f"{base}_gaussian_1_{flag}/", "-", "red"),
        ("uniform", f"{base}_uniform_{flag}/", ":", "blue"),
        ("zero", f"{base}_zero_{flag}/", "--", "black"),
    )
    xlabel = {
        "gd": "training iteration with GD",
        "adam": "training iteration with Adam",
    }.get(optimizer, f"training iteration with {optimizer}")
    # NOTE(review): the output prefix is the literal "xxx", not circuit_name;
    # outputs for different circuits would overwrite each other. Confirm intent.
    suffix = f"_{optimizer}_{n_qubits}_{n_params}_{noise_con}.pdf"

    def load_runs(folder, stem):
        """Return a (time, iteration) array stacking ``{stem}_{i}.npy`` for each run."""
        runs = np.zeros((time, iteration))
        for i in range(time):
            path = f"{folder}{stem}_{i}.npy"
            values = np.load(path)[:iteration]
            # A short run would otherwise fail with an opaque broadcast error.
            # A single-value run would be silently broadcast to every iteration.
            if np.size(values) != iteration:
                raise ValueError(
                    f"{path}: expected at least {iteration} values, found {np.size(values)}"
                )
            runs[i] = values
        return runs

    def plot_panel(stem, ylabel, prefix):
        """Draw the three mean/min-max bands for ``stem`` and save to ``prefix + suffix``."""
        figure = plt.figure()
        try:
            for label, folder, style, color in series:
                runs = load_runs(folder, stem)
                plt.plot(x, runs.mean(axis=0), linestyle=style, label=label, color=color, lw=line)
                plt.fill_between(x, runs.min(axis=0), runs.max(axis=0), color=color, alpha=0.3)
            plt.xlabel(xlabel, fontsize=17)
            plt.ylabel(ylabel, fontsize=17)
            plt.xticks(fontsize=12)
            plt.yticks(fontsize=12)
            plt.grid(visible=True, axis="y", which="major", color="#666666", linestyle="--")
            plt.legend(fontsize=legend_size, loc="upper right")
            plt.savefig(prefix + suffix)
        finally:
            # Close even on failure so a half-drawn figure cannot leak into the
            # next plot through pyplot's implicit "current figure".
            plt.close(figure)

    plot_panel("loss", r"$f$", "xxx_loss")
    plot_panel("grad_norm", "gradient norm", "xxx_gradnorm")


# Experiment configuration shared by every figure generated below.
n_qubits = 15
iteration = 100
time = 5
noise_rate = ["0.01", "0.1", "1", "10", "100"]

para_list = [300]
for n_params in para_list:
    # Noise-rate sweep figures: gd/0, gd/1, adam/0, adam/1.
    for opt_name in ("gd", "adam"):
        for con in (0, 1):
            fig_multigamma("xxx", n_qubits, n_params, iteration, time, opt_name, con, noise_rate)
    # gaussian / uniform / zero comparison figures, in the same order.
    for opt_name in ("gd", "adam"):
        for con in (0, 1):
            fig("xxx", n_qubits, n_params, iteration, time, opt_name, con)
