import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
from torch.nn import functional as F
from functools import partial
import numpy as np
import math
from torch import device, nn, einsum
from inspect import isfunction
from functools import partial
from dpm_model.dpm_utils import make_beta_schedule, default
from config.base_config import Config

from ldm.modules.diffusionmodules.openaimodel_tailor import UNetModel
import yaml

from dpm_model.DiT_txt_trunc import DiT_txt_trunc

def load_config(config_file):
    """Parse the YAML file at ``config_file`` and return the loaded object.

    ``yaml.safe_load`` is used, so only plain YAML types are constructed.
    """
    with open(config_file, 'r') as stream:
        return yaml.safe_load(stream)

def extract(a, t, x_shape):
    """Look up per-sample entries of ``a`` and make them broadcastable.

    Args:
        a: Tensor indexed along its last dimension by ``t``.
        t: Index tensor; its first dimension is taken as the batch size.
        x_shape: Shape of the tensor the result will be multiplied with;
            only its length (rank) is used.

    Returns:
        ``a.gather(-1, t)`` reshaped to ``(batch, 1, ..., 1)`` with
        ``len(x_shape) - 1`` trailing singleton dimensions.
    """
    batch, *_ = t.shape
    singleton_dims = [1] * (len(x_shape) - 1)
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *singleton_dims)

def noise_like(shape, device, repeat=False):
    """Draw standard-normal noise of the given ``shape`` on ``device``.

    Args:
        shape: Target shape; ``shape[0]`` is the leading (batch) dimension.
        device: Device passed to ``torch.randn``.
        repeat: If truthy, a single sample of shape ``(1, *shape[1:])`` is
            drawn and tiled along the first dimension, so every entry along
            that dimension holds the same noise.

    Returns:
        A tensor of shape ``shape``.
    """
    if not repeat:
        return torch.randn(shape, device=device)

    single_sample = torch.randn((1, *shape[1:]), device=device)
    tile_factors = (shape[0],) + (1,) * (len(shape) - 1)
    return single_sample.repeat(*tile_factors)


