import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import CLIPImageProcessor, CLIPVisionModel, StoppingCriteria
from .llava import LlavaMPTForCausalLM, LlavaLlamaForCausalLM, conv_templates, SeparatorStyle
from . import get_image
import torch.nn.functional as F


# Placeholder that prompts carry for an image; get_images_input_ids() replaces
# it with a run of DEFAULT_IMAGE_PATCH_TOKEN before tokenization.
DEFAULT_IMAGE_TOKEN = "<image>"
# Token repeated once per image position in the expanded prompt; load_model()
# registers it with the tokenizer as a special token.
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
# Optional markers placed around the patch-token run; only used when the model
# config has mm_use_im_start_end set.
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any keyword shows up in the newly generated text.

    Only the first sequence of the batch is inspected (index ``[0]`` of the
    decoded batch), so this criterion is meant for single-sequence generation.

    Args:
        keywords: Iterable of strings; generation stops when one of them is a
            substring of the decoded continuation.
        tokenizer: Object exposing ``batch_decode``; used to turn the generated
            ids back into text.
        input_ids: Prompt ids; ``input_ids.shape[1]`` is taken as the prompt
            length so that only the continuation is decoded.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        # Resolved lazily on the first call, from the prompt length.
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when a keyword occurs in the text generated so far."""
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
        # The continuation is checked on every call, including the first one;
        # previously the first call only recorded start_len, so a keyword
        # produced by the very first token was noticed one step late.
        generated = self.tokenizer.batch_decode(
            output_ids[:, self.start_len:], skip_special_tokens=True
        )[0]
        return any(keyword in generated for keyword in self.keywords)


def get_model_name(model_path):
    """Derive a short model name from a '/'-separated checkpoint path.

    The name is the last path component. When that component starts with
    ``checkpoint-`` and a non-empty parent component exists, the parent is
    prepended (``parent_checkpoint-N``) so intermediate checkpoints of
    different runs stay distinguishable.

    Args:
        model_path: Path string; any number of trailing slashes is ignored.

    Returns:
        The derived model name (empty string for an empty or all-slash path).
    """
    # rstrip handles "dir//" as well as "dir/"; stripping a single character
    # would leave an empty last component.
    parts = model_path.rstrip("/").split("/")
    leaf = parts[-1]
    # Guard on the parent so a bare "checkpoint-N" does not raise IndexError.
    if leaf.startswith("checkpoint-") and len(parts) > 1 and parts[-2]:
        return parts[-2] + "_" + leaf
    return leaf


def get_conv(model_name):
    """Pick the conversation template matching a model name.

    Matching is done on the lower-cased name throughout, which is the same
    convention load_model() uses to choose the model class; otherwise a name
    such as "MPT-7B" would be loaded as an MPT model but paired with the
    generic template.

    Args:
        model_name: Model name, e.g. as produced by get_model_name().

    Returns:
        A copy of the selected entry of ``conv_templates``.
    """
    name = model_name.lower()
    if "llava" in name:
        if "v1" in name:
            template_name = "llava_v1"
        elif "mpt" in name:
            template_name = "mpt_multimodal"
        else:
            template_name = "multimodal"
    elif "mpt" in name:
        template_name = "mpt_text"
    elif "koala" in name:  # Hardcode the condition
        template_name = "bair_v1"
    elif "v1" in name:  # vicuna v1_1/v1_2
        template_name = "vicuna_v1_1"
    else:
        template_name = "v1"
    return conv_templates[template_name].copy()


def load_model(model_path, model_name, dtype=torch.float16, device='cpu'):
    """Load tokenizer, language model and (for LLaVA) the image processor.

    Args:
        model_path: Location handed to the ``from_pretrained`` loaders.
        model_name: Name used to pick the model class ('llava' / 'mpt'
            substrings, case-insensitive).
        dtype: torch dtype requested when loading the weights.
        device: Device the model (and vision tower) is moved to.

    Returns:
        ``(tokenizer, model, image_processor, context_len)``. The image
        processor is ``None`` for non-LLaVA models; ``context_len`` is
        ``model.config.max_sequence_length`` when present, else 2048.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    lowered_name = model_name.lower()
    is_llava = 'llava' in lowered_name
    is_mpt = 'mpt' in lowered_name

    # Keyword arguments shared by every model loader below.
    load_kwargs = dict(torch_dtype=dtype, low_cpu_mem_usage=True)
    if is_llava:
        model_cls = LlavaMPTForCausalLM if is_mpt else LlavaLlamaForCausalLM
        model = model_cls.from_pretrained(model_path, **load_kwargs)
    elif is_mpt:
        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, **load_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

    image_processor = None
    if is_llava:
        image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=dtype)

        # Register the image-related special tokens with the tokenizer.
        use_start_end = getattr(model.config, "mm_use_im_start_end", False)
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if use_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

        vision_tower = model.get_model().vision_tower[0]
        if vision_tower.device.type != 'meta':
            vision_tower.to(device=device, dtype=dtype)
        else:
            # A tower on the 'meta' device is reloaded from its own
            # checkpoint and swapped into the model.
            vision_tower = CLIPVisionModel.from_pretrained(
                vision_tower.config._name_or_path, torch_dtype=dtype, low_cpu_mem_usage=True
            ).to(device=device)
            model.get_model().vision_tower[0] = vision_tower

        # Record the special-token ids on the vision config.
        vision_config = vision_tower.config
        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
        vision_config.use_im_start_end = use_start_end
        if use_start_end:
            start_id, end_id = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
            vision_config.im_start_token, vision_config.im_end_token = start_id, end_id

    context_len = getattr(model.config, "max_sequence_length", 2048)

    model.to(device=device)

    return tokenizer, model, image_processor, context_len


