import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, LlamaTokenizer
from transformers.generation import GenerationConfig
import requests
from PIL import Image
import sys
sys.path.append('/mnt/14T-disk/code/contrastive_decoding/Multi-Modality-Arena/model_weights/Qwen/Qwen-VL-Chat-Int4')
from qwen_generation_utils import make_context,pad_batch
import re
from transformers import TextStreamer
try:
    # Import the optional mPLUG-Owl2 dependencies.
    from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
    from mplug_owl2.conversation import conv_templates, SeparatorStyle
    from mplug_owl2.model.builder import load_pretrained_model
    from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
except ImportError:
    # mPLUG-Owl2 is optional: if it is not installed, these names stay undefined.
    pass
import google.generativeai as genai

# The mPLUG-Owl2 names (IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, conv_templates,
# load_pretrained_model, process_images, tokenizer_image_token, ...) are imported
# in the guarded try/except above; they are only needed for the 'mplug-owl2' backend.
# Upstream model ids of the other local backends (the code loads local copies):
# THUDM/visualglm-6b
# internlm/internlm-xcomposer-7b
# Qwen/Qwen-VL-Chat


class TestAutoModel:
    """Uniform answer-generation wrapper over several vision-language backends.

    The backend is chosen in ``__init__`` by substring match on ``model_name``:

    * ``'Qwen'``       -> local Qwen-VL-Chat-Int4 checkpoint (batched generation)
    * ``'Gemini'``     -> ``gemini-pro-vision`` via ``google.generativeai``
    * ``'mplug-owl2'`` -> local mPLUG-Owl2 (llama2-7b) checkpoint
    * ``'cogvlm'``     -> ``THUDM/cogvlm-chat-hf`` with the ``lmsys/vicuna-7b-v1.5`` tokenizer
    * ``'intern'``     -> local InternLM-XComposer-VL-7B checkpoint
    * anything else    -> local VisualGLM-6B checkpoint
    """

    # Token-id sequences banned during decoding when ``bad_words_ids=True``.
    # Stored as tuples so these shared class attributes cannot be mutated; they
    # are copied into the list-of-lists form ``generate`` requires on each call.
    _GLM_BAD_WORDS_IDS = (
        (666, 103), (31880, 103), (174, 666, 103), (105, 666, 103), (105, 666),
        (174, 666), (15099,), (4340,), (206, 114), (2530, 114), (525,), (925,),
        (350,), (8999,), (32007,), (38427,),
    )
    _MPLUG_BAD_WORDS_IDS = (
        (6124,), (3462, 654), (6124, 635), (19814,), (1316, 408), (3160,),
        (7805,), (3704,), (512, 2325), (512, 27722), (512, 22368),
    )
    _COGVLM_BAD_WORDS_IDS = ((29892,),)
    # Phrases that are tokenized at call time and banned (InternLM-XComposer / Qwen).
    _INTERN_BAD_PHRASES = (
        ' addition to', ' Addition to', ' In addition to', ' in addition to',
        ' in addition', ' In addition', ' additionally', ' Additionally',
        ' such as', ' Such as', ' include', ' includes', ' including',
        'Include', ' Includes', ' Including',
    )
    _QWEN_BAD_PHRASES = _INTERN_BAD_PHRASES + ('Additionally',)

    # Padding token id used when batching Qwen prompts.
    _QWEN_PAD_ID = 151643

    # Strips every non-ASCII character from a reply.
    _NON_ASCII_RE = re.compile(r'[^\x00-\x7F]+')

    # Gemini retry policy: exponential backoff, then give up with 'ERROR'.
    _GEMINI_MAX_ATTEMPTS = 10
    _GEMINI_BACKOFF_BASE_S = 1.0
    _GEMINI_BACKOFF_MAX_S = 30.0

    def __init__(self, model_name, device=None) -> None:
        """Load the backend selected by ``model_name``.

        Args:
            model_name: Matched by substring ('Qwen', 'Gemini', 'mplug-owl2',
                'cogvlm', 'intern'); anything else selects VisualGLM. For the
                local backends the weights path is hard-coded below.
            device: Torch device for VisualGLM, InternLM-XComposer and
                mPLUG-Owl2 (default ``'cuda'``). Qwen and CogVLM use
                ``device_map="auto"``.
        """
        device = 'cuda' if device is None else device
        if 'Qwen' in model_name:
            self.response_type = 'Qwen'
            model_name = "/mnt/14T-disk/code/contrastive_decoding/Multi-Modality-Arena/model_weights/Qwen/Qwen-VL-Chat-Int4"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True).eval()
            self.model.generation_config = GenerationConfig.from_pretrained(model_name, trust_remote_code=True)
        elif 'Gemini' in model_name:
            import os

            self.response_type = 'gemini'
            # NOTE(security): an API key is committed in source. The GOOGLE_API_KEY
            # environment variable now takes precedence; the literal remains only as
            # a backward-compatible fallback and should be rotated and removed.
            api_key = os.environ.get('GOOGLE_API_KEY', 'AIzaSyD6hOtpphqFpZ2B6Ai1H7WRLv2pwhvirks')
            genai.configure(api_key=api_key)
            self.model = genai.GenerativeModel('gemini-pro-vision')
        elif 'mplug-owl2' in model_name:
            self.response_type = 'mplug-owl2'
            model_path = '/mnt/14T-disk/code/contrastive_decoding/Multi-Modality-Arena/model_weights/mplug-owl2-llama2-7b'
            model_name = get_model_name_from_path(model_path)
            self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
                model_path, None, model_name, load_8bit=False, load_4bit=False, device=device)
        elif 'cogvlm' in model_name:
            self.response_type = 'cogvlm'
            self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
            self.model = AutoModelForCausalLM.from_pretrained(
                'THUDM/cogvlm-chat-hf',
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map="auto"
            ).eval()
        else:
            self.response_type = 'intern' if 'intern' in model_name else 'glm'
            if self.response_type == 'intern':
                model_name = "/mnt/14T-disk/code/contrastive_decoding/Multi-Modality-Arena/model_weights/internlm-xcomposer-vl-7b"
            else:
                model_name = "/mnt/14T-disk/code/contrastive_decoding/Multi-Modality-Arena/model_weights/THUDM/visualglm-6b"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().to(device)
            if self.response_type == 'intern':
                # InternLM-XComposer's chat() reads the tokenizer from the model.
                self.model.tokenizer = self.tokenizer

    @staticmethod
    def _as_lists(seqs):
        """Return a fresh list-of-lists copy of ``seqs`` (``bad_words_ids`` must be lists)."""
        return [list(seq) for seq in seqs]

    @staticmethod
    def _count_attention_outliers(outputs, skip=4, num_std=2):
        """Count generated tokens that later steps attend to unusually strongly.

        For each generated token ``wi`` this takes the last-layer attention of
        every step from ``wi + skip`` onward, reduces each step to the max over
        heads of the weight on that token's position, and averages those
        values. It returns how many tokens score above
        ``mean + num_std * std`` (population std) across all tokens. Only the
        first sequence of the batch is analysed.

        Args:
            outputs: Generation output exposing ``attentions`` (one tuple of
                per-layer tensors per step). If it also has ``beam_indices``,
                the beam row for each step is taken from ``beam_indices[0]``.
                Otherwise row 0 is used, which is the greedy/sampling case.
            skip: Number of steps after a token before its attention is measured.
            num_std: Outlier threshold in standard deviations.

        Returns:
            The outlier count as an ``int``. It is 0 when too few steps were
            generated.
        """
        import numpy as np

        # Last-layer attention tensor for every generation step.
        step_attns = [layers[-1] for layers in outputs.attentions]
        if len(step_attns) <= skip:
            return 0
        beam_indices = getattr(outputs, 'beam_indices', None)
        if beam_indices is None:
            # Greedy/sampling outputs carry no beam_indices: row 0 is sequence 0.
            step_attns = [a[0] for a in step_attns]
        else:
            step_attns = [a[int(beam_indices[0][i])] for i, a in enumerate(step_attns)]
        # Step 0 processes the whole prompt, so its key length is the prompt length.
        prompt_len = step_attns[0].shape[-1]
        scores = np.asarray([
            np.mean([torch.max(a[:, :, prompt_len + wi]).item() for a in step_attns[wi + skip:]])
            for wi in range(len(step_attns) - skip)
        ])
        return int(np.sum(scores > scores.mean() + num_std * scores.std()))

    def _generate_glm(self, image, question, max_new_tokens, num_beams, use_bad_words):
        """VisualGLM-6B: one chat turn, with non-ASCII characters stripped from the reply."""
        extra = {'bad_words_ids': self._as_lists(self._GLM_BAD_WORDS_IDS)} if use_bad_words else {}
        # NOTE(review): VisualGLM's chat() only takes ``max_length``. If that bounds
        # prompt + answer (as in HF generate), the answer budget is smaller than
        # ``max_new_tokens`` suggests. Confirm this against the model's chat() code.
        response, _ = self.model.chat(
            self.tokenizer, image, question, history=[],
            max_length=max_new_tokens, num_beams=num_beams, **extra,
        )
        return self._NON_ASCII_RE.sub('', response)

    def _generate_cogvlm(self, image, question, max_new_tokens, num_beams, use_bad_words):
        """CogVLM-chat: build chat inputs, decode deterministically, return only new text."""
        pil_image = Image.open(image).convert('RGB')
        built = self.model.build_conversation_input_ids(
            self.tokenizer, query=question, history=[], images=[pil_image])  # chat mode
        inputs = {
            'input_ids': built['input_ids'].unsqueeze(0).to('cuda'),
            'token_type_ids': built['token_type_ids'].unsqueeze(0).to('cuda'),
            'attention_mask': built['attention_mask'].unsqueeze(0).to('cuda'),
            'images': [[built['images'][0].to('cuda').to(torch.bfloat16)]],
        }
        # Bug fix: this used to pass ``max_length=max_new_tokens``. In HF generate
        # ``max_length`` also counts the prompt, so the answer budget shrank by the
        # prompt length, or was exhausted outright.
        gen_kwargs = {"max_new_tokens": max_new_tokens, "do_sample": False, "num_beams": num_beams}
        if use_bad_words:
            gen_kwargs["bad_words_ids"] = self._as_lists(self._COGVLM_BAD_WORDS_IDS)
        outputs = self.model.generate(**inputs, **gen_kwargs)
        # Keep only the tokens generated after the prompt.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        return self.tokenizer.decode(outputs[0])

    def _generate_mplug_owl2(self, image, question, max_new_tokens, num_beams, use_bad_words):
        """mPLUG-Owl2: single-turn conversation prompt with one square-resized image."""
        conv = conv_templates["mplug_owl2"].copy()

        pil_image = Image.open(image).convert('RGB')
        # Square images are recommended for best performance.
        max_edge = max(pil_image.size)
        pil_image = pil_image.resize((max_edge, max_edge))
        image_tensor = process_images([pil_image], self.image_processor)
        image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)

        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(self.model.device)

        gen_kwargs = {
            'images': image_tensor,
            'do_sample': False,
            # Has no effect while do_sample=False; kept for parity with earlier runs.
            'temperature': 0.7,
            'max_new_tokens': max_new_tokens,
            'use_cache': True,
            'num_beams': num_beams,
        }
        if use_bad_words:
            gen_kwargs['bad_words_ids'] = self._as_lists(self._MPLUG_BAD_WORDS_IDS)
        with torch.inference_mode():
            output_ids = self.model.generate(input_ids, **gen_kwargs)
        return self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

    def _generate_intern(self, image, question, max_new_tokens, num_beams, use_bad_words):
        """InternLM-XComposer: one chat turn, optionally banning the configured phrases."""
        extra = {}
        if use_bad_words:
            # [1:] drops the first token of each encoding. It is presumably a
            # tokenizer-added BOS token; verify this against the tokenizer.
            encoded = self.tokenizer(list(self._INTERN_BAD_PHRASES)).input_ids
            extra['bad_words_ids'] = [ids[1:] for ids in encoded]
        response, _ = self.model.chat([question], image, max_new_tokens=max_new_tokens,
                                      num_beams=num_beams, **extra)
        return response

    def _generate_qwen(self, images, questions, max_new_tokens, num_beams, use_bad_words):
        """Qwen-VL: batched generation over parallel lists; returns a list of answers.

        Without bad words, per-step attentions are also requested and an
        attention-outlier count is printed (see ``_count_attention_outliers``).
        """
        tokens = [
            make_context(self.tokenizer, self.tokenizer.from_list_format([{'image': img}, {'text': q}]))[1]
            for img, q in zip(images, questions)
        ]
        if not tokens:
            return []
        # NOTE(review): no attention_mask is passed. If pad_batch right-pads (as the
        # upstream qwen_generation_utils does; verify this), shorter prompts in a
        # mixed-length batch see pad tokens before generation starts.
        tokens = pad_batch(tokens, self._QWEN_PAD_ID, max(len(t) for t in tokens))
        tokens = torch.stack([torch.tensor(t) for t in tokens], 0).cuda()

        if use_bad_words:
            bad_words_ids = self.tokenizer(list(self._QWEN_BAD_PHRASES)).input_ids
            sequences = self.model.generate(tokens, num_beams=num_beams, bad_words_ids=bad_words_ids,
                                            max_new_tokens=max_new_tokens, do_sample=False)
        else:
            outputs = self.model.generate(tokens, num_beams=num_beams, max_new_tokens=max_new_tokens,
                                          do_sample=False, output_attentions=True,
                                          return_dict_in_generate=True, output_scores=True)
            print(self._count_attention_outliers(outputs))
            sequences = outputs['sequences']
        # The decoded prompt ends with the assistant marker; keep the text after it.
        return [text.split('\nassistant\n')[1]
                for text in self.tokenizer.batch_decode(sequences, skip_special_tokens=True)]

    def _generate_gemini(self, image, question):
        """Gemini API: ask once, retrying with exponential backoff on errors.

        Returns the reply with non-ASCII characters removed, or ``'ERROR'`` if
        every attempt fails or the reply exposes no ``.text``.
        """
        import time

        contents = [Image.open(image), question]
        response = None
        for attempt in range(self._GEMINI_MAX_ATTEMPTS):
            try:
                response = self.model.generate_content(contents)
                break
            except Exception as e:  # API/network failures are not narrowly typed.
                if attempt + 1 == self._GEMINI_MAX_ATTEMPTS:
                    print(f"An error occurred: {e}. Giving up after {attempt + 1} attempts.")
                    return 'ERROR'
                print(f"An error occurred: {e}. Retrying...")
                time.sleep(min(self._GEMINI_BACKOFF_BASE_S * 2 ** attempt, self._GEMINI_BACKOFF_MAX_S))
        try:
            return self._NON_ASCII_RE.sub('', response.text)
        except Exception as e:
            print(f'response:{response}')
            print(f"An error occurred: {e}.")
            return 'ERROR'

    @torch.no_grad()
    def generate(self, image, question, max_new_tokens=64, bad_words_ids=False, num_beams=1):
        """Answer ``question`` about ``image`` with the loaded backend.

        Args:
            image: Image file path. For Qwen, a list of images in the form
                accepted by the tokenizer's ``from_list_format``.
            question: Prompt text. For Qwen, a list parallel to ``image``.
            max_new_tokens: Answer length budget. It is passed as ``max_length``
                to VisualGLM's chat() and ignored by Gemini.
            bad_words_ids: A boolean flag, despite the name. When truthy, a
                backend-specific list of phrases/token ids is banned.
            num_beams: Beam width forwarded to generation (unused by Gemini).

        Returns:
            The answer string, or a list of answers for Qwen.

        Raises:
            NotImplementedError: Unknown backend, or ``bad_words_ids`` used with Gemini.
        """
        use_bad_words = bool(bad_words_ids)
        with torch.autocast('cuda'):
            if self.response_type == 'gemini':
                if use_bad_words:
                    # The Gemini backend has no bad-words support.
                    raise NotImplementedError(f"Invalid response type: {self.response_type}")
                return self._generate_gemini(image, question)
            handlers = {
                'glm': self._generate_glm,
                'cogvlm': self._generate_cogvlm,
                'mplug-owl2': self._generate_mplug_owl2,
                'intern': self._generate_intern,
                'Qwen': self._generate_qwen,
            }
            handler = handlers.get(self.response_type)
            if handler is None:
                raise NotImplementedError(f"Invalid response type: {self.response_type}")
            return handler(image, question, max_new_tokens, num_beams, use_bad_words)

    @torch.no_grad()
    def batch_generate(self, image_list, question_list, max_new_tokens=1282, bad_words_ids=False, num_beams=1,):
        """Answer each (image, question) pair; returns a list of answers.

        Qwen handles the whole batch in one ``generate`` call. Gemini pairs run
        concurrently on 8 threads with ``generate``'s defaults, so the other
        arguments are not forwarded. Every other backend runs the pairs one by
        one.

        NOTE(review): the ``max_new_tokens`` default of 1282 looks unusual and may
        be a typo. It is kept for backward compatibility.
        """
        if self.response_type == "Qwen":
            return self.generate(image_list, question_list, max_new_tokens,
                                 bad_words_ids=bad_words_ids, num_beams=num_beams)
        if self.response_type == "gemini":
            from concurrent.futures import ThreadPoolExecutor

            num_threads = 8  # concurrent API requests
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                return list(executor.map(self.generate, image_list, question_list))
        return [self.generate(image, question, max_new_tokens, bad_words_ids=bad_words_ids, num_beams=num_beams)
                for image, question in zip(image_list, question_list)]