''' 
This script does conditional image generation on MNIST, using a diffusion model

This code is modified from,
https://github.com/cloneofsimo/minDiffusion

Diffusion model is based on DDPM,
https://arxiv.org/abs/2006.11239

The conditioning idea is taken from 'Classifier-Free Diffusion Guidance',
https://arxiv.org/abs/2207.12598

This technique also features in Imagen, 'Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding',
https://arxiv.org/abs/2205.11487

'''
import random
# import time
import warnings
import torch.backends.cudnn as cudnn
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np

import wandb
from utils import AverageMeter, CustomDataset  # Import the AverageMeter class
from unet import ContextUnet, ddpm_schedules  # Import the ContextUnet class

import os
import argparse
from PIL import Image

from supervised_mnist import Net, train_supervised, test
from train_mnist_lenet import LeNet5
from fid import calculate_fid, compute_metrics

# Command-line interface of the experiment. The parsed namespace is later
# extended in main() with n_classes, store_name and device.
parser = argparse.ArgumentParser(description='PyTorch DDPM MNIST')

# --- data / optimisation -----------------------------------------------------
parser.add_argument('--dataset', default='mnist', help='dataset setting')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size')
parser.add_argument('--epochs', default=20, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')

# --- bookkeeping -------------------------------------------------------------
# NOTE(review): --root_log, --resume, --root_model and --print-freq are not
# referenced anywhere else in this script; they are kept so existing command
# lines keep parsing.
parser.add_argument('--root_log',type=str, default='log')
parser.add_argument('--exp_str', default='0', type=str, help='number to indicate which experiment it is')
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
parser.add_argument('--root_model', type=str, default='runs')
parser.add_argument('--log_results', action='store_true',
                    help='To log results on wandb')
parser.add_argument('--save_model', action='store_true',
                    help='Save the final-epoch model checkpoint of each generation')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many epochs to wait between saving sample grids')
parser.add_argument('-p', '--print-freq', default=100, type=int,
                    metavar='N', help='print frequency (default: 100)')

# --- diffusion model / generational loop -------------------------------------
parser.add_argument('-T','--diffusion-steps', default=500, type=int,
                     help='Number of diffusion steps')
parser.add_argument('-G','--num_generations', default=10, type=int,
                     help='Number of generations to train (generation 0 uses real data, '
                          'later ones the samples of the previous generation)')
parser.add_argument('--start_gen', default=-1, type=int,
                     help='Last generation already completed; training starts at start_gen + 1')
parser.add_argument('--feature_dim', default=256, type=int,
                     help='Feature Dim')
parser.add_argument('--num_sampled_images', default=60000, type=int,
                     help='Number of sampled images')

class DDPM(nn.Module):
    """Noise-prediction diffusion model with classifier-free guidance.

    ``forward`` returns the training loss; ``sample`` and
    ``conditional_sample`` run the reverse process starting from pure noise.
    """

    def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1):
        """
        Args:
            nn_model: network called as ``nn_model(x_t, c, t, context_mask)``
                whose output is compared against the added noise.
            betas: pair whose two entries are forwarded to ``ddpm_schedules``.
            n_T: number of diffusion steps.
            device: device the network and the sampled timesteps are moved to.
            drop_prob: probability of masking out the context of a training
                example.
        """
        super().__init__()
        self.nn_model = nn_model.to(device)

        # Registering every schedule entry as a buffer makes it reachable as
        # an attribute (e.g. self.sqrtab) and lets it follow the module in .to().
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, x, c):
        """Return the noise-prediction MSE for a batch, at random timesteps.

        Args:
            x: batch of clean images; the schedule coefficients are broadcast
                against it as ``(batch, 1, 1, 1)``.
            c: per-example context (class labels), one entry per image.
        """
        # t is drawn uniformly from {1, ..., n_T} (upper bound of randint is exclusive)
        _ts = torch.randint(1, self.n_T+1, (x.shape[0],)).to(self.device)
        noise = torch.randn_like(x)  # eps ~ N(0, 1)

        # x_t = sqrt(alphabar) * x_0 + sqrt(1 - alphabar) * eps
        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * noise
        )

        # A mask value of 1 (drawn with probability drop_prob) drops the context.
        context_mask = torch.bernoulli(torch.zeros_like(c)+self.drop_prob).to(self.device)

        # The network sees the timestep normalised by n_T.
        return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T, context_mask))

    def sample(self, n_sample, size, device, guide_w = 0.0):
        """Generate ``n_sample`` images whose labels cycle through 0..9.

        Args:
            n_sample: number of images; must be a multiple of 10 so that the
                cycled labels line up with the images.
            size: per-image shape, e.g. ``(1, 28, 28)``.
            device: device to sample on.
            guide_w: guidance weight; larger values mean more guidance.

        Raises:
            ValueError: if ``n_sample`` is not a multiple of 10.
        """
        c_i = torch.arange(0,10).to(device)  # context just cycles through the MNIST labels
        c_i = c_i.repeat(int(n_sample/c_i.shape[0]))
        # The reverse process is identical to the label-conditioned one, so share it.
        return self.conditional_sample(c_i, n_sample, size, device, guide_w=guide_w)

    def conditional_sample(self, c_i, n_sample, size, device, guide_w = 0.0):
        """Generate one image per label in ``c_i``.

        Follows the sampling scheme of 'Classifier-Free Diffusion Guidance':
        each step runs the network once on a doubled batch (first half with
        context, second half with the context masked out) and mixes the two
        noise predictions with ``guide_w``.

        Args:
            c_i: 1-D tensor of labels, one per requested image.
            n_sample: number of images; must equal the number of labels.
            size: per-image shape, e.g. ``(1, 28, 28)``.
            device: device to sample on.
            guide_w: guidance weight; larger values mean more guidance.

        Returns:
            Tensor of shape ``(n_sample, *size)``.

        Raises:
            ValueError: if the number of labels differs from ``n_sample``.
        """
        if c_i.numel() != n_sample:
            raise ValueError(
                f"expected {n_sample} context labels, got {c_i.numel()}"
            )

        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1), initial noise

        # don't drop context at test time
        context_mask = torch.zeros_like(c_i).to(device)

        # double the batch: second half is evaluated context free
        c_i = c_i.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1.

        for i in range(self.n_T, 0, -1):
            # normalised timestep, one entry per element of the doubled batch
            t_is = torch.tensor([i / self.n_T]).to(device)
            t_is = t_is.repeat(2 * n_sample, 1, 1, 1)

            x_i = x_i.repeat(2,1,1,1)

            # no noise is added on the final step
            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            # split predictions and compute the guided estimate
            eps = self.nn_model(x_i, c_i, t_is, context_mask)
            eps1 = eps[:n_sample]   # with context
            eps2 = eps[n_sample:]   # context free
            eps = (1+guide_w)*eps1 - guide_w*eps2
            x_i = x_i[:n_sample]
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )

        return x_i


