from __future__ import print_function

import math
import numpy as np
import torch
import torch.optim as optim
import random

from PIL import ImageFilter


from tabulate import tabulate

class ValueTable:
    """A named column of values that can be printed together with its mean."""

    def __init__(self, name, values):
        self.name = name
        self.values = values

    def print_table_and_mean(self):
        """Print one row per value under a ``name`` header, then a "Mean" row.

        The table is rendered with ``tabulate`` using the first row as header.
        The mean is ``sum(values) / len(values)``, so ``values`` must be a
        non-empty sized collection.
        """
        rows = [[self.name]]
        for value in self.values:
            rows.append([value])
        average = sum(self.values) / len(self.values)
        rows.append(["Mean", average])
        rendered = tabulate(rows, headers="firstrow")
        print(rendered)

class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Args:
        sigma: Two-element sequence ``[low, high]``. Every call draws the blur
            radius uniformly from this range. Defaults to ``[0.1, 2.0]``.
    """

    def __init__(self, sigma=None):
        # Build the default per instance: a list literal in the signature is a
        # single object shared by every instance, so mutating one instance's
        # ``sigma`` would silently change all the others.
        self.sigma = [0.1, 2.0] if sigma is None else sigma

    def __call__(self, x):
        """Return ``x`` blurred with a freshly sampled radius.

        ``x`` must expose ``filter(...)`` accepting a PIL ``ImageFilter``
        (presumably a ``PIL.Image`` -- confirm against the data pipeline).
        """
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))


class SaveFeaturesInputHook:
    """Callable hook that records the first input passed to a module.

    It takes ``(module, input, output)``, the argument layout PyTorch uses for
    forward hooks. Recording starts disabled; call :meth:`enable` to begin
    collecting and :meth:`reset` to drop what has been collected.
    """

    def __init__(self):
        self.reset()
        self.enabled = False

    def __call__(self, module, input, output):
        """Append ``input[0]`` to ``features`` if recording is enabled."""
        if not self.enabled:
            return
        first_input = input[0]
        self.features.append(first_input)

    def reset(self):
        """Discard all recorded features."""
        self.features = list()

    def enable(self):
        """Turn recording on."""
        self.enabled = True

    def disable(self):
        """Turn recording off; already recorded features are kept."""
        self.enabled = False


class SaveFeaturesListInputHook(SaveFeaturesInputHook):
    """Hook variant that records each module output, flattened per sample.

    Despite "Input" in the name, the stored tensor is taken from ``output``:
    it is detached and viewed as ``(output.size(0), -1)`` before being kept.
    """

    def __call__(self, module, input, output):
        """Append the detached, flattened ``output`` if recording is enabled."""
        if not self.enabled:
            return
        flattened = output.detach().view(output.size(0), -1)
        self.features.append(flattened)

    def reset(self):
        """Discard all recorded features."""
        self.features = list()


class ExponentialMovingAverage:
    """Exponential moving average (EMA) of a model's trainable parameters.

    Only parameters with ``requires_grad=True`` are tracked. Expected call
    order: ``register()`` once, ``update()`` whenever the average should absorb
    the current weights, and ``apply_shadow_weights()`` followed by
    ``restore_original_weights()`` to temporarily run the model with the
    averaged weights.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        decay_rate: Weight given to the previous average on each update; the
            current parameter value gets ``1 - decay_rate``.
    """

    def __init__(self, model, decay_rate):
        self.model = model
        self.decay_rate = decay_rate
        # name -> averaged tensor
        self.shadow_weights = {}
        # name -> live tensor saved while the shadow weights are applied
        self.original_weights = {}

    def register(self):
        """Seed the average with a copy of every trainable parameter."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow_weights[name] = param.data.clone()

    def update(self):
        """Blend the current parameter values into the running average."""
        decay = self.decay_rate
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            assert name in self.shadow_weights, \
                "parameter %r was not registered; call register() first" % name
            # The arithmetic already produces a fresh tensor, so no extra
            # clone is needed before storing it.
            self.shadow_weights[name] = (
                (1 - decay) * param.data + decay * self.shadow_weights[name]
            )

    def apply_shadow_weights(self):
        """Swap the averaged weights into the model, saving the live ones."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            assert name in self.shadow_weights, \
                "parameter %r was not registered; call register() first" % name
            self.original_weights[name] = param.data
            # Hand the model a copy: binding the shadow tensor itself would
            # make the parameter share its storage, so any in-place change to
            # the parameter while applied would corrupt the EMA state.
            param.data = self.shadow_weights[name].clone()

    def restore_original_weights(self):
        """Put back the weights saved by ``apply_shadow_weights``."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            assert name in self.original_weights, \
                "no saved weights for %r; call apply_shadow_weights() first" % name
            param.data = self.original_weights[name]
        self.original_weights = {}

    def get_shadow_weights(self):
        """Return the dict mapping parameter names to averaged tensors."""
        return self.shadow_weights

    def set_shadow_weights(self, new_shadow_weights):
        """Replace the averaged weights with ``new_shadow_weights``."""
        self.shadow_weights = new_shadow_weights


