import torch
import copy
import pandas as pd

from safetensors.torch import load_file
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from peft import LoraConfig

def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default",
                                alpha: float = 8):
    """Rename a PEFT-style LoRA state dict to kohya_ss-style keys.

    For every entry the key is rewritten as follows:
      * the substring ``"unet.base_model.model"`` is replaced by ``prefix``;
      * ``"lora_A"`` -> ``"lora_down"`` and ``"lora_B"`` -> ``"lora_up"``;
      * every dot except the last two is turned into an underscore, so a key
        ends up as ``<flattened_module_path>.<lora_down|lora_up>.<param>``.
    For each ``lora_down`` entry an additional ``<flattened_module_path>.alpha``
    scalar tensor is emitted.

    Args:
        module: Mapping from key to tensor (anything exposing ``.items()``
            whose values support ``.to(dtype)``).
        prefix: Replacement for ``"unet.base_model.model"`` in each key.
        dtype: Dtype every output tensor (weights and alphas) is cast to.
        adapter_name: Not used by this function; accepted only so existing
            call sites keep working.
        alpha: Value stored under every ``.alpha`` key (defaults to 8, the
            value that was previously hard-coded).

    Returns:
        A new dict with the renamed keys and the added alpha entries.
    """
    kohya_ss_state_dict = {}
    for peft_key, weight in module.items():
        kohya_key = peft_key.replace("unet.base_model.model", prefix)
        kohya_key = kohya_key.replace("lora_A", "lora_down")
        kohya_key = kohya_key.replace("lora_B", "lora_up")
        # Flatten all but the last two dots. Clamp at zero: str.replace treats
        # a negative count as "replace everything", which would also flatten
        # the trailing ".<param>" separator of keys with fewer than two dots.
        dots_to_flatten = max(kohya_key.count(".") - 2, 0)
        kohya_key = kohya_key.replace(".", "_", dots_to_flatten)
        kohya_ss_state_dict[kohya_key] = weight.to(dtype)
        # Set alpha parameter (one per LoRA module, keyed off the down matrix)
        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(alpha).to(dtype)

    return kohya_ss_state_dict

# Load models (DM, CM)
def load_models(model_id,
                device,
                forward_checkpoint,
                inverse_checkpoint,
                r=64,
                w_embed_dim=0,
                teacher_checkpoint=None,
                dtype='fp32',
                ):
    """Load the base diffusion pipeline plus forward and inverse consistency pipelines.

    Args:
        model_id: Identifier/path passed to ``from_pretrained`` for both the
            Stable Diffusion pipeline and the UNets.
        device: Device the pipeline and UNets are moved to.
        forward_checkpoint: safetensors file with the LoRA weights of the
            forward consistency model.
        inverse_checkpoint: safetensors file with the LoRA weights of the
            inverse consistency model.
        r: Not used by this function; kept so existing callers that pass it
            keep working.
        w_embed_dim: When > 0, the forward UNet is created with
            ``time_cond_proj_dim=w_embed_dim``.
        teacher_checkpoint: Optional state dict file loaded into the forward
            UNet. Only consulted when ``w_embed_dim > 0``.
        dtype: ``'fp32'`` selects ``torch.float32``; any other value selects
            ``torch.float16``.

    Returns:
        Tuple ``(ldm_stable, forw_cons_model, inv_cons_model)``.
    """
    torch_dtype = torch.float32 if dtype == 'fp32' else torch.float16

    # Diffusion
    # ------------------------------------------------------------
    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    )
    ldm_stable = StableDiffusionPipeline.from_pretrained(
        model_id, scheduler=scheduler
    ).to(device, dtype=torch_dtype)
    # ------------------------------------------------------------

    # Forward consistency
    # ------------------------------------------------------------
    if w_embed_dim > 0:
        print(f'Forward CD is initialized with guidance embedding, dim {w_embed_dim}')
        forward_unet = UNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet",
            time_cond_proj_dim=w_embed_dim, low_cpu_mem_usage=False, device_map=None
        ).to(device)
        if teacher_checkpoint is not None:
            print(f'Embedded model is loading from {teacher_checkpoint}')
            # SECURITY: torch.load unpickles the file; only use trusted checkpoints.
            # Load onto CPU so a checkpoint saved on a device that is not
            # available here still deserializes; load_state_dict then copies
            # the values into the UNet's existing parameters.
            forward_unet.load_state_dict(torch.load(teacher_checkpoint, map_location="cpu"))
    else:
        forward_unet = UNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet"
        ).to(device)
    print(f'Forward CD is loading from {forward_checkpoint}')
    forw_cons_model = _build_fused_lora_pipeline(ldm_stable, forward_unet, forward_checkpoint, torch_dtype)
    # ------------------------------------------------------------

    # Inverse consistency
    # ------------------------------------------------------------
    print(f'Inverse CD is loading from {inverse_checkpoint}')
    inverse_unet = UNet2DConditionModel.from_pretrained(
        model_id, subfolder="unet"
    ).to(device)
    inv_cons_model = _build_fused_lora_pipeline(ldm_stable, inverse_unet, inverse_checkpoint, torch_dtype)
    # ------------------------------------------------------------

    return ldm_stable, forw_cons_model, inv_cons_model


def _build_fused_lora_pipeline(base_pipeline, unet, lora_checkpoint, dtype):
    """Copy ``base_pipeline``, swap in ``unet``, then load and fuse the LoRA from ``lora_checkpoint``."""
    pipeline = copy.deepcopy(base_pipeline)
    pipeline.unet = unet
    peft_weights = load_file(lora_checkpoint)
    # LoRA tensors are converted as float16; the whole pipeline is cast to the
    # requested dtype only after fusing.
    kohya_weights = get_module_kohya_state_dict(peft_weights, "lora_unet", torch.float16)
    pipeline.load_lora_weights(kohya_weights)
    pipeline.fuse_lora()
    pipeline.to(dtype=dtype)
    return pipeline


# Load benchmarks (editing or generation)
def load_benchmark(path_to_prompts,
                   path_to_images=None):
    """Read a benchmark CSV for generation or for editing.

    Args:
        path_to_prompts: CSV file read with ``pandas.read_csv``.
        path_to_images: Image directory prefix. When ``None`` the CSV is
            treated as a generation benchmark; otherwise as an editing one.

    Returns:
        Generation mode: list of the values in the ``caption`` column.
        Editing mode: list of ``(img_path, {'before': ..., 'after': ...})``
        tuples built from the ``file_name``, ``old_caption`` and
        ``edited_caption`` columns, with ``img_path`` being
        ``f'{path_to_images}/{file_name}'``.
    """
    table = pd.read_csv(path_to_prompts)

    # Generation benchmark: only the captions are needed.
    if path_to_images is None:
        print(f'Generation benchmark: Loading from {path_to_prompts}')
        return list(table['caption'])

    # Editing benchmark: pair every image path with its before/after prompts.
    print(f'Editing benchmark: Loading prompts, images from {path_to_prompts}, {path_to_images}')
    table = table.reset_index()
    return [
        (
            f"{path_to_images}/{record['file_name']}",
            {'before': record['old_caption'], 'after': record['edited_caption']},
        )
        for _, record in table.iterrows()
    ]
        