import matplotlib
from batchgenerators.utilities.file_and_folder_operations import join

matplotlib.use('agg')
import seaborn as sns
import matplotlib.pyplot as plt


class AdamCPRLogger(object):
    """
    Minimal per-epoch logger for the ``lagmul``, ``kappa`` and ``l2`` series of a set of named items.

    For every name passed to the constructor three lists are kept, stored under the keys
    ``lagmul_<name>``, ``kappa_<name>`` and ``l2_<name>``. Index ``i`` of each list holds the value logged
    for epoch ``i``. The class exists to keep epoch numbers and the number of logged values in sync and to
    keep that bookkeeping out of the trainer.

    Exactly one value per epoch must be logged for each key; ``log`` asserts that a list never lags behind
    or runs ahead of the epoch counter by more than one entry.
    """

    # Key prefixes of the tracked quantities. Each one is plotted into "<prefix>s.png".
    _QUANTITIES = ('lagmul', 'kappa', 'l2')
    # Subplot grid: fixed number of columns, at least _MIN_GRID_ROWS rows (more when needed).
    _GRID_COLS = 4
    _MIN_GRID_ROWS = 8
    # Figure size in inches: total width and height contributed by each grid row (8 rows -> 60 x 80).
    _FIG_WIDTH = 60
    _FIG_HEIGHT_PER_ROW = 10

    def __init__(self, names, verbose: bool = False):
        """
        :param names: iterable of item names; one lagmul/kappa/l2 list is created per name
        :param verbose: if True, every call to ``log`` prints what is being logged
        """
        # Materialise once so that one-shot iterables (e.g. generators) still yield all three key groups.
        names = list(names)
        # Key order: all lagmul_* keys, then all kappa_* keys, then all l2_* keys.
        self.my_fantastic_logging = {
            f'{quantity}_{name}': list() for quantity in self._QUANTITIES for name in names
        }
        self.verbose = verbose

    def log(self, key, value, epoch: int):
        """
        Store ``value`` as the entry of ``key`` for ``epoch``.

        If the list for ``key`` has no entry for ``epoch`` yet, the value is appended. If it already has
        exactly ``epoch + 1`` entries, the entry for ``epoch`` is overwritten and a warning is printed.

        :param key: one of the keys present in the internal logging dict
        :param value: value to store
        :param epoch: zero-based epoch index
        :raises AssertionError: if ``key`` is unknown or not backed by a list, or if the list already holds
            more than ``epoch + 1`` entries
        """
        assert key in self.my_fantastic_logging.keys() and isinstance(self.my_fantastic_logging[key], list), \
            'This function is only intended to log stuff to lists and to have one entry per epoch'

        if self.verbose:
            print(f'logging {key}: {value} for epoch {epoch}')

        history = self.my_fantastic_logging[key]
        if len(history) < (epoch + 1):
            history.append(value)
        else:
            assert len(history) == (epoch + 1), 'something went horribly wrong. My logging ' \
                                                'lists length is off by more than 1'
            print(f'maybe some logging issue!? logging {key} and {value}')
            history[epoch] = value

        # Special case: logging 'mean_fg_dice' also logs an exponential moving average under 'ema_fg_dice'.
        # None of the keys created by __init__ is 'mean_fg_dice', so this only triggers when a checkpoint
        # containing both keys has been loaded via load_checkpoint.
        if key == 'mean_fg_dice':
            ema_history = self.my_fantastic_logging['ema_fg_dice']
            if len(ema_history) > 0:
                new_ema_pseudo_dice = ema_history[epoch - 1] * 0.9 + 0.1 * value
            else:
                new_ema_pseudo_dice = value
            self.log('ema_fg_dice', new_ema_pseudo_dice, epoch)

    def _series_with_prefix(self, quantity):
        """Return the logged series whose key starts with ``<quantity>_``, in insertion order."""
        prefix = f'{quantity}_'
        # Match on the key prefix, not on a substring: item names may themselves contain
        # 'lagmul', 'kappa' or 'l2_' and must not end up in another quantity's figure.
        return {k: v for k, v in self.my_fantastic_logging.items() if k.startswith(prefix)}

    def _plot_quantity(self, quantity, n_epochs, output_file):
        """Draw one subplot per series of ``quantity`` (first ``n_epochs`` values) and save to ``output_file``."""
        series = self._series_with_prefix(quantity)
        # Ceiling division; never fewer rows than the default grid so small runs keep the usual layout.
        n_rows = max(self._MIN_GRID_ROWS, -(-len(series) // self._GRID_COLS))
        fig, axes = plt.subplots(n_rows, self._GRID_COLS,
                                 figsize=(self._FIG_WIDTH, self._FIG_HEIGHT_PER_ROW * n_rows))
        try:
            # Applied to the still-empty grid, before anything is drawn.
            fig.tight_layout()
            x_values = list(range(n_epochs))
            for i, (key, values) in enumerate(series.items()):
                row, col = divmod(i, self._GRID_COLS)
                ax = axes[row, col]
                ax.plot(x_values, values[:n_epochs], color='b', ls='-', label=key, linewidth=4)
                # Only the bottom row of the grid gets an x-axis label.
                if row == n_rows - 1:
                    ax.set_xlabel("epoch")
                ax.set_ylabel(quantity)
                ax.legend(loc=(0, 1))
            fig.savefig(output_file)
        finally:
            # Close this specific figure; an argument-less plt.close() only closes the current one.
            plt.close(fig)

    def plot_progress_png(self, output_folder):
        """
        Write ``lagmuls.png``, ``kappas.png`` and ``l2s.png`` into ``output_folder``.

        Each image contains one subplot per logged series of the respective quantity. Only epochs for which
        every series has a value are plotted.

        :param output_folder: directory the three PNG files are written to
        """
        # Infer the number of completed epochs from the shortest list (a list of epoch 0 has length 1).
        n_epochs = min((len(v) for v in self.my_fantastic_logging.values()), default=0)

        sns.set(font_scale=2.5)
        for quantity in self._QUANTITIES:
            self._plot_quantity(quantity, n_epochs, join(output_folder, f"{quantity}s.png"))

    def get_checkpoint(self):
        """Return the internal logging dict (the object itself, not a copy)."""
        return self.my_fantastic_logging

    def load_checkpoint(self, checkpoint: dict):
        """Replace the internal logging dict with ``checkpoint`` (stored by reference, not copied)."""
        self.my_fantastic_logging = checkpoint
