import os
import time

import torch
from PIL import Image, ImageFilter
from diffusers import (
    StableDiffusionInpaintPipeline, 
    UNet2DConditionModel,
    DDIMScheduler
)
from transformers import CLIPTextModel

import cv2
import json
import numpy as np
from itertools import chain
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
import torch.nn.functional as F

from model import Homography, Siren
from util import get_mgrid, apply_homography, jacobian, VideoFitting, TestVideoFitting

# GPU placement: device1 hosts the networks fitted by this script (homography
# and SIREN MLPs); device2 hosts the Stable Diffusion inpainting pipeline.
device1, device2 = 'cuda:1', 'cuda:0'

# def get_uncertainty_map(uncertainty, step, output_dir):
#     uncertainty = (uncertainty - torch.min(uncertainty)) / (torch.max(uncertainty - torch.min(uncertainty)))
#     thresh = (1 + torch.mean(uncertainty)) / 2
#     uncertainty = (uncertainty > thresh) * 1
#     uncertainty = uncertainty.view(512, 1024)
#     uncertainty = uncertainty.cpu().detach().numpy()
#     if step % 1000 == 0:
#         uncertainty_save = np.copy(uncertainty)
#         uncertainty_save = Image.fromarray(np.uint8(uncertainty_save * 255))
#         uncertainty_save.save(output_dir + '/uncertainty/uncertainty_%d.png'%step)

#     return uncertainty.astype('uint8')


# def get_dilated_mask(np_mask, step, output_dir):
#     # set kernel's size & shape
#     kernel_size = 5
#     kernel_shape = cv2.MORPH_RECT  # option: MORPH_RECT, MORPH_ELLIPSE, MORPH_CROSS

#     dilation_kernel = cv2.getStructuringElement(kernel_shape, (kernel_size, kernel_size))

#     # start to dilate
#     dilated_mask = cv2.dilate(np_mask, dilation_kernel, iterations=1)
#     if step % 1000 == 0:
#         dilated_mask_save = np.copy(dilated_mask)
#         cv2.imwrite(output_dir + '/dilated/dilated_%d.png'%step, np.uint8(dilated_mask_save * 255))

#     return dilated_mask


def get_canonical(canonical, step, output_dir, height=512, width=512):
    """Turn a flat RGB prediction into a (3, H, W) image scaled to [0, 1].

    Args:
        canonical: tensor holding ``height * width * 3`` values laid out
            pixel-major with the RGB channels last (e.g. shape
            ``(height * width, 3)``). Values are clipped to [-1, 1] before
            being rescaled.
        step: current training step. On every step divisible by 1000 a PNG
            snapshot is written to
            ``output_dir + '/canonical/canonical_<step>.png'``.
        output_dir: path prefix under which the snapshot is stored.
        height: number of rows of the canonical image (default 512).
        width: number of columns of the canonical image (default 512).

    Returns:
        Tensor of shape ``(3, height, width)`` with values in [0, 1]. It is
        not detached, so gradients can still flow back into ``canonical``.
    """
    canonical = canonical.view(height, width, 3).permute(2, 0, 1)
    # Map the clipped [-1, 1] range onto [0, 1].
    canonical = torch.clip(canonical, -1, 1) * 0.5 + 0.5

    if step % 1000 == 0:
        snapshot_dir = output_dir + '/canonical'
        # Create the folder on demand so the first snapshot cannot abort the run.
        os.makedirs(snapshot_dir, exist_ok=True)
        snapshot = canonical.detach().permute(1, 2, 0).cpu().numpy()
        Image.fromarray(np.uint8(snapshot * 255)).save(
            snapshot_dir + '/canonical_%d.png' % step
        )

    return canonical


# --- run configuration (placeholders; fill in before launching) ---
seed = None       # int -> seeded diffusion sampling, None -> unseeded
model_path = ""   # directory containing the "unet" and "text_encoder" subfolders
output_dir = ""   # path prefix for everything this script writes

# Build the stock SD2 inpainting pipeline, then replace its UNet and text
# encoder with the weights stored under model_path and switch to DDIM sampling.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    revision=None,
    torch_dtype=torch.float32,
)
pipe.unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", revision=None)
pipe.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", revision=None)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device2)

# Random generator handed to the pipeline; stays None when no seed is set.
generator = None if seed is None else torch.Generator(device=device2).manual_seed(seed)

# Restore the previously fitted homography network. map_location places the
# checkpoint tensors directly on device1, so loading does not depend on the
# device the checkpoint happened to be saved from.
checkpoint_g_old = torch.load('pth file/train_all/homography_g.pth', map_location=device1)

