import argparse, os, sys, glob, datetime, yaml
import torch
import torch.nn as nn
import time
import numpy as np
from tqdm import trange

from omegaconf import OmegaConf
from PIL import Image

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
from ldm.modules.diffusionmodules.util import make_ddim_timesteps

from quant.utils_dm import resume_cali_model
from quant.quant_model_ldm import QuantModel

from torch import autocast
import csv

def rescale(x):
    """Affinely map *x* from the range [-1, 1] to [0, 1] via ``(x + 1) / 2``."""
    return (x + 1.) / 2.

def custom_to_pil(x):
    """Convert a single 3-D image tensor into an RGB ``PIL.Image``.

    The tensor is detached, moved to the CPU and clamped to [-1, 1], then
    mapped to [0, 1], scaled by 255 and cast to ``uint8`` before being handed
    to ``Image.fromarray``. The first axis is moved last beforehand
    (presumably channels-first -> channels-last; confirm against callers).

    Returns:
        The resulting image, converted to mode ``"RGB"`` if PIL produced a
        different mode.
    """
    clamped = torch.clamp(x.detach().cpu(), -1., 1.)
    unit_range = (clamped + 1.) / 2.
    # Move axis 0 to the end so the array layout matches what PIL expects.
    pixels = (255 * unit_range.permute(1, 2, 0).numpy()).astype(np.uint8)
    image = Image.fromarray(pixels)
    if image.mode == "RGB":
        return image
    return image.convert("RGB")


def custom_to_np(x):
    """Convert a 4-D batch tensor to ``uint8`` with axis 1 moved last.

    Saves the batch in ADM style, as in
    https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py

    Values are mapped with ``(x + 1) * 127.5``, clamped to [0, 255] and cast to
    ``torch.uint8``. Despite the name, the result is a contiguous CPU
    ``torch.Tensor`` (no ``.numpy()`` call is made here).
    """
    batch = x.detach().cpu()
    scaled = torch.clamp((batch + 1) * 127.5, 0, 255)
    as_bytes = scaled.to(torch.uint8)
    # Axis order (0, 2, 3, 1): the second axis becomes the trailing one.
    return as_bytes.permute(0, 2, 3, 1).contiguous()


def logs2pil(logs, keys=["sample"]):
    """Convert every entry of *logs* to a PIL image where possible.

    Args:
        logs: Mapping from name to a tensor-like object exposing ``.shape``.
        keys: Unused; kept (with its original default) so existing callers
            keep working. It is never mutated.

    Returns:
        A dict with the same keys as *logs*. A value is the image produced by
        ``custom_to_pil`` for 3-D entries (or the first element of 4-D
        entries), and ``None`` when the entry has another rank or the
        conversion raised an ordinary exception.
    """
    imgs = dict()
    for k in logs:
        try:
            ndim = len(logs[k].shape)
            if ndim == 4:
                # Batched entry: only the first element is converted.
                img = custom_to_pil(logs[k][0, ...])
            elif ndim == 3:
                img = custom_to_pil(logs[k])
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        except Exception:
            # Best-effort conversion: non-image entries (e.g. floats) map to
            # None. Catching Exception rather than using a bare ``except``
            # lets KeyboardInterrupt / SystemExit propagate.
            img = None
        imgs[k] = img
    return imgs

@torch.no_grad()
def convsample(model, shape, return_intermediates=True,
               verbose=True,
               make_prog_row=False):
    """Draw unconditional samples from *model* with its own sampling loop.

    Args:
        model: Object providing ``p_sample_loop`` and ``progressive_denoising``.
        shape: Shape forwarded unchanged to the model's sampler.
        return_intermediates: Forwarded to ``p_sample_loop`` only.
        verbose: Forwarded to whichever sampler is called. Previously the
            progressive branch ignored this and always passed ``True``.
        make_prog_row: If true, call ``progressive_denoising`` instead of
            ``p_sample_loop``.

    Returns:
        Whatever the selected model method returns. The conditioning argument
        is always ``None``.
    """
    if make_prog_row:
        return model.progressive_denoising(None, shape, verbose=verbose)
    return model.p_sample_loop(None, shape,
                               return_intermediates=return_intermediates,
                               verbose=verbose)