def train_epoch(ddpm, dataloader, optim, device, epoch, args):
    """Train ``ddpm`` for one pass over ``dataloader`` and return the mean loss.

    Args:
        ddpm: model called as ``ddpm(images, labels)`` that returns a scalar loss.
        dataloader: iterable of ``(images, labels)`` batches.
        optim: optimizer stepping the parameters of ``ddpm``.
        device: device every batch is moved to.
        epoch: epoch index, only used for the progress print.
        args: parsed command-line namespace (currently unused here).

    Returns:
        ``avg`` of the ``AverageMeter`` that was updated with every batch loss.
    """
    print(f'epoch {epoch}')
    ddpm.train()

    loss_meter = AverageMeter()

    for images, labels in dataloader:
        optim.zero_grad()
        # Affine rescale 2*x - 1 (maps [0, 1] pixel values to [-1, 1]).
        images = 2*images - 1
        images = images.to(device)
        labels = labels.to(device)
        loss = ddpm(images, labels)
        loss.backward()
        # Read the scalar once: every .item() forces a host/device sync.
        loss_meter.update(loss.item())
        optim.step()

    return loss_meter.avg

# Define the sample function
def sample(ddpm, n_sample, image_shape, device, ws_test):
    """Put ``ddpm`` in eval mode and draw ``n_sample`` images without gradients.

    ``ws_test`` is forwarded to ``ddpm.sample`` as its ``guide_w`` argument.
    """
    ddpm.eval()
    with torch.no_grad():
        return ddpm.sample(n_sample, image_shape, device, guide_w=ws_test)