g_old = Homography(hidden_features=256, hidden_layers=2).to(device1)
g_old.load_state_dict(checkpoint_g_old)
g_old.eval()
# g_old is only evaluated in this script (the optimizer is built from g and f
# alone), so stop tracking gradients for its parameters.
g_old.requires_grad_(False)

print("---Loading successfully---")

# Wall-clock reference for the total training time printed at the end.
start_time = time.time()

def train_residual_flow(path, total_steps, lambda_flow=0.02, verbose=True, steps_til_summary=100):
    """Fit a residual-flow network ``g`` and an image network ``f`` to a video.

    Each sample's coordinates are warped by the frozen homography ``g_old``
    plus the residual predicted by ``g``; ``f`` maps the warped coordinates to
    RGB. From step 1000 on, the 512x512 image rendered by ``f`` is additionally
    pulled towards an image produced by the diffusion inpainting pipeline.

    Reads the module-level ``g_old``, ``pipe``, ``generator``, ``output_dir``,
    ``device1`` and ``device2``.

    Args:
        path: dataset location handed to ``VideoFitting``.
        total_steps: number of optimisation steps.
        lambda_flow: weight of the flow (Jacobian) regulariser.
        verbose: print the loss terms every ``steps_til_summary`` steps.
        steps_til_summary: logging interval in steps.

    Returns:
        Tuple ``(f, g)`` of the trained networks.
    """
    transform = Compose([
        Resize(512),
        ToTensor(),
        Normalize(torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5]))
    ])
    v = VideoFitting(path, transform)
    videoloader = DataLoader(v, batch_size=1, pin_memory=True, num_workers=0)

    # g: (x, y, t) -> 2-D residual offset; f: warped (x, y) -> RGB.
    g = Siren(in_features=3, out_features=2, hidden_features=256,
              hidden_layers=5, outermost_linear=True)
    g.to(device1)
    f = Siren(in_features=2, out_features=3, hidden_features=256,
              hidden_layers=5, outermost_linear=True)
    f.to(device1)

    optim = torch.optim.Adam(lr=1e-4, params=chain(g.parameters(), f.parameters()))
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, milestones=[4000, 8000, 10000, 12500], gamma=0.1)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, milestones=[2000, 8000, 10000], gamma=0.1)

    # The loader yields the whole video as one item; keep it resident on device1.
    model_input, ground_truth = next(iter(videoloader))
    model_input, ground_truth = model_input[0].to(device1), ground_truth[0].to(device1)

    # Loop-invariant tensors for the diffusion term, built once.
    # All-ones mask: the whole 512x512 image is offered to the inpainting pipeline.
    dilated_mask = torch.ones((1, 512, 512), dtype=torch.float32).to(device1)
    # Per-pixel loss weight: rows 0..174 are ignored, the remaining rows count fully.
    myweight = torch.zeros(512, 512, dtype=torch.float32).to(device1)
    myweight[175:, :] = 1.0
    # Most recent diffusion target, reused between refreshes.
    dm_target = None

    batch_size = (v.H * v.W) // 32
    for step in range(total_steps):
        # Walk through the samples in consecutive windows, wrapping around.
        start = (step * batch_size) % len(model_input)
        end = min(start + batch_size, len(model_input))

        # The last column is fed to g_old as t; the leading columns are xy.
        xy, t = model_input[start:end, :-1], model_input[start:end, [-1]]
        xyt = model_input[start:end].requires_grad_()

        # g_old is frozen and xy/t carry no gradient, so skip graph construction.
        with torch.no_grad():
            h_old = apply_homography(xy, g_old(t))
        h = g(xyt)
        xy_ = h_old + h
        o = f(xy_)

        loss_recon = (o - ground_truth[start:end]).abs().mean()
        loss_flow = jacobian(h, xyt).abs().mean()
        loss = loss_recon + lambda_flow * loss_flow

        # Diffusion schedule: refresh often and strongly early on, rarely and
        # weakly later.
        if step <= 3000:
            dm_freq = 10
            dm_strength = 0.4
        elif step <= 5000:
            dm_freq = 100
            dm_strength = 0.3
        else:
            dm_freq = 2000
            dm_strength = 0.2

        loss_dm = 0
        # The diffusion prior joins the objective from step 1000 on.
        if step >= 1000:
            # NOTE(review): the grid is rebuilt every step; if get_mgrid is a
            # pure function it could be cached outside the loop - confirm.
            xy_c = get_mgrid([512, 512], [-1.0577, -1.1265], [1.0863, 1.0740]).to(device1)
            # o_c: (C, H, W) with values in [0, 1], still attached to f's graph.
            o_c = get_canonical(f(xy_c), step, output_dir)

            # Refresh the target every dm_freq steps. The None check guarantees
            # a target exists even if the first diffusion step is not a
            # multiple of dm_freq.
            if dm_target is None or step % dm_freq == 0:
                dm_result = pipe(
                    ["a photo of sks"] * 1, image=o_c.detach().clone().to(device2), mask_image=dilated_mask.detach().clone().to(device2),
                    num_inference_steps=60, guidance_scale=1, generator=generator, strength=dm_strength
                ).images

                # Convert the current render and the mask to PIL for compositing.
                o_image = o_c.detach().clone().permute(1, 2, 0).cpu().numpy()
                o_image = Image.fromarray(np.uint8(o_image * 255))

                mask_image = dilated_mask.detach().clone().cpu().numpy().squeeze()
                mask_image = Image.fromarray(np.uint8(mask_image * 255))
                # MaxFilter takes the window maximum (grows white areas), then
                # a box blur softens the mask edge.
                mask_image = mask_image.filter(ImageFilter.MaxFilter(3))
                mask_image = mask_image.filter(ImageFilter.BoxBlur(1))

                # Blend each diffusion output over the current render; the last
                # blended image becomes the target.
                for dm_image in dm_result:
                    blended = Image.composite(dm_image, o_image, mask_image)
                    if step % 1000 == 0:
                        os.makedirs(output_dir + '/dm_result', exist_ok=True)
                        blended.save(output_dir + '/dm_result/result_%d.png'%step)

                dm_target = torch.tensor(np.array(blended), dtype=torch.float32).to(device1) / 255.0
                dm_target = dm_target.permute(2, 0, 1)

            # Weighted L1 distance between the render and the diffusion target.
            loss_dm += (myweight * (o_c - dm_target)).abs().mean()
            # loss_dm += (o_c - dm_target).abs().mean()

            loss += loss_dm

        if verbose and not step % steps_til_summary:
            print("Step [%04d/%04d]: recon=%0.8f, flow=%0.4f, dm=%.05f" % (step, total_steps, loss_recon, loss_flow, loss_dm))

        optim.zero_grad()
        loss.backward()
        optim.step()
        scheduler.step()

    return f, g


