import numpy as np
import sys
import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import matplotlib
import os
import warnings
from matplotlib.mathtext import MathTextWarning

# Helper: after drawing, remove the '\mathdefault' macro from minor tick labels
def fix(ax=None):
    r"""Remove the ``\mathdefault`` macro from an Axes' minor tick labels.

    The figure is drawn first (so matplotlib has generated the tick label
    text), then every visible minor tick label on the x and y axes is
    replaced by the same text with ``\mathdefault`` stripped out.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to modify. Defaults to the current Axes (``plt.gca()``).
    """
    if ax is None:
        ax = plt.gca()
    # Tick label text is only filled in by a draw; silence mathtext warnings
    # that the draw may emit.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=MathTextWarning)
        ax.get_figure().canvas.draw()
    for axis in (ax.xaxis, ax.yaxis):
        _strip_minor_mathdefault(axis)


def _strip_minor_mathdefault(axis):
    r"""Rewrite *axis*' minor tick labels without the ``\mathdefault`` macro."""
    labels = [label.get_text().replace(r'\mathdefault', '')
              for label in axis.get_minorticklabels()]
    if not labels:
        # No minor tick labels on this axis: nothing to rewrite.
        return
    locs = axis.get_minorticklocs()
    if len(locs) == len(labels):
        # Pin the tick positions so the fixed labels stay attached to the
        # ticks they were read from, even if the layout changes afterwards
        # (also avoids matplotlib's "set_ticklabels() should only be used
        # with a fixed number of ticks" warning).
        axis.set_ticks(locs, minor=True)
    axis.set_ticklabels(labels, minor=True)



name_result = 'result'

# Matplotlib styling: font sizes for legend, axis labels, titles and tick
# labels, then Computer Modern for math text and cmr10 for regular text.
params = {
    'legend.fontsize': 24,
    'axes.labelsize': 22,
    'axes.titlesize': 26,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
}
pylab.rcParams.update(params)
matplotlib.rcParams.update({'mathtext.fontset': 'cm', 'mathtext.rm': 'serif'})
plt.rcParams["font.family"] = "cmr10"

# Figure constants.
# NOTE(review): ls_plot, fs_legend, linewidth, line_alpha and title_offset
# appear unused by this script — confirm before removing.
ls_plot, fs_label, fs_legend = 1.25, 22, 13
dpi_value = 300
linewidth, line_alpha = 3, .8
title_offset = -0.5

# One independent list per logged quantity read from the training log (the
# current parser does not fill every one of them).
(loss_list, mse_mu_list, recon_loss_list, recon_fft_loss_list,
 kl_q_loss_list, fft_loss_list, state_time_loss_list,
 state_four_loss_list) = ([] for _ in range(8))

# Parse the training log. Each line containing 'Epoch' is split on '=', and
# from the selected fields the text before the first ', ' is read as a float.
# NOTE(review): these indices match the log format this script currently
# expects. An earlier format used other positions, and the MSE-of-mu, DFT and
# Fourier-state values are no longer read, so those lists stay empty. Confirm
# against the training script's print statement if the plots look wrong.
_EPOCH_FIELDS = (1, 4, 6, 9)  # loss, reconstruction, KL, time-state loss
count = 0
with open(os.path.join(home_dir + file_dir, name_result), encoding='utf-8') as fp:
    for line in fp:
        if 'Epoch' not in line:
            continue
        words = line.split('=')
        # The first 'Epoch' line is skipped (count == 0).
        # NOTE(review): presumably not a regular per-epoch entry — confirm.
        if count >= 1:
            try:
                # Parse every field before appending anything, so a bad line
                # cannot leave the lists with different lengths.
                loss, recon, kl_q, state_time = [
                    float(words[i].split(', ')[0]) for i in _EPOCH_FIELDS]
            except (IndexError, ValueError):
                # A malformed or truncated line (e.g. a log still being
                # written) ends parsing; report where instead of failing
                # silently.
                print(f'Stopped parsing {name_result} at Epoch line {count}',
                      file=sys.stderr)
                break
            loss_list.append(loss)
            recon_loss_list.append(recon)
            kl_q_loss_list.append(kl_q)
            state_time_loss_list.append(state_time)
        count += 1

# One panel per logged quantity, left to right: (values, y label, log y-axis?).
# NOTE(review): lists that the parser does not fill produce blank panels.
panels = [
    (loss_list, 'Loss', True),
    (mse_mu_list, r'MSE of $\mu$', True),
    (recon_loss_list, 'Reconstruction loss', False),
    (kl_q_loss_list, 'KL loss', False),
    (recon_fft_loss_list, 'Recons. DFT loss', False),
    (state_time_loss_list, 'Time state loss', False),
    (state_four_loss_list, 'Fourier state loss', False),
]

# constrained_layout handles panel spacing. Calling tight_layout() on top of
# it makes matplotlib warn and replace the layout engine, so it is not used.
fig1 = plt.figure(constrained_layout=True, figsize=(4 * len(panels), 4))
gs = fig1.add_gridspec(1, len(panels))
for col, (values, ylabel, log_scale) in enumerate(panels):
    ax = fig1.add_subplot(gs[0, col])
    ax.plot(values, '-', color='red')
    ax.grid(True, which='both', alpha=.3)
    if log_scale:
        ax.set_yscale('log')
    ax.set_xlabel('Iterations', fontsize=fs_label)
    ax.set_ylabel(ylabel, fontsize=fs_label)

# Clean up the tick labels of every panel only after all panels exist, so the
# draw performed inside fix() happens with the complete figure in place.
for ax in fig1.axes:
    fix(ax)

fig1.savefig(os.path.join(home_dir + file_dir, 'plot.png'), dpi=dpi_value)
plt.close(fig1)