class TestLLaVA:
    def __init__(self, model_path=None):
        """Load the LLaVA checkpoint and move it to the best available device.

        Args:
            model_path: Checkpoint directory. When omitted, the historical
                hard-coded location is used, so ``TestLLaVA()`` behaves as
                before while other machines can pass their own path.
        """
        if model_path is None:
            model_path = "/mnt/14T-disk/code/contrastive_decoding/Multi-Modality-Arena/model_weights/llava_weights/LLaVA-Lightning-7B-v1-1"
        model_name = get_model_name(model_path)
        self.tokenizer, self.model, self.image_processor, self.context_len = load_model(model_path, model_name)
        self.conv = get_conv(model_name)
        self.image_process_mode = "Resize" # Crop, Resize, Pad

        self.move_to_device()
        
    def move_to_device(self):
        """Select device/dtype and move the model and its vision tower there.

        Uses CUDA with float16 when available, otherwise CPU with float32.
        The choice is stored on ``self.dtype`` and ``self.device``.
        """
        has_cuda = torch.cuda.is_available()
        self.dtype = torch.float16 if has_cuda else torch.float32
        self.device = 'cuda' if has_cuda else 'cpu'

        # The vision tower is moved explicitly before the enclosing model.
        tower = self.model.get_model().vision_tower[0]
        tower.to(device=self.device, dtype=self.dtype)
        self.model.to(device=self.device, dtype=self.dtype)
    
    @torch.no_grad()
    def generate(self, image, question, max_new_tokens=128):
        """Answer one question about one image.

        Args:
            image: Anything accepted by ``get_image``.
            question: Question text; the image placeholder is appended to it.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The first (only) entry of the list returned by ``do_generate``.
        """
        loaded_image = get_image(image)
        dialog = self.conv.copy()

        # The user turn carries the text, the image and the processing mode.
        user_turn = (question + '\n<image>', loaded_image, self.image_process_mode)
        dialog.append_message(dialog.roles[0], user_turn)
        dialog.append_message(dialog.roles[1], None)
        prompt = dialog.get_prompt()

        if dialog.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]:
            stop_str = dialog.sep
        else:
            stop_str = dialog.sep2

        answers = self.do_generate(
            [prompt],
            [loaded_image],
            stop_str=stop_str,
            dtype=self.dtype,
            max_new_tokens=max_new_tokens,
        )
        return answers[0]
    
    @torch.no_grad()
    def batch_generate(
        self,
        image_list,
        question_list,
        max_new_tokens=128,
        cond_coeff=0.,
        cf_coeff=0.,
        mi_coeff=0.,
        max_ensemble=False,
        min_ensemble=False,
        instruction=None,
        ins_coeff=False,
        uncertainty_threshold=-1e9,
        penalty_alpha=0.,
        bad_words_ids=False,
        num_beams=1,
    ):
        """Answer a batch of (image, question) pairs.

        Args:
            image_list: Images, each accepted by ``get_image``.
            question_list: Questions, paired with ``image_list`` by position.
            max_new_tokens: Upper bound on generated tokens per sample.
            instruction: Optional format string; when given, each question is
                wrapped as ``instruction.format(question)``.
            cond_coeff, cf_coeff, mi_coeff, ins_coeff, uncertainty_threshold,
            bad_words_ids, num_beams: Forwarded unchanged to ``do_generate``.
            max_ensemble, min_ensemble, penalty_alpha: Accepted for interface
                compatibility; not used by this method.

        Returns:
            List of generated strings, one per pair; ``[]`` for an empty batch.
        """
        # Materialise once so a one-shot iterable can be both formatted and
        # forwarded as the raw question list.
        raw_question_list = list(question_list)
        if instruction is not None:
            question_list = [instruction.format(q) for q in raw_question_list]
        else:
            question_list = raw_question_list

        images, prompts = [], []
        conv = None
        for image, question in zip(image_list, question_list):
            image = get_image(image)
            conv = self.conv.copy()
            text = (question + '\n<image>', image, self.image_process_mode)
            conv.append_message(conv.roles[0], text)
            conv.append_message(conv.roles[1], None)
            prompts.append(conv.get_prompt())
            images.append(image)

        # Nothing to generate: previously this fell through to a read of the
        # loop variable `conv`, which is unbound when the loop never ran.
        if not prompts:
            return []

        stop_str = conv.sep if conv.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else conv.sep2
        outputs = self.do_generate(
            prompts,
            images,
            stop_str=stop_str,
            dtype=self.dtype,
            max_new_tokens=max_new_tokens,
            cond_coeff=cond_coeff,
            cf_coeff=cf_coeff,
            mi_coeff=mi_coeff,
            question_list=raw_question_list,
            ins_coeff=ins_coeff,
            uncertainty_threshold=uncertainty_threshold,
            bad_words_ids=bad_words_ids,
            num_beams=num_beams,
        )

        return outputs
    
    def get_images_input_ids(self, images, prompts, dtype=torch.float16, keep_aspect_ratio=False):
        """Preprocess images and turn prompts into left-padded token ids.

        Each ``DEFAULT_IMAGE_TOKEN`` placeholder in a prompt is expanded into
        a run of ``DEFAULT_IMAGE_PATCH_TOKEN`` (wrapped in start/end markers
        when the model config sets ``mm_use_im_start_end``).

        Args:
            images: Images, paired with ``prompts`` by position.
            prompts: Prompt strings containing the image placeholder.
            dtype: dtype the pixel tensors are cast to.
            keep_aspect_ratio: When True, each image is resized individually
                (no centre crop) and the patch-token count follows the
                resulting tensor size; the returned images are then a list of
                tensors. When False, the images are processed as one batch
                and a fixed run of 256 patch tokens is used.

        Returns:
            ``(images, input_ids, batch_size)`` where ``input_ids`` is a list
            of equally long id lists, left-padded with the pad token id.
        """
        use_start_end = getattr(self.model.config, 'mm_use_im_start_end', False)

        def patch_run(num_tokens):
            # Text that stands in for one image inside the prompt.
            run = DEFAULT_IMAGE_PATCH_TOKEN * num_tokens
            if use_start_end:
                run = DEFAULT_IM_START_TOKEN + run + DEFAULT_IM_END_TOKEN
            return run

        if keep_aspect_ratio:
            new_images, new_prompts = [], []
            for image, prompt in zip(images, prompts):
                max_hw, min_hw = max(image.size), min(image.size)
                aspect_ratio = max_hw / min_hw
                max_len, min_len = 448, 224
                shortest_edge = int(min(max_len / aspect_ratio, min_len))
                pixels = self.image_processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0]
                new_images.append(pixels.to(self.model.device, dtype=dtype))
                # One token per 14x14 tile of the processed image
                # (assumes the vision tower's patch size is 14 — TODO confirm).
                cur_token_len = (pixels.shape[1]//14) * (pixels.shape[2]//14)
                # Replace the first placeholder only (one image per prompt).
                # The expanded prompt must be collected: rebinding the loop
                # variable, as before, silently dropped the replacement.
                new_prompts.append(prompt.replace(DEFAULT_IMAGE_TOKEN, patch_run(cur_token_len), 1))
            images, prompts = new_images, new_prompts
        else:
            images = self.image_processor(images, return_tensors='pt')['pixel_values']
            images = images.to(self.model.device, dtype=dtype)
            replace_token = patch_run(256)    # HACK: 256 is the max image token length hacked
            prompts = [prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) for prompt in prompts]

        # Tokenization and left-padding are shared with get_input_ids().
        input_ids, batch_size = self.get_input_ids(prompts, dtype, keep_aspect_ratio)

        return images, input_ids, batch_size

    def get_input_ids(self, prompts, dtype=torch.float16, keep_aspect_ratio=False):
        """Tokenize prompts and left-pad them to a common length.

        Args:
            prompts: List of prompt strings.
            dtype: Unused here; kept for signature parity with
                get_images_input_ids().
            keep_aspect_ratio: Unused here; kept for the same reason.

        Returns:
            ``(input_ids, batch_size)`` where ``input_ids`` is a list of id
            lists, each left-padded with ``tokenizer.pad_token_id`` up to the
            longest prompt, and ``batch_size`` is the number of prompts.
        """
        token_lists = self.tokenizer(prompts).input_ids
        batch_size = len(token_lists)
        longest = max(map(len, token_lists))
        pad_id = self.tokenizer.pad_token_id

        # Padding goes on the left so every sequence ends at the same index.
        for row, ids in enumerate(token_lists):
            token_lists[row] = [pad_id] * (longest - len(ids)) + ids

        return token_lists, batch_size


    @torch.no_grad()
    def do_generate(self, prompts, images, dtype=torch.float16, temperature=0.2, max_new_tokens=256, stop_str=None, keep_aspect_ratio=False, cond_coeff=0., cf_coeff=0., mi_coeff=0., question_list=None, ins_coeff=False, uncertainty_threshold=-1e9, remove_qs=True, bad_words_ids=False,num_beams=1):
        """Generate text token by token for a batch of prompts and images.

        The first step runs the model on the full prompt plus images; later
        steps feed only the last sampled token together with the cached
        key/values. Generation stops once every sequence has produced the
        stop token or the tokenizer's EOS token, or after ``max_new_tokens``.

        Args:
            prompts: Prompt strings containing the image placeholder.
            images: Images, paired with ``prompts`` by position.
            dtype: dtype for the pixel tensors.
            temperature: Sampling temperature; values below ``1e-4`` select
                the arg-max token instead of sampling.
            max_new_tokens: Upper bound on generated tokens.
            stop_str: Optional stop marker. It is used as a stop token when
                it tokenizes to a single id, and the decoded text is cut at
                its first occurrence.
            keep_aspect_ratio: Forwarded to ``get_images_input_ids``.
            bad_words_ids: When truthy, a fixed set of token ids is banned by
                setting their logits to ``-inf`` (one set on the first step,
                another on later steps).
            cond_coeff, cf_coeff, mi_coeff, question_list, ins_coeff,
            uncertainty_threshold, remove_qs, num_beams: Accepted for
                interface compatibility; not used by this method.

        Returns:
            List with one decoded string per prompt (special tokens are kept
            by the decoder; text from the stop marker onwards is removed).
        """
        images, input_ids, batch_size = self.get_images_input_ids(images, prompts, dtype, keep_aspect_ratio)

        # Nothing requested: concatenating an empty list of steps would raise.
        if max_new_tokens <= 0:
            return [''] * batch_size

        stop_idx = None
        if stop_str is not None:
            stop_ids = self.tokenizer(stop_str).input_ids
            # Only a single-id stop marker can be matched against a token.
            if len(stop_ids) == 1:
                stop_idx = stop_ids[0]

        # Token ids banned when bad_words_ids is set (values kept as-is).
        first_step_banned = (6124, 19814, 1316, 3160, 7805, 3704)
        later_step_banned = (29892,)

        device = self.model.device
        eos_token_id = self.tokenizer.eos_token_id
        output_ids = []
        finished = [False] * batch_size
        past_key_values = None
        token = None
        for step in range(max_new_tokens):
            if step == 0:
                out = self.model(
                    torch.as_tensor(input_ids).to(device),
                    use_cache=True,
                    images=images)
                banned = first_step_banned
            else:
                attention_mask = torch.ones(
                    batch_size, past_key_values[0][0].shape[-2] + 1, device=device)
                out = self.model(input_ids=token,
                            use_cache=True,
                            attention_mask=attention_mask,
                            past_key_values=past_key_values)
                banned = later_step_banned
            logits = out.logits
            if bad_words_ids:
                for banned_id in banned:
                    logits[:, :, banned_id] = -torch.inf
            past_key_values = out.past_key_values

            last_token_logits = logits[:, -1]
            if temperature < 1e-4:
                # keepdim keeps the (batch, 1) shape that the next forward
                # pass and the final concatenation along dim=1 rely on.
                token = torch.argmax(last_token_logits, dim=-1, keepdim=True)
            else:
                probs = torch.softmax(last_token_logits / temperature, dim=-1)
                token = torch.multinomial(probs, num_samples=1)
            token = token.long().to(device)

            output_ids.append(token)
            for idx, token_id in enumerate(token.view(-1).tolist()):
                if token_id == stop_idx or token_id == eos_token_id:
                    finished[idx] = True
            if all(finished):
                break

        output_ids = torch.cat(output_ids, dim=1).long()
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=False)
        if stop_str:
            # Cut at the FIRST stop marker: sequences that finish early keep
            # receiving tokens until the whole batch is done, so anything
            # after their first marker is not part of the answer.
            for i, text in enumerate(outputs):
                pos = text.find(stop_str)
                if pos != -1:
                    outputs[i] = text[:pos]

        return outputs

    @torch.no_grad()
    def old_do_generate(self, prompts, images, dtype=torch.float16, temperature=0.2, max_new_tokens=256, stop_str=None, keep_aspect_ratio=False, cond_coeff=0., cf_coeff=0., mi_coeff=0., question_list=None, ins_coeff=False, uncertainty_threshold=-1e9, remove_qs=True):
        """Legacy token-by-token generation with an optional contrastive term.

        The first step runs the model on the full prompt plus images; later
        steps feed the last sampled token with the cached key/values. Token
        id 29892 is banned on every step by setting its logit to ``-inf``.

        When ``cf_coeff > 0`` a second, text-only forward pass is made on a
        variant of each prompt with the question text removed. For rows whose
        next-token entropy exceeds ``uncertainty_threshold`` the logits are
        shifted by ``cf_coeff * (logits - cf_logits)``. This only happens
        from the second step on.

        Args:
            prompts: Prompt strings containing the image placeholder.
            images: Images, paired with ``prompts`` by position.
            dtype: dtype for the pixel tensors.
            temperature: Sampling temperature; below ``1e-4`` arg-max is used.
            max_new_tokens: Upper bound on generated tokens.
            stop_str: Optional stop marker (used as a stop token when it
                tokenizes to a single id; decoded text is cut at its last
                occurrence).
            keep_aspect_ratio: Forwarded to the tokenization helpers.
            cf_coeff: Weight of the contrastive term; ``0`` disables it.
            question_list: Question strings removed from the prompts to build
                the contrastive inputs; read only when ``cf_coeff > 0`` and
                ``remove_qs`` is truthy.
            uncertainty_threshold: Entropy threshold for applying the term.
            remove_qs: Whether to strip the question from the prompts.
            cond_coeff, mi_coeff, ins_coeff: Not used by this method.

        Returns:
            List with one decoded string per prompt (special tokens skipped).
        """
        
        images, input_ids, batch_size = self.get_images_input_ids(images, prompts, dtype, keep_aspect_ratio)
        if cf_coeff>0:
            if remove_qs:
                cf_prompts = [prompts[i].replace(question_list[i],'') for i in range(len(prompts))]
            # NOTE(review): cf_prompts is only assigned when remove_qs is
            # truthy; with cf_coeff > 0 and remove_qs=False the next line
            # would hit an unbound name — confirm the intended fallback.
            cf_input_ids, _ = self.get_input_ids(cf_prompts, dtype, keep_aspect_ratio)
        stop_idx = None
        if stop_str is not None:
            stop_idx = self.tokenizer(stop_str).input_ids
            # Only a single-id stop marker can be compared against a token.
            if len(stop_idx) == 1:
                stop_idx = stop_idx[0]
            else:
                stop_idx = None
        
        output_ids = []
        
        # One flag per sequence: True once it has produced a stop/EOS token.
        get_result = [False for _ in range(batch_size)]
        for i in range(max_new_tokens):
            if i == 0:
                # (Wei Hongliang) LLaVA forward pass on the full prompt plus images
                out = self.model(
                    torch.as_tensor(input_ids).to(self.model.device),
                    use_cache=True,
                    images=images)
                logits = out.logits
                # Ban token id 29892 (the meaning of this id is not visible here).
                logits[:,:,29892]=-torch.inf
                
                past_key_values = out.past_key_values
                last_token_logits = logits[:, -1]



            else:
                # Incremental step: feed only the last token plus the cache.
                out = self.model(input_ids=token,
                            use_cache=True,
                            attention_mask=torch.ones(batch_size, past_key_values[0][0].shape[-2] + 1, device=self.model.device),
                            past_key_values=past_key_values)
                logits = out.logits
                logits[:,:,29892]=-torch.inf

                last_token_logits = logits[:, -1]

                past_key_values = out.past_key_values
                from torch import nn

                # Entropy of the next-token distribution, one value per row.
                # NOTE(review): the banned logit is -inf, so log_softmax is
                # -inf and softmax is 0 at that index; their product looks
                # like it yields NaN, which would make the entropy NaN and
                # every comparison below False — confirm on real tensors.
                predictive_uncertanty = -(nn.functional.log_softmax(
                logits[:, -1].to(torch.float), dim=-1)*nn.functional.softmax(logits[:, -1].to(torch.float), dim=-1)).sum(dim=1)
                
                if cf_coeff>0 and sum(predictive_uncertanty>uncertainty_threshold)>0:

                    # Text-only forward pass (no images, no cache) for the
                    # rows above the threshold.
                    cf_out = self.model(
                    torch.as_tensor([cf_input_ids[i] for i in range(len(cf_input_ids)) if (predictive_uncertanty>uncertainty_threshold)[i]]).to(self.model.device),
                    )
                    cf_logits = cf_out.logits
                    cf_last_token_logits = cf_logits[:, -1]

                    # Scatter the selected rows back to full batch size; rows
                    # below the threshold keep a copy of their own logits.
                    zeros = logits[:, -1].clone()
                    m = 0
                    for n in range(len(predictive_uncertanty>uncertainty_threshold)):
                        if (predictive_uncertanty>uncertainty_threshold)[n]:
                            zeros[n] = cf_last_token_logits[m].clone()
                            m = m+1



                    cf_last_token_logits = zeros
                    # Infinite entries (e.g. the banned id) are reset to 0.
                    cf_last_token_logits[torch.isinf(cf_last_token_logits)] = 0



                    # Contrastive shift, applied only to rows above the threshold.
                    last_token_logits = last_token_logits+ torch.mul((predictive_uncertanty>uncertainty_threshold).unsqueeze(1),(
                                    cf_coeff*(last_token_logits- cf_last_token_logits)))
                    # NOTE(review): this also resets the banned -inf logit
                    # to 0, which appears to lift the ban — confirm intent.
                    last_token_logits[torch.isinf(last_token_logits)] = 0
                                    
                
                    
            

            if temperature < 1e-4:
                # NOTE(review): argmax without keepdim gives a 1-D tensor,
                # while the sampling branch gives (batch, 1); the later
                # torch.cat(dim=1) looks incompatible with 1-D — confirm.
                token = torch.argmax(last_token_logits, dim=-1)
            else:
                probs = torch.softmax(last_token_logits / temperature, dim=-1)
                token = torch.multinomial(probs, num_samples=1)
            token = token.long().to(self.model.device)

            if cf_coeff>0:
                # Keep the text-only inputs in step with the generated tokens.
                # NOTE(review): this appends tensor elements to lists of ints.
                for m in range(len(cf_input_ids)):
                    cf_input_ids[m].append(token[m])
            output_token = token
            output_ids.append(output_token)
            for idx in range(len(output_token)):
                if output_token[idx] == stop_idx or output_token[idx] == self.tokenizer.eos_token_id:
                    get_result[idx] = True
            if all(get_result):
                break
        
        output_ids = torch.cat(output_ids, dim=1).long()
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        if stop_str is not None:
            # Drop everything from the last occurrence of the stop marker.
            for i in range(len(outputs)):
                pos = outputs[i].rfind(stop_str)
                if pos != -1:
                    outputs[i] = outputs[i][:pos]
        
        return outputs