from collections import OrderedDict

import logging
import torch

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from PIL import Image
from pytorch_gan_metrics import get_inception_score_and_fid


def create_logger(logging_dir=None):
    """
    Create a logger that writes to stdout and, optionally, to a log file.

    Only the rank-0 process configures real handlers (through
    ``logging.basicConfig``); every other rank gets a logger with just a
    ``NullHandler`` attached. When ``torch.distributed`` is unavailable or no
    process group has been initialized, the process is treated as rank 0.

    Args:
        logging_dir: Directory that receives ``log.txt``. When ``None``, no
            file handler is created and only the stream handler is used.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    # Single-process runs have no default process group; treat them as rank 0
    # instead of failing inside dist.get_rank().
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    logger = logging.getLogger(__name__)
    if rank == 0:  # real logger
        handlers = [logging.StreamHandler()]
        if logging_dir is not None:
            # Without this guard the default argument yields "None/log.txt".
            handlers.append(logging.FileHandler(f"{logging_dir}/log.txt"))
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=handlers,
        )
    else:  # dummy logger (does nothing)
        logger.addHandler(logging.NullHandler())
    return logger


# For FID & IS calculation
class TensorDataset(torch.utils.data.Dataset):
    r"""Dataset over one or more tensors that share their first dimension.

    Sample ``i`` is built by indexing every wrapped tensor with ``i``. With a
    single wrapped tensor the sample is that tensor's slice itself; with
    several tensors it is a tuple of slices, one per tensor.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    def __init__(self, *tensors) -> None:
        # All leading sizes must collapse to a single distinct value.
        assert len({t.size(0) for t in tensors}) <= 1, "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        picked = tuple(t[index] for t in self.tensors)
        if len(picked) == 1:
            # Unwrap so single-tensor datasets yield the bare slice.
            return picked[0]
        return picked

    def __len__(self):
        leading = self.tensors[0]
        return leading.size(0)


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    '''
    Move every EMA parameter one step towards the matching model parameter.

    Each EMA parameter is updated in place as
    ``ema = decay * ema + (1 - decay) * param``. Parameters are matched by
    the names reported by ``named_parameters()``, so both modules must expose
    the same parameter names.
    '''
    averaged = dict(ema_model.named_parameters())

    for key, source in model.named_parameters():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        target = averaged[key]
        target.mul_(decay)
        target.add_(source.data, alpha=1 - decay)


def requires_grad(model, flag=True):
    '''
    Enable or disable gradient tracking on every parameter of ``model``.

    ``flag`` is assigned to the ``requires_grad`` attribute of each parameter
    yielded by ``model.parameters()``.
    '''
    for weight in model.parameters():
        weight.requires_grad = flag


def save_ckpt(args, model, ema, opt, checkpoint_path):
    """
    Write a training checkpoint to ``checkpoint_path`` using ``torch.save``.

    The saved object is a dict with the keys ``'args'`` (stored as given),
    ``'model'``, ``'ema'`` and ``'opt'`` (each the ``state_dict()`` of the
    corresponding object).

    Args:
        args: Run configuration, stored unchanged under ``'args'``.
        model: The online model. If it exposes a ``module`` attribute (as a
            DDP wrapper does), the wrapped module's state dict is saved;
            otherwise the model's own state dict is saved.
        ema: The EMA model.
        opt: The optimizer.
        checkpoint_path: Destination passed to ``torch.save``.
    """
    # Unwrap a DDP-style container when present; a bare module has no
    # `.module` attribute and is saved directly instead of raising.
    net = getattr(model, 'module', model)
    checkpoint = {
            'args': args,
            'model': net.state_dict(),
            'ema': ema.state_dict(),
            'opt': opt.state_dict(),
            }
    torch.save(checkpoint, checkpoint_path)
 

def sample_image(args, model, device, image_path, cond=False):
    """
    Sample a 16x16 grid of 3x32x32 images from ``model`` and save it as one image.

    The model is called as ``model(z, c)`` with Gaussian noise ``z`` of shape
    ``(256, 3, 32, 32)`` and, when ``cond`` is true, random integer labels
    ``c`` drawn from ``[0, args.num_classes)``; otherwise ``c`` is ``None``.
    The model runs in eval mode under ``torch.no_grad()`` and is put back into
    the train/eval mode it had on entry, even if the forward pass raises.

    Args:
        args: Configuration; ``args.num_classes`` is read only when ``cond`` is true.
        model: Callable module producing a ``(256, 3, 32, 32)`` output.
        device: Device the noise and labels are moved to.
        image_path: Destination of the saved 512x512 RGB image.
        cond: Whether to sample random class labels for conditioning.
    """
    # Remember the caller's mode so an eval-mode model (e.g. an EMA copy) is
    # not left in train mode. Objects without `training` keep the old
    # behaviour of ending in train mode.
    was_training = getattr(model, 'training', True)
    model.eval()
    try:
        z = torch.randn(16*16, 3, 32, 32).to(device)
        c = torch.randint(0, args.num_classes, (16*16,)).to(device) if cond else None
        with torch.no_grad():
            x = model(z, c)
    finally:
        model.train(was_training)

    # (grid_row, grid_col, C, H, W)
    x = x.view(16, 16, 3, 32, 32)
    # Map to 8-bit pixel values; assumes outputs are roughly in [-1, 1] — TODO confirm.
    x = (x * 127.5 + 128).clip(0, 255).to(torch.uint8)
    # Reorder to (grid_row, H, grid_col, W, C) so the reshape tiles the samples
    # into a single (512, 512, 3) image.
    images = x.permute(0, 3, 1, 4, 2).reshape(16*32, 16*32, 3).cpu().numpy()

    Image.fromarray(images, 'RGB').save(image_path)


