import argparse
import os

import blobfile as bf
import torch as th

import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW, SGD

from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.image_datasets import load_data
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    add_dict_to_argparser,
    args_to_dict,
    classifier_and_diffusion_defaults,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    create_imagenet_classifier,
)

import random
from guided_diffusion.train_util import parse_resume_step_from_filename, log_loss_dict

import torch
from torch import nn
# Disable cuDNN benchmark mode (per-input-shape autotuning of convolution
# algorithms) for the whole script.
torch.backends.cudnn.benchmark=False

def main():
    """Train a timestep-conditioned classifier on (optionally) noised images.

    As implemented below:
      * ``model``: the classifier being trained. It is initialised from the
        newest ``model*.pt`` in ``args.log_dir`` if one exists, otherwise
        (non-strictly) from ``workdirs/256x256_classifier.pt``.
      * ``old_model``: a frozen copy of ``workdirs/256x256_classifier.pt``.
        Its accuracy is only logged for comparison.
      * ``score_model``: a frozen diffusion model. Its first three output
        channels (presumably the epsilon prediction; confirm) build a
        "denoised" input. That input is concatenated along dim 1 to the
        classifier input, scaled by ``args.denoise_augment``.
    Checkpoints are written every ``args.save_interval`` steps and once at the
    end of training.
    """
    args = create_argparser().parse_args()
    print(args)
    dist_util.setup_dist()
    logger.configure(dir=args.log_dir)

    logger.log("creating model and diffusion...")
    model = create_imagenet_classifier(args.denoise_augment)
    old_model = create_imagenet_classifier(False)
    old_model.load_state_dict(
        dist_util.load_state_dict(
            "workdirs/256x256_classifier.pt", map_location=dist_util.dev()
        ),
        strict=False,
    )
    old_model.convert_to_fp16()
    # Same device helper as every other model here, rather than the implicit
    # default device picked by ``.cuda()``.
    old_model = old_model.to(dist_util.dev()).eval()
    model.to(dist_util.dev())

    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.log(f"Trainable Parameters {num_trainable}")
    score_model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    score_model.to(dist_util.dev())
    # map_location: do not depend on the device the checkpoint was saved from.
    score_model.load_state_dict(
        torch.load(
            "workdirs/256x256_diffusion_uncond.pt", map_location=dist_util.dev()
        )
    )
    score_model.convert_to_fp16()
    score_model = score_model.eval()

    if args.noised:
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion
        )

    resume_step = 0
    files = os.listdir(args.log_dir)
    # save_model names checkpoints model{step:06d}.pt / opt{step:06d}.pt, so a
    # lexicographic sort orders them by step. Always resume from the newest
    # one, not whichever file os.listdir happens to list first.
    model_ckpts = sorted(
        f for f in files if f.endswith("pt") and f.startswith("model")
    )
    if model_ckpts:
        ckpt_file = model_ckpts[-1]
        resume_step = parse_resume_step_from_filename(ckpt_file)
        if dist.get_rank() == 0:
            logger.log(
                f"loading model from checkpoint: {ckpt_file}... at {resume_step} step"
            )
        model.load_state_dict(
            dist_util.load_state_dict(
                f"{args.log_dir}/{ckpt_file}", map_location=dist_util.dev()
            )
        )
    else:
        model.load_state_dict(
            dist_util.load_state_dict(
                "workdirs/256x256_classifier.pt", map_location=dist_util.dev()
            ),
            strict=False,
        )
    # Needed for creating correct EMAs and fp16 parameters.
    dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(
        model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=13.4
    )

    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=True,
        random_crop=True,
        cifar_name_style=False,
    )
    if args.val_data_dir:
        val_data = load_data(
            data_dir=args.val_data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=True,
            cifar_name_style=False,
        )
    else:
        val_data = None

    logger.log("creating optimizer...")
    opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)
    opt_ckpts = sorted(f for f in files if f.endswith("pt") and f.startswith("opt"))
    if opt_ckpts:
        opt_checkpoint = f"{args.log_dir}/{opt_ckpts[-1]}"
        logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
        try:
            opt.load_state_dict(
                dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
            )
        except Exception as e:
            # Best effort: training continues with a fresh optimizer state, but
            # the reason is recorded instead of being discarded.
            logger.log(f"Could not load optimizer: {e!r}")
    # NOTE(review): these hard-coded values override --lr / --weight_decay and
    # any values restored from the optimizer checkpoint. With --anneal_lr the
    # lr is still re-set from args.lr on every step. Kept as-is; confirm this
    # is intended.
    for param_group in opt.param_groups:
        param_group["weight_decay"] = 0.05
        param_group["lr"] = 1e-5
    logger.log("training classifier model...")

    def loss_fn(prefix, sub_batch, sub_t, sub_std, sub_labels):
        """Return the mean cross-entropy on one micro-batch and log accuracies.

        Also logs the top-1/top-5 accuracy of the frozen ``old_model``.
        """
        with torch.no_grad():
            model_out = score_model(sub_batch, sub_t)
            score_unc = model_out[:, :3]
            logits_old = old_model(sub_batch[:, :3], sub_t)
        # NOTE(review): suppose sub_std is sqrt(1 - alpha_bar_t) and score_unc
        # is the epsilon prediction. Then this is sqrt(alpha_bar_t) * x0_hat,
        # not x0_hat itself. Confirm that this is intended.
        xstart = sub_batch - sub_std * score_unc
        aug = 1 * args.denoise_augment  # 0 zeroes out the extra channels
        logits = model(torch.cat([sub_batch, aug * xstart], dim=1), timesteps=sub_t)
        ce_loss = F.cross_entropy(logits, sub_labels, reduction="none")

        losses = {
            f"{prefix}_loss": ce_loss.detach(),
            f"{prefix}_acc@1": compute_top_k(
                logits, sub_labels, k=1, reduction="none"
            ),
            f"{prefix}_acc@5": compute_top_k(
                logits, sub_labels, k=5, reduction="none"
            ),
            f"{prefix}_oldacc@1": compute_top_k(
                logits_old, sub_labels, k=1, reduction="none"
            ),
            f"{prefix}_oldacc@5": compute_top_k(
                logits_old, sub_labels, k=5, reduction="none"
            ),
        }
        log_loss_dict(diffusion, sub_t, losses)
        return ce_loss.mean()

    def forward_backward_log(data_loader, prefix="train"):
        """Pull one batch from ``data_loader``, run ``loss_fn`` per micro-batch,
        and accumulate gradients into ``mp_trainer`` when the loss requires grad.
        """
        batch, extra = next(data_loader)
        labels = extra["y"].to(dist_util.dev())
        batch = batch.to(dist_util.dev())

        if args.noised:
            # Noisy images: sample timesteps, then x_t = mean + std * z using
            # the diffusion's q_mean_variance.
            t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
            mean, variance, _ = diffusion.q_mean_variance(batch, t)
            z = torch.randn_like(mean)
            std = variance**0.5
            xt = mean + std * z
        else:
            # Clean images at t=0. Previously xt/std were left undefined here,
            # which raised NameError below. Zero std makes the "denoised" input
            # equal to the clean batch.
            t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())
            std = th.zeros_like(batch)
            xt = batch

        for i, (sub_batch, sub_labels, sub_t, sub_std) in enumerate(
            split_microbatches(args.microbatch, xt, labels, t, std)
        ):
            loss = loss_fn(prefix, sub_batch, sub_t, sub_std, sub_labels)
            if loss.requires_grad:
                if i == 0:
                    mp_trainer.zero_grad()
                # Weight by the micro-batch's share of the batch, so that the
                # accumulated gradients equal those of the full-batch mean.
                mp_trainer.backward(loss * len(sub_batch) / len(batch))
            del loss

    step = 0  # keeps the final save well-defined if the loop body never runs
    for step in range(args.iterations - resume_step):
        logger.logkv("step", step + resume_step)
        logger.logkv(
            "samples",
            (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
        )
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations)
        forward_backward_log(data)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % args.eval_interval:
            with th.no_grad():
                with model.no_sync():
                    model.eval()
                    forward_backward_log(val_data, prefix="val")
                    model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
        if (
            step
            and dist.get_rank() == 0
            and not (step + resume_step) % args.save_interval
        ):
            logger.log("saving model...")
            save_model(mp_trainer, opt, step + resume_step)

    if dist.get_rank() == 0:
        logger.log("saving model...")
        save_model(mp_trainer, opt, step + resume_step)
    dist.barrier()