class TwoCropTransform:
    """Produce two independently transformed views of one input."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        """Apply ``transform`` to ``x`` twice and return both results as a list."""
        first_view = self.transform(x)
        second_view = self.transform(x)
        return [first_view, second_view]


class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes are ``val`` (last value), ``sum`` (weighted sum of values),
    ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times.

        Args:
            val: The new value (e.g. a per-batch mean).
            n: Weight of ``val`` in the running average (e.g. the batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch on a fresh meter) has no
        # defined mean; keep the reset value instead of dividing by zero.
        self.avg = self.sum / self.count if self.count else 0


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: 2-D score tensor; the top-k is taken along dimension 1.
        target: Tensor of ground-truth indices, one per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k``, each holding the
        percentage (0-100) of rows whose target is among the ``k`` best scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk best scores per row, best first.
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        # Transpose so that row i holds every sample's i-th best guess.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` inherits the transposed (non-contiguous) layout, so
            # ``view(-1)`` raises for k > 1; ``reshape`` copies when required.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def adjust_learning_rate(args, optimizer, epoch):
    """Set the learning rate of every param group for the given epoch.

    With ``args.cosine`` the rate follows a cosine curve from
    ``args.learning_rate`` down to ``learning_rate * lr_decay_rate ** 3`` over
    ``args.epochs``. Otherwise it is multiplied by ``args.lr_decay_rate`` once
    for each entry of ``args.lr_decay_epochs`` that ``epoch`` has passed.
    """
    base_lr = args.learning_rate

    if args.cosine:
        floor_lr = base_lr * (args.lr_decay_rate ** 3)
        cosine_term = math.cos(math.pi * epoch / args.epochs)
        new_lr = floor_lr + (base_lr - floor_lr) * (1 + cosine_term) / 2
    else:
        milestones = np.asarray(args.lr_decay_epochs)
        # Number of milestones strictly before the current epoch.
        passed = np.sum(epoch > milestones)
        new_lr = base_lr
        if passed > 0:
            new_lr = base_lr * (args.lr_decay_rate ** passed)

    for group in optimizer.param_groups:
        group['lr'] = new_lr


def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    """Linearly ramp the learning rate during the warm-up epochs.

    Does nothing unless ``args.warm`` is truthy and ``epoch`` is at most
    ``args.warm_epochs``. Otherwise the rate is interpolated between
    ``args.warmup_from`` and ``args.warmup_to`` according to how many batches
    have elapsed (epochs counted from 1) and written to every param group.
    """
    in_warmup = args.warm and epoch <= args.warm_epochs
    if not in_warmup:
        return

    batches_done = batch_id + (epoch - 1) * total_batches
    warmup_batches = args.warm_epochs * total_batches
    progress = batches_done / warmup_batches
    warm_lr = args.warmup_from + progress * (args.warmup_to - args.warmup_from)

    for group in optimizer.param_groups:
        group['lr'] = warm_lr


def set_optimizer(opt, model):
    """Build a Nesterov-momentum SGD optimizer over ``model``'s parameters.

    Reads ``learning_rate``, ``momentum`` and ``weight_decay`` from ``opt``.
    """
    # NOTE: an AdamW alternative would be
    # optim.AdamW(model.parameters(), lr=opt.learning_rate,
    #             weight_decay=opt.weight_decay)
    trainable = model.parameters()
    sgd_settings = dict(
        lr=opt.learning_rate,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay,
        nesterov=True,
    )
    return optim.SGD(trainable, **sgd_settings)


def save_model(model, optimizer, opt, epoch, save_file):
    """Write a checkpoint to ``save_file`` with ``torch.save``.

    The checkpoint is a dict with the keys ``'opt'`` (the given options
    object), ``'model'`` and ``'optimizer'`` (their ``state_dict()``s) and
    ``'epoch'``.
    """
    print('==> Saving...')
    checkpoint = dict(
        opt=opt,
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch,
    )
    torch.save(checkpoint, save_file)