@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0, weight_quant=True, quant_steps_abits=None
                    ):
    """Sample from *model* with a ``DDIMSampler``.

    Args:
        model: Model handed to ``DDIMSampler``.
        steps: Number of sampler steps (first positional argument of
            ``DDIMSampler.sample``).
        shape: Sequence whose first entry is the batch size; the remaining
            entries are passed on as the per-sample shape.
        eta: Forwarded to the sampler.
        weight_quant: Forwarded to the ``DDIMSampler`` constructor.
        quant_steps_abits: Forwarded to ``DDIMSampler.sample``.

    Returns:
        The ``(samples, intermediates)`` pair produced by the sampler.
    """
    sampler = DDIMSampler(model, weight_quant=weight_quant)
    batch_size, sample_shape = shape[0], shape[1:]
    samples, intermediates = sampler.sample(
        steps,
        batch_size=batch_size,
        shape=sample_shape,
        eta=eta,
        verbose=False,
        quant_steps_abits=quant_steps_abits,
    )
    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0, weight_quant=True, quant_steps_abits=None):
    """Generate one batch of samples, decode it and report timing.

    Args:
        model: Model exposing ``model.diffusion_model`` (with ``in_channels``
            and ``image_size``), ``ema_scope`` and ``decode_first_stage``.
        batch_size: Number of samples to draw.
        vanilla: If true, sample through ``convsample`` with
            ``make_prog_row=True``; otherwise through ``convsample_ddim``.
        custom_steps, eta, weight_quant, quant_steps_abits: Forwarded to
            ``convsample_ddim`` (ignored when ``vanilla`` is true).

    Returns:
        Dict with ``"sample"`` (decoded batch), ``"time"`` (seconds spent in
        the sampler, decoding excluded) and ``"throughput"`` (samples/second).
    """
    unet = model.model.diffusion_model
    shape = [batch_size, unet.in_channels, unet.image_size, unet.image_size]

    with model.ema_scope("Plotting"):
        started = time.time()
        if vanilla:
            sample, _ = convsample(model, shape, make_prog_row=True)
        else:
            sample, _ = convsample_ddim(
                model,
                steps=custom_steps,
                shape=shape,
                eta=eta,
                weight_quant=weight_quant,
                quant_steps_abits=quant_steps_abits,
            )
        finished = time.time()

    # Decoding happens after the timer stops, so it is not part of "time".
    decoded = model.decode_first_stage(sample)
    elapsed = finished - started

    log = {
        "sample": decoded,
        "time": elapsed,
        'throughput': sample.shape[0] / elapsed,
    }
    print(f'Throughput for this batch: {log["throughput"]}')
    return log