# Train on the full clip, then stop the clock before any checkpoint I/O.
path = 'data/train_all'
f, g = train_residual_flow(path, total_steps=13001, lambda_flow=0.03)
end_time = time.time()

# Persist the weights of both trained networks.
for net, ckpt_file in (
    (f, 'pth file/train_all/DM_mlp_f.pth'),
    (g, 'pth file/train_all/DM_mlp_g.pth'),
):
    torch.save(net.state_dict(), ckpt_file)

# Re-render the video with the trained networks and export it as JPEG frames.
transform = Compose([
    Resize(512),
    ToTensor(),
    Normalize(torch.Tensor([0.5, 0.5, 0.5]), torch.Tensor([0.5, 0.5, 0.5]))
])
v = TestVideoFitting(path, transform)
videoloader = DataLoader(v, batch_size=1, pin_memory=True, num_workers=0)

model_input, ground_truth = next(iter(videoloader))
model_input, ground_truth = model_input[0].to(device1), ground_truth[0].to(device1)

# Number of directory entries in path, used as the frame count.
# NOTE(review): any non-frame file in path would inflate this - confirm.
data_len = len(os.listdir(path))

# Collect the per-step predictions and concatenate once at the end; calling
# torch.cat inside the loop would re-copy everything gathered so far each time.
chunks = []
with torch.no_grad():
    batch_size = (v.H * v.W)
    for step in range(data_len):
        start = (step * batch_size) % len(model_input)
        end = min(start + batch_size, len(model_input))

        xy, t = model_input[start:end, :-1], model_input[start:end, [-1]]
        xyt = model_input[start:end]
        h_old = apply_homography(xy, g_old(t))
        h = g(xyt)
        xy_ = h_old + h
        o = f(xy_)
        chunks.append(o)

myoutput = torch.cat(chunks)

out_folder = output_dir + '/reconstruction/'
# Make sure the target folder exists before the first frame is written.
os.makedirs(out_folder, exist_ok=True)

# Assumes the samples are ordered (row, column, frame) and that frames are
# 512x910 after Resize(512) - TODO confirm against TestVideoFitting.
myoutput = myoutput.reshape(512, 910, data_len, 3).permute(2, 0, 1, 3).clone().detach().cpu().numpy().astype(np.float32)
# Map the clipped [-1, 1] range onto [0, 1].
myoutput = np.clip(myoutput, -1, 1) * 0.5 + 0.5

for i, frame in enumerate(myoutput):
    img = Image.fromarray(np.uint8(frame * 255)).resize((854, 480))
    img.save(out_folder + '%05d.jpg'%(i))

# Training wall-clock time in seconds (excludes the reconstruction above).
print(end_time - start_time)