def num_to_groups(num, divisor):
    """
    Split ``num`` into batch sizes of at most ``divisor`` (for sampling using DDP).

    Returns a list holding ``num // divisor`` entries equal to ``divisor``,
    followed by the remainder when it is positive.
    """
    full_count, leftover = divmod(num, divisor)
    tail = [leftover] if leftover > 0 else []
    return [divisor] * full_count + tail


def sample_fid(args, model, device, rank, set_grad=False, cond=False):
    """
    Generate this process's share of the evaluation samples.

    ``args.eval_samples`` must be divisible by the world size; each process
    produces ``args.eval_samples // world_size`` samples in batches of at most
    ``args.eval_batch_size``. The model is moved to ``device``, wrapped in DDP
    for the duration of sampling and called as ``fid_model(z, c)`` with noise
    of shape ``(n, 3, 32, 32)`` and, when ``cond`` is true, random labels from
    ``[0, args.num_classes)``.

    On return the model is in train mode with gradients enabled when
    ``set_grad`` is true; otherwise it is in eval mode with ``requires_grad``
    disabled on all parameters.

    Returns:
        The generated outputs concatenated along dim 0.
    """
    # Work out the per-process batch sizes.
    world_size = dist.get_world_size()
    assert args.eval_samples % world_size == 0
    per_process = args.eval_samples // world_size
    batch_sizes = num_to_groups(per_process, args.eval_batch_size)

    # Wrap the EMA/online model in DDP. Gradients are switched on first,
    # presumably because DDP refuses modules without trainable parameters.
    model.eval()
    requires_grad(model, True)
    wrapped = DDP(model.to(device), device_ids=[rank])

    outputs = []
    with torch.no_grad():
        for count in batch_sizes:
            noise = torch.randn(count, 3, 32, 32).to(device)
            if cond:
                labels = torch.randint(0, args.num_classes, (count,)).to(device)
            else:
                labels = None
            outputs.append(wrapped(noise, labels))
    images = torch.cat(outputs, dim=0)

    # Drop the wrapper before releasing cached GPU memory.
    model.train()
    del wrapped
    torch.cuda.empty_cache()

    if set_grad:
        return images

    # Frozen models go back to eval mode with gradients off.
    requires_grad(model, False)
    model.eval()
    return images


def compute_fid_is(args, all_images, rank):
    """
    Compute FID and Inception Score for a list of generated image batches.

    The batches are concatenated along dim 0, mapped with ``x * 127.5 + 128``,
    clipped to ``[0, 255]``, divided by 255 (no rounding to uint8) and moved
    to CPU. The result is fed through a ``DataLoader`` to
    ``get_inception_score_and_fid`` together with ``args.stats_path``.

    Args:
        args: Configuration providing ``eval_batch_size``, ``num_workers``
            and ``stats_path``.
        all_images: Sequence of tensors to concatenate along dim 0.
        rank: Rank handed to the ``DistributedSampler``.

    Returns:
        Tuple ``(FID, IS)``; the Inception Score standard deviation is dropped.
    """
    scaled = (torch.cat(all_images, dim=0) * 127.5 + 128).clip(0, 255).div(255).cpu()
    dataset = TensorDataset(scaled)

    # NOTE(review): with num_replicas=1 the sampler only accepts rank 0, so
    # this is presumably meant to be called on rank 0 only — confirm.
    sampler = DistributedSampler(
            dataset, num_replicas=1, rank=rank, shuffle=False,
    )
    loader = DataLoader(
            dataset,
            batch_size=args.eval_batch_size,
            num_workers=args.num_workers,
            shuffle=False,
            sampler=sampler,
            pin_memory=True,
    )
    (inception_score, _), fid = get_inception_score_and_fid(loader, args.stats_path)

    # Release the evaluation objects before clearing the CUDA cache.
    del dataset, sampler, loader
    torch.cuda.empty_cache()

    return fid, inception_score
 




