import os
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import argparse

# Command-line interface: selects which result files to read (batch size and
# epoch count appear in the directory / file names) and how the curves are
# smoothed and averaged.
parser = argparse.ArgumentParser(description="plots the results")
parser.add_argument(
    "-d", "--decay",
    type=float, default=.99,
    help="Decay value for running averages",
)
parser.add_argument(
    "-b", "--batch",
    type=int, default=100,
    help="Batch size",
)
parser.add_argument(
    "-e", "--epochs",
    type=int, default=5,
    help="Number of epochs",
)
parser.add_argument(
    "-r", "--runs",
    type=int, default=5,
    help="Number of runs to average over",
)

args = parser.parse_args()

# Pull the parsed command-line settings into short module-level names.
batch, epochs, runs, decay = args.batch, args.epochs, args.runs, args.decay

# Training-set size: dividing by the batch size gives the number of optimizer
# steps per epoch, which is how per-epoch test points are placed on the
# step-based x-axis.
train_size = 9e4
dir_name = 'teacher_d10w10o1/learner_d15w15b' + str(batch)
title = f'Batch Size = {batch}, '

# data[optimizer name][metric] -> list of per-run arrays (filled by the loader).
data = {}
epoch_step = train_size / batch
total_steps = epochs * epoch_step

metrics = ['Test R²', 'Training R²', 'Test Loss', 'Training Loss']

# Optimizer run names exactly as they appear in the saved result filenames,
# paired with the short labels shown in the plot legends.
names = [
    'NAG, lr=1e-4, m=.99',
    'ADAM 1e-3',
    'SGD, lr=1e-4, m=.99',
    'AGNES, lr=1e-4, eta=1e-3, m=.99',
]
labels = dict(zip(names, ['NAG', 'ADAM', 'SGD+momentum', 'AGNES']))

def _running_average(values, decay):
    """Return the exponential moving average of *values* as a NumPy array.

    The average is seeded with ``values[0]`` and then updated once for every
    element (including the first). The result therefore has
    ``len(values) + 1`` entries, and its first two entries both equal
    ``values[0]``.

    Args:
        values: sequence of per-step scalars (e.g. training losses).
        decay: weight kept from the previous average at each update; values
            close to 1 give heavier smoothing.

    Raises:
        IndexError: if *values* is empty.
    """
    averages = [values[0]]
    for value in values:
        averages.append(decay * averages[-1] + (1 - decay) * value)
    return np.array(averages)


# Not filled in by the loop below (the max-accuracy bookkeeping is disabled);
# kept so any code that expects the name still finds it.
max_acc = {}

# data[name][metric] collects one array per successfully loaded run.
for name in names:
    data[name] = {metric: [] for metric in metrics}

    for i in range(runs):
        # Run files are named like "ADAM 1e-3+_r1_5.pth": name, "+_r", run index, epochs.
        filename = f'{name}+_r{i}_{epochs}.pth'
        try:
            with open(os.path.join(dir_name, filename), 'rb') as file:
                temp = torch.load(file, map_location=torch.device('cpu'))
        except FileNotFoundError:
            # Missing runs are reported and skipped; statistics are computed
            # over whichever runs were found.
            print(filename, "does not exist.")
            continue

        # Test metrics are stored as-is; training metrics are smoothed.
        data[name]['Test Loss'].append(np.array(temp['test_losses']))
        data[name]['Test R²'].append(np.array(temp['test_accuracies']))
        data[name]['Training Loss'].append(_running_average(temp['train_losses'], decay))
        data[name]['Training R²'].append(_running_average(temp['train_accuracies'], decay))


def _plot_metric(metric, x_step, log_scale=False, ylim=None, zoomed_ylim=None):
    """Plot the across-run mean +/- one std of *metric* for every optimizer and save it.

    One curve per entry of ``names`` is drawn from ``data[name][metric]``,
    with a shaded band of one standard deviation. The figure is written to
    ``os.path.join(dir_name, title + metric)``. If *zoomed_ylim* is given, a
    second copy with those y-limits is written with a ``_zoomed`` suffix.

    Args:
        metric: key into ``data[name]``; also the suffix of the title/filename.
        x_step: x-distance between consecutive recorded points
            (``epoch_step`` for per-epoch test metrics, ``1`` for per-step
            training metrics).
        log_scale: draw the curves with a logarithmic y-axis.
        ylim: y-limits applied before the first save, or ``None`` to autoscale.
        zoomed_ylim: y-limits for an additional ``_zoomed`` copy, or ``None``.
    """
    fig = plt.figure()
    draw = plt.semilogy if log_scale else plt.plot
    for name in names:
        runs_data = data[name][metric]
        if not runs_data:
            # Every run file was missing: np.mean would give a scalar NaN that
            # cannot be plotted against an x-axis.
            print(f"No runs loaded for {name}; skipping {metric}.")
            continue
        mean = np.mean(runs_data, axis=0)
        std = np.std(runs_data, axis=0)
        # Derive the x-axis from the data length so curve and band always
        # match, even when train_size / batch is not an integer.
        x = x_step * np.arange(len(mean))
        draw(x, mean, label=labels[name])
        plt.fill_between(x, mean + std, mean - std, alpha=0.2)

    plt.title(title + metric)
    plt.legend()
    if ylim is not None:
        plt.ylim(ylim)
    out_path = os.path.join(dir_name, title + metric)
    plt.savefig(out_path)
    if zoomed_ylim is not None:
        plt.ylim(zoomed_ylim)
        plt.savefig(out_path + "_zoomed")
    # Release the figure; otherwise every call leaves an open figure behind.
    plt.close(fig)


_plot_metric('Test R²', x_step=epoch_step, ylim=[.9, 1])
_plot_metric('Training R²', x_step=1, ylim=[.9, 1])
_plot_metric('Test Loss', x_step=epoch_step, log_scale=True, zoomed_ylim=[0.05, .8])
_plot_metric('Training Loss', x_step=1, log_scale=True, zoomed_ylim=[0.05, .8])
