import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import math
import random
import numpy as np
import os

from pytorch_lightning import LightningModule
import matplotlib.pyplot as plt

from models import utils as mutils
from models.ema import ExponentialMovingAverage
import sde_lib
import scipy.interpolate as interpolate
from functools import partial
from torchvision.utils import save_image
from collections import defaultdict
import torch.distributed as dist
import torch.distributed.nn as dist_nn
from torch.cuda.amp import GradScaler
from sklearn.metrics import f1_score

class ScoreModel(LightningModule):
    """Lightning module training a score-based SDE model and a noise-conditional classifier.

    Both networks (``self.score_model`` and ``self.clf``) keep an exponential
    moving average of their weights (``self.ema`` / ``self.clf_ema``).
    Optimization is manual (``automatic_optimization = False``):
    :meth:`training_step` drives both optimizers itself, gated by
    ``config.training.score_model`` and ``config.training.clf_model``.
    """

    def __init__(self, config, workdir):
        """Build the networks, their EMAs and the forward SDE.

        Args:
            config: experiment configuration; ``config.unlock()`` is stored as
                ``self.config``. When ``config.data.pu`` is set the classifier
                is forced to two classes.
            workdir: output directory; it and ``{workdir}/samples`` are created.

        Raises:
            NotImplementedError: if ``config.training.sde`` is not ``vpsde``,
                ``subvpsde`` or ``vesde`` (case-insensitive).
        """
        super().__init__()
        # Positive-unlabeled (PU) mode is a two-class problem.
        if config.data.pu:
            config.classifier.classes = 2

        def get_data_scaler(config):
            """Data normalizer. Assume data are always in [0, 1]."""
            if config.data.centered:
                # Rescale to [-1, 1]
                return lambda x: x * 2. - 1.
            else:
                return lambda x: x

        def get_data_inverse_scaler(config):
            """Inverse data normalizer."""
            if config.data.centered:
                # Rescale [-1, 1] to [0, 1]
                return lambda x: (x + 1.) / 2.
            else:
                return lambda x: x

        self.config = config = config.unlock()
        self.workdir = workdir

        os.makedirs(workdir, exist_ok=True)
        os.makedirs(f"{workdir}/samples", exist_ok=True)

        self.score_model = mutils.create_model(config)
        self.ema = ExponentialMovingAverage(self.score_model.parameters(),
                                            decay=config.model.ema_rate)

        self.clf = mutils.create_classifier(config)
        self.clf_ema = ExponentialMovingAverage(self.clf.parameters(),
                                                decay=config.model.ema_rate)

        sde_name = config.training.sde.lower()
        if sde_name == 'vpsde':
            sde = sde_lib.VPSDE(beta_min=config.model.beta_min,
                                beta_max=config.model.beta_max,
                                N=config.model.num_scales)
            sampling_eps = 1e-3
        elif sde_name == 'subvpsde':
            sde = sde_lib.subVPSDE(beta_min=config.model.beta_min,
                                   beta_max=config.model.beta_max,
                                   N=config.model.num_scales)
            sampling_eps = 1e-3
        elif sde_name == 'vesde':
            sde = sde_lib.VESDE(sigma_min=config.model.sigma_min,
                                sigma_max=config.model.sigma_max,
                                N=config.model.num_scales)
            sampling_eps = 1e-5
        else:
            # Fail loudly here rather than with an UnboundLocalError on `sde`.
            raise NotImplementedError(f"SDE {config.training.sde} unknown.")

        self.sde = sde
        self.sampling_eps = sampling_eps
        self.scaler = get_data_scaler(config)
        self.inverse_scaler = get_data_inverse_scaler(config)

        self.automatic_optimization = False

    def adjust_parameters(self):
        """Freeze whichever network is not being trained.

        If ``config.training.score_model`` is set, the classifier parameters are
        frozen. Otherwise the score network is overwritten with its EMA weights,
        frozen and switched to eval mode.
        """
        if self.config.training.score_model:
            for param in self.clf.parameters():
                param.requires_grad = False
        else:
            self.ema.copy_to(self.score_model.parameters())
            for param in self.score_model.parameters():
                param.requires_grad = False
            self.score_model = self.score_model.eval()

    def score_fn(self, x, t, score_model='score_model'):
        """Evaluate a score network on noisy inputs ``x`` at times ``t``.

        Args:
            x: noisy batch; treated as 4-D (``std[:, None, None, None]``).
            t: per-sample times (callers in this class pass shape ``(batch,)``).
            score_model: name of the attribute holding the network to call.

        Returns:
            For ``vpsde``/``subvpsde``: the network output divided by
            ``-std(t)``. For ``vesde``: the raw network output. The network is
            conditioned on ``999 * t`` (continuous VP or subVP), the discrete
            step index (discrete VP), ``std(t)`` (continuous VE) or the rounded,
            reversed step index (discrete VE).
        """
        score_model = getattr(self, score_model)
        config = self.config
        sde = self.sde
        sde_name = config.training.sde.lower()

        if sde_name in ['vpsde', 'subvpsde']:
            if config.training.continuous or sde_name in ['subvpsde']:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                # The maximum value of time embedding is assumed to be 999 for
                # continuously-trained models.
                labels = t * 999
                score = score_model(x, labels)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                score = score_model(x, labels)
                std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]

            score = -score / std[:, None, None, None]
        elif sde_name in ['vesde']:
            if config.training.continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level.
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()
            score = score_model(x, labels)
        return score

    def clf_fn(self, x, t, embed=False, clf_model='clf'):
        """Evaluate a time-conditioned classifier.

        Args:
            x: classifier input, forwarded unchanged (callers in this class pass
                a two-element list ``[noisy, second_view]``).
            t: per-sample times.
            embed: forwarded to the classifier; callers passing ``True`` unpack a
                two-element result and use the second element as logits.
            clf_model: name of the attribute holding the classifier.

        Raises:
            NotImplementedError: for discrete-time ``vpsde`` or ``vesde``.
        """
        config = self.config
        sde = self.sde
        clf_model = getattr(self, clf_model)
        sde_name = config.training.sde.lower()

        if sde_name in ['vpsde', 'subvpsde']:
            if config.training.continuous or sde_name in ['subvpsde']:
                # Same time embedding as the score network (see score_fn).
                labels = t * 999
                y = clf_model(x, labels, embed)
            else:
                raise NotImplementedError(
                    "SDE class not yet supported for the classifier.")

        elif sde_name in ['vesde']:
            if config.training.continuous:
                # Only the std of marginal_prob is used; the dummy mean tensor is
                # placed on t's device to avoid a CPU/GPU mix.
                labels = sde.marginal_prob(
                    torch.zeros(t.shape[0], 1, device=t.device), t)[1]
            else:
                raise NotImplementedError(
                    "SDE class not yet supported for the classifier.")
            y = clf_model(x, labels, embed)

        return y

    def combined_score_fn(self, batch, batch_t, labels=None, classifier_scale=1):
        """Score of the (optionally class-guided) model at ``(batch, batch_t)``.

        Without ``labels`` this is :meth:`score_fn`. With ``labels`` it adds
        ``classifier_scale`` times the gradient w.r.t. the input of the summed
        log-probability that the classifier assigns to ``labels``. If ``batch``
        is a leaf tensor, its ``requires_grad`` flag is set in place.
        """
        if labels is None:
            return self.score_fn(batch, batch_t)

        x = batch
        t = batch_t
        with torch.enable_grad():
            _, std_t = self.sde.marginal_prob(x, t)
            if x.is_leaf:
                x.requires_grad = True
            if self.config.training.denoise_augment:
                score = self.score_fn(x, t)
            else:
                with torch.no_grad():
                    score = self.score_fn(x, t)
            # Second classifier input: the denoised estimate x + std^2 * score,
            # zeroed out when denoise augmentation is disabled.
            gate = 1 if self.config.training.denoise_augment else 0
            xs = [x, (x + std_t[:, None, None, None] ** 2 * score) * gate]
            logp = torch.log_softmax(self.clf_fn(xs, t), dim=1)
            oh = F.one_hot(labels, logp.shape[1]).to(logp)
            eq_energy = (logp * oh).sum()
            grad = torch.autograd.grad(eq_energy, x)[0]
        return score + classifier_scale * grad

    def _classify(self, x, t, embed=False):
        """Classify ``x`` at times ``t`` with the input layout used in training.

        The second input is ``x`` when ``config.training.denoise_augment`` is
        set, and all zeros otherwise.
        """
        i = 1 if self.config.training.denoise_augment else 0
        return self.clf_fn([x, x * i], t, embed)

    @staticmethod
    def _masked_mean(values, mask):
        """Average ``values`` over the rows (dim 0) selected by boolean ``mask``.

        Unlike ``values[mask].mean()`` — which is NaN for an empty selection and
        would poison the whole loss through ``manual_backward`` — this returns
        zeros when nothing is selected, while keeping ``values`` in the autograd
        graph (unselected rows simply get zero gradient).
        """
        keep = mask.reshape(mask.shape[0], *([1] * (values.dim() - 1)))
        total = values.masked_fill(~keep, 0.0).sum(dim=0)
        return total / mask.sum().clamp_min(1)

    def clf_gradient(self,
                     samples_tau,
                     tau,
                     samples_t,
                     t,
                     samples_s,
                     s,
                     labels):
        """Run the classifier on the perturbed views used by :meth:`get_loss`.

        Args:
            samples_tau: batch perturbed to times ``tau``.
            tau: per-sample times of ``samples_tau``.
            samples_t: batch perturbed to times ``t``; its ``requires_grad``
                flag is set if it is a leaf tensor.
            t: per-sample times of ``samples_t``.
            samples_s: optional third view at times ``s`` (``None`` to skip).
                Its confident predictions become soft targets for time ``t``.
            s: per-sample times of ``samples_s`` (``None`` when unused).
            labels: unused; kept for interface compatibility.

        Returns:
            ``(logp_tau, logp_t, E_loss)``: classifier log-probabilities at
            ``tau`` and ``t``, and the consistency loss (``0`` when
            ``samples_s`` is ``None``).
        """
        with torch.enable_grad():
            x2 = samples_t
            if x2.is_leaf:
                x2.requires_grad = True

            _, std_t = self.sde.marginal_prob(samples_tau, t)
            # `i` switches the denoised second classifier input on (1) or off (0).
            if self.config.training.denoise_augment:
                with torch.no_grad():
                    score_t = self.score_fn(x2, t)
                i = 1
            else:
                score_t = torch.zeros_like(x2)
                i = 0

            if self.config.data.dataset == 'CIFAR10':
                transform = transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
                ])
            else:
                # MNIST or SVHN
                transform = transforms.Compose([
                    transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
                ])

            samples_tau = transform(samples_tau)
            _, y_tau = self.clf_fn([samples_tau, samples_tau * i], tau, embed=True)
            logp_tau = torch.log_softmax(y_tau, dim=1)

            # Noisy batch and its denoised estimate go through a single transform
            # call so both halves receive the same augmentation.
            aug_x2 = transform(torch.cat(
                [x2, x2 + score_t * std_t[:, None, None, None] ** 2], dim=0))
            n = x2.shape[0]
            _, yt = self.clf_fn([aug_x2[:n], aug_x2[n:] * i], t, embed=True)
            logp_t = torch.log_softmax(yt, dim=1)

            E_loss = 0
            if samples_s is not None:
                _, std_s = self.sde.marginal_prob(samples_tau, s)
                with torch.no_grad():
                    if self.config.training.denoise_augment:
                        score_s = self.score_fn(samples_s, s)
                    else:
                        score_s = torch.zeros_like(samples_s)
                    denoised_s = samples_s + score_s * std_s[:, None, None, None] ** 2
                    # Condition on `s`, the noise level samples_s was drawn at.
                    _, ys = self.clf_fn([samples_s, denoised_s * i], s, embed=True)
                    p_s = torch.softmax(ys, dim=1)

                # Soft cross-entropy towards confident (> 0.95) time-s predictions.
                E_loss = E_loss + self._masked_mean(
                    -(logp_t * p_s).sum(dim=1), p_s.max(dim=1)[0] > 0.95)
        return logp_tau, logp_t, E_loss

    def unconditional_loss(self, batch):
        """Denoising score-matching loss of ``score_model`` on a scaled batch.

        Samples ``t ~ U(sampling_eps, T)``, perturbs ``batch`` with the SDE
        marginal, and penalises ``score * std + z`` (or, with
        ``config.training.likelihood_weighting``, ``score + z / std`` weighted by
        ``g(t)^2``). Per-sample losses are reduced by mean or by ``0.5 * sum``
        per ``config.training.reduce_mean``, then averaged over the batch.
        """
        config = self.config
        sde = self.sde
        eps = self.sampling_eps

        reduce_mean = config.training.reduce_mean
        likelihood_weighting = config.training.likelihood_weighting

        reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(
            *args, **kwargs)

        t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps

        z = torch.randn_like(batch)
        mean, std = sde.marginal_prob(batch, t)

        perturbed_data = mean + std[:, None, None, None] * z
        score = self.score_fn(perturbed_data, t)
        if not likelihood_weighting:
            losses_unc = torch.square(score * std[:, None, None, None] + z)
            losses_unc = reduce_op(losses_unc.reshape(losses_unc.shape[0], -1), dim=-1)
        else:
            g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2
            losses = torch.square(score + z / std[:, None, None, None])
            losses_unc = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * g2

        return torch.mean(losses_unc)

    def get_loss(self, batch, labels):
        """Classifier losses for one scaled batch.

        Args:
            batch: scaled images.
            labels: class indices, with ``-1`` marking unlabeled samples.
                Not modified.

        Returns:
            dict with ``loss_ce`` (cross-entropy at times ``t`` and ``tau``),
            ``ent_tau`` (python float, for logging), ``loss_sm`` (consistency
            loss, plus the PU prior term when ``config.data.pu``), ``E_loss``
            and, in PU mode, ``pu_loss``.
        """
        config = self.config
        sde = self.sde
        eps = self.sampling_eps
        unlabeled = labels == -1

        t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
        # tau is confined to [eps, eps + 0.01]; its predictions act as pseudo-labels.
        tau = torch.rand(batch.shape[0], device=batch.device) * 0.01 + eps

        # samples_tau
        mean, std = sde.marginal_prob(batch, tau)
        samples_tau = mean + std[:, None, None, None] * torch.randn_like(mean)

        # samples_t
        z = torch.randn_like(batch)
        mean, std = sde.marginal_prob(batch, t)
        samples_t = mean + std[:, None, None, None] * z

        # samples_s (only in the partially-labeled setting)
        if config.data.labels_per_class != -1:
            s = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
            # Perturb the clean batch to level s; reusing the time-t mean would mix
            # two noise levels for VP-type SDEs.
            mean_s, std_s = sde.marginal_prob(batch, s)
            samples_s = mean_s + std_s[:, None, None, None] * torch.randn_like(mean_s)
        else:
            samples_s = None
            s = None

        logp_tau, logp_t, E_loss = self.clf_gradient(
            samples_tau, tau, samples_t, t, samples_s, s, labels)

        num_classes = config.classifier.classes
        # Shift by one so -1 maps to a column that is dropped: unlabeled rows
        # start as all zeros.
        targets = F.one_hot(labels + 1, num_classes=num_classes + 1)[:, 1:].float()
        # Unlabeled rows take the time-tau prediction; it only feeds the argmax
        # and confidence mask below, so it is detached.
        targets[unlabeled] = torch.exp(logp_tau[unlabeled]).detach()
        confidence, hard_targets = targets.max(dim=1)
        oh = F.one_hot(hard_targets, num_classes=num_classes).float()

        # Time-t CE over labeled samples and confident (> 0.95) pseudo-labels.
        ce_t = self._masked_mean(-(logp_t * oh).sum(dim=1), confidence > 0.95)
        # Time-tau CE over labeled samples only.
        ce_tau = self._masked_mean(-(logp_tau * oh).sum(dim=1), ~unlabeled)
        ent_tau = -(logp_tau * torch.exp(logp_tau)).sum(dim=1).mean()

        losses = {
            'loss_ce': ce_t + ce_tau,
            'ent_tau': ent_tau.item(),
            'loss_sm': E_loss,
            'E_loss': E_loss,
        }

        if config.data.pu:
            n_used = len(config.data.pu_config.use_classes)
            n_positive = len(config.data.pu_config.positive_classes)
            prior = n_positive / n_used
            # Average predicted class distribution over unlabeled samples.
            avg_tau = self._masked_mean(torch.exp(logp_tau), unlabeled)
            avg_t = self._masked_mean(torch.exp(logp_t), unlabeled)
            log_ent = torch.log(avg_tau + 1e-5) + torch.log(avg_t + 1e-5)

            pu_loss = -((1 - prior) * log_ent[0] + prior * log_ent[1])
            losses['loss_sm'] = losses['loss_sm'] + pu_loss
            losses['pu_loss'] = pu_loss

        return losses

    def getloss_from_losses(self, losses):
        """Total classifier objective: ``loss_sm + loss_ce``."""
        return losses["loss_sm"] + losses["loss_ce"]

    def training_step(self, batch, batch_idx):
        """One manual-optimization step for the score model and/or classifier.

        ``batch`` is either an ``(images, labels)`` pair or a dict of such pairs.
        In the dict form, ``batch['score']`` feeds the score loss and the
        accuracy probe. Every entry whose key does not contain ``'score'`` is
        concatenated into the classifier batch. A plain pair is used for both.
        """
        if isinstance(batch, dict):
            values = [batch[k] for k in batch if 'score' not in k]
            score_batch, score_labels = batch['score']
            batch = torch.cat([v[0] for v in values])
            labels = torch.cat([v[1] for v in values])
        else:
            batch, labels = batch
            # A plain pair serves both the score loss and the accuracy probe.
            score_batch, score_labels = batch, labels
        score_batch = self.scaler(score_batch)
        batch = self.scaler(batch)
        score_opt, clf_opt = self.optimizers()
        score_sched, clf_sched = self.lr_schedulers()
        config = self.config

        score_train = config.training.score_model
        clf_train = config.training.clf_model
        all_losses = {}

        total_loss = 0
        if score_train:
            self.score_model = self.score_model.train()
            score_opt.zero_grad(set_to_none=True)
            unc_loss = self.unconditional_loss(score_batch)
            all_losses['unc_loss'] = unc_loss
            total_loss = total_loss + unc_loss

        if clf_train:
            losses = self.get_loss(batch, labels)
            all_losses |= losses
            total_loss = total_loss + self.getloss_from_losses(losses)

        # With neither network training there is no tensor to backpropagate.
        if score_train or clf_train:
            self.manual_backward(total_loss)

        if score_train:
            if config.optim.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.score_model.parameters(),
                                               max_norm=config.optim.grad_clip)
            all_losses['score_lr'] = score_opt.param_groups[0]['lr']
            score_opt.step()
            score_sched.step()
            self.ema.update(self.score_model.parameters())

        if clf_train:
            if config.optim.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.clf.parameters(),
                                               max_norm=config.optim.grad_clip)

            clf_opt.param_groups[0]['lr'] = clf_sched.get_last_lr()[0]
            all_losses['clf_lr'] = clf_opt.param_groups[0]['lr']
            all_losses['clf_wd'] = clf_opt.param_groups[0]['weight_decay']
            clf_opt.step()
            clf_sched.step()
            clf_opt.zero_grad(set_to_none=True)

            self.clf_ema.update(self.clf.parameters())

            # Accuracy probe on the score batch perturbed to t = sampling_eps,
            # with the classifier temporarily in eval mode.
            with torch.no_grad():
                was_training = self.training
                self.clf.train(False)
                try:
                    t = torch.zeros(score_batch.shape[0],
                                    device=score_batch.device) + self.sampling_eps
                    mean, std = self.sde.marginal_prob(score_batch, t)
                    noisy = mean + std[:, None, None, None] * torch.randn_like(mean)
                    preds = self._classify(noisy, t).max(dim=1)[1]
                    acc = (1.0 * (preds == score_labels)).mean()
                    if config.data.pu:
                        f1 = f1_score(score_labels.detach().cpu().numpy(),
                                      preds.detach().cpu().numpy())
                        all_losses |= {'f1': f1}
                    all_losses |= {'acc': acc}
                finally:
                    self.clf.train(was_training)

        self.log_dict(all_losses,
                      prog_bar=True,
                      logger=True,
                      on_step=True,
                      on_epoch=True,
                      sync_dist=True)

    def validation_step(self, batch, batch_idx):
        """Log validation losses and, if enabled, raw/EMA classifier metrics.

        When ``config.training.clf_model`` is set every logged key gets a ``_v``
        suffix. Otherwise only ``unc_loss`` is logged, unsuffixed.
        """
        batch, labels = batch
        batch = self.scaler(batch)

        unc_loss = self.unconditional_loss(batch)
        losses = {'unc_loss': unc_loss}

        if self.config.training.clf_model:
            losses = self.get_loss(batch, labels)
            labeled = labels != -1

            t = torch.zeros(batch.shape[0], device=batch.device) + self.sampling_eps
            preds = self._classify(batch, t).max(dim=1)[1]
            acc = ((preds == labels) * labeled).sum() / labeled.sum()
            if self.config.data.pu:
                losses['f1'] = f1_score(labels.detach().cpu().numpy(),
                                        preds.detach().cpu().numpy())

            # Evaluate with EMA weights; always restore the live weights.
            self.clf_ema.store(self.clf.parameters())
            self.clf_ema.copy_to(self.clf.parameters())
            try:
                # Same input layout as training so EMA metrics are comparable.
                preds_ema = self._classify(batch, t).max(dim=1)[1]
            finally:
                self.clf_ema.restore(self.clf.parameters())
            acc_ema = ((preds_ema == labels) * labeled).sum() / labeled.sum()
            if self.config.data.pu:
                losses['f1_ema'] = f1_score(labels.detach().cpu().numpy(),
                                            preds_ema.detach().cpu().numpy())

            losses["unc_loss"] = unc_loss
            losses["acc"] = acc
            losses["acc_ema"] = acc_ema
            losses = {k + "_v": v for k, v in losses.items()}
        self.log_dict(losses,
                      prog_bar=True,
                      logger=True,
                      on_step=True,
                      on_epoch=True,
                      sync_dist=True)

    def on_save_checkpoint(self, checkpoint):
        """Store both EMA states under ``"ema"`` and ``"clf_ema"``."""
        print("Saved EMAs")
        checkpoint["ema"] = self.ema.state_dict()
        checkpoint["clf_ema"] = self.clf_ema.state_dict()

    def on_load_checkpoint(self, checkpoint):
        """Restore both EMA states written by :meth:`on_save_checkpoint`."""
        print("Loaded EMAs")
        self.ema.load_state_dict(checkpoint["ema"])
        self.clf_ema.load_state_dict(checkpoint["clf_ema"])

    def configure_optimizers(self):
        """Adam + linear warm-up for the score model; AdamW, constant LR, for the classifier."""
        config = self.config
        score_opt = torch.optim.Adam(list(self.score_model.parameters()),
                                     lr=config.optim.lr,
                                     betas=(config.optim.beta1, 0.999),
                                     eps=config.optim.eps)

        def nodecay(step):
            # Linear warm-up, then constant. A non-positive warm-up length means
            # "no warm-up" rather than a ZeroDivisionError.
            if config.optim.warmup <= 0:
                return 1
            return min(step / config.optim.warmup, 1)

        score_scheduler = torch.optim.lr_scheduler.LambdaLR(score_opt,
                                                            lr_lambda=nodecay)

        clf_opt = torch.optim.AdamW(list(self.clf.parameters()),
                                    lr=3e-4,
                                    weight_decay=0.05)
        clf_scheduler = torch.optim.lr_scheduler.LambdaLR(clf_opt, lambda step: 1)

        return (
            {
                "optimizer": score_opt,
                "lr_scheduler": score_scheduler,
            },
            {
                "optimizer": clf_opt,
                "lr_scheduler": clf_scheduler,
            },
        )

    @torch.no_grad()
    def reversediffusion_langevin_samples(self, num_samples, labels=None, classifier_scale=1):
        """Sample with Langevin corrector steps plus a reverse-diffusion predictor step.

        Runs under ``torch.no_grad`` so the chain does not build an autograd
        graph across steps. :meth:`combined_score_fn` re-enables grad locally
        for the classifier-guidance gradient.

        Args:
            num_samples: number of images to draw.
            labels: optional class labels for classifier guidance (``None`` for
                unguided sampling).
            classifier_scale: weight of the classifier gradient.

        Returns:
            Samples mapped through ``inverse_scaler``. When
            ``config.sampling.noise_removal`` is set, this is the last
            predictor mean, without the final noise.
        """
        import tqdm
        score_fn = partial(self.combined_score_fn,
                           labels=labels,
                           classifier_scale=classifier_scale)
        config = self.config
        N = config.model.num_scales
        x = self.sde.prior_sampling(shape=(num_samples,
                                           config.data.num_channels,
                                           config.data.image_size,
                                           config.data.image_size))
        x = x.to(device=self.device)
        timesteps = torch.linspace(self.sde.T,
                                   self.sampling_eps,
                                   N,
                                   device=self.device)

        ones = torch.ones(num_samples, device=x.device)
        for i in tqdm.tqdm(range(N)):
            vec_t = ones * timesteps[i]

            # VP-type SDEs scale the Langevin step by alpha at the current step.
            if isinstance(self.sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
                timestep = (vec_t * (self.sde.N - 1) / self.sde.T).long()
                alpha = self.sde.alphas.to(vec_t.device)[timestep]
                alpha = alpha[:, None, None, None]
            else:
                alpha = 1

            # Corrector: Langevin steps with an SNR-matched step size.
            for _ in range(config.sampling.n_steps_each):
                grad = score_fn(x, vec_t)
                noise = torch.randn_like(x)
                grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=-1).mean()
                noise_norm = torch.norm(noise.view(noise.shape[0], -1), dim=-1).mean()
                step_size = (config.sampling.snr * noise_norm / grad_norm) ** 2 * 2 * alpha
                x_mean = x + step_size * grad
                x = x_mean + torch.sqrt(step_size * 2) * noise

            # Predictor: one reverse-diffusion step of the discretized SDE.
            f_forward, g_forward = self.sde.discretize(x, vec_t)
            f_backward = f_forward - g_forward[:, None, None, None] ** 2 * score_fn(x, vec_t)
            z = torch.randn_like(x)
            x_mean = x - f_backward
            x = x_mean + g_forward[:, None, None, None] * z

        if config.sampling.noise_removal:
            return self.inverse_scaler(x_mean)
        else:
            return self.inverse_scaler(x)

    def load_pretrained_state_dict(self, state_dict):
        """Load ``state_dict["model"]`` (non-strict) into the score model and ``state_dict["ema"]`` into its EMA."""
        print(state_dict.keys())
        self.score_model.load_state_dict(
            state_dict["model"],
            strict=False)
        self.ema.load_state_dict(
            state_dict["ema"])