# %%
from MyDiffusers import StableDiffusionManager
import torch
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import save_image
import torchvision.transforms.functional as TF
import os
from copy import deepcopy
from MyDiffusers import *
from flowsrepo import *
import time
from PIL import Image
import numpy as np
import cv2
from diffusers import StableDiffusionPipeline
import argparse

# %%
# Shared Stable Diffusion manager used by every cell below.
# NOTE(review): the GPU index is hard-coded, and the meaning of `tau` is
# defined inside MyDiffusers -- confirm both before running on another machine.
SDM = StableDiffusionManager(tau=100, device='cuda:2')

# %%
# prompt='a picture of a planet earth'
# negative_prompt = [""]
# num_inference_steps=200
# seed=6
# name='earth'

# %%
# prompt = 'transparent man made by water and smoke, in style of Yoji Shinkawa and Hyung-tae Kim, trending on ArtStation, dark fantasy, great composition, concept art, highly  human  made of water and foam, in the style of Pierre Koenig, red pigment, pastel paint, pink color scheme'
# negative_prompt = ["poorly drawn"]
# num_inference_steps = 200
# seed = 1330
# name='melting_man'

# %%
# prompt = 'a satellite image of a city'
# negative_prompt = ["poorly drawn,cartoon, 2d, disfigured, bad art, deformed, poorly drawn, extra limbs, close up, b&w, weird colors, blurry"]
# num_inference_steps = 200
# seed = 46
# name='satellite_city'

# %%
# prompt = 'a Baroque-style batlle with only  dragons that  spit fire each other, shining like diamonds, draped in white fabric. The background is a dark, creepy eastern europen forrest.  night, horroristic shadows, high contrasts, lumnious, theatrical, character concept art by ruan jia, thomas kinkade, and  trending on Artstation'
# negative_prompt = ["poorly drawn,cartoon, 2d, disfigured, bad art, deformed, poorly drawn, extra limbs, close up, b&w, weird colors, blurry"]
# num_inference_steps = 200
# seed = 8
# name='dragons'

# %%
# prompt = "a small flock bird flying in the sky at the sunset"
# negative_prompt = ["poorly drawn,cartoon, 2d, disfigured, bad art, deformed, poorly drawn, extra limbs, close up, b&w, weird colors, blurry"]
# num_inference_steps = 100
# seed = 801
# name='birds'

# %%
# prompt = "a drop of water falling on a table"
# negative_prompt = ["poorly drawn"]
# num_inference_steps = 100
# seed = 1330
# name = 'glass'

# %%
# Fail fast if no configuration cell above has been uncommented and run.
# `name` is only consumed by the final save cell, so checking it here avoids
# discovering the problem after generation and inversion have already run.
_required_settings = ('prompt', 'negative_prompt', 'num_inference_steps', 'seed', 'name')
_missing_settings = [key for key in _required_settings if key not in globals()]
if _missing_settings:
    raise NameError(
        'Undefined configuration variable(s): ' + ', '.join(_missing_settings)
        + '. Uncomment and run one of the configuration cells above first.'
    )

# Generate the base image for the selected configuration (inference only,
# so gradient tracking is disabled).
with torch.no_grad():
    # Seed the torch (CPU and CUDA) and NumPy generators before sampling.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

    # The initial latent is sampled on the CPU and then moved to the manager's
    # device -- presumably so the sample depends only on the seed, not the GPU.
    z = torch.randn(1, 4, 64, 64, device='cpu').to(SDM.device)
    out = SDM.full_generation(
        z=z,
        prompt=prompt,
        guidance_scale=7.5,
        eta=0,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
    )
    print(seed, flush=True)
    # `display` is not explicitly imported in this file; presumably the
    # IPython/Jupyter builtin -- this cell is meant to run interactively.
    display(out[0].images[0])

# %%
# Encode the generated image into a latent and run the inversion from it.
# Inference only, so gradient tracking stays disabled.
with torch.no_grad():
    latent_t0 = SDM.image_to_latent(out[0].images[0])
    # Guard against an unexpected latent layout before the inversion starts.
    assert latent_t0.shape == (1, 4, 64, 64)
    # Guidance is switched off (scale 0.0) for the inversion pass.
    out2 = SDM.full_inversion(
        z=latent_t0, prompt=prompt,
        guidance_scale=0.0, eta=0,
        num_inference_steps=num_inference_steps,
    )

# %%
# Regenerate an image starting from the last latent recorded by the inversion.
# This is inference only, so it runs under no_grad like the generation and
# inversion cells above; this avoids any autograd bookkeeping across the
# 999 sampling steps.
with torch.no_grad():
    latent_tT = out2[1]['latents'][-1]
    assert latent_tT.shape == (1, 4, 64, 64)
    # NOTE(review): this uses a fixed 999 steps while the inversion used
    # `num_inference_steps` -- confirm the mismatch is intentional.
    out3 = SDM.full_generation(
        z=latent_tT,
        prompt=prompt,
        guidance_scale=0.0,
        eta=0,
        num_inference_steps=999,
    )

# %%
# Show the regenerated image inline.
display(out3[0].images[0])

# %%
# Make sure the output folder exists first: writing into a missing directory
# would fail only after the whole pipeline has already been run.
os.makedirs('base_images', exist_ok=True)
out3[0].images[0].save(f'base_images/{name}.png')