class DPM(nn.Module):
    """Gaussian diffusion wrapper around a configurable denoising network.

    The denoiser ``self.net_d`` is selected by ``dpm_arch``.  Before sampling
    or computing a loss, call :meth:`set_new_noise_schedule` (it registers the
    schedule buffers such as ``betas``) and then :meth:`set_loss` (it reads
    ``self.betas.device`` and creates ``self.loss_func``).
    """

    def __init__(self, config: Config, dpm_arch='3D_transformer'):
        """Store the configuration and build the denoiser named by ``dpm_arch``.

        Args:
            config: Project configuration.  It is forwarded to the denoiser
                constructor, except for ``'ldm_unet'`` where the YAML file at
                ``config.yml_pth`` is parsed and forwarded instead.
            dpm_arch: Name of the denoiser architecture.

        Raises:
            NotImplementedError: If ``dpm_arch`` is not one of the known names.
        """
        super(DPM, self).__init__()
        self.config = config
        self.dpm_arch = dpm_arch
        # NOTE(review): only UNetModel and DiT_txt_trunc are imported in the
        # visible part of this file.  Unless the other constructors are
        # defined elsewhere, their branches fail with NameError -- including
        # the default '3D_transformer'.  TODO confirm and import or remove.
        if self.dpm_arch == '2D_transformer':
            self.net_d = denoising_transformer(config)
        elif self.dpm_arch == '3D_transformer':
            self.net_d = denoising_transformer_3D(config)
        elif self.dpm_arch == 'unet_256':
            self.net_d = UNetMaxDim256(config)
        elif self.dpm_arch == 'unet_512':
            self.net_d = UNetMaxDim512(config)
        elif self.dpm_arch == 'ldm_unet':
            yml_config = load_config(config.yml_pth)
            self.net_d = UNetModel(yml_config)
        elif self.dpm_arch == 'DiT_video':
            self.net_d = DiT_video(config)
        elif self.dpm_arch == 'DiT_frame':
            self.net_d = DiT_frame(config)
        elif self.dpm_arch == 'DiT_vidpool':
            self.net_d = DiT_vidpool(config)
        elif self.dpm_arch == 'DiT_txt':
            self.net_d = DiT_txt(config)
        elif self.dpm_arch == 'DiT_txt_trunc':
            self.net_d = DiT_txt_trunc(config)
        else:
            raise NotImplementedError(f'Unknown dpm_arch: {dpm_arch!r}')

    def set_new_noise_schedule(self):
        """Compute the beta schedule and register all derived buffers.

        Reads ``dpm_beta_schedule``, ``n_timestep``, ``beta_linear_start``,
        ``beta_linear_end``, ``trunc_timestep``, ``v_posterior`` and ``t_var``
        from ``self.config``.  Sets ``num_timesteps``, ``trunc_timestep``,
        ``v_posterior``, ``t_var`` and ``sqrt_alphas_cumprod_prev`` (a NumPy
        array) as plain attributes and registers the remaining schedule
        quantities as float32 buffers.
        """
        # The buffers used to be created unconditionally on 'cuda', which
        # crashes on hosts without CUDA.  Keep CUDA when it is available
        # (unchanged behavior) and fall back to CPU otherwise.
        buffer_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        to_torch = partial(torch.tensor, dtype=torch.float32,
                           device=buffer_device)

        # β1, β2, ..., βΤ (T)
        betas = make_beta_schedule(
            schedule=self.config.dpm_beta_schedule,
            n_timestep=self.config.n_timestep,
            linear_start=self.config.beta_linear_start,
            linear_end=self.config.beta_linear_end)
        if isinstance(betas, torch.Tensor):
            betas = betas.detach().cpu().numpy()

        # α1, α2, ..., αΤ (T)
        alphas = 1. - betas
        # α1, α1α2, ..., α1α2...αΤ (T)
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # 1, α1, α1α2, ...., α1α2...αΤ-1 (T)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
        # 1, √α1, √α1α2, ...., √α1α2...αΤ (T+1)
        self.sqrt_alphas_cumprod_prev = np.sqrt(
            np.append(1., alphas_cumprod))

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.trunc_timestep = self.config.trunc_timestep
        self.v_posterior = self.config.v_posterior
        self.t_var = self.config.t_var

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev',
                             to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod',
                             to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            (1 - self.v_posterior) * betas
            * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        )
        # above (when v_posterior == 0): equal to
        # 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance',
                             to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0
        # at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(
            np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

    def predict_start_from_noise(self, x_t, t, noise):
        """Estimate x_0 from ``x_t`` and a noise prediction at timestep ``t``.

        Returns ``sqrt(1/ᾱ_t) * x_t - sqrt(1/ᾱ_t - 1) * noise`` using the
        registered schedule buffers.
        """
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0).

        The mean is ``coef1 * x_start + coef2 * x_t``.  When
        ``config.wo_posterior_mean_coef1`` / ``config.wo_posterior_mean_coef2``
        is truthy, the corresponding coefficient is dropped and the raw
        ``x_start`` / ``x_t`` is used for that term.
        """
        if self.config.wo_posterior_mean_coef1:
            start_term = x_start
        else:
            start_term = extract(self.posterior_mean_coef1, t, x_t.shape) * x_start

        if self.config.wo_posterior_mean_coef2:
            current_term = x_t
        else:
            current_term = extract(self.posterior_mean_coef2, t, x_t.shape) * x_t

        posterior_mean = start_term + current_term
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised=True, ema_model=False):
        """Compute the reverse-step distribution parameters for ``x`` at ``t``.

        The denoiser predicts the noise, from which x_0 is reconstructed
        (optionally clamped to [-1, 1]) and fed to :meth:`q_posterior`.

        Args:
            x: Current noisy sample (the original author's debug notes report
                shape [bs, 1, dim]).
            t: Long tensor of timesteps, one per batch entry.
            clip_denoised: Clamp the reconstructed x_0 to [-1, 1].
            ema_model: Must be falsy; EMA sampling is not implemented.

        Returns:
            Tuple ``(model_mean, posterior_variance, posterior_log_variance)``.

        Raises:
            NotImplementedError: If ``ema_model`` is truthy.
            ValueError: If ``self.dpm_arch`` is not ``'DiT_txt_trunc'``.
        """
        if ema_model:
            # Previously this printed "TODO" and then failed with an
            # UnboundLocalError on x_recon; fail explicitly instead.
            raise NotImplementedError(
                'p_mean_variance does not support ema_model=True yet')
        if self.dpm_arch != 'DiT_txt_trunc':
            raise ValueError(
                f'p_mean_variance only supports dpm_arch="DiT_txt_trunc", '
                f'got {self.dpm_arch!r}')

        predicted_noise = self.net_d(x=x, t=t)
        x_recon = self.predict_start_from_noise(x, t=t, noise=predicted_noise)
        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False, ema_model=False):
        """Draw x_{t-1} from the model's reverse distribution given ``x`` at ``t``.

        No noise is added for batch entries whose timestep is 0.

        NOTE(review): ``ema_model`` is accepted but not forwarded to
        :meth:`p_mean_variance`, so the regular denoiser is always used.
        This is kept as-is because forwarding it would make existing
        ``ema_model=True`` callers raise.
        """
        batch = x.shape[0]
        model_mean, _, model_log_variance = self.p_mean_variance(
            x=x, t=t, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, x.device, repeat_noise)
        # 0 where t == 0, 1 elsewhere; shaped (batch, 1, ..., 1) to broadcast.
        nonzero_mask = (1 - (t == 0).float()).reshape(
            batch, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    def p_sample_loop(self, x_txt, ema_model=False):
        """Run the reverse process from ``trunc_timestep - 1`` down to 0.

        Args:
            x_txt: Starting tensor; a singleton dimension is inserted at
                position 1 before sampling.
            ema_model: Passed to :meth:`p_sample` (which currently ignores it).

        Returns:
            The sample after the last reverse step, with the inserted
            dimension still present.
        """
        device = self.betas.device

        embed = x_txt.unsqueeze(1)
        b = x_txt.shape[0]
        for i in reversed(range(0, self.trunc_timestep)):
            step = torch.full((b,), i, device=device, dtype=torch.long)
            embed = self.p_sample(embed, step, ema_model=ema_model)
        return embed

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to timestep ``t``.

        Returns ``sqrt(ᾱ_t) * x_start + sqrt(1 - ᾱ_t) * noise``; fresh
        standard-normal noise is drawn when ``noise`` is None.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract(self.sqrt_one_minus_alphas_cumprod,
                        t, x_start.shape) * noise
        )

    def q_sample_without_trunc_timestep(self, txt_feature, t, noise, sqrt_alphas_cumprod_trunc_value, sqrt_one_minus_alphas_cumprod_trunc_value):
        """Re-diffuse ``txt_feature`` to timestep ``t``.

        ``txt_feature`` is first mapped back to an x_0 estimate by inverting
        the formula used in :meth:`q_sample_at_trunc_timestep` with the given
        trunc-timestep coefficients and ``noise``; that estimate is then
        diffused to ``t`` with the same ``noise``.
        """
        noise = default(noise, lambda: torch.randn_like(txt_feature))

        x_start = 1. / sqrt_alphas_cumprod_trunc_value * (
            txt_feature - sqrt_one_minus_alphas_cumprod_trunc_value * noise)

        return (
                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract(self.sqrt_one_minus_alphas_cumprod,
                        t, x_start.shape) * noise
        )

    def q_sample_at_trunc_timestep(self, x_start, noise, sqrt_alphas_cumprod_trunc_value, sqrt_one_minus_alphas_cumprod_trunc_value):
        """Diffuse ``x_start`` using explicitly supplied schedule coefficients.

        Returns ``sqrt_alphas_cumprod_trunc_value * x_start +
        sqrt_one_minus_alphas_cumprod_trunc_value * noise``; fresh
        standard-normal noise is drawn when ``noise`` is None.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return sqrt_alphas_cumprod_trunc_value * x_start + sqrt_one_minus_alphas_cumprod_trunc_value * noise

    def set_loss(self):
        """Create ``self.loss_func`` from ``config.dm_loss_type`` ('l1' or 'l2').

        Requires the ``betas`` buffer (see :meth:`set_new_noise_schedule`),
        whose device is used for the loss module.

        Raises:
            NotImplementedError: For any other ``dm_loss_type``.
        """
        device = self.betas.device
        loss_type = self.config.dm_loss_type
        if loss_type == 'l1':
            self.loss_func = nn.L1Loss(reduction='mean').to(device)
        elif loss_type == 'l2':
            self.loss_func = nn.MSELoss(reduction='mean').to(device)
        else:
            raise NotImplementedError(f'Unknown dm_loss_type: {loss_type!r}')

    def naive_loss(self, text_feature, vid_feature, noise=None):
        """Noise-prediction loss on ``vid_feature`` at random timesteps.

        Args:
            text_feature: Unused by this method.
            vid_feature: 2-D tensor ``[bs, dim]`` used as x_0.
            noise: Optional noise tensor; drawn from a standard normal with
                the shape of ``vid_feature`` when None.

        Returns:
            ``self.loss_func`` evaluated between the true noise and the
            denoiser's prediction.
        """
        x_start = vid_feature
        bs, _ = x_start.shape  # also enforces a 2-D input
        # One timestep per sample, drawn uniformly from [0, num_timesteps).
        t = torch.randint(0, self.num_timesteps, (bs,), device=x_start.device).long()

        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)  # [bs, dim]

        # The denoiser receives [bs, 1, dim]; the extra dim is squeezed again
        # before comparing with the [bs, dim] noise.
        x_recon = self.net_d(x=x_noisy.unsqueeze(1), t=t)

        loss = self.loss_func(noise, x_recon.squeeze(1))

        return loss