def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None,
        weight_quant=True, quant_steps_abits=None):
    """Sample images from an unconditional model and write them to disk.

    Each batch is written as PNG files into *logdir*; all batches are also
    collected and stored as one ``.npz`` archive in *nplog*.

    Args:
        model: Model to sample from; ``model.cond_stage_model`` must be None.
        logdir: Directory receiving the PNG files. PNGs already present are
            counted so numbering continues after them.
        batch_size: Samples generated per sampler call.
        vanilla: Use the model's own sampling loop instead of DDIM.
        custom_steps: Number of DDIM steps (DDIM mode only).
        eta: DDIM eta, forwarded to the sampler.
        n_samples: Total number of samples requested.
        nplog: Directory for the ``.npz`` archive; falls back to *logdir*
            when None.
        weight_quant: Forwarded to the sampler.
        quant_steps_abits: Forwarded to the sampler.

    Raises:
        NotImplementedError: If the model has a conditioning stage.
    """
    if vanilla:
        print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
    else:
        print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')

    tstart = time.time()
    # Number of PNGs already on disk. (The former "- 1" made a fresh run start
    # at index -1, i.e. "sample_-00001.png", and under-reported the total.)
    n_saved = len(glob.glob(os.path.join(logdir, '*.png')))

    if model.cond_stage_model is not None:
        raise NotImplementedError('Currently only sampling for unconditional models supported.')

    if nplog is None:
        # Without a dedicated numpy directory, keep the archive next to the PNGs
        # instead of failing in os.path.join after all sampling is done.
        nplog = logdir

    print(f"Running unconditional sampling for {n_samples} samples")
    # Ceiling division: a trailing partial batch must not be dropped.
    n_batches = -(-n_samples // batch_size)
    all_images = []
    n_collected = 0
    for _ in trange(n_batches, desc="Sampling Batches (unconditional)"):
        logs = make_convolutional_sample(model, batch_size=batch_size,
                                         vanilla=vanilla, custom_steps=custom_steps,
                                         eta=eta,
                                         weight_quant=weight_quant,
                                         quant_steps_abits=quant_steps_abits)
        # Trim the last batch so exactly n_samples images are written.
        remaining = n_samples - n_collected
        if logs["sample"].shape[0] > remaining:
            logs["sample"] = logs["sample"][:remaining]
        n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
        batch_np = custom_to_np(logs["sample"])
        all_images.append(batch_np)
        n_collected += batch_np.shape[0]
        if n_saved >= n_samples:
            print(f'Finish after generating {n_saved} samples')
            break
        # added for reduce GPU-memory
        torch.cuda.empty_cache()

    all_img = np.concatenate(all_images, axis=0)
    all_img = all_img[:n_samples]
    shape_str = "x".join([str(x) for x in all_img.shape])
    nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
    np.savez(nppath, all_img)

    print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")


def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
    """Write the batch stored under ``logs[key]`` to disk.

    Args:
        logs: Mapping that may contain *key*; other entries are ignored.
        path: Directory for PNG output (used when *np_path* is None).
        n_saved: Running count of saved samples; used to name the files.
        key: Entry of *logs* to save, also used as the PNG filename prefix.
        np_path: If given, the whole batch is stored as one ``.npz`` file in
            this directory instead of individual PNGs.

    Returns:
        The updated sample count (unchanged when *key* is not in *logs*).
    """
    if key not in logs:
        return n_saved

    batch = logs[key]
    if np_path is not None:
        npbatch = custom_to_np(batch)
        shape_str = "x".join(str(dim) for dim in npbatch.shape)
        np.savez(os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz"), npbatch)
        return n_saved + npbatch.shape[0]

    # One PNG per batch element, numbered consecutively from n_saved.
    for element in batch:
        target = os.path.join(path, f"{key}_{n_saved:06}.png")
        custom_to_pil(element).save(target)
        n_saved += 1
    return n_saved


def get_parser():
    """Build the command-line parser for the quantized sampling script.

    Note that ``--channel_wise``, ``--weight_quant`` and ``--act_quant`` are
    ``store_false`` switches: each feature is enabled by default and passing
    the flag turns it off. Their help texts state this explicitly.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    # parameters for data and model
    parser.add_argument("--img_type", type=str, default="lsun_churches", nargs="?", help="type of image generation [lsun_churches, lsun_beds]")
    parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none")
    parser.add_argument("-n", "--n_samples", type=int, nargs="?",help="number of samples to draw",default=50000)
    parser.add_argument("-c", "--custom_steps", type=int, nargs="?",
                        help="number of steps for ddim and fastdpm sampling", default=200)
    parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10)

    # ldm model options
    parser.add_argument("-e", "--eta", type=float, nargs="?",
                        help="eta for ddim sampling (0.0 yields deterministic sampling)", default=0.0)
    parser.add_argument("-v", "--vanilla_sample", default=False, action='store_true',
                        help="vanilla sampling (default option is DDIM sampling)?")

    # quantization parameters
    parser.add_argument('--n_bits_w', default=4, type=int, help='bitwidth for weight quantization')
    parser.add_argument('--channel_wise', action='store_false',
                        help='disable channel-wise quantization for weights (enabled by default)')
    parser.add_argument('--weight_quant', default=True, action='store_false',
                        help='disable weight quantization (enabled by default)')
    parser.add_argument("--scale_method_w", default="mse", type=str, help="scale_method for weight")
    parser.add_argument('--n_bits_a', default=[8], type=int, nargs='+', help='bitwidth for activation quantization')
    parser.add_argument('--act_quant', default=True, action='store_false',
                        help='disable activation quantization (enabled by default)')
    parser.add_argument("--scale_method_a", default="max", type=str, help="scale_method for activation")

    # path to calibrated file
    parser.add_argument("--cali_ckpt", type=str, nargs="?", default="/mnt/nfs-vlsi/yulhwakim/pre_trained/qdiff_model/church_w4a8_ckpt.pth", help="load from partially or fully calibrated model")

    # path to save results
    parser.add_argument('--outdir', default="outputs", type=str, help="dir to write results to")

    # qdiff options
    parser.add_argument("--act_quant_mode", default="dynamic", type=str, help="quantization mode to use")
    parser.add_argument('--a_sym', action='store_true', help='act quantizers use symmetric quantization')
    parser.add_argument('--sm_abit', default=8, type=int, help='attn softmax activation bit')

    # one entry per quantization group; group i covers steps in [end_i, start_i)
    parser.add_argument('--quant_start_steps', default=[200], nargs='+', type=int)
    parser.add_argument('--quant_end_steps', default=[0], nargs='+',type=int)

    return parser


# Script entry point: parse options, build and quantize the model, optionally
# restore calibration data, then sample images and write them to disk.
if __name__ == "__main__":

    # NOTE(review): `now` and `command` are computed but not used anywhere in
    # the visible code.
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    # Make modules in the current working directory importable.
    sys.path.append(os.getcwd())
    command = " ".join(sys.argv)

    parser = get_parser()
    # Arguments the parser does not know are kept and later interpreted as
    # OmegaConf dotlist overrides for the model config.
    opt, unknown = parser.parse_known_args()

    # The image type doubles as the config name and the initial log directory.
    logdir = opt.img_type
    # Looks for ./configs/<img_type>.yaml. NOTE(review): if that file does not
    # exist the list is empty and `config` below consists of the CLI overrides
    # only, so `config.model` would presumably fail - confirm.
    base_configs = sorted(glob.glob(os.path.join("./configs", f"{logdir}.yaml")))
    opt.base = base_configs

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    # Later sources win, so command-line overrides take precedence.
    config = OmegaConf.merge(*configs, cli)

    # NOTE(review): `gpu` and `eval_mode` are not read anywhere in the visible code.
    gpu = True
    eval_mode = True

    # "none" is the parser default and means: log into ./<img_type>.
    # Otherwise the output goes to <opt.logdir>/<last path component of logdir>.
    if opt.logdir != "none":
        locallog = logdir.split(os.sep)[-1]
        # A trailing separator yields an empty last component; use the one before.
        if locallog == "": locallog = logdir.split(os.sep)[-2]
        print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
        logdir = os.path.join(opt.logdir, locallog)

    print(config)


    # build the model from the `model` section of the config (whether weights
    # are loaded at this point depends on instantiate_from_config / the config,
    # which is not visible here)
    model = instantiate_from_config(config.model)
    # NOTE(review): `ddim_timesteps` is not used below; the call passes
    # verbose=True, so it may only serve to print the selected timesteps - confirm.
    ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=opt.custom_steps,
                                        num_ddpm_timesteps=model.num_timesteps, verbose=True)

    # build quantization parameters
    wq_params = {'n_bits': opt.n_bits_w,
                'channel_wise': opt.channel_wise,
                'scale_method': opt.scale_method_w}
    # Only the first activation bit-width is used here; per-step bit-widths are
    # passed to the sampler separately through `quant_steps_abits` (built below).
    aq_params = {'n_bits': opt.n_bits_a[0],
                'channel_wise': False,
                'symmetric': opt.a_sym,
                'scale_method': opt.scale_method_a,
                'leaf_param': opt.act_quant,
                'act_quant_mode': opt.act_quant_mode}

    # copy model_ema parameter if use ema, then clear the flag on the model
    if model.use_ema:
        model.model_ema.copy_to(model.model)
        model.use_ema = False

    # upload model to gpu
    model.cuda()
    model.eval()

    # quantize model: wrap the diffusion network in a QuantModel
    q_model = QuantModel(model=model.model.diffusion_model,
                    weight_quant_params=wq_params,
                    act_quant_params=aq_params,
                    sm_abits=opt.sm_abit)

    # Replace the original network with the quantized wrapper.
    model.model.diffusion_model = q_model

    ## upload model to gpu
    #model.cuda()
    # eval() is called again so the newly attached wrapper is in eval mode too.
    model.eval()

    print(75 * "=")
    print("logging to:")

    # os.path.join with a single argument returns the path unchanged.
    logdir = os.path.join(logdir)
    imglogdir = os.path.join(logdir, "img")
    numpylogdir = os.path.join(logdir, "numpy")

    os.makedirs(imglogdir, exist_ok=True)
    os.makedirs(numpylogdir, exist_ok=True)
    print(logdir)
    print(75 * "=")

    # write config out: all parsed options (including `opt.base` set above)
    sampling_file = os.path.join(logdir, "sampling_config.yaml")
    sampling_conf = vars(opt)

    with open(sampling_file, 'w') as f:
        yaml.dump(sampling_conf, f, default_flow_style=False)
    print(sampling_conf)

    # build quantization steps: map each step index to an activation bit-width.
    # Group i covers the steps quant_end_steps[i] (inclusive) up to
    # quant_start_steps[i] (exclusive). If groups overlap, later groups
    # overwrite earlier ones.
    # NOTE(review): assumes quant_end_steps has at least as many entries as
    # quant_start_steps, and n_bits_a has either one entry or one per group;
    # otherwise indexing raises IndexError.
    num_quant_groups = len(opt.quant_start_steps)
    quant_steps_abits = {}
    for i in range(num_quant_groups):
        if len(opt.n_bits_a) == 1:
            # A single bit-width applies to every group.
            abits_tmp = opt.n_bits_a[0]
        else:
            abits_tmp = opt.n_bits_a[i]
        quant_steps_tmp = [x for x in range(opt.quant_end_steps[i], opt.quant_start_steps[i])]
        # Reversing only changes the insertion order of the dict keys.
        quant_steps_tmp.reverse()
        for j in quant_steps_tmp:
            quant_steps_abits[j] = abits_tmp
        print(f"quant_steps : from {opt.quant_end_steps[i]} to {opt.quant_start_steps[i]}, activation bits of {abits_tmp}")


    # load calibrated model
    # NOTE(review): if the checkpoint path does not exist this step is skipped
    # without any message and sampling continues without restored calibration.
    if opt.cali_ckpt is not None:
        if os.path.exists(opt.cali_ckpt):
            # Random dummy input: a (1, channels, image_size, image_size) tensor
            # plus one random integer in [0, 1000). How resume_cali_model uses
            # it is not visible here.
            cali_data = (torch.randn(1, config.model.params.channels,
                                    config.model.params.image_size, config.model.params.image_size),
                        torch.randint(0, 1000, (1,)))
            tmp_act_quant = False
            resume_cali_model(q_model, opt.cali_ckpt, cali_data, tmp_act_quant, opt.act_quant_mode, cond=False)

    # set quant state according to the command-line switches
    q_model.set_quant_state(weight_quant=opt.weight_quant, act_quant=opt.act_quant)

    print(q_model)

    # sample images after quantization: PNGs go to <logdir>/img, the .npz
    # archive to <logdir>/numpy
    run(model, imglogdir, eta=opt.eta,
        vanilla=opt.vanilla_sample,  n_samples=opt.n_samples, custom_steps=opt.custom_steps,
        batch_size=opt.batch_size, nplog=numpylogdir, weight_quant=opt.weight_quant, quant_steps_abits=quant_steps_abits)

    print("done.")
