import os
import math
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import sympy
import random
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from scipy.optimize import minimize
from sympy import *
import copy
import json
import warnings
from absl import app, flags
import torch
#from torchmin import minimize
from tensorboardX import SummaryWriter
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid, save_image
from torchvision import transforms
from tqdm import trange
from tqdm import tqdm
import logging
from model import UNet
from score.both import get_inception_and_fid_score
from libs.iddpm import UNetModel,UNetModel4Pretrained,UNetModel4Pretrained2
from adan import Adan
from shampoo import Shampoo
FLAGS = flags.FLAGS
flags.DEFINE_bool('train', False, help='train from scratch')
flags.DEFINE_bool('eval', False, help='load ckpt.pt and evaluate FID and IS')
# UNet: IDDPM
flags.DEFINE_integer('in_channel', 3, help='input channel of UNet')
flags.DEFINE_integer('out_channel', 3, help='output channel of UNet')
flags.DEFINE_integer('ch', 128, help='base channel of UNet')
flags.DEFINE_integer('num_res_blocks', 3, help='# resblock in each level')
flags.DEFINE_integer('num_heads', 4, help='Multi-Heads for attention')
flags.DEFINE_integer('dims', 2, help='1,2,3 dims')
flags.DEFINE_multi_integer('ch_mult', [1, 2, 2, 2], help='channel multiplier')
flags.DEFINE_multi_integer('attn', [32 // 16, 32 // 8], help='add attention to these levels')
flags.DEFINE_float('dropout', 0.3, help='dropout rate of resblock')
# Help text previously duplicated the 'eval' flag's description.
flags.DEFINE_bool('use_scale_shift_norm', True, help='passed to the UNet as use_scale_shift_norm')
flags.DEFINE_string('exp_name', 'CIFAR10', help='name of experiment')

flags.DEFINE_integer('head_out_channels', 3, help='the final layer of High order noise network')
flags.DEFINE_enum('mode', 'simple', ['simple','complex'], help='the mode for honn modeling')

# Gaussian Diffusion
flags.DEFINE_float('beta_1', 1e-4, help='start beta value')
flags.DEFINE_float('beta_T', 0.02, help='end beta value')
flags.DEFINE_integer('T', 1000, help='total diffusion training noising steps')
flags.DEFINE_enum('sample_type', 'ddpm', ['ddpm', 'analyticdpm', 'gmddpm','mean_network'], help='sample type for sampling')
flags.DEFINE_enum('mean_type', 'epsilon', ['xprev', 'xstart', 'epsilon'], help='predict variable')
flags.DEFINE_enum('var_type', 'fixedlarge', ['fixedlarge', 'fixedsmall'], help='variance type')
# Training
flags.DEFINE_float('lr', 1e-4, help='target learning rate')
flags.DEFINE_float('grad_clip', 1., help="gradient norm clipping")
flags.DEFINE_integer('total_steps', 500001, help='total training steps')
flags.DEFINE_integer('img_size', 32, help='image size')
flags.DEFINE_integer('warmup', 5000, help='learning rate warmup')
flags.DEFINE_integer('batch_size', 128, help='batch size')
flags.DEFINE_integer('num_workers', 4, help='workers of Dataloader')
flags.DEFINE_integer('noise_order', 1, help="the order of noise used to training")
flags.DEFINE_float('ema_decay', 0.9999, help="ema decay rate")
flags.DEFINE_bool('parallel', False, help='multi gpu training')
# The default value is a checkpoint file, not a directory.
flags.DEFINE_string('pretrained_dir', './logs/iDDPM_CIFAR10_EPS/models/ckpt_1_450000.pt', help='path to the pretrained checkpoint')

# Logging & Sampling
flags.DEFINE_string('logdir', './logs/iDDPM_CIFAR10_EPS', help='log directory')
flags.DEFINE_integer('sample_size', 64, "sampling size of images")
flags.DEFINE_integer('sample_step', 10000, help='frequency of sampling')
flags.DEFINE_integer('sample_steps', 1000, help='Sampling steps for generation stage')
# Evaluation
flags.DEFINE_integer('save_step', 50000, help='frequency of saving checkpoints, 0 to disable during training')
flags.DEFINE_integer('eval_step', 0, help='frequency of evaluating model, 0 to disable during training')
flags.DEFINE_integer('num_images', 50000, help='the number of generated images for evaluation')
flags.DEFINE_bool('fid_use_torch', False, help='calculate IS and FID on gpu')
flags.DEFINE_bool('time_shift', False, help='whether the noised data is from t=1')
flags.DEFINE_bool('rescale_time', True, help='adjust the maximum time input to the network to 1000')
flags.DEFINE_bool('nll_training', False, help='training the model to fit the noise.pow(a)')
# Help text previously duplicated the 'mode' flag's description.
flags.DEFINE_enum('noise_schedule', 'linear', ['linear','cosine'], help='beta (noise) schedule: linear or cosine')
flags.DEFINE_string('fid_cache', './stats/cifar10.train.npz', help='FID cache')
# Model checkpoints (the values are .pt files, not directories)
flags.DEFINE_string('eps1_dir', './logs/iDDPM_CIFAR10_EPS/models/ckpt_1_300000.pt', help='eps1 model checkpoint path')
flags.DEFINE_string('eps2_dir', './logs/iDDPM_CIFAR10_EPS2/models/ckpt_2_300000.pt', help='eps2 model checkpoint path')
flags.DEFINE_string('eps3_dir', './logs/iDDPM_CIFAR10_complex_EPS3/models/ckpt_3_300000.pt', help='eps3 model checkpoint path')
flags.DEFINE_string('eps4_dir', './logs/iDDPM_CIFAR10_complex_EPS4/models/ckpt_4_300000.pt', help='eps4 model checkpoint path')

device = torch.device('cuda:0')

def statistics2str(statistics):
    """Render a statistics dict as the ``str`` of a dict of formatted values.

    Every value is formatted with the ``'.6g'`` spec (6 significant digits);
    keys are kept unchanged and in insertion order.

    Args:
        statistics: mapping of name -> number (or 0-dim tensor, which formats
            through its Python scalar).

    Returns:
        ``str`` of a new dict mapping each key to its formatted value string.
    """
    formatted = {}
    for name, value in statistics.items():
        formatted[name] = format(value, '.6g')
    return str(formatted)


def report_statistics(s, t, statistics):
    """Log one sampling step's diagnostics at INFO level.

    Args:
        s: target timestep of the step (anything formattable with ``'.6g'``).
        t: source timestep of the step.
        statistics: dict of diagnostics, rendered via ``statistics2str``.
    """
    step_label = f'[(s, r): ({s:.6g}, {t:.6g})]'
    rendered = statistics2str(statistics)
    logging.info(f'{step_label} [{rendered}]')


class TemporaryGrad(object):
    """Context manager that force-enables autograd for the body of a ``with`` block.

    Useful when the surrounding code runs under ``torch.no_grad()`` but a small
    local optimization still needs gradients. The previous grad mode is restored
    on exit, including when the body raises; exceptions are not suppressed.
    """

    def __enter__(self):
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(True)
        # Return self so `with TemporaryGrad() as ctx:` binds the manager, not None.
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        torch.set_grad_enabled(self.prev)

def solve_analytic(mean,cov,ske):
    """Closed-form two-component moment match with fallbacks to a single Gaussian.

    Uses Z11 = mean, Z12 = mean**2 + cov and Z13 = ske to compute two component
    means and a shared component variance. Elements are reset to the single-Gaussian
    solution ``(mean, mean, cov)`` when:
      * the cubic discriminant ``2*Z11**3 - 3*Z11*Z12 + Z13`` is negative (it is clamped to 0);
      * the resulting component variance is negative;
      * that variance is at most 0.85 * cov;
      * either component mean moves by at least ~10% relative to ``mean``.

    Returns:
        ``(mean1, mean2, var)``, each broadcast like the inputs.
    """
    z11 = mean
    z12 = mean.pow(2) + cov
    z13 = ske

    discriminant = 2*z11.pow(3) - 3*z11*z12 + z13
    com_p = torch.where(discriminant >= 0, discriminant, 0)
    spread = 1.5874*com_p.pow(1/3)          # 1.5874 ~= 2**(2/3)

    mean1 = mean + spread
    mean2 = 0.5 * (2*z11 - spread)
    cov_m = -z11.pow(2) + z12 - 1.25992*com_p.pow(2/3)   # 1.25992 ~= 2**(1/3)

    # Negative component variance -> fall back to the single Gaussian.
    feasible = cov_m >= 0
    cov_m2 = torch.where(feasible, cov_m, cov)
    mean1 = torch.where(feasible, mean1, mean)
    mean2 = torch.where(feasible, mean2, mean)

    bar1, bar2 = 0.85, 0.9

    # Component variance shrank too much relative to cov -> fall back.
    too_narrow = cov_m2 <= bar1*cov
    var_out = torch.where(too_narrow, cov, cov_m2)
    mean1 = torch.where(too_narrow, mean, mean1)
    mean2 = torch.where(too_narrow, mean, mean2)

    # Relative drift checks. Order matters: the second check sees mean2
    # after the first check may already have reset it.
    tolerance = 1 - bar2
    drift1 = torch.abs((mean1 - mean)/mean) >= tolerance
    mean1_out = torch.where(drift1, mean, mean1)
    mean2 = torch.where(drift1, mean, mean2)

    drift2 = torch.abs((mean2 - mean)/mean) >= tolerance
    mean2_out = torch.where(drift2, mean, mean2)
    mean1_out = torch.where(drift2, mean, mean1_out)

    return mean1_out, mean2_out, var_out

def solve_gmm(mean,cov,ske,kur,gt,timestep,report_dict):
    """Fit a two-component Gaussian mixture to per-element moments by gradient descent.

    The mixture is ``1/3 * N(mu_1, beta*gt) + 2/3 * N(mu_2, beta*gt)``. ``mu_1``,
    ``mu_2`` and the variance scale ``beta`` are optimized element-wise; ``beta`` is
    clamped to [0.1, 1.2]. The target is that the mixture's first three raw moments
    match ``mean``, ``mean**2 + cov`` and ``ske``. The optimizer is Adan with linear
    warm-up followed by cosine decay, for a fixed 25 iterations.

    Args:
        mean: target first moment.
        cov: target variance; the second raw moment target is ``mean**2 + cov``.
        ske: target third raw moment.
        kur: fourth-moment target; currently ignored (not part of the objective).
        gt: base component variance, scaled per element by ``beta``.
        timestep: current diffusion timestep; only used to pick the learning rate.
        report_dict: dict that receives diagnostics; mutated and returned.
            'mean optimize' is final/initial mean loss; '3-max optimize' is
            final/initial max third-moment error.

    Returns:
        ``(mu_1, mu_2, beta, report_dict)``. The tensors are detached, shaped like
        ``mean``, and ``beta`` is clamped to [0.1, 1.2].
    """
    import time

    # The fit is purely numerical. Detach the targets so the loss graph never reaches
    # the networks that produced them. Otherwise parameter grads would accumulate, and
    # the second backward would traverse an already-freed graph.
    mean, cov, ske, cov_g = mean.detach(), cov.detach(), ske.detach(), gt.detach()

    # Unknowns stacked along a new leading axis: [mu_1, mu_2, beta].
    x = torch.stack([mean, mean - 1e-3, torch.full_like(mean, 0.998)], dim=0)
    pi = 1/3  # weight of the first component

    def loss_f(tensor):
        """Return (mean over elements of the summed squared moment errors, max third-moment error)."""
        x0, x1, beta = tensor[0,...], tensor[1,...], tensor[2,...]
        beta = torch.clamp(beta, 0.1, 1.2)
        E0 = (pi*x0 + (1-pi)*x1 - mean).pow(2)
        E1 = (pi*(x0**2+cov_g*beta)+(1-pi)*(x1**2+cov_g*beta) - (mean**2+cov)).pow(2)
        E2 = (pi*(x0**3+3*x0*cov_g*beta)+(1-pi)*(x1**3+3*x1*cov_g*beta) - ske).pow(2)
        return (E0+E1+E2).mean(), E2.max()

    start = time.time()

    warm_up    = 18
    iterations = 25
    # Learning rate grows as timestep approaches 1000, with a floor of 0.12.
    lr     = max(-0.1*((1000-timestep)**2/1000**2)+0.16,0.12)
    min_lr = 0.10

    def lr_factor(step):
        # Linear warm-up for `warm_up` steps, then cosine decay floored at min_lr / lr.
        if step <= warm_up:
            return step / warm_up
        progress = (step - warm_up) / (iterations - warm_up)
        return max(0.5 * (math.cos(progress * math.pi) + 1), min_lr / lr)

    with torch.no_grad():
        initial_loss, initial_max_e2 = loss_f(x)

    with TemporaryGrad():
        x.requires_grad_(True)
        optimizer_solve = Adan([x],lr=lr,betas=(0.9,0.92,0.92))
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer_solve, lr_factor)
        for _step in range(iterations):
            loss, _max_e2 = loss_f(x)
            optimizer_solve.zero_grad()
            loss.backward()
            optimizer_solve.step()
            scheduler.step()
        # Evaluate at the *final* parameters. The last in-loop loss predates the
        # last optimizer step, so reporting it understated the improvement.
        with torch.no_grad():
            final_loss, final_max_e2 = loss_f(x)

    logging.info(time.time() - start)
    # Store plain floats so report_dict holds no tensors.
    report_dict['mean optimize'] = (final_loss / initial_loss).item()
    report_dict['3-max optimize'] = (final_max_e2 / initial_max_e2).item()
    solved = x.detach()
    return solved[0,...], solved[1,...], torch.clamp(solved[2,...], 0.1, 1.2), report_dict

# Closed-form two-component fit used by `solve_analytic`, in terms of
#   Z11 = mean,  Z12 = mean^2 + cov,  Z13 = ske:
#   x  -> Z11 + 1.5874 (2 Z11^3 - 3 Z11 Z12 + Z13)^(1/3)             (first component mean)
#   y  -> 0.5 (2 Z11 - 1.5874 (2 Z11^3 - 3 Z11 Z12 + Z13)^(1/3))     (second component mean)
#   Z1 -> -Z11^2 + Z12 - 1.25992 (2 Z11^3 - 3 Z11 Z12 + Z13)^(2/3)   (shared component variance)
# where 1.5874 ~= 2^(2/3) and 1.25992 ~= 2^(1/3).

def extract(v, t, x_shape,ratio=None):
    """Pick per-sample coefficients ``v[t]`` and shape them for broadcasting.

    Args:
        v: 1-D tensor of per-timestep coefficients.
        t: 1-D long tensor of timestep indices, one per batch element.
        x_shape: shape of the tensor the result will be broadcast against;
            only its length is used.
        ratio: unused; kept so existing call sites stay valid.

    Returns:
        float32 tensor of shape ``[len(t), 1, ..., 1]`` with ``len(x_shape)`` dims.
    """
    n_batch = t.shape[0]
    picked = v.gather(0, t).float()
    broadcast_shape = [n_batch] + [1] * (len(x_shape) - 1)
    return picked.view(broadcast_shape)


class GaussianDiffusionSampler(nn.Module):
    """Reverse-diffusion sampler with optional higher-moment corrections.

    ``sample_type`` selects how each reverse step t -> s is formed (see ``p_mean_variance``):
      * ``'ddpm'``         -- Gaussian step: mean from ``eps1_model``, fixed variance.
      * ``'mean_network'`` -- ``eps1_model``'s output is used directly as the step mean.
      * ``'analyticdpm'``  -- mean from ``eps1_model``; variance corrected with ``eps2_model``.
      * ``'gmddpm'``       -- mean/variance/third moment from ``eps1..eps3`` are matched by a
                              two-component Gaussian mixture via ``solve_gmm``.

    ``forward`` visits the descending timesteps in ``self.t_list``, with stride
    ``self.ratio`` on a ``total_T = 1000`` grid. Per-step diagnostics go into
    ``self.statistics``, which ``forward`` resets every step. The ``predict_*``
    helpers write into it, so they assume they are being called from ``forward``.
    """

    def __init__(self, eps1_model,eps2_model,eps3_model,eps4_model, beta_1, beta_T, T,img_size=32,
                 sample_type='ddpm',time_shift=True,noise_schedule='linear'):
        assert sample_type in ['ddpm', 'analyticdpm', 'gmddpm','mean_network']
        super().__init__()
        self.model      = eps1_model   # output treated as E[eps | x_t]
        self.cov_model  = eps2_model   # output treated as E[eps^2 | x_t]
        self.eps3_model = eps3_model   # output treated as E[eps^3 | x_t]
        self.eps4_model = eps4_model   # only used by the (currently disabled) 4th-moment path
        self.T = T
        self.total_T = 1000
        # Stride between consecutive sampled timesteps: ceil(total_T / T).
        if self.total_T % self.T  ==0:
            self.ratio = int(self.total_T/self.T)
        else:
            self.ratio = int(self.total_T/self.T)+1
        # Descending timesteps visited by forward(); always terminates with 0.
        self.t_list = [max(self.total_T-1-self.ratio*x,0) for x in range(T)]
        if self.t_list[-1] != 0:
            self.t_list.append(0)
        logging.info(self.t_list)

        self.img_size  = img_size
        self.sample_type = sample_type
        self.time_shift  = time_shift
        self.noise_schedule = noise_schedule
        if noise_schedule=='linear':
            self.register_buffer(
                'betas', torch.linspace(beta_1, beta_T, self.total_T).double())
            alphas = 1. - self.betas
            alphas_bar = torch.cumprod(alphas, dim=0)
        else:
            logging.info(noise_schedule)
            # Cosine schedule: beta_i = min(1 - g(t2)/g(t1), 0.999); the leading 0 is dropped below.
            g = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
            betas = [0.]
            for i in range(self.total_T):
                t1 = i / self.total_T
                t2 = (i + 1) / self.total_T
                betas.append(min(1 - g(t2) / g(t1), 0.999))
            betas = torch.tensor(np.array(betas))
            self.register_buffer(
                'betas', betas[1:])
            alphas= 1-betas
            alphas_bar = torch.cumprod(alphas[1:], dim=0)
            alphas = alphas[1:]
            logging.info(alphas_bar.size())

        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:self.total_T]
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'one_minus_alphas_bar', (1.- alphas_bar))
        self.register_buffer(
            'sqrt_recip_one_minus_alphas_bar', 1./torch.sqrt(1.- alphas_bar))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_recip_alphas_bar', torch.sqrt(1. / alphas_bar))
        self.register_buffer(
            'sqrt_recipm1_alphas_bar', torch.sqrt(1. / alphas_bar - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.register_buffer(
            'posterior_var',
            self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

        # below: log calculation clipped because the posterior variance is 0 at
        # the beginning of the diffusion chain
        self.register_buffer(
            'posterior_log_var_clipped',
            torch.log(
                torch.cat([self.posterior_var[1:2], self.posterior_var[1:]])))

        self.register_buffer(
            'posterior_mean_coef1',
            torch.sqrt(alphas_bar_prev) * self.betas / (1. - alphas_bar))
        self.register_buffer(
            'posterior_mean_coef2',
            torch.sqrt(alphas) * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def q_mean_variance(self, x_0, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior
        q(x_{t-1} | x_t, x_0)
        """
        assert x_0.shape == x_t.shape
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_0 +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_log_var_clipped = extract(
            self.posterior_log_var_clipped, t, x_t.shape)
        return posterior_mean, posterior_log_var_clipped

    def _step_coefficients(self, t, x_shape):
        """Return ``(a_t, a_s, a_ts, sigma_t, sigma_s, beta_ts)`` for the jump t -> s.

        ``s = t - ratio`` when that is non-negative (decided from the first batch
        element), otherwise index 0. ``a_*`` are sqrt(alpha_bar), ``sigma_*`` are
        sqrt(1 - alpha_bar), ``a_ts = a_t / a_s`` and
        ``beta_ts = sigma_t**2 - a_ts**2 * sigma_s**2``.
        """
        s = t-self.ratio if (t-self.ratio)[0]>=0 else t-t
        a_t = extract(self.sqrt_alphas_bar, t, x_shape)
        a_s = extract(self.sqrt_alphas_bar, s, x_shape)
        a_ts = a_t/a_s
        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, x_shape))
        sigma_s = torch.sqrt(extract(self.one_minus_alphas_bar, s, x_shape))
        beta_ts = sigma_t**2-a_ts**2*sigma_s**2
        return a_t, a_s, a_ts, sigma_t, sigma_s, beta_ts

    def _eps(self, x_t, t):
        """First-moment network call; the timestep is shifted by +1 when time_shift is set."""
        if self.time_shift:
            return self.model(x_t, t+1)
        return self.model(x_t, t)

    def _final_var_threshold(self):
        """Upper bound applied to the variance of the last stochastic step."""
        clip_pixel = 2 if self.noise_schedule == 'linear' else 1
        return (clip_pixel * 2. / 255. * (math.pi / 2.) ** 0.5) ** 2

    # use eps to estimate one order moment
    def predict_xpre_from_eps(self, x_t, t, eps):
        """Posterior mean of x_s given x_t and predicted eps; returns ``(mean_xs, clipped x0 estimate)``."""
        assert x_t.shape == eps.shape
        a_t, a_s, a_ts, sigma_t, sigma_s, beta_ts = self._step_coefficients(t, x_t.shape)
        mean_x0 = (x_t - sigma_t * eps)/a_t
        self.statistics['xt_mean'] = x_t.mean().item()
        self.statistics['eps_mean'] = eps.mean().item()
        self.statistics['unclip mean_x0_mean'] = mean_x0.mean().item()
        mean_x0 = mean_x0.clamp(-1.,1.)
        self.statistics['clip mean_x0_mean'] = mean_x0.mean().item()
        mean_xs = a_ts*sigma_s.pow(2)/(sigma_t.pow(2)) * x_t + a_s*beta_ts/(sigma_t.pow(2)) * mean_x0
        mean_xs = mean_xs.clamp(-100.,100.)
        self.statistics['clip mean_xs_max'] = mean_xs.max().item()
        return mean_xs,mean_x0

    # use eps and eps2 to estimate one order moment
    def predict_xpre_cov_from_eps(self, x_t, t, eps):
        """Analytic-DPM style variance of x_s; returns ``(model_var clamped to [0, 1], eps2)``."""
        # cov_model is called with t+1 regardless of time_shift (unchanged from the original).
        eps2 = self.cov_model(x_t, t+1)
        a_t, a_s, a_ts, sigma_t, sigma_s, beta_ts = self._step_coefficients(t, x_t.shape)

        sigma2_small = (sigma_s**2*beta_ts)/(sigma_t**2)
        # Var[x0 | x_t] estimated from E[eps^2] - E[eps]^2, clamped to [0, 1].
        cov_x0_pred = sigma_t.pow(2)/a_t.pow(2) * (eps2-eps.pow(2))
        self.statistics['unclip cov_x0_mean'] = cov_x0_pred.mean().item()
        cov_x0_pred = cov_x0_pred.clamp(0., 1.)
        self.statistics['clip cov_x0_mean'] = cov_x0_pred.mean().item()
        offset = a_s.pow(2)*beta_ts.pow(2)/sigma_t.pow(4) * cov_x0_pred
        self.statistics['offset'] = offset.mean().item()
        self.statistics['offset_max'] = offset.max().item()
        self.statistics['sigma2_small'] = sigma2_small.mean().item()
        model_var  = sigma2_small + offset
        model_var  = model_var.clamp(0., 1.)
        return model_var,eps2

    def ddpm_cov(self, x_t, t):
        """Fixed DDPM posterior variance ``sigma_s**2 * beta_ts / sigma_t**2`` for the jump t -> s."""
        _, _, _, sigma_t, sigma_s, beta_ts = self._step_coefficients(t, x_t.shape)
        model_var1 = (sigma_s**2*beta_ts)/(sigma_t**2)
        self.statistics['sigma2_small'] = model_var1.mean().item()
        return model_var1

    # use eps and eps2 and eps3 to estimate one order moment
    def predict_xpre_3moment_from_eps(self, x_t, t, eps, eps2, mean):
        """Third raw moment of x_s; returns ``(skew_xs, eps3)``. ``mean`` is unused."""
        # eps3_model is called with t regardless of time_shift (unchanged from the original).
        eps3 = self.eps3_model(x_t, t)
        a_t, a_s, a_ts, sigma_t, sigma_s, beta_ts = self._step_coefficients(t, x_t.shape)

        # First three raw moments of x0 given x_t, from the eps moments.
        mean_x0 = (x_t - sigma_t * eps)/a_t
        twom_x0 = 1/(a_t.pow(2))*(x_t.pow(2)+sigma_t.pow(2)*eps2-2*x_t*sigma_t*eps)
        skew_x0 = 1/(a_t.pow(3))*(x_t.pow(3) - sigma_t.pow(3)*eps3 - 3*x_t.pow(2)*sigma_t*eps + 3*x_t*sigma_t.pow(2)*eps2)
        mean_x0 = mean_x0.clamp(-1., 1.)
        twom_x0 = twom_x0.clamp(0., 1.)
        self.statistics['unclip_x0_skew'] = skew_x0.mean().item()
        skew_x0 = torch.where(torch.abs(skew_x0)<=torch.abs(mean_x0),skew_x0,mean_x0)
        skew_x0 = skew_x0.clamp(-1., 1.)
        self.statistics['clip_x0_skew'] = skew_x0.mean().item()
        sigma2_small = (sigma_s**2*beta_ts)/(sigma_t**2)

        # x_s = xt_term + x0_coef * x0 + N(0, sigma2_small); expand E[x_s^3].
        xt_term = a_ts*sigma_s.pow(2)/(sigma_t.pow(2)) * x_t
        x0_coef = a_s*beta_ts/sigma_t.pow(2)
        skew_xs_part1 = xt_term.pow(3) + \
            3*xt_term.pow(2)*x0_coef*mean_x0 + \
            3*xt_term*x0_coef.pow(2)*twom_x0 + \
            x0_coef.pow(3)*skew_x0
        skew_xs_part2 = 3*sigma2_small*(xt_term + x0_coef * mean_x0)
        skew_xs  = skew_xs_part1+skew_xs_part2
        self.statistics['clip_xs_skew'] = skew_xs.mean().item()
        return skew_xs,eps3

    # use eps and eps2 and eps3 and eps4 to estimate one order moment
    def predict_xpre_4moment_from_eps(self, x_t, t, eps,eps2,eps3):
        """Fourth raw moment of x_s (not called by p_mean_variance; that call is disabled)."""
        # NOTE(review): the timestep offset here is the reverse of the eps1 call
        # (time_shift -> t here, t+1 for eps1). Left unchanged because this path is
        # disabled; confirm before enabling.
        if self.time_shift:
            eps4 = self.eps4_model(x_t, t)
        else:
            eps4 = self.eps4_model(x_t, t+1)
        _, _, a_ts, sigma_t, sigma_s, beta_ts = self._step_coefficients(t, x_t.shape)

        part1 = 1/(a_ts**4) * ((x_t**4)-4*(x_t**3)*(beta_ts/sigma_t)*eps+6*(x_t**2)*(beta_ts/sigma_t)**2*eps2-4*(x_t)*(beta_ts/sigma_t)**3*eps3+(beta_ts/sigma_t)**4*eps4)
        part2 = 6*1/(a_ts**2)*((x_t**2)-2*(x_t)*(beta_ts/sigma_t)*eps+(beta_ts/sigma_t)**2*eps2)*(sigma_s**2*beta_ts)/sigma_t**2
        part3 = 3*((sigma_s**2*beta_ts)/sigma_t**2)**2
        four_moment = part1 + part2 + part3
        return four_moment

    def p_mean_variance(self, x_t, t):
        """Moments of the reverse step t -> s for the configured sample_type.

        Returns:
            ``(mean, var)`` for 'ddpm', 'mean_network' and 'analyticdpm'.
            ``(mean, cov, third_moment, fourth_moment_or_None, cov)`` for 'gmddpm'.
        """
        if self.sample_type == 'ddpm':   # the model predicts epsilon
            eps = self._eps(x_t, t)
            model_mean,mean_x0 = self.predict_xpre_from_eps(x_t, t, eps=eps)
            if self.ratio == 1:
                # Always the clipped posterior ("fixedsmall") variance; the var_type flag is not consulted.
                model_log_var = extract(self.posterior_log_var_clipped, t, x_t.shape)
                return model_mean, torch.exp(model_log_var)
            return model_mean, self.ddpm_cov(x_t,t)

        elif  self.sample_type == 'mean_network':
            model_mean = self.model(x_t, t)
            model_log_var = extract(self.posterior_log_var_clipped, t, x_t.shape)
            return model_mean,torch.exp(model_log_var)

        elif self.sample_type == 'analyticdpm':
            assert self.cov_model is not None
            eps = self._eps(x_t, t)
            model_mean,mean_x0 = self.predict_xpre_from_eps(x_t, t, eps=eps)
            model_var,eps2 = self.predict_xpre_cov_from_eps(x_t, t, eps)
            return model_mean, model_var

        elif self.sample_type == 'gmddpm':
            assert self.eps3_model is not None
            eps = self._eps(x_t, t)
            mean,mean_x0     = self.predict_xpre_from_eps(x_t, t, eps=eps)
            cov,eps2 = self.predict_xpre_cov_from_eps(x_t, t, eps)
            import time
            s1 = time.time()
            skeness,eps3  = self.predict_xpre_3moment_from_eps(x_t, t, eps, eps2, mean)
            e1 = time.time()
            logging.info('time for network {0}'.format(e1-s1))
            # Only called for its side effect on self.statistics['sigma2_small'].
            self.ddpm_cov(x_t,t)
            # NOTE(review): the fifth value is `cov`, not the DDPM sigma2_small, so the
            # caller's base mixture variance is the Analytic-DPM variance. Possibly
            # intentional; confirm before changing. The 4th-moment path is disabled.
            fmoment = None
            return mean,cov,skeness,fmoment,cov
        else:
            raise NotImplementedError(self.sample_type)

    def _final_x0(self, x_t, t):
        """Last step (t == 0): return the network's x_0 estimate clipped to [-1, 1]."""
        eps = self._eps(x_t, t)
        a_ts = extract(self.sqrt_alphas_bar, t, x_t.shape)
        sigma_t = torch.sqrt(extract(self.one_minus_alphas_bar, t, x_t.shape))
        beta_ts = (1-a_ts**2)
        x_0 = 1/a_ts*( x_t - eps * beta_ts/sigma_t)
        return torch.clip(x_0, -1, 1)

    def forward(self, x_T):
        """Run the reverse chain from ``x_T`` and return the x_0 estimate clipped to [-1, 1].

        Side effects: resets ``self.cluster_k`` (per timestep, which mixture component
        each sample drew in 'gmddpm') and ``self.statistics`` every step. Logs per-step
        diagnostics via ``report_statistics``.
        """
        x_t = x_T
        self.cluster_k  = {}
        for time_step in self.t_list:
            self.statistics = {}
            self.cluster_k[time_step] = []
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long)  * time_step
            if time_step <= 0:
                return self._final_x0(x_t, t)
            noise = torch.randn_like(x_t).to(x_T.device)

            # sample with mixture of Gaussian
            if self.sample_type == 'gmddpm':
                mean,cov,tmoment,fmoment,sigma2_small = self.p_mean_variance(x_t=x_t, t=t)
                self.statistics['moment error'] =  (torch.abs(tmoment-mean.pow(3)-3*mean*cov)).mean().item()
                pre_cov = sigma2_small
                if time_step-self.ratio <= 0:
                    # Last stochastic step: plain Gaussian with a clipped variance.
                    var_threshold = self._final_var_threshold()
                    # Bug fix: `var` was previously read before this assignment. That logged
                    # a stale value, or raised UnboundLocalError on the first step.
                    var = cov
                    self.statistics['unclip var_mean'] = var.mean().item()
                    var = var.clamp(0., var_threshold)
                    self.statistics['clip var_mean'] = var.mean().item()
                    self.statistics['threshold for var'] = var_threshold
                    x_t = mean + var**0.5 * noise
                    report_statistics(torch.tensor(max(time_step-self.ratio,0)), torch.tensor(time_step), self.statistics)
                    continue
                mean1,mean2,beta,self.statistics = solve_gmm(mean,cov,tmoment,fmoment,pre_cov,time_step,self.statistics)
                var  = pre_cov*beta
                # Each sample independently picks component 0 with probability 1/3.
                mean = torch.zeros(size=mean1.size()).to(mean1.device)
                for n_count in range(mean.size()[0]):
                    if (torch.rand(size=(1,1))<torch.tensor(1/3))[0][0]:
                        self.cluster_k[time_step].append(0)
                        mean[n_count,...]= mean1[n_count,...]
                    else:
                        self.cluster_k[time_step].append(1)
                        mean[n_count,...]= mean2[n_count,...]
                self.statistics['Gaussian_cov'] = pre_cov.mean().item()
                self.statistics['choosend_cov'] = var.mean().item()
                self.statistics['choosend_cov_min'] = var.min().item()
                x_t = mean + var**0.5 * noise
            # sample with DDPM/Imperfect Analytic-DPM (Bao et al. (2022))
            else:
                mean, var = self.p_mean_variance(x_t=x_t, t=t)
                if time_step-self.ratio <= 0:
                    var_threshold = self._final_var_threshold()
                    self.statistics['unclip var_mean'] = var.mean().item()
                    var = var.clamp(0., var_threshold)
                    self.statistics['clip var_mean'] = var.mean().item()
                    self.statistics['threshold for var'] = var_threshold
                x_t = mean + var**0.5 * noise
            # logging the var-related result
            report_statistics(torch.tensor(max(time_step-self.ratio,0)), torch.tensor(time_step), self.statistics)

device = torch.device('cuda:0')

def Sample_parallel(net_sampler):
    """Generate FLAGS.num_images samples, save them as PNGs, and print IS / FID.

    Samples are drawn in batches of at most FLAGS.batch_size, starting from
    standard Gaussian noise of shape (batch, 3, FLAGS.img_size, FLAGS.img_size)
    placed on ``device``. Each output is rescaled with (x + 1) / 2 and written to
    ``./sample/cifar10/<sample_type><sample_steps>/<global_index>.png``. A grid of
    (up to) the first 64 images of the current batch is written to
    ``sample.png`` in the same directory, overwritten every batch. Finally, the
    Inception Score and FID of all rescaled samples are computed against
    FLAGS.fid_cache and printed.

    Args:
        net_sampler: callable mapping a noise tensor to a batch of images.
            Presumably it returns values in [-1, 1], given the (x + 1) / 2
            rescaling. NOTE(review): confirm against GaussianDiffusionSampler.
    """
    save_file = './sample/cifar10/'+str(FLAGS.sample_type)+str(FLAGS.sample_steps)+'/'
    # Create the output directory (including any missing parents) once, up
    # front, so individual save_image failures are reported as themselves.
    os.makedirs(save_file, exist_ok=True)
    images = []
    for i in trange(0, FLAGS.num_images, FLAGS.batch_size):
        # The final batch may be smaller when num_images % batch_size != 0.
        batch_size = min(FLAGS.batch_size, FLAGS.num_images - i)
        x_T = torch.randn((batch_size, 3, FLAGS.img_size, FLAGS.img_size))
        batch_images = net_sampler(x_T.to(device)).cpu()
        # Rescale once; reused for both the FID/IS buffer and per-image PNGs.
        batch_images_01 = (batch_images + 1) / 2
        images.append(batch_images_01)
        for offset, single_image in enumerate(batch_images_01):
            save_image(single_image, save_file + str(i + offset) + '.png')
        # The grid is built from the raw samples and rescaled afterwards, so
        # make_grid's zero padding ends up as mid-gray (0.5) in the PNG.
        grid = (make_grid(batch_images[:64, ...]) + 1) / 2
        save_image(grid, os.path.join(save_file, 'sample.png'))
    images = torch.cat(images, dim=0).numpy()
    print(images.shape)
    (IS, IS_std), FID = get_inception_and_fid_score(
        images, FLAGS.fid_cache, num_images=FLAGS.num_images,
        use_torch=FLAGS.fid_use_torch, verbose=True)
    print(IS)
    print(FID)

def _iddpm_unet_kwargs():
    """Return the constructor kwargs shared by every IDDPM UNet built in eval()."""
    return dict(
        in_channels=FLAGS.in_channel,
        model_channels=FLAGS.ch,
        out_channels=FLAGS.out_channel,
        num_res_blocks=FLAGS.num_res_blocks,
        attention_resolutions=FLAGS.attn,
        dropout=FLAGS.dropout,
        channel_mult=FLAGS.ch_mult,
        conv_resample=True,
        dims=FLAGS.dims,
        num_classes=None,
        use_checkpoint=False,
        num_heads=FLAGS.num_heads,
        num_heads_upsample=-1,
        use_scale_shift_norm=FLAGS.use_scale_shift_norm,
    )


def _load_checkpoint(path, key=None):
    """Load a checkpoint onto the CPU and return it, or only ``ckpt[key]``.

    map_location='cpu' avoids materialising the tensors on whichever GPU they
    were saved from; load_state_dict copies them into the model's own
    parameters regardless of the source device.
    """
    ckpt = torch.load(path, map_location='cpu')
    return ckpt if key is None else ckpt[key]


def eval():
    """Build the noise-prediction networks from FLAGS, then sample and score.

    Models built:

    * eps1_model: UNetModel4Pretrained2 when FLAGS.time_shift is set, otherwise
      UNetModel. Its weights are chosen by FLAGS.noise_schedule ('cosine' vs.
      anything else).
    * eps2_model: UNetModel4Pretrained (mode='simple'), only for sample_type
      'analyticdpm' or 'gmddpm'.
    * eps3_model: UNetModel4Pretrained (mode='complex'), only for 'gmddpm'.
    * eps4_model: always None.

    The models are wrapped in a GaussianDiffusionSampler on ``device`` (inside
    DataParallel when FLAGS.parallel is set), and Sample_parallel is run under
    torch.no_grad().

    NOTE: this name shadows the builtin ``eval`` at module scope. It is kept
    because ``main`` calls it by this name.
    """
    cosine = FLAGS.noise_schedule == 'cosine'
    # Jointly pretrained eps / eps^2 checkpoint, one per noise schedule. It is
    # used for eps1 under time_shift and for eps2 in every configuration.
    if cosine:
        eps_eps2_ckpt_path = '/home/aiops/allanguo/cifar/logs/cifar10_cosine1000_ema_eps_eps2_pretrained_460000.ckpt.pth'
    else:
        eps_eps2_ckpt_path = '/home/aiops/allanguo/MixtureGaussianDiffusion/models/cifar10_ema_eps_eps2_pretrained_340000.ckpt.pth'

    if FLAGS.time_shift:
        eps1_model = UNetModel4Pretrained2(
            **_iddpm_unet_kwargs(),
            head_out_channels=FLAGS.head_out_channels, mode='simple')
        ckpt1 = _load_checkpoint(eps_eps2_ckpt_path)
    else:
        if not cosine:
            logging.info(FLAGS.noise_schedule)
            ckpt1_path = './logs/iDDPM_CIFAR10_EPS1/models/ckpt_1_800000.pt'
        else:
            ckpt1_path = './logs/iDDPM_CIFAR10_cos_EPS1/models/ckpt_1_1200000.pt'
        eps1_model = UNetModel(**_iddpm_unet_kwargs())
        # These training checkpoints store the EMA weights under 'ema_model'.
        ckpt1 = _load_checkpoint(ckpt1_path, 'ema_model')
    eps1_model.load_state_dict(ckpt1)
    eps1_model.eval()

    # Sampling for Extended Analytic DPM
    if FLAGS.sample_type in ('analyticdpm', 'gmddpm'):
        print('Sample IS not using DDPM')
        eps2_model = UNetModel4Pretrained(
            **_iddpm_unet_kwargs(),
            head_out_channels=FLAGS.head_out_channels, mode='simple')
        # The eps^2 weights come from the same file whether or not
        # time_shift is enabled.
        eps2_model.load_state_dict(_load_checkpoint(eps_eps2_ckpt_path))
        eps2_model.eval()

        if FLAGS.sample_type == 'gmddpm':
            eps3_model = UNetModel4Pretrained(
                **_iddpm_unet_kwargs(),
                head_out_channels=FLAGS.head_out_channels, mode='complex')
            if not cosine:
                ckpt3_path = './logs/iDDPM_CIFAR10_EPS3_2/models/ckpt_3_950000.pt'
            else:
                ckpt3_path = './logs/iDDPM_CIFAR10_cos_EPS3_3/models/ckpt_3_1500000.pt'
            ckpt3 = _load_checkpoint(ckpt3_path)
            eps3_model.load_state_dict(ckpt3['ema_model'])
            logging.info(ckpt3_path)
            eps3_model.eval()
        else:
            eps3_model = None
        # A fourth-order head is not used by any sampler configuration here.
        eps4_model = None
    else:
        eps2_model = None
        eps3_model = None
        eps4_model = None

    print(FLAGS.beta_1)
    print(FLAGS.time_shift)
    print(FLAGS)
    net_sampler = GaussianDiffusionSampler(
        eps1_model, eps2_model, eps3_model, eps4_model, FLAGS.beta_1, FLAGS.beta_T, FLAGS.sample_steps, FLAGS.img_size,
        FLAGS.sample_type, FLAGS.time_shift, FLAGS.noise_schedule).to(device)
    if FLAGS.parallel:
        net_sampler = torch.nn.DataParallel(net_sampler)
    with torch.no_grad():
        Sample_parallel(net_sampler)

def main(argv):
    """absl entry point: silence FutureWarning messages, then run evaluation.

    Args:
        argv: positional command-line arguments left after absl flag parsing;
            unused.
    """
    warnings.simplefilter(action='ignore', category=FutureWarning)
    eval()


# Launch absl only when executed as a script. Importing this module (e.g. to
# reuse Sample_parallel) then does not parse sys.argv and start evaluation.
if __name__ == '__main__':
    app.run(main)