def set_annealed_lr(opt, base_lr, frac_done):
    """Linearly anneal the learning rate of every param group in ``opt``.

    Each group's ``"lr"`` becomes ``base_lr * (1 - frac_done)``.

    :param opt: optimizer whose ``param_groups`` are updated in place.
    :param base_lr: learning rate at the start of training.
    :param frac_done: fraction of training completed.
    """
    annealed = base_lr * (1 - frac_done)
    for group in opt.param_groups:
        group.update(lr=annealed)


def save_model(mp_trainer, opt, step):
    """Write a classifier/optimizer checkpoint pair and prune older ones (rank 0 only).

    Saves ``model{step:06d}.pt`` (master params as a state dict) and
    ``opt{step:06d}.pt`` (``opt.state_dict()``) into the logger directory.
    It then removes every other ``model*.pt`` / ``opt*.pt`` file there, so
    only the latest pair remains.

    :param mp_trainer: MixedPrecisionTrainer whose master params are saved.
    :param opt: optimizer whose state dict is saved.
    :param step: global step number embedded in the file names.
    """
    if dist.get_rank() != 0:
        return
    ckpt_dir = logger.get_dir()
    model_name = f"model{step:06d}.pt"
    opt_name = f"opt{step:06d}.pt"
    # Write the new pair first. Deleting beforehand (as the old ``rm -rf``
    # did) leaves no checkpoint at all if saving fails part-way.
    th.save(
        mp_trainer.master_params_to_state_dict(mp_trainer.master_params),
        os.path.join(ckpt_dir, model_name),
    )
    th.save(opt.state_dict(), os.path.join(ckpt_dir, opt_name))

    # Prune stale checkpoints without a shell, which avoids quoting problems
    # with unusual directory names. Only this script's own checkpoint files
    # are touched, so unrelated *.pt files (e.g. pretrained weights) in a
    # shared log dir survive.
    keep = {model_name, opt_name}
    for name in os.listdir(ckpt_dir):
        if name in keep or not name.endswith(".pt"):
            continue
        if not name.startswith(("model", "opt")):
            continue
        path = os.path.join(ckpt_dir, name)
        if os.path.isfile(path):
            os.remove(path)


