import torch, os
import numpy as np
from PIL import Image
from empatches import EMPatches

from omegaconf import OmegaConf
from diffusers import AutoPipelineForInpainting
from diffusers.utils import make_image_grid
import torchvision.transforms as transforms
from torchvision.transforms import functional as F

from ldm.models.diffusion.ddim import DDIMSampler
from main import instantiate_from_config



class InpaintPipeline:
    """Load one inpainting backend and expose one inference method per backend.

    The backend is selected by ``pipe_name`` (a key of ``self.pipe_cfg``) and
    loaded once in the constructor. Call the method matching the backend that
    was loaded: ``runwayml``, ``compvis``, ``sdxl_inpaint``, ``kandinsky`` or
    ``sd``.
    """

    def __init__(
        self, pipe_name="runwayml", image_size=512, image_size_out=900, patch=False, patch_size=512, overlap=0.2
    ):
        """Store the configuration and load the selected backend.

        Args:
            pipe_name: key of ``self.pipe_cfg`` selecting the backend.
            image_size: side length (pixels) images are resized to before
                inference and that diffusers results are resized back to.
            image_size_out: size of the merged result in patch mode; either a
                single int (square output) or a ``(width, height)`` pair.
            patch: when True, ``runwayml`` inpaints patch by patch and merges.
            patch_size: patch side length handed to EMPatches.
            overlap: patch overlap handed to EMPatches.

        Raises:
            KeyError: if ``pipe_name`` is not a key of ``self.pipe_cfg``.
        """
        self.pipe_name = pipe_name
        self.image_size = image_size
        self.image_size_out = image_size_out
        self.patch_size = patch_size
        self.overlap = overlap
        self.patch = patch
        self.emp = EMPatches()

        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        # Hub repository ids, or a local checkpoint path for "compvis".
        self.pipe_cfg = {
            "runwayml": "runwayml/stable-diffusion-inpainting",
            "sdxl_inpaint": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            "compvis": "models/ldm/inpainting_big/last.ckpt",
            "kandinsky": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
            "sd": "runwayml/stable-diffusion-v1-5",
        }

        self.pipe = self.get_pipe()
        # Fixed seed; only passed to the sdxl_inpaint, kandinsky and sd calls.
        self.generator = torch.Generator(device=self.device).manual_seed(0)
        # Array/tensor -> PIL image resized to (image_size, image_size).
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((image_size, image_size), interpolation=F.InterpolationMode.LANCZOS),
        ])

    def get_pipe(self):
        """Instantiate the backend named by ``self.pipe_name`` on ``self.device``.

        Returns:
            The loaded pipeline/model, moved to ``self.device``.

        Raises:
            KeyError: if ``self.pipe_name`` is not a key of ``self.pipe_cfg``.
        """
        model_ref = self.pipe_cfg[self.pipe_name]

        if self.pipe_name == "compvis":
            config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
            pipe = instantiate_from_config(config.model)
            # Load onto the CPU first so a checkpoint saved from a GPU also
            # opens on CPU-only hosts; the model is moved to the device below.
            # NOTE(review): torch.load unpickles the file - only point this at
            # trusted checkpoints.
            checkpoint = torch.load(model_ref, map_location="cpu")
            pipe.load_state_dict(checkpoint["state_dict"], strict=False)
        elif self.pipe_name in ("sdxl_inpaint", "sd"):
            # These two are requested with the explicit fp16 weight variant.
            pipe = AutoPipelineForInpainting.from_pretrained(
                model_ref, torch_dtype=torch.float16, variant="fp16",
            )
        else:
            pipe = AutoPipelineForInpainting.from_pretrained(
                model_ref, torch_dtype=torch.float16,
            )

        pipe = pipe.to(self.device)

        # The original condition misspelled "kandinsky", so that backend never
        # got offloading. Offload is skipped when no GPU is present.
        if self.pipe_name in ("kandinsky", "sd") and self.device.type == "cuda":
            pipe.enable_model_cpu_offload()

        return pipe

    def get_patches(self, image, mask):
        """Split ``image`` and ``mask`` into aligned overlapping patches.

        Both arrays are cut with the same ``patch_size`` and ``overlap``.

        Returns:
            ``(img_patches, mask_patches, indices)`` where ``indices`` are the
            image patch positions returned by EMPatches (used for merging).
        """
        img_patches, indices = self.emp.extract_patches(image, patchsize=self.patch_size, overlap=self.overlap)
        mask_patches, _ = self.emp.extract_patches(mask, patchsize=self.patch_size, overlap=self.overlap)

        return img_patches, mask_patches, indices

    def _output_size(self):
        """Return ``image_size_out`` as a ``(width, height)`` tuple.

        A single integer (such as the constructor default) means a square
        output; any indexable pair is used as given.
        """
        size = self.image_size_out
        if isinstance(size, (int, np.integer)):
            return (int(size), int(size))
        return (size[0], size[1])

    def _prepare(self, img):
        """Resize ``img`` to ``(image_size, image_size)`` and return a PIL image.

        ``self.transform`` starts with ``ToPILImage``, which only accepts
        tensors and ndarrays, so inputs that already are PIL images are resized
        directly with the same LANCZOS filter.
        """
        if isinstance(img, Image.Image):
            return img.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        return self.transform(img)

    @staticmethod
    def _composite(image, mask_arr, inpainted_image):
        """Keep original pixels outside the mask and inpainted pixels inside it.

        Args:
            image: original image, array-like of shape (h, w, c).
            mask_arr: single-channel mask, array-like of shape (h, w) with
                values in 0..255; values >= 127.5 count as "masked".
            inpainted_image: model output, array-like of shape (h, w, c).

        Returns:
            float32 ndarray of shape (h, w, c).
        """
        scaled = np.asarray(mask_arr)[:, :, None].astype(np.float32) / 255.0
        weights = (scaled >= 0.5).astype(np.float32)  # hard 0/1 mask
        return (1 - weights) * np.asarray(image) + weights * np.asarray(inpainted_image)

    def runwayml(self, opt, suffix, image, mask, save_patch_res=False):
        """Inpaint with the pipeline call ``pipe(prompt="", image=..., mask_image=...)``.

        Args:
            opt: options object; only ``opt.outdir`` is read, and only when
                ``save_patch_res`` is True in patch mode.
            suffix: string inserted into the saved grid file names.
            image: image to inpaint (PIL image or array).
            mask: inpainting mask (PIL image or array).
            save_patch_res: in patch mode, save a 2x2 grid (image, mask,
                inpainted, composited) per patch under
                ``<opt.outdir>/grids_inpainting``.

        Returns:
            Patch mode (``self.patch`` is True): ``(merged_image, None)`` where
            ``merged_image`` is the merged composite resized to
            ``image_size_out``.
            Otherwise: ``(inpainted_image, unmasked_unchanged_image)``, both at
            ``image_size``; the second keeps the original pixels outside the
            mask.

        Note:
            The two modes return their tuples in a different order; this is
            kept as-is because callers may depend on it.
        """
        if not self.patch:
            image, mask = self._prepare(image), self._prepare(mask)
            inpainted_image = self.pipe(prompt="", image=image, mask_image=mask).images[0]

            blended = self._composite(image, np.array(mask.convert("L")), inpainted_image)
            unmasked_unchanged_image = Image.fromarray(blended.round().astype("uint8"))

            return inpainted_image, unmasked_unchanged_image

        image_patches, mask_patches, indices = self.get_patches(np.array(image), np.array(mask))

        if save_patch_res:
            grid_dir = os.path.join(opt.outdir, 'grids_inpainting')
            os.makedirs(grid_dir, exist_ok=True)

        blended_patches = []
        for patch_idx, (image_patch, mask_patch) in enumerate(zip(image_patches, mask_patches)):
            # NOTE(review): patches are resized to image_size while `indices`
            # describe patch_size tiles - presumably both must be equal for the
            # merge below to line up; confirm against EMPatches.
            patch_image, patch_mask = self.transform(image_patch), self.transform(mask_patch)
            inpainted_image = self.pipe(prompt="", image=patch_image, mask_image=patch_mask).images[0]

            blended = self._composite(patch_image, np.array(patch_mask.convert("L")), inpainted_image)

            if save_patch_res:
                blended_image = Image.fromarray(blended.round().astype("uint8"))
                save_res = make_image_grid(
                    [patch_image, patch_mask, inpainted_image, blended_image], rows=2, cols=2
                )
                save_res.save(os.path.join(grid_dir, f'inpt_{suffix}_{patch_idx}.png'))

            blended_patches.append(blended)

        merged_img = self.emp.merge_patches(blended_patches, indices)
        merged_image = Image.fromarray(merged_img.round().astype("uint8"))
        merged_image = merged_image.resize(self._output_size(), Image.Resampling.LANCZOS)

        return merged_image, None

    def compvis(self, image, mask, batch):
        """Inpaint with the latent-diffusion model using 50 DDIM steps.

        Args:
            image: PIL image, used only for the result grid.
            mask: PIL image, used only for the result grid.
            batch: dict of tensors with keys "image", "mask" and
                "masked_image" (documented by the original as [1, c, h, w]).
                The values are rescaled with ``x * 2 - 1`` on a local copy;
                the caller's dict is left untouched.

        Returns:
            ``(inpainted_image, save_res)``: the composited result as a PIL
            image and a 1x3 grid of image, mask and result.
        """
        # Work on a copy so calling this twice with the same dict does not
        # rescale the tensors twice.
        batch = {
            key: value.to(device=self.device) * 2.0 - 1.0
            for key, value in batch.items()
        }

        sampler = DDIMSampler(self.pipe)

        with torch.no_grad(), self.pipe.ema_scope():
            # Conditioning = encoded masked image concatenated with the mask
            # resized to the latent resolution.
            c = self.pipe.cond_stage_model.encode(batch["masked_image"])
            cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
            c = torch.cat((c, cc), dim=1)

            # Sample shape drops the one extra mask channel.
            shape = (c.shape[1] - 1,) + c.shape[2:]
            samples_ddim, _ = sampler.sample(
                S=50,
                conditioning=c,
                batch_size=c.shape[0],
                shape=shape,
                verbose=False,
            )
            x_samples_ddim = self.pipe.decode_first_stage(samples_ddim)

            # Undo the x * 2 - 1 scaling and clamp to [0, 1].
            image_torch = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
            mask_torch = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
            predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
            inpainted = (1 - mask_torch) * image_torch + mask_torch * predicted_image

            # First batch element, channels-last, scaled to 0..255.
            s = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
            inpainted_image = Image.fromarray(s.astype(np.uint8))

        save_res = make_image_grid([image, mask, inpainted_image], rows=1, cols=3)

        return inpainted_image, save_res

    def sdxl_inpaint(self, image, mask):
        """Inpaint with the SDXL inpainting pipeline.

        Args:
            image: PIL image to inpaint.
            mask: PIL mask image.

        Returns:
            ``(inpainted_image, save_res)``: the result resized to
            ``(image_size, image_size)`` and a 1x3 grid of image, mask and
            result.
        """
        inpainted_image = self.pipe(
            prompt="",
            image=image,
            mask_image=mask,
            guidance_scale=8,
            num_inference_steps=20,  # steps between 15 and 30 work well for us
            strength=0.99,
            generator=self.generator,
        ).images[0]
        inpainted_image = inpainted_image.resize((self.image_size, self.image_size))

        save_res = make_image_grid([image, mask, inpainted_image], rows=1, cols=3)
        return inpainted_image, save_res

    def kandinsky(self, image, mask):
        """Inpaint with the Kandinsky pipeline using an empty prompt.

        Args:
            image: PIL image to inpaint.
            mask: PIL mask image.

        Returns:
            ``(inpainted_image, save_res)``: the result resized to
            ``(image_size, image_size)`` and a 1x3 grid of image, mask and
            result.
        """
        inpainted_image = self.pipe(
            prompt="", image=image, mask_image=mask, generator=self.generator
        ).images[0]
        inpainted_image = inpainted_image.resize((self.image_size, self.image_size))

        save_res = make_image_grid([image, mask, inpainted_image], rows=1, cols=3)

        return inpainted_image, save_res

    def sd(self, image, mask):
        """Inpaint with the "sd" pipeline using an empty prompt.

        Args:
            image: PIL image to inpaint.
            mask: PIL mask image.

        Returns:
            ``(inpainted_image, save_res)``: the result resized to
            ``(image_size, image_size)`` and a 1x3 grid of image, mask and
            result.
        """
        inpainted_image = self.pipe(
            prompt="", image=image, mask_image=mask, generator=self.generator
        ).images[0]
        inpainted_image = inpainted_image.resize((self.image_size, self.image_size))

        save_res = make_image_grid([image, mask, inpainted_image], rows=1, cols=3)
        return inpainted_image, save_res