def evaluate_mnist(X, Y, gen, args):
    """Score a generated dataset with a pretrained LeNet5 and with FID.

    Prints (and, when ``args.log_results`` is set, logs to wandb):
      * the result of ``test`` on the real MNIST test split (sanity check),
      * the result of ``test`` on the generated samples,
      * the first value returned by ``calculate_fid`` between the real MNIST
        training split and the generated samples.

    Args:
        X: generated images, wrapped in ``CustomDataset`` together with ``Y``.
        Y: labels belonging to ``X``.
        gen: generation index, logged alongside every metric.
        args: parsed namespace; reads ``dataset``, ``device``, ``batch_size``,
            ``workers`` and ``log_results``.

    Raises:
        ValueError: if ``args.dataset`` is not ``"mnist"``.
    """
    if args.dataset != "mnist":
        raise ValueError("Dataset not supported")

    device = args.device
    tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    original_train_dataset = MNIST("./data", train=True, download=True, transform=tf)
    test_dataset = MNIST('data', train=False,
                   transform=tf)
    generated_dataset = CustomDataset(X, Y, tf)

    # All loaders are evaluation-only, so none of them shuffles.
    original_train_dataloader = DataLoader(original_train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    generated_mnist_dataloader = DataLoader(generated_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    eval_model = LeNet5().to(device)
    # map_location keeps the load working when the checkpoint was saved on a
    # device that does not exist on this machine.
    ckpt = torch.load("mnist_lenet5.pth", map_location=device)
    eval_model.load_state_dict(ckpt)
    eval_model.eval()
    print("Checkpoint Loaded")

    # The classifier should do well on real test data before its verdict on
    # generated data is trusted.
    sanity_check_acc = test(eval_model, device, test_dataloader)
    print(f"Sanity Check Accuracy: {sanity_check_acc}")

    eval_model_acc_on_generated_samples = test(eval_model, device, generated_mnist_dataloader)
    print(f"Accuracy of the Generated MNIST dataloader on a pretrained LeNet: {eval_model_acc_on_generated_samples}")
    if args.log_results:
        wandb.log({'acc-gen_mnist_lenet':eval_model_acc_on_generated_samples, "gen":gen})

    fid, _, _ = calculate_fid(original_train_dataloader,
                 generated_mnist_dataloader,
                  eval_model,
                  args)
    print("FID",fid)
    if args.log_results:
        wandb.log({'fid':fid, "gen":gen})
    

def _load_generation_dataset(gen, save_dir, tf, guide_w=0.0):
    """Return the training set of generation ``gen``: real MNIST for 0, otherwise the samples saved by generation ``gen - 1``."""
    if gen == 0:
        return MNIST("./data", train=True, download=True, transform=tf)
    # Indexing an NpzFile materialises the array, so it is safe to close the
    # file before the dataset is used.
    with np.load(f"{save_dir}/gen_{gen-1}_generated_data_w_{guide_w}.npz") as data:
        X = data['X']
        Y = data['Y']
    return CustomDataset(X, Y, tf)


def _log_sample_grids(ddpm, gen, epoch, save_dir, ws_test, args):
    """Save one image grid per guidance weight in ``ws_test`` and optionally log it to wandb."""
    ddpm.eval()
    with torch.no_grad():
        n_sample = 10 * args.n_classes
        for w in ws_test:
            x_gen = ddpm.sample(n_sample, (1, 28, 28), args.device, guide_w=w)
            # Samples live in [-1, 1]; map them back to [0, 1] for saving.
            grid = make_grid(((x_gen + 1)/2).clamp(0.0, 1.0), nrow=10)
            image_path = f"{save_dir}/gen_{gen}_image_ep{epoch}_w{w}.png"
            save_image(grid, image_path)
            if args.log_results:
                wandb.log({f"gen_{gen}_image_w{w}": [wandb.Image(image_path, caption=f"Epoch{epoch}-{w}")], "epoch":epoch})
            print('saved image at ' + image_path)


def _generate_synthetic_dataset(ddpm, total_num_samples, n_classes, device, guide_w=0.0, n_sample_batch=1000):
    """Sample ``total_num_samples`` labelled images in batches and return them as ``(X, Y)`` numpy arrays."""
    if total_num_samples <= 0:
        raise ValueError("num_sampled_images must be positive")

    images, labels = [], []
    ddpm.eval()
    with torch.no_grad():
        for batch_idx, start in enumerate(range(0, total_num_samples, n_sample_batch)):
            print(batch_idx)
            # The last batch may be smaller so that no requested sample is dropped.
            batch = min(n_sample_batch, total_num_samples - start)
            # Labels cycle 0, 1, ..., n_classes - 1 within every batch.
            c_i = (torch.arange(batch) % n_classes).to(device)
            x_gen = ddpm.conditional_sample(c_i, batch, (1, 28, 28), device, guide_w=guide_w)
            # [-1, 1] floats -> [0, 255] uint8, channels last, then drop only
            # the channel axis (a bare squeeze() would also drop a batch of 1).
            x_gen = ((x_gen + 1) * 127.5).clamp(0, 255).to(torch.uint8)
            images.append(x_gen.permute(0, 2, 3, 1).squeeze(-1).cpu().numpy())
            labels.append(c_i.cpu().numpy())

    X = np.concatenate(images, axis=0).astype(np.uint8)
    Y = np.concatenate(labels, axis=0)
    return X, Y


def train_mnist(args):
    """Train a chain of diffusion models, each on the samples of its predecessor.

    For every generation from ``args.start_gen + 1`` up to (excluding)
    ``args.num_generations``:
      1. build a fresh DDPM,
      2. train it on real MNIST (generation 0) or on the ``.npz`` written by
         the previous generation,
      3. periodically save sample grids, optionally save the final checkpoint,
      4. sample ``args.num_sampled_images`` labelled images, store them as
         ``gen_<gen>_generated_data_w_<guide_w>.npz`` in the run directory,
      5. evaluate them with ``evaluate_mnist``.

    Args:
        args: parsed namespace extended by ``main`` with ``device``,
            ``n_classes`` and ``store_name``.
    """
    n_T = args.diffusion_steps # 500
    device = args.device
    n_feat = args.feature_dim # 128 ok, 256 better (but slower)
    save_dir = f"runs/{args.store_name}"
    ws_test = [0.0, 0.5, 2.0] # strength of generative guidance
    guide_w = 0.0  # guidance used for the dataset handed to the next generation

    # main() creates this directory as well; repeating it keeps the function
    # usable on its own.
    os.makedirs(save_dir, exist_ok=True)

    for gen in range(args.start_gen+1, args.num_generations):
        print(f'generation {gen}')
        ddpm = DDPM(nn_model=ContextUnet(in_channels=1, n_feat=n_feat, n_classes=args.n_classes), betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1)
        # DDPM only moves the network itself; this also moves the schedule buffers.
        ddpm.to(device)

        tf = transforms.Compose([transforms.ToTensor()])
        dataset = _load_generation_dataset(gen, save_dir, tf, guide_w)

        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
        optimizer = torch.optim.Adam(ddpm.parameters(), lr=args.lr)

        for epoch in range(args.epochs):
            # linear learning-rate decay from args.lr towards 0
            optimizer.param_groups[0]['lr'] = args.lr * (1 - epoch / args.epochs)

            avg_loss = train_epoch(ddpm, dataloader, optimizer, device, epoch, args)

            print(f"Average loss: {avg_loss:.4f}")
            if args.log_results:
                wandb.log({"train_loss": avg_loss, "epoch":epoch})

            if epoch % args.log_interval == 0:
                _log_sample_grids(ddpm, gen, epoch, save_dir, ws_test, args)

            # optionally save the model after the last epoch
            if args.save_model and epoch == args.epochs - 1:
                model_path = f"{save_dir}/gen_{gen}_model_{epoch}.pth"
                torch.save(ddpm.state_dict(), model_path)
                print('saved model at ' + model_path)

        # Sample the dataset the next generation trains on and persist it.
        X, Y = _generate_synthetic_dataset(ddpm, args.num_sampled_images, args.n_classes, device, guide_w=guide_w)
        np.savez(f"{save_dir}/gen_{gen}_generated_data_w_{guide_w}.npz", X=X, Y=Y)

        # Classifier accuracy and FID of the generated images
        evaluate_mnist(X, Y, gen, args)

def main():
    """Parse the command line, prepare the run and start training.

    Extends the parsed namespace with ``n_classes``, ``device`` and
    ``store_name``, seeds the RNGs when ``--seed`` is given, creates
    ``runs/<store_name>``, optionally initialises wandb and finally calls
    ``train_mnist``.

    Raises:
        ValueError: if the dataset is not supported or CUDA is unavailable.
    """
    args = parser.parse_args()

    # Fail fast: without n_classes the run would only break later, inside
    # train_mnist, with an AttributeError.
    if args.dataset=="mnist":
        args.n_classes = 10
    else:
        raise ValueError("Dataset not supported")

    # Check the device before any directory or wandb run is created.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if args.device == 'cpu':
        raise ValueError("Cuda not available")

    if args.seed is not None:
        random.seed(args.seed)
        # numpy only accepts seeds in [0, 2**32 - 1]
        np.random.seed(args.seed % (2**32))
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')

    # The run name encodes every setting that distinguishes experiments.
    args.store_name = '_'.join([args.dataset, 'ddpm', 'T', str(args.diffusion_steps),'UNet', str(args.feature_dim), 'bs', str(args.batch_size), 'epochs', str(args.epochs), 'gen', str(args.num_generations), 'seed', str(args.seed),args.exp_str])
    os.makedirs("runs/" + args.store_name, exist_ok=True)

    if args.log_results:
        wandb.init(project="synthetic",
                                   entity="neurips", name=args.store_name)
        wandb.config.update(args)
        wandb.run.log_code(".")

    train_mnist(args)

# Run the experiment only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()