from functools import partial

import torch
import torch.nn as nn



from omegaconf import OmegaConf
import numpy as np
from LFDNM.models.diffusion.ddim import DDIMSampler
from LFDNM.util import load_model
from premodel.models.model_helper import ModelHelper
from premodel.utils.misc_helper import (
    update_config,
)
class CustomGroupNorm(nn.Module):
    """Standardize contiguous channel slices of a tensor independently.

    ``indices`` lists the exclusive end offset of each group along dim 1; each
    group starts where the previous one ended (the first starts at 0). Each
    group is shifted to zero mean and divided by its (unbiased) standard
    deviation plus ``eps``; both statistics are computed over dim 1 of the
    group (one value per row for a 2-D input). The normalized groups are
    concatenated back along dim 1.

    Notes:
        * Empty groups (``end <= start``, e.g. a leading ``0`` in ``indices``)
          are skipped: they contribute nothing to the output, and computing
          ``std`` over zero elements only produces a torch warning.
        * Channels past the last entry of ``indices`` are not in the output.
        * A group of width 1 has an undefined unbiased std, so its output is NaN.

    Args:
        indices: iterable of increasing end offsets along dim 1.
        eps: value added to the std to avoid division by zero.
    """

    def __init__(self, indices, eps=1e-5):
        super().__init__()
        self.indices = indices
        self.eps = eps

    def forward(self, x):
        normalized = []
        start = 0
        for end in self.indices:
            if end > start:
                group = x[:, start:end]
                mean = group.mean(dim=1, keepdim=True)
                std = group.std(dim=1, keepdim=True) + self.eps
                normalized.append((group - mean) / std)
            # Advance even for empty groups so the next group starts here.
            start = end

        if not normalized:
            # Every group was empty: return an empty (batch, 0) slice instead
            # of calling torch.cat on an empty list.
            return x[:, :0]
        return torch.cat(normalized, dim=1)


