import os
import torch as t
from tqdm import tqdm
import torch.nn.functional as F

from utils import utils
from utils.itdiffusion import DiffusionModel
from utils.stablediffusion import StableDiffuser
from configs.eval_itd_configs import parse_args_and_update_config

import json
from matplotlib import pyplot as plt
import textwrap
from transformers import AutoTokenizer, BertForMaskedLM


def word_prob_from_model(tokenizer, model, word, context = "[MASK]"):
    """Score how likely ``word`` is to fill the ``[MASK]`` slot of ``context`` under a masked LM.

    ``context`` is tokenized without special tokens (no [CLS]/[SEP]) and run through
    ``model`` without gradients. A softmax over the vocabulary is taken at the FIRST mask
    position.

    Args:
        tokenizer: HuggingFace-style tokenizer exposing ``mask_token_id`` and ``encode``.
        model: masked-LM whose output has a ``.logits`` tensor.
        word: the candidate word. Only its first word-piece is scored when it splits
            into several tokens.
        context: text containing at least one ``[MASK]``. Defaults to a bare mask, i.e. an
            unconditional prior.

    Returns:
        The probability of the word's first token, as a numpy scalar.

    Raises:
        IndexError: if ``context`` has no mask token or ``word`` encodes to no tokens.
    """
    encoded = tokenizer(context, return_tensors = "pt", add_special_tokens = False)

    with t.no_grad():
        all_logits = model(**encoded).logits

    # Locate mask positions inside the single encoded sequence; only the first one is used.
    is_mask = encoded.input_ids == tokenizer.mask_token_id
    mask_positions = is_mask[0].nonzero(as_tuple = True)[0]
    first_mask_logits = all_logits[0, mask_positions][0]
    vocab_probs = t.softmax(first_mask_logits, dim = -1)

    token_ids = tokenizer.encode(word, add_special_tokens = False)
    return vocab_probs[token_ids].cpu().numpy()[0]


def min_max_norm(im):
    """Linearly rescale ``im`` so its minimum maps to 0 and its maximum to 1.

    Works for torch tensors and numpy arrays. A new array/tensor is returned and ``im`` is
    left untouched. Callers such as plotting code pass in result maps they may still need.

    Args:
        im: array-like supporting ``.min()``, ``.max()``, subtraction and division.

    Returns:
        The normalized copy. If ``im`` is constant, the range is zero and an all-zero result
        is returned instead of NaNs.
    """
    shifted = im - im.min()
    peak = shifted.max()
    # A zero range would divide 0 / 0; leave the (all-zero) shifted map as is.
    if peak > 0:
        shifted = shifted / peak
    return shifted


def plot_heatmaps_pid(img, caption, obj1, obj2, redun, syn, uniq1, uniq2, type_):
    """Draw an image next to its four PID heatmaps in a single row of five panels.

    Panels, left to right:
    - the image, clamped to [0, 1];
    - redundancy, uniq1, uniq2 and synergy, each passed through ``min_max_norm``.

    All four heatmaps use the 'jet' colormap on a fixed [0, 1] range. All axes are hidden.
    A caption block is written onto the first panel. It contains the caption wrapped at 52
    characters, followed by ``obj1 vs obj2 (type_)``.

    NOTE(review): the text y position (700 + 50 per caption line, in data coordinates of the
    first panel) looks tuned for ~512 px images; confirm for other sizes.

    Returns:
        The matplotlib figure.
    """
    panel_titles = ['Real COCO', 'Redundancy', 'Uniqueness1', 'Uniqueness2', 'Synergy']
    n_panels = len(panel_titles)
    fig, axes = plt.subplots(1, n_panels, figsize=(7, 3))

    # Normalize in a fixed order: redundancy, uniq1, uniq2, synergy.
    redun_n, uniq1_n, uniq2_n, syn_n = (min_max_norm(m) for m in (redun, uniq1, uniq2, syn))

    axes[0].imshow(t.clamp(img, 0, 1))
    for axis, heat in zip(axes[1:], (redun_n, uniq1_n, uniq2_n, syn_n)):
        axis.imshow(heat, cmap='jet', vmax=1, vmin=0)

    for axis, title in zip(axes, panel_titles):
        axis.set_title(title, fontsize=14)
        axis.axis('off')

    # Long captions are wrapped; the text is pushed down by one step per wrapped line.
    wrapped_text = "\n".join(textwrap.wrap(caption, width=52))
    label = f'c = {wrapped_text}\n$y_*$ = {obj1} vs {obj2} ({type_})'
    line_count = wrapped_text.count('\n') + 1
    axes[0].text(0, 700 + 50 * line_count, label, va="bottom", ha='left', fontsize=13)

    return fig