def compute_top_k(logits, labels, k, reduction="mean"):
    """Compute top-k accuracy of ``logits`` against integer class ``labels``.

    :param logits: class scores, with classes on the last dimension.
    :param labels: true class indices, one per row of ``logits``.
    :param k: a sample counts as a hit if its label is among its ``k``
        highest-scoring classes.
    :param reduction: ``"mean"`` returns the hit rate as a Python float.
        ``"none"`` returns a float tensor of per-sample 0/1 hits.
    :raises ValueError: for any other ``reduction`` (previously this silently
        returned ``None``).
    """
    _, top_ks = th.topk(logits, k, dim=-1)
    # Compare each row's top-k indices against its own label.
    hits = (top_ks == labels[:, None]).float().sum(dim=-1)
    if reduction == "mean":
        return hits.mean().item()
    if reduction == "none":
        return hits
    raise ValueError(f"unknown reduction {reduction!r}; expected 'mean' or 'none'")


def split_microbatches(microbatch, *args):
    """Yield aligned slices of ``args`` containing at most ``microbatch`` items.

    The batch size is ``len(args[0])``. Every argument is sliced with the same
    ``[start : start + microbatch]`` window, and ``None`` arguments are passed
    through as ``None``.

    :param microbatch: slice size. ``-1`` (or any other non-positive value),
        or a value >= the batch size, yields the inputs unsplit as one tuple.
    """
    bs = len(args[0])
    if microbatch <= 0 or microbatch >= bs:
        # Previously 0 raised from range() and values < -1 yielded nothing,
        # which silently skipped the whole batch.
        yield tuple(args)
        return
    for start in range(0, bs, microbatch):
        yield tuple(
            None if x is None else x[start : start + microbatch] for x in args
        )


def create_argparser():
    """Build the command-line parser for this training script.

    Script-specific options come first. The guided-diffusion classifier
    defaults and then the model/diffusion defaults are merged on top, with
    later sources overriding earlier ones for shared keys.
    """
    script_defaults = {
        "denoise_augment": False,
        "data_dir": "/scratch/ssd004/datasets/imagenet/train",
        "log_dir": "workdirs/im64",
        "val_data_dir": "/scratch/ssd004/datasets/imagenet/val",
        "noised": True,
        "iterations": 50000,
        "lr": 3e-4,
        "weight_decay": 0.0,
        "anneal_lr": False,
        "batch_size": 64,
        "microbatch": -1,
        "schedule_sampler": "uniform",
        "log_interval": 10,
        "eval_interval": 5,
        "save_interval": 100,
    }
    merged = {
        **script_defaults,
        **classifier_and_diffusion_defaults(),
        **model_and_diffusion_defaults(),
    }
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, merged)
    return parser


# Entry point: start training only when this file is executed as a script.
if __name__ == "__main__":
    main()