class DMLR(nn.Module):
    """Encode images with a frozen pretrained encoder and sample representations
    with a DDIM sampler built on a pretrained LFDNM model.

    Args:
        use_class_label: if True, ``gen_image`` replaces the returned labels with
            random integers in ``[0, 1000)``.
        pretrained_LFDNM_ckpt: checkpoint path passed to ``load_model``; the DDIM
            sampler is only built when this AND ``pretrained_LFDNM_cfg`` are given.
        pretrained_LFDNM_cfg: OmegaConf YAML path. When given, the pretrained
            encoder is built from ``model.params.pretrained_enc_config``.
    """

    # Exclusive end offsets of the channel groups normalized independently in
    # ``forward``. The leading 0 yields an empty first group.
    # NOTE(review): presumably these match the channel layout of the encoder's
    # concatenated multi-level features (720 channels total) — confirm.
    _FEATURE_GROUP_INDICES = (0, 24, 56, 112, 272, 720)

    def __init__(self,
                 use_class_label=False,
                 pretrained_LFDNM_ckpt=None,
                 pretrained_LFDNM_cfg=None):
        super().__init__()

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.use_class_label = use_class_label
        # NOTE(review): not used by any method in this class.
        self.pool = nn.AdaptiveAvgPool2d((2, 2))

        # Load the config whenever a cfg path is given: the encoder built below
        # needs it even when no LFDNM checkpoint is supplied.
        LFDNM_config = None
        if pretrained_LFDNM_cfg is not None:
            LFDNM_config = update_config(OmegaConf.load(pretrained_LFDNM_cfg))

        if LFDNM_config is not None and pretrained_LFDNM_ckpt is not None:
            # Last class index of the conditioning stage; used as the label when
            # the model is not class-conditional (see gen_image).
            self.LFDNM_fake_class_label = (
                LFDNM_config.model.params.cond_stage_config.params.n_classes - 1)
            LFDNM_model = load_model(LFDNM_config, pretrained_LFDNM_ckpt)
            self.LFDNM_sampler = DDIMSampler(LFDNM_model)
        else:
            self.LFDNM_fake_class_label = 0

        if LFDNM_config is not None:
            self.instantiate_pretrained_enc(
                LFDNM_config.model.params.pretrained_enc_config)

    def forward(self, imgs, class_label,
                gen__image=True, bsz=None, num_iter=None, choice_temperature=None,
                sampled_rep=None, LFDNM_steps=100, eta=1.0, cfg=0.0, class_label_gen=None):
        """Encode ``imgs``, group-normalize the pooled features, and sample.

        Args:
            imgs: passed to the encoder under the ``"image"`` key.
            class_label: accepted but not used; ``class_label_gen`` is what is
                forwarded to ``gen_image``.
            gen__image: when False nothing is computed and None is returned.
            bsz: batch size for ``gen_image``; defaults to 64 when None.
            num_iter, choice_temperature, sampled_rep, LFDNM_steps, eta, cfg:
                forwarded to ``gen_image``.
            class_label_gen: forwarded to ``gen_image`` as its ``class_label``.

        Returns:
            ``(sampled_rep, class_label, z)``, where ``z`` is the
            group-normalized pooled encoder feature; or None if ``gen__image``
            is False.
        """
        if not gen__image:
            return None

        # The encoder is frozen: force eval mode and skip autograd.
        self.pretrained_encoder.eval()
        with torch.no_grad():
            outputs = self.pretrained_encoder({"image": imgs})
            rep = outputs["feature_align"]

        # Pool to 1x1 and drop the two trailing singleton dims.
        # Assumes rep is (B, C, H, W) — TODO confirm against the encoder.
        z = self.global_avg_pool(rep).squeeze(-1).squeeze(-1)
        z = CustomGroupNorm(self._FEATURE_GROUP_INDICES)(z)

        gen_bsz = 64 if bsz is None else bsz
        sampled_rep, class_label = self.gen_image(
            gen_bsz, z, num_iter, choice_temperature, sampled_rep,
            LFDNM_steps, eta, cfg, class_label_gen)
        return sampled_rep, class_label, z

    def gen_image(self, bsz=64, imgs=None, num_iter=12, choice_temperature=4.5, sampled_rep=None, LFDNM_steps=100, eta=1.0,
                  cfg=0.0, class_label=None):
        """Sample representations with the DDIM sampler, or pass given ones through.

        Args:
            bsz: batch size for the sampler, the fake-label tensor, the cfg
                unconditional batch and the random labels.
            imgs: passed to the sampler as ``x0``.
            num_iter, choice_temperature: accepted but not used.
            sampled_rep: if not None, sampling is skipped and it is returned as-is.
            LFDNM_steps: number of DDIM steps.
            eta: DDIM ``eta``.
            cfg: if > 0, ``self.fake_latent`` repeated ``bsz`` times is appended
                to the samples along dim 0.
            class_label: conditioning labels used when the model is
                class-conditional.

        Returns:
            ``(sampled_rep, class_label)``.
        """
        if sampled_rep is None:
            sampler = self.LFDNM_sampler
            ldm = sampler.model
            with ldm.ema_scope("Plotting"):
                diffusion_model = ldm.model.diffusion_model
                shape = [diffusion_model.in_channels,
                         diffusion_model.image_size,
                         diffusion_model.image_size]
                if not ldm.class_cond:
                    # Unconditional model: condition on the fixed "fake" label.
                    class_label = self.LFDNM_fake_class_label * torch.ones(bsz).cuda().long()
                cond = ldm.get_learned_conditioning({"class_label": class_label})

                sampled_rep, _ = sampler.sample(LFDNM_steps, conditioning=cond, x0=imgs,
                                                batch_size=bsz, shape=shape,
                                                eta=eta, verbose=False)
                sampled_rep = sampled_rep.squeeze(-1).squeeze(-1)

                if cfg > 0:
                    # Nothing in this class sets fake_latent; fail with a clear
                    # message instead of an opaque attribute error.
                    fake_latent = getattr(self, "fake_latent", None)
                    if fake_latent is None:
                        raise AttributeError(
                            "DMLR.gen_image with cfg > 0 requires `fake_latent` to be set")
                    sampled_rep = torch.cat([sampled_rep, fake_latent.repeat(bsz, 1)], dim=0)

        if self.use_class_label:
            assert cfg == 0
            # NOTE(review): 1000 is presumably the class count — confirm.
            class_label = torch.randint(0, 1000, (bsz,)).cuda()

        return sampled_rep, class_label

    def instantiate_pretrained_enc(self, pretrained_enc_cfg):
        """Build the pretrained encoder via ``ModelHelper``, move it to CUDA, set eval."""
        # ``OmegaConf`` is a utility class, not the config type: loaded configs
        # are DictConfig/ListConfig, so ``OmegaConf.is_config`` is the right test
        # for converting to plain (resolved) Python containers.
        if OmegaConf.is_config(pretrained_enc_cfg):
            pretrained_enc_cfg = OmegaConf.to_container(pretrained_enc_cfg, resolve=True)

        self.pretrained_encoder = ModelHelper(pretrained_enc_cfg)
        self.pretrained_encoder.cuda()
        self.pretrained_encoder.eval()