def main():
    """Run the ITD evaluation configured on the command line and save results to disk.

    All settings come from ``parse_args_and_update_config()``:

    1. Load the Stable Diffusion checkpoint named by ``sdm_version``. Wrap its UNet in a
       ``DiffusionModel``.
    2. Build the dataset selected by ``dataset_type`` and iterate it in batches. Batches
       before ``min_batch_idx`` are skipped. Iteration stops after ``max_batch_idx`` when
       that value is >= 0.
    3. For each batch, build correct / wrong / unconditional prompts according to
       ``eval_metrics`` ('mi', 'cmi', 'pid', 'cpid'). The PID variants also collect
       per-object word probabilities:
       - 'pid': from the ``word_freq_path`` JSON table, or from BERT when
         ``word_freq_model`` is set;
       - 'cpid': from BERT with the object masked inside its context.
    4. Compute the quantity selected by ``res_type`` ('mse_1D', 'mse_2D', 'nll_1D', 'nll_2D',
       'pid') and accumulate it.
    5. Save the accumulated chunk to ``res_out_dir`` with ``t.save`` every ``save_freq``
       batches and after the last processed batch. Accumulators are then reset, so each
       file holds only the batches since the previous save.

    Raises:
        ValueError: if ``sdm_version``, ``dataset_type``, ``eval_metrics`` or ``res_type`` is
            unsupported, or if ``res_type == 'pid'`` is combined with an ``eval_metrics``
            other than 'pid' / 'cpid' (no word probabilities would exist).
    """
    config = parse_args_and_update_config()
    t.manual_seed(config.seed)

    # set hyper-parameters
    res_type = config.res_type
    sdm_version = config.sdm_version
    eval_metrics = config.eval_metrics
    res_out_dir = config.res_out_dir
    data_in_dir = config.data_in_dir
    csv_name = config.csv_name
    z_sample_num = config.n_samples_per_point
    snr_num = config.num_steps
    bs = config.batch_size
    logsnr_loc = config.logsnr_loc
    logsnr_scale = config.logsnr_scale
    clip = config.clip
    upscale_mode = config.upscale_mode
    int_mode = config.int_mode
    dataset_type = config.dataset_type
    save_freq = config.save_freq
    min_batch_idx = config.min_batch_idx
    max_batch_idx = config.max_batch_idx
    word_freq_path = config.word_freq_path
    word_freq_model = config.word_freq_model
    image_level = config.image_level

    print(f'sdm: {sdm_version} | result: {res_type} | dataset: {csv_name} | integration: {int_mode} '
          f'| upscale: {upscale_mode} | z sample #: {z_sample_num} | snr #: {snr_num} | eval metrics: {eval_metrics}')

    # Fail fast on unsupported settings, before any model or dataset is loaded.
    sdm_checkpoints = {
        'sdm_2_0_base': "stabilityai/stable-diffusion-2-base",
        'sdm_2_1_base': "stabilityai/stable-diffusion-2-1-base",
    }
    dataset_types = ('COCO-IT', 'coco_ours', 'custom')
    metric_types = ('mi', 'cmi', 'pid', 'cpid')
    result_types = ('mse_1D', 'mse_2D', 'nll_1D', 'nll_2D', 'pid')
    if sdm_version not in sdm_checkpoints:
        raise ValueError(f"Unknown sdm_version {sdm_version!r}; choose among {tuple(sdm_checkpoints)}.")
    if dataset_type not in dataset_types:
        raise ValueError(f"Unknown dataset_type {dataset_type!r}; choose among {dataset_types}.")
    if eval_metrics not in metric_types:
        raise ValueError(f"Please input the correct strategy type, choosing among: {metric_types}; got {eval_metrics!r}.")
    if res_type not in result_types:
        raise ValueError(f"Please input the correct results type, choosing among: {result_types}; got {res_type!r}.")
    if res_type == 'pid' and eval_metrics not in ('pid', 'cpid'):
        raise ValueError("res_type 'pid' requires eval_metrics 'pid' or 'cpid' (word probabilities are needed).")

    # load diffusion models
    sdm = StableDiffuser(sdm_checkpoints[sdm_version])

    latent_shape = (sdm.channels, sdm.width, sdm.height)
    itd = DiffusionModel(sdm.unet, latent_shape, logsnr_loc=logsnr_loc, logsnr_scale=logsnr_scale, clip=clip, logsnr2t=sdm.logsnr2t)

    # load data
    # TODO: convert it from CSV file to JSON file
    img_dir = os.path.join(data_in_dir, 'val2017')
    csv_dir = os.path.join(data_in_dir, f'{csv_name}.csv')
    annotation_file = os.path.join(data_in_dir, 'annotations/instances_val2017.json')

    if dataset_type == "COCO-IT":
        dataset = utils.CocoDataset(img_dir, annotation_file, csv_dir)
    elif dataset_type == "coco_ours":
        dataset = utils.CocoDatasetOurs(img_dir, annotation_file, csv_dir)
    else:  # "custom" (validated above)
        dataset = utils.CustomDataset(csv_dir)

    dataloader = t.utils.data.DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=0)

    # Index of the last batch that will actually be processed. The final (possibly partial)
    # chunk is saved there, even when max_batch_idx stops iteration early.
    last_batch_idx = len(dataloader) - 1
    if max_batch_idx >= 0:
        last_batch_idx = min(last_batch_idx, max_batch_idx)

    # assign noise levels: snr_num evenly spaced log-SNR values over loc +/- clip * scale
    logsnrs = t.linspace(logsnr_loc - clip * logsnr_scale, logsnr_loc + clip * logsnr_scale, snr_num).to(sdm.device)

    # NOTE(review): BERT is loaded even when no branch uses it. This is kept unconditional in
    # case model loading touches the seeded RNG state, which would change later sampling.
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = BertForMaskedLM.from_pretrained("google-bert/bert-base-uncased")

    # evaluate
    with t.no_grad():
        results = {}
        mmses_list = []
        mmses_diff_appx_list = []
        pixel_mmses_list = []
        pixel_mmses_diff_appx_list = []
        nll_list = []
        mi_appx_list = []
        pixel_nll_list = []
        pixel_mi_appx_list = []

        pixel_redun_list = []
        pixel_uniq_list = []
        pixel_syn_list = []

        # save results
        os.makedirs(res_out_dir, exist_ok=True)

        word_probs = {}
        if eval_metrics == 'pid':
            with open(word_freq_path, encoding='utf-8') as f:
                word_probs = json.load(f)

        for batch_idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
            if batch_idx < min_batch_idx:
                continue
            if batch_idx > last_batch_idx:
                # Past max_batch_idx: nothing left to do, so stop loading further batches.
                break

            if eval_metrics == 'pid' and dataset_type == "COCO-IT":
                batch['obj1'] = batch['category']
                batch['obj2'] = batch['context']
                batch['full'] = batch['caption']

            elif eval_metrics in ['mi', 'cmi'] and dataset_type == "custom":
                # Evaluate both objects of each custom sample by doubling the batch.
                batch['image'] = t.cat([batch['image'], batch['image']])
                batch['caption'] = batch['caption'] + batch['caption']
                batch['category'] = batch['obj1'] + batch['obj2']
                batch['context'] = batch['context1'] + batch['context2']

            if eval_metrics == 'mi':
                images = batch['image']
                correct_prompts = batch['category']
                wrong_prompts = [""] * len(images)
                none_prompts = [""] * len(images)
            elif eval_metrics == 'cmi':
                images = batch['image']
                correct_prompts = batch['caption']
                wrong_prompts = batch['context']
                none_prompts = [""] * len(images)
            elif eval_metrics == 'pid':
                images = batch['image']
                full_prompts = batch['full']
                correct_prompts1 = batch['obj1']
                correct_prompts2 = batch['obj2']
                correct_prompts = correct_prompts1 + correct_prompts2 + full_prompts
                wrong_prompts = [""] * len(images)
                none_prompts = [""] * len(images)

                # Probability of each object word, taken from BERT with the default
                # "[MASK]" context or from the word-frequency table (0.0 when absent).
                batch_word_probs1 = []
                batch_word_probs2 = []
                for word1, word2 in zip(batch['obj1'], batch['obj2']):
                    if word_freq_model:
                        batch_word_probs1.append(word_prob_from_model(tokenizer, model, word1))
                        batch_word_probs2.append(word_prob_from_model(tokenizer, model, word2))
                    else:
                        batch_word_probs1.append(word_probs.get(word1, 0.0))
                        batch_word_probs2.append(word_probs.get(word2, 0.0))

                batch_word_probs = batch_word_probs1 + batch_word_probs2
            elif eval_metrics == 'cpid':
                images = batch['image']
                full_prompts = batch['caption']
                correct_prompts1 = batch['context2']
                correct_prompts2 = batch['context1']
                correct_prompts = correct_prompts1 + correct_prompts2 + full_prompts
                wrong_prompts = batch['context']
                none_prompts = [""] * len(images)

                # Contextual probability: mask each object inside the context prompt that
                # contains it and ask BERT how likely the object is in that slot.
                batch_word_probs1 = []
                batch_word_probs2 = []
                for word1, word2, ctx1, ctx2 in zip(batch['obj1'], batch['obj2'], batch['context1'], batch['context2']):
                    masked1 = ctx2.replace(word1, "[MASK]")
                    masked2 = ctx1.replace(word2, "[MASK]")
                    batch_word_probs1.append(word_prob_from_model(tokenizer, model, word1, masked1))
                    batch_word_probs2.append(word_prob_from_model(tokenizer, model, word2, masked2))

                batch_word_probs = batch_word_probs1 + batch_word_probs2
            else:
                raise ValueError(f"Please input the correct strategy type, choosing among: {metric_types}; got {eval_metrics!r}.")

            # compute latent variable z and text embeddings in stable diffusion model
            latent_images = sdm.encode_latents(images)
            text_embeddings = sdm.encode_prompts(correct_prompts)
            wro_embeddings = sdm.encode_prompts(wrong_prompts)
            uncond_embeddings = sdm.encode_prompts(none_prompts)
            if eval_metrics in ('pid', 'cpid'):
                # correct_prompts was obj1 + obj2 + full, so split it back into three conditions.
                conds = t.stack((*text_embeddings.chunk(3), wro_embeddings, uncond_embeddings))
            else:
                conds = t.stack([text_embeddings, wro_embeddings, uncond_embeddings])

            if res_type == 'mse_1D':
                print(f'Calculate image-level mmses and its difference for Batch {batch_idx}...')
                mmses, mmses_diff_appx = itd.image_level_mmse(latent_images, conds, logsnrs, total=z_sample_num)

                # Post-process results
                mmses = mmses.permute(2, 1, 0) # bs * 3 * snr_num
                mmses_diff_appx = mmses_diff_appx.permute(2, 1, 0)  # bs * 3 * snr_num
                mmses_list.append(mmses)
                mmses_diff_appx_list.append(mmses_diff_appx)
                print('Done\n')
            elif res_type == 'mse_2D':
                print(f'Calculate pixel-level mmses and its difference for Batch {batch_idx}...')
                pixel_mmses, pixel_mmses_diff_appx = itd.pixel_level_mmse(latent_images, conds, logsnrs, total=z_sample_num)

                # Post-process results
                pixel_mmses_up = t.zeros(list(pixel_mmses.shape[:-2]) + [512, 512])
                pixel_mmses_diff_appx_up = t.zeros(list(pixel_mmses_diff_appx.shape[:-2]) + [512, 512])
                for i in range(snr_num):
                    pixel_mmses_up[i] = F.interpolate(pixel_mmses[i], size=(512, 512), mode=upscale_mode)
                    pixel_mmses_diff_appx_up[i] = F.interpolate(pixel_mmses_diff_appx[i], size=(512, 512), mode=upscale_mode)
                pixel_mmses_up = pixel_mmses_up.permute(2, 1, 0, 3, 4)  # bs * 3 * snr_num * h * w
                pixel_mmses_diff_appx_up = pixel_mmses_diff_appx_up.permute(2, 1, 0, 3, 4)  # bs * 3 * snr_num * h * w
                pixel_mmses_list.append(pixel_mmses_up)
                pixel_mmses_diff_appx_list.append(pixel_mmses_diff_appx_up)
                print('Done\n')
            elif res_type == 'nll_1D':
                print(f'Calculate image-level nll and {eval_metrics} for Batch {batch_idx}...')
                nll, mi_appx = itd.image_level_nll(latent_images, conds, snr_num=snr_num, z_sample_num=z_sample_num, int_mode=int_mode)

                # Post-process results
                nll = nll.permute(1, 0) # bs * 3
                mi_appx = mi_appx.permute(1, 0) # bs * 3
                nll_list.append(nll)
                mi_appx_list.append(mi_appx)
                print('Done\n')
            elif res_type == 'nll_2D':
                print(f'Calculate pixel-level nll and {eval_metrics} for Batch {batch_idx}...')
                pixel_nll, pixel_mi_appx = itd.pixel_level_nll(latent_images, conds, snr_num=snr_num, z_sample_num=z_sample_num, int_mode=int_mode)

                # Post-process results
                pixel_nll = F.interpolate(pixel_nll, size=(512, 512), mode=upscale_mode)  # 3 * bs * h' * w'
                pixel_mi_appx = F.interpolate(pixel_mi_appx, size=(512, 512), mode=upscale_mode)  # 3 * bs * h' * w'
                pixel_nll = pixel_nll.permute(1, 0, 2, 3) # bs * 3 * h' * w'
                pixel_mi_appx = pixel_mi_appx.permute(1, 0, 2, 3) # bs * 3 * h' * w'
                pixel_nll_list.append(pixel_nll)
                pixel_mi_appx_list.append(pixel_mi_appx)
                print('Done\n')

            elif res_type == 'pid':
                print(f'Calculate pixel-level nll and {eval_metrics} for Batch {batch_idx}...')
                redundancy, uniqueness, synergy = itd.pid(latent_images, conds, batch_word_probs, snr_num=snr_num, z_sample_num=z_sample_num, int_mode=int_mode, image_level=image_level)

                # Post-process results
                if not image_level:
                    pixel_mi_redun = F.interpolate(redundancy.unsqueeze(0), size = (512, 512), mode = upscale_mode)  # 1 * bs * h' * w'
                    pixel_mi_uniq = F.interpolate(uniqueness, size = (512, 512), mode = upscale_mode)  # num_words * bs * h' * w'
                    pixel_mi_syn = F.interpolate(synergy.unsqueeze(0), size = (512, 512), mode = upscale_mode)  # 1 * bs * h' * w'
                else:
                    pixel_mi_redun = redundancy.unsqueeze(0)
                    pixel_mi_uniq = uniqueness
                    pixel_mi_syn = synergy.unsqueeze(0)

                pixel_redun_list.append(pixel_mi_redun[0])
                pixel_uniq_list.append(pixel_mi_uniq.transpose(1, 0))
                pixel_syn_list.append(pixel_mi_syn[0])
                print('Done\n')

            else:
                raise ValueError(f"Please input the correct results type, choosing among: {result_types}; got {res_type!r}.")

            # Save every save_freq batches (never at index 0) and always after the last
            # batch that is actually processed.
            if (batch_idx and not batch_idx % save_freq) or batch_idx == last_batch_idx:
                if res_type == 'mse_1D':
                    results['mmses'] = t.cat(mmses_list)  # N * 3 * snr_num
                    results['mmses_diff_appx'] = t.cat(mmses_diff_appx_list)  # N * 3 * snr_num
                elif res_type == 'mse_2D':
                    results['pixel_mmses'] = t.cat(pixel_mmses_list)  # N * 3 * snr_num * h * w
                    results['pixel_mmses_diff_appx'] = t.cat(pixel_mmses_diff_appx_list)  # N * 3 * snr_num * h * w
                elif res_type == 'nll_1D':
                    results['nll'] = t.cat(nll_list)
                    results['mi'] = t.cat(mi_appx_list)
                elif res_type == 'nll_2D':
                    results['pixel_nll'] = t.cat(pixel_nll_list)
                    results['pixel_mi'] = t.cat(pixel_mi_appx_list)
                elif res_type == 'pid':
                    results['pixel_redun'] = t.cat(pixel_redun_list)
                    results['pixel_uniq'] = t.cat(pixel_uniq_list)
                    results['pixel_syn'] = t.cat(pixel_syn_list)

                out_file_name = f'{sdm_version}-{res_type}-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-{eval_metrics}-{batch_idx}.pt'
                out_path = os.path.join(res_out_dir, out_file_name)
                t.save(results, out_path)
                print(f'Results are saved to: {out_path}')

                # Start a fresh chunk: each saved file only holds batches since the last save.
                mmses_list = []
                mmses_diff_appx_list = []
                pixel_mmses_list = []
                pixel_mmses_diff_appx_list = []
                nll_list = []
                mi_appx_list = []
                pixel_nll_list = []
                pixel_mi_appx_list = []

                pixel_redun_list = []
                pixel_uniq_list = []
                pixel_syn_list = []


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
