"""
This code is adapted from
https://github.com/nanlliu/Unsupervised-Compositional-Concepts-Discovery/blob/master/daam_ddim_visualize.py
"""
import torch
from functools import partial
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers.models import AutoencoderKL, UNet2DConditionModel
# from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler,PNDMScheduler, LMSDiscreteScheduler
from diffusers.utils import logging
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

import os
from tqdm import tqdm
from PIL import Image
from daam import trace, set_seed
from utils import utils
from utils.stablediffusion import StableDiffuser
from typing import Callable, List, Optional, Union
from configs.eval_daam_configs import parse_args_and_update_config

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline


def backward_ddim(x_t, alpha_t, alpha_tm1, eps_xt):
    """Apply one deterministic DDIM update to ``x_t``.

    The sample is moved from the noise level described by the cumulative
    alpha ``alpha_t`` to the one described by ``alpha_tm1``, using the noise
    estimate ``eps_xt``. The update is written as an increment added to
    ``x_t``; when ``alpha_t == alpha_tm1`` the increment is zero.

    Args:
        x_t: current sample.
        alpha_t: cumulative alpha product at the current step.
        alpha_tm1: cumulative alpha product at the target step.
        eps_xt: noise predicted for ``x_t``.

    Returns:
        The sample at the target step.
    """
    # Coefficient applied to the current sample.
    sample_coeff = alpha_t**-0.5 - alpha_tm1**-0.5
    # Coefficient applied to the predicted noise.
    noise_coeff = (1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5
    increment = alpha_tm1**0.5 * (sample_coeff * x_t + noise_coeff * eps_xt)
    return increment + x_t


def forward_ddim(x_t, alpha_t, alpha_tp1, eps_xt):
    """Apply one DDIM step from image towards noise.

    This is the same update as ``backward_ddim``; only the target cumulative
    alpha differs (``alpha_tp1`` takes the place of ``alpha_tm1``).
    """
    return backward_ddim(x_t=x_t, alpha_t=alpha_t, alpha_tm1=alpha_tp1, eps_xt=eps_xt)


class DDIMPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline with a hand-written deterministic DDIM loop.

    ``backward_diffusion`` iterates over the scheduler timesteps and applies
    ``backward_ddim`` at each step (noise -> image). ``forward_diffusion`` is
    the same method with ``reverse_process=True`` pre-bound, which iterates the
    timesteps in reversed order and swaps the two cumulative alphas
    (image -> noise).
    """

    def __init__(
            self,
            vae: AutoencoderKL,
            text_encoder: CLIPTextModel,
            tokenizer: CLIPTokenizer,
            unet: UNet2DConditionModel,
            scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
            safety_checker: Optional[StableDiffusionSafetyChecker] = None,
            feature_extractor: Optional[CLIPFeatureExtractor] = None,
    ):
        """Build the pipeline from its components and bind ``forward_diffusion``."""
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

        self.vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
        # Inversion (image -> noise) reuses the sampling loop in reverse order.
        self.forward_diffusion = partial(self.backward_diffusion, reverse_process=True)

    def run_safety_checker(self, image, device, dtype):
        """Run the optional safety checker on ``image``.

        Returns:
            ``(image, has_nsfw_concept)``. When no safety checker is configured
            the image is returned untouched and ``has_nsfw_concept`` is ``None``.
        """
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    @torch.inference_mode()
    def get_text_embedding(self, prompt):
        """Tokenize ``prompt`` (padded/truncated to the tokenizer's max length) and encode it.

        Only the prompt itself is encoded: no unconditional embeddings are
        prepended, so the result is not shaped for classifier-free guidance.
        """
        text_input_ids = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
        return text_embeddings

    @torch.inference_mode()
    def get_image_latents(self, image, sample=True, rng_generator=None):
        """Encode ``image`` with the VAE and return the scaled latents.

        Args:
            image: input passed directly to ``self.vae.encode``.
            sample: if true, draw from the latent distribution (using
                ``rng_generator``); otherwise take its mode.
            rng_generator: generator forwarded to ``latent_dist.sample``.
        """
        encoding_dist = self.vae.encode(image).latent_dist
        if sample:
            encoding = encoding_dist.sample(generator=rng_generator)
        else:
            encoding = encoding_dist.mode()
        # 0.18215 is the inverse of the factor applied in `decode_image`;
        # presumably the Stable Diffusion VAE scaling factor -- TODO confirm.
        latents = encoding * 0.18215
        return latents

    @torch.inference_mode()
    def backward_diffusion(
            self,
            use_old_emb_i=25,
            prompt=None,
            text_embeddings=None,
            old_text_embeddings=None,
            new_text_embeddings=None,
            latents: Optional[torch.FloatTensor] = None,
            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
            callback_steps: Optional[int] = 1,
            reverse_process: bool = False,
            **kwargs,
    ):
        """Run the DDIM loop over ``latents`` and return the resulting latents.

        Args:
            use_old_emb_i: when both ``old_text_embeddings`` and
                ``new_text_embeddings`` are given, steps with index below this
                value use the old embeddings and later steps use the new ones.
            prompt: prompt(s) encoded with ``_encode_prompt`` when
                ``text_embeddings`` is not supplied.
            text_embeddings: pre-computed conditioning for the UNet. It is
                passed to the UNet as-is, so with ``guidance_scale > 1`` it must
                already match the duplicated latent batch.
            old_text_embeddings: see ``use_old_emb_i``.
            new_text_embeddings: see ``use_old_emb_i``.
            latents: starting latents (required).
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight; values ``<= 1``
                disable guidance.
            callback: called as ``callback(i, t, latents)`` before the update
                of every ``callback_steps``-th step.
            callback_steps: callback period in steps.
            reverse_process: iterate the timesteps in reversed order and swap
                the two cumulative alphas (image -> noise).
            **kwargs: accepted for call compatibility and ignored.

        Raises:
            ValueError: if ``latents`` is missing, or if no conditioning at all
                (``prompt``, ``text_embeddings`` or the old/new pair) is given.
        """
        if latents is None:
            raise ValueError("`latents` must be provided.")

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_to_prompt = old_text_embeddings is not None and new_text_embeddings is not None

        # Only encode when there is a prompt to encode; the prompt-to-prompt
        # path brings its own embeddings and may legitimately omit `prompt`.
        if text_embeddings is None and prompt is not None:
            text_embeddings = self._encode_prompt(prompt, device=self.device,
                                                  num_images_per_prompt=1,
                                                  do_classifier_free_guidance=do_classifier_free_guidance)
        if text_embeddings is None and not prompt_to_prompt:
            raise ValueError(
                "Provide `prompt`, `text_embeddings`, or both `old_text_embeddings` and `new_text_embeddings`."
            )

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # Distance between two consecutive inference timesteps (loop invariant).
        timestep_stride = self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps

        for i, t in enumerate(
                self.progress_bar(timesteps_tensor if not reverse_process else reversed(timesteps_tensor))):
            if prompt_to_prompt:
                if i < use_old_emb_i:
                    text_embeddings = old_text_embeddings
                else:
                    text_embeddings = new_text_embeddings

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                )

            prev_timestep = t - timestep_stride
            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

            # ddim
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            # Past the last timestep there is no table entry; use the
            # scheduler's final cumulative alpha instead.
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[prev_timestep]
                if prev_timestep >= 0
                else self.scheduler.final_alpha_cumprod
            )
            if reverse_process:
                # Inversion moves from the less noisy level to the noisier one.
                alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t
            latents = backward_ddim(
                x_t=latents,
                alpha_t=alpha_prod_t,
                alpha_tm1=alpha_prod_t_prev,
                eps_xt=noise_pred,
            )
        return latents

    @torch.inference_mode()
    def decode_image(self, latents: torch.FloatTensor, **kwargs) -> torch.Tensor:
        """Decode ``latents`` with the VAE, one sample at a time.

        Returns:
            The decoded samples concatenated along dim 0 (a tensor, not PIL images).
        """
        scaled_latents = 1 / 0.18215 * latents
        # Decode item by item and concatenate afterwards.
        image = [
            self.vae.decode(scaled_latents[i: i + 1]).sample for i in range(len(latents))
        ]
        image = torch.cat(image, dim=0)
        return image

    @torch.inference_mode()
    def torch_to_numpy(self, image) -> "numpy.ndarray":
        """Clamp ``image`` to [0, 1], move it to the CPU and return it as a channels-last array."""
        # NOTE(review): the `(image / 2 + 0.5)` rescale was disabled by the
        # author on the assumption that the decoder output is already in
        # [0, 1]; confirm against the VAE actually in use.
        # image = (image / 2 + 0.5).clamp(0, 1) # the output of decode_latents() is in [0, 1]
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        return image

    def _encode_prompt(
            self,
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt=None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        Returns:
            The prompt embeddings; with classifier free guidance the negative embeddings are concatenated in
            front of them along the batch dimension.
        Raises:
            ValueError: if neither a `str`/`list` prompt nor `prompt_embeds` is given, or if the batch size of
                `negative_prompt` does not match the one of `prompt`.
            TypeError: if `negative_prompt` and `prompt` have different types.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        elif prompt_embeds is not None:
            batch_size = prompt_embeds.shape[0]
        else:
            raise ValueError("Either `prompt` (`str` or `list`) or `prompt_embeds` must be provided.")

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize again without truncation only to report what was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds


def load_img(path, target_size=512):
    """Load an image, resize/center-crop it to ``target_size`` and rescale it with ``2 * x - 1``.

    Args:
        path: location of the image file, passed to ``PIL.Image.open``.
        target_size: size given to both the resize and the center crop.

    Returns:
        The transformed image tensor. The intended output range is -1..1,
        which assumes the tensor conversion yields values in 0..1.
    """
    # Open inside a context manager so the file handle is released as soon as
    # the RGB copy has been made, instead of whenever the image is collected.
    with Image.open(path) as raw_image:
        image = raw_image.convert("RGB")

    tform = transforms.Compose(
        [
            transforms.Resize(target_size),
            transforms.CenterCrop(target_size),
            transforms.ToTensor(),
        ]
    )
    image = tform(image)
    return 2.0 * image - 1.0


def latents_to_imgs(latents, pipeline=None):
    """Decode latents and convert the result to PIL images.

    Args:
        latents: latents handed to the pipeline's ``decode_image``.
        pipeline: object providing ``decode_image``, ``torch_to_numpy`` and
            ``numpy_to_pil``. When omitted, the module-level ``pipe`` is used,
            which only exists once the script entry point has created it.

    Returns:
        Whatever the pipeline's ``numpy_to_pil`` returns for the decoded images.
    """
    # Resolve the global lazily so callers that pass a pipeline explicitly do
    # not depend on the script-level `pipe` being defined.
    active_pipe = pipe if pipeline is None else pipeline
    x = active_pipe.decode_image(latents)
    x = active_pipe.torch_to_numpy(x)
    x = active_pipe.numpy_to_pil(x)
    return x


if __name__ == '__main__':
    # Script entry point: invert every dataset image with DDIM, regenerate it
    # under a DAAM trace, and store the per-keyword heat maps in `.pt` files.
    config = parse_args_and_update_config()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # set hyper-parameters
    sdm_version = config.sdm_version
    data_in_dir = config.data_in_dir
    res_out_dir = config.res_out_dir
    guide_scale = config.guide_scale
    csv_name = config.csv_name
    gen = set_seed(config.seed)
    bs = config.batch_size  # DAAM only support single image processing
    steps = config.infer_step
    dataset_type = config.dataset_type
    save_freq = config.save_freq
    min_batch_idx = config.min_batch_idx
    max_batch_idx = config.max_batch_idx

    print(f'sdm: {sdm_version} | result: attnmaps | dataset: {csv_name} | guide_scale: {guide_scale} | steps: {steps}')

    # load data
    img_dir = os.path.join(data_in_dir, 'val2017')
    csv_dir = os.path.join(data_in_dir, f'{csv_name}.csv')
    annotation_file = os.path.join(data_in_dir, 'annotations/instances_val2017.json')

    if dataset_type == "COCO-IT":
        dataset = utils.CocoDataset(img_dir, annotation_file, csv_dir)
    elif dataset_type == "coco_ours":
        dataset = utils.CocoDatasetOurs(img_dir, annotation_file, csv_dir)
    elif dataset_type == "custom":
        dataset = utils.CustomDataset(csv_dir)
    else:
        # Fail with a clear message instead of a NameError on `dataset` below.
        raise ValueError(f"Unsupported dataset_type: {dataset_type!r}")

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=0)

    # load model pipeline
    if sdm_version == 'sdm_2_0_base':
        sdm = StableDiffuser("stabilityai/stable-diffusion-2-base")
    elif sdm_version == 'sdm_2_1_base':
        sdm = StableDiffuser("stabilityai/stable-diffusion-2-1-base")
    else:
        # Fail with a clear message instead of a NameError on `sdm` below.
        raise ValueError(f"Unsupported sdm_version: {sdm_version!r}")
    sdm.scheduler = DDIMScheduler.from_config(sdm.scheduler.config)
    # `pipe` stays a module-level name because `latents_to_imgs` reads it.
    pipe = DDIMPipeline(
        vae=sdm.vae,
        text_encoder=sdm.text_encoder,
        tokenizer=sdm.tokenizer,
        unet=sdm.unet,
        scheduler=sdm.scheduler,
        safety_checker=sdm.sdm_pipe.safety_checker
    )

    def _save_heat_maps(maps, last_batch_idx):
        # Stack the accumulated heat maps and write them to one result file
        # named after the last batch index they cover.
        results = {}
        results['attnmaps'] = torch.stack(maps)
        file_name = f'{sdm_version}-attnmaps-{csv_name}-{steps}-{last_batch_idx}.pt'
        out_path = os.path.join(res_out_dir, file_name)
        torch.save(results, out_path)
        print('Done')
        print(f'The result is stored to {out_path}.')

    heat_map_list = []
    last_processed_idx = None

    os.makedirs(res_out_dir, exist_ok=True)

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader)):
            if batch_idx < min_batch_idx:
                continue
            if max_batch_idx >= 0 and batch_idx > max_batch_idx:
                # Indices only grow, so nothing after this point is wanted;
                # stop instead of loading the remaining batches.
                break

            print(f"\nProcessing batch {batch_idx}...\n")

            if dataset_type == "custom":
                # Each custom sample carries two keywords: duplicate image and
                # caption so that every keyword gets its own entry.
                batch['image'] = torch.cat([batch['image'], batch['image']])
                batch['caption'] = batch['caption'] + batch['caption']
                batch['obj'] = batch['obj1'] + batch['obj2']

            img = batch['image']
            prompt = batch['caption']
            keyword = batch['obj']

            # encode image and text
            text_embeddings = pipe.get_text_embedding(prompt)
            image_latents = pipe.get_image_latents(img.to(device), rng_generator=gen)

            # adding noise process
            # NOTE(review): `text_embeddings` holds the prompt embeddings only,
            # while the diffusion loop duplicates the latents when
            # guidance_scale > 1 -- presumably this call expects
            # guide_scale <= 1; confirm against the configs in use.
            reversed_latents = pipe.forward_diffusion(
                latents=image_latents,
                text_embeddings=text_embeddings,
                guidance_scale=guide_scale,
                num_inference_steps=steps,
            )

            # denoising process
            with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
                with trace(pipe) as tc:
                    reconstructed_latents = pipe.backward_diffusion(
                        latents=reversed_latents,
                        prompt=prompt,
                        guidance_scale=guide_scale,
                        num_inference_steps=steps,
                    )
                    # generate heatmap with [correct obj word + context]
                    rec_img = latents_to_imgs(reconstructed_latents)[0]
                    # rec_img = pipe(prompt, guidance_scale = guide_scale, num_inference_steps = steps)

                    for ix, caption in enumerate(prompt):
                        heat_map = tc.compute_global_heat_map(prompt=caption)
                        heat_map = heat_map.compute_word_heat_map(keyword[ix])
                        heat_map_abs = heat_map.expand_as(rec_img)
                        heat_map_list.append(heat_map_abs)

            last_processed_idx = batch_idx

            # periodic checkpoint every `save_freq` batches (never at index 0)
            if batch_idx and not batch_idx % save_freq:
                _save_heat_maps(heat_map_list, batch_idx)
                heat_map_list = []

    # Flush whatever was collected after the last periodic save. This also
    # covers runs that stop at `max_batch_idx` before the final batch.
    if heat_map_list:
        _save_heat_maps(heat_map_list, last_processed_idx)
        heat_map_list = []
