import argparse
import torch
import copy
import sys
import pickle
import os
# TODO DELETE: machine-specific path hack. Presumably it exists so that the
# local `utils` package imported below resolves on the original author's
# machine; elsewhere, install the repo or add it to PYTHONPATH instead.
sys.path.append('/home/quickjkee/projects/consistency_inversion_editing/repo_editing')

from safetensors.torch import load_file
from peft import LoraConfig
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel,DDPMScheduler
from tqdm import tqdm

from utils import p2p, generation, inversion, metrics
from utils.loading import load_models, load_benchmark

# Utils
# -------------------------------------------------------------------------------------
def str2bool(v):
    """Convert a command-line token to a bool (for use as argparse ``type=``).

    Real booleans are returned unchanged. Strings are matched
    case-insensitively against yes/true/t/y/1 (True) and no/false/f/n/0
    (False).

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    truthy_tokens = ('yes', 'true', 't', 'y', '1')
    falsy_tokens = ('no', 'false', 'f', 'n', '0')
    if token in truthy_tokens:
        return True
    if token in falsy_tokens:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def find_difference(word1, word2):
    """Return the first pair of differing space-separated tokens.

    Tokens of ``word1`` and ``word2`` are compared position by position;
    comparison stops at the end of the shorter token list. Returns the
    tuple ``(token_from_word1, token_from_word2)`` for the first mismatch,
    or ``None`` when every compared position is equal.
    """
    token_pairs = zip(word1.split(' '), word2.split(' '))
    return next(((left, right) for left, right in token_pairs if left != right), None)

def n_differences(word1, word2):
    """Count positions where the space-separated tokens of two strings differ.

    Only positions present in both token lists are compared (extra trailing
    tokens of the longer string are ignored). Returns an int.
    """
    tokens_a = word1.split(' ')
    tokens_b = word2.split(' ')
    return sum(a != b for a, b in zip(tokens_a, tokens_b))
# -------------------------------------------------------------------------------------


# Arguments parser
# -------------------------------------------------------------------------------------
def parse_args(argv=None):
    """Define and parse the command-line options of the editing benchmark.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like running the script from a shell;
            passing a list makes the parser usable from tests or notebooks.

    Returns:
        argparse.Namespace holding one attribute per option defined below.

    Options declared with ``required=True`` must always be given explicitly.
    argparse never falls back to a ``default`` for a required option, so none
    is declared for them; suggested values appear in the help text instead.
    """
    parser = argparse.ArgumentParser(
        description="Run the image-editing benchmark and report preservation/editing metrics.",
    )

    # Loading settings
    ################################
    parser.add_argument(
        "--model_id_DM",
        type=str,
        required=True,
        help="Path to pretrained DM",
    )
    parser.add_argument(
        "--forward_checkpoint",
        type=str,
        default=None,
        help="Path to forward CM",
    )
    parser.add_argument(
        "--inverse_checkpoint",
        type=str,
        default=None,
        help="Path to inverse CM",
    )
    parser.add_argument(
        "--teacher_checkpoint",
        type=str,
        default=None,
        help="Path to teacher DM with w embedding",
    )
    parser.add_argument(
        "--path_to_prompts",
        type=str,
        required=True,
        help="Path to prompts for benchmarking",
    )
    parser.add_argument(
        "--path_to_images",
        type=str,
        default=None,
        required=False,
        help="Path to images for editing benchmarking only",
    )
    ################################

    # Models settings
    ################################
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=64,
        help="rank of lora weights",
    )
    parser.add_argument(
        "--w_embed_dim",
        type=int,
        default=0,
        help="dimension of guidance embedding",
    )
    parser.add_argument(
        "--num_ddim_steps",
        type=int,
        default=50,
        help="number of DDIM steps for the diffusion solver (also assigned to p2p.NUM_DDIM_STEPS)",
    )
    parser.add_argument(
        "--num_forw_cons_steps",
        type=int,
        required=True,
        help="number of steps for forward CM (e.g. 4)",
    )
    parser.add_argument(
        "--num_inv_cons_steps",
        type=int,
        required=True,
        help="number of steps for inverse CM (e.g. 3)",
    )
    parser.add_argument(
        "--max_inverse_timestep_idx",
        type=int,
        default=49,
        help="the last timestep for inverse CM for encode",
    )
    parser.add_argument(
        "--start_timestep",
        type=int,
        default=19,
        help="starting timestep for noising real images",
    )
    ################################

    # Inversion settings
    ################################
    parser.add_argument(
        "--use_cons_inversion",
        type=str2bool,
        required=True,
        help='whether to do inversion with CM (e.g. True)'
    )
    parser.add_argument(
        "--nti_guidance_scale",
        type=float,
        default=8.0,
        help="guidance scale for inversion with NTI",
    )
    parser.add_argument(
        "--use_npi",
        type=str2bool,
        default=False,
        help='whether to use negative prompt inversion'
    )
    parser.add_argument(
        "--use_nti",
        type=str2bool,
        default=False,
        help='whether to use null text inversion'
    )
    ################################

    # Editing settings
    ################################
    parser.add_argument(
        "--preservation_metric",
        type=str,
        default='dinov2',
        help="metric comparing original and edited images: clip_score or dinov2",
    )
    parser.add_argument(
        "--editing_metric",
        type=str,
        default='imagereward',
        help="metric scoring edited images against edited prompts: clip_score or imagereward",
    )
    parser.add_argument(
        "--use_cons_editing",
        type=str2bool,
        required=True,
        help='whether to do editing with CM (e.g. True)'
    )
    parser.add_argument(
        "--dynamic_guidance",
        type=str2bool,
        required=True,
        help='whether to use dynamic guidance for editing (e.g. True)'
    )
    parser.add_argument(
        "--tau1",
        type=float,
        required=True,
        help="first hyperparameter for dynamic guidance (e.g. 0.8)",
    )
    parser.add_argument(
        "--tau2",
        type=float,
        required=True,
        help="second hyperparameter for dynamic guidance (e.g. 0.8)",
    )
    parser.add_argument(
        "--cross_replace_steps",
        type=float,
        default=0.4,
        help="cross-attention replacement value passed to p2p.make_controller",
    )
    parser.add_argument(
        "--self_replace_steps",
        type=float,
        default=0.4,
        help="self-attention replacement value passed to p2p.make_controller",
    )
    parser.add_argument(
        "--amplify_factor",
        type=float,
        default=3,
        help="value applied to the edited word via the controller's eq_params",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        required=True,
        help="guidance scale for editing (e.g. 8.0)",
    )
    ################################

    # Others
    ################################
    parser.add_argument(
        "--device",
        type=str,
        default='cuda',
        help="device used for models and metric computation",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=30,
        help="random seed passed to the inversion step",
    )
    parser.add_argument(
        "--saving_dir",
        type=str,
        required=True,
        help="output directory for the metrics pickle and edited images (e.g. results)",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default='fp32',
        help="model dtype string passed to load_models (e.g. fp32)",
    )
    ################################

    return parser.parse_args(argv)
# -------------------------------------------------------------------------------------

# Running 
# -------------------------------------------------------------------------------------
def main(args):
    """Run the editing benchmark, then save metric values and edited images.

    Benchmark entries are kept only when their 'before'/'after' prompts have
    the same number of space-separated words and differ in exactly one of
    them. Each kept source image is inverted with ``inversion.invert`` and
    regenerated with both prompts through ``generation.runner`` under a
    controller built by ``p2p.make_controller``. The edited images are then
    scored with the selected preservation and editing metrics.

    Files written under ``args.saving_dir``:
        editing_metrics_values.pickle -- {'preservation': ..., 'editing': ...}
        edited_images/<j>.png         -- the j-th edited image

    Raises:
        ValueError: if ``args.preservation_metric`` or ``args.editing_metric``
            is not a supported metric name.
        RuntimeError: if no benchmark entry passes the single-word-edit filter.
    """
    # Metric registry. Validate the requested names up front so a typo fails
    # before the slow model loading; `assert` would be stripped under -O.
    available_metrics = {'preservation': {'clip_score': metrics.calc_clip_score_images_images,
                                          'dinov2': metrics.calc_dinov2_images_images},
                         'editing': {'clip_score': metrics.calc_clip_score_images_prompts,
                                     'imagereward': metrics.calc_ir},
                        }
    for metric_kind, metric_name in (('preservation', args.preservation_metric),
                                     ('editing', args.editing_metric)):
        if metric_name not in available_metrics[metric_kind]:
            raise ValueError(
                f"Unknown {metric_kind} metric {metric_name!r}; available metrics: "
                f"{', '.join(available_metrics[metric_kind])}")
    preservation_metric_fn = available_metrics['preservation'][args.preservation_metric]
    editing_metric_fn = available_metrics['editing'][args.editing_metric]

    # Models loading
    ldm_stable, forw_cons_model, inv_cons_model = load_models(
        model_id=args.model_id_DM,
        device=args.device,
        forward_checkpoint=args.forward_checkpoint,
        inverse_checkpoint=args.inverse_checkpoint,
        r=args.lora_rank,
        w_embed_dim=args.w_embed_dim,
        teacher_checkpoint=args.teacher_checkpoint,
        dtype=args.dtype)

    tokenizer = ldm_stable.tokenizer
    noise_scheduler = DDPMScheduler.from_pretrained(
        args.model_id_DM, subfolder="scheduler",
    )

    # Benchmark loading
    editing_benchmark = load_benchmark(args.path_to_prompts,
                                       args.path_to_images)

    # Generator configuration
    generator = generation.Generator(
                            model=ldm_stable,
                            noise_scheduler=noise_scheduler,
                            n_steps=args.num_ddim_steps,
                            inv_cons_model=inv_cons_model,
                            forw_cons_model=forw_cons_model,
                            num_endpoints=args.num_forw_cons_steps,
                            num_inverse_endpoints=args.num_inv_cons_steps,
                            max_inverse_timestep_index=args.max_inverse_timestep_idx,
                            start_timestep=args.start_timestep)
    # p2p is configured through module-level globals.
    p2p.NUM_DDIM_STEPS = args.num_ddim_steps
    p2p.tokenizer = tokenizer
    p2p.device = args.device

    # EDITING PART
    print('Running editing...')
    eval_collection = {'orig_prompt': [], 'orig_image': [], 'edited_prompt': [], 'edited_image': []}
    for image_path, prompts_dict in tqdm(editing_benchmark):
        source_prompt = prompts_dict['before']
        target_prompt = prompts_dict['after']
        # Only single-word substitutions are benchmarked: equal word counts
        # and exactly one differing position.
        if len(source_prompt.split(' ')) != len(target_prompt.split(' ')):
            continue
        if n_differences(source_prompt, target_prompt) != 1:
            continue

        # Consistency-model editing (below) overwrites p2p.NUM_DDIM_STEPS with
        # the forward-CM step count. Restore it here so every inversion runs
        # with the same global value as the first one did.
        p2p.NUM_DDIM_STEPS = args.num_ddim_steps

        prompt = [source_prompt]
        (image_gt, image_rec), latent, uncond_embeddings = inversion.invert(
                                                                   # Playing params
                                                                   is_cons_inversion=args.use_cons_inversion,
                                                                   do_npi=args.use_npi,
                                                                   do_nti=args.use_nti,
                                                                   stop_step=50, # from [0, NUM_DDIM_STEPS]

                                                                   nti_guidance_scale=args.nti_guidance_scale,
                                                                   inv_guidance_scale=1.0,
                                                                   dynamic_guidance=False,
                                                                   tau1=0.0,
                                                                   tau2=0.0,

                                                                   # Fixed params
                                                                   solver=generator,
                                                                   image_path=image_path,
                                                                   prompt=prompt,
                                                                   offsets=(0,0,200,0),
                                                                   num_inner_steps=10,
                                                                   early_stop_epsilon=1e-5,
                                                                   seed=args.seed)

        if args.use_cons_editing:
            p2p.NUM_DDIM_STEPS = args.num_forw_cons_steps
            model = forw_cons_model
        else:
            model = ldm_stable

        prompts = [source_prompt, target_prompt]
        cross_replace_steps = {'default_': args.cross_replace_steps,}
        self_replace_steps = args.self_replace_steps
        # Guaranteed to be a pair: the filter above ensures exactly one difference.
        w1, w2 = find_difference(source_prompt, target_prompt)
        blend_word = ((w1,), (w2,))
        eq_params = {"words": (w2,), "values": (args.amplify_factor,)}

        controller = p2p.make_controller(prompts,
                                         True,
                                         cross_replace_steps,
                                         self_replace_steps,
                                         blend_word,
                                         eq_params)
        edited_batch, _ = generation.runner(
                                 # Playing params
                                 model=model, # ldm_stable or forw_cons_model
                                 is_cons_forward=args.use_cons_editing,

                                 w_embed_dim=args.w_embed_dim,
                                 guidance_scale=args.guidance_scale,
                                 dynamic_guidance=args.dynamic_guidance,
                                 tau1=args.tau1,
                                 tau2=args.tau2,
                                 start_time=50,

                                 # Fixed params
                                 solver=generator,
                                 prompt=prompts,
                                 controller=controller,
                                 num_inference_steps=50,
                                 generator=None,
                                 latent=latent,
                                 uncond_embeddings=uncond_embeddings,
                                 return_type='image')

        pil_img_orig = generation.to_pil_images(image_gt)
        # Index 1 presumably corresponds to prompts[1] (the 'after' prompt).
        pil_img_edited = generation.to_pil_images(edited_batch[1, :, :, :])
        eval_collection['orig_prompt'].append(source_prompt)
        eval_collection['orig_image'].append(pil_img_orig)
        eval_collection['edited_prompt'].append(target_prompt)
        eval_collection['edited_image'].append(pil_img_edited)

    if not eval_collection['edited_image']:
        raise RuntimeError('No benchmark entry is a single-word edit with equal word '
                           'counts; nothing to evaluate.')

    # VALIDATION PART
    preserve_metric_value = preservation_metric_fn(eval_collection['orig_image'],
                                                   eval_collection['edited_image'],
                                                   device=args.device, batch_size=16)
    editing_metric_value = editing_metric_fn(eval_collection['edited_image'],
                                             eval_collection['edited_prompt'],
                                             device=args.device, batch_size=16)
    results = {'preservation': preserve_metric_value, 'editing': editing_metric_value}

    # SAVING PART
    outdir = args.saving_dir
    os.makedirs(outdir, exist_ok=True)
    with open(os.path.join(outdir, 'editing_metrics_values.pickle'), 'wb') as handle:
        pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

    outdir_images = os.path.join(outdir, 'edited_images')
    os.makedirs(outdir_images, exist_ok=True)
    for j, edited_image in enumerate(eval_collection['edited_image']):
        edited_image.save(os.path.join(outdir_images, f'{j}.png'))
# -------------------------------------------------------------------------------------


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the benchmark.
    args = parse_args()
    main(args=args)