from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import uvicorn, json, datetime
import torch
import transformers


def torch_gc():
    """Return cached CUDA memory to the driver when a GPU is present.

    Calls ``torch.cuda.empty_cache()`` (frees unused blocks held by PyTorch's
    caching allocator) followed by ``torch.cuda.ipc_collect()``. Does nothing
    when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

app = FastAPI()

# Llama-2 chat delimiters placed around each user turn by create_item.
B_INST = "[INST]"
E_INST = "[/INST]"

# Model cache, filled in lazily by create_item on the first request for a model.
# `model`/`tokenizer` back the generic path; `pipeline` backs llama-3.
model = None
tokenizer = None
pipeline = None
last_model_name = None  # name of the most recently requested/loaded model

gpu_count = torch.cuda.device_count()  # number of visible CUDA devices

def _release_models():
    """Drop every cached model, tokenizer and pipeline, then free GPU memory.

    Called before loading a different model so the previously loaded one
    (whether a generic model or the llama-3 pipeline) does not stay resident.
    """
    global model, tokenizer, pipeline
    model = None
    tokenizer = None
    pipeline = None
    import gc
    gc.collect()  # break reference cycles so the weights are actually released
    torch_gc()


def _is_safe_model_name(name):
    """Return True if *name* is a non-empty string with no '..' path component."""
    if not isinstance(name, str) or not name.strip():
        return False
    return ".." not in name.replace("\\", "/").split("/")


@app.post("/")
async def create_item(request: Request):
    """Generate a chat completion with a model stored under ``/llm_models``.

    Request JSON keys:
        model: directory name under ``/llm_models`` (required).
        system_message: optional system prompt (Llama-2 path).
        prompt: user message (Llama-2 path).
        history: optional list of ``[question, answer]`` pairs; defaults to [].
        messages: chat-template message list (llama-3 and mistral paths).
        max_length: new-token budget (llama-3/mistral) or total length incl.
            prompt (Llama-2 path); default 4096.
        top_p: nucleus sampling value; default 0.9 (llama-3) / 1.0 (Llama-2).
        temperature: sampling temperature; default 0.1.

    Returns:
        dict with ``response``, the updated ``history``, ``status`` and ``time``.
        An invalid model name yields ``status`` 400 without loading anything.
    """
    global model, tokenizer, last_model_name, pipeline
    # NOTE(review): loading and generation below are synchronous and block the
    # event loop for the whole request; that also serialises requests, which
    # the shared module-level model globals currently rely on.
    json_post_list = await request.json()
    model_name = json_post_list.get('model')
    system_message = json_post_list.get('system_message')
    prompt = json_post_list.get('prompt') or ''
    history = json_post_list.get('history') or []
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')

    # model_name is client-controlled and joined into a filesystem path.
    if not _is_safe_model_name(model_name):
        return {
            "response": f"invalid model name: {model_name!r}",
            "history": history,
            "status": 400,
            "time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

    model_path = f"/llm_models/{model_name}"

    if model_name == 'llama-3-8b-instruct':
        if pipeline is None or last_model_name != model_name:
            _release_models()
            print("====================================")
            print(f"Loading model {model_path}...")
            pipeline = transformers.pipeline(
                "text-generation",
                model=model_path,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map='auto',
            )
            last_model_name = model_name

        prompt = pipeline.tokenizer.apply_chat_template(
            json_post_list.get('messages'),
            tokenize=False,
            add_generation_prompt=True,
        )
        # Llama-3 ends assistant turns with <|eot_id|> as well as the EOS token.
        terminators = [
            pipeline.tokenizer.eos_token_id,
            pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        outputs = pipeline(
            prompt,
            max_new_tokens=max_length if max_length else 4096,
            eos_token_id=terminators,
            do_sample=True,
            temperature=temperature if temperature else 0.1,
            top_p=top_p if top_p else 0.9,
        )
        # generated_text echoes the prompt first; keep only the completion.
        response = outputs[0]["generated_text"][len(prompt):]
        history.append([prompt, response])
    else:
        if model is None or tokenizer is None or last_model_name != model_name:
            _release_models()
            print("====================================")
            print(f"Loading model {model_path}...")
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                # load_in_8bit=True,  # faster, but may lose some accuracy
                device_map='auto',
            )
            model.eval()
            last_model_name = model_name
            print("====================================")
            print("Load success!")

        if model_name == 'mistral-7B':
            messages = json_post_list.get('messages')
            input_ids = tokenizer.apply_chat_template(
                messages, return_tensors="pt").to(model.device)
            generated_ids = model.generate(
                input_ids,
                max_new_tokens=max_length if max_length else 4096,
                do_sample=True,
            )
            # Decode only the tokens produced after the prompt rather than
            # guessing the prompt's decoded text length.
            response = tokenizer.decode(
                generated_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
            history.append([messages, response])
        else:
            # Llama-2 chat format: optional <<SYS>> block, then [INST] turns.
            if system_message:
                prompt = f"<<SYS>>\n{system_message}\n<</SYS>>\n\n" + prompt.strip()
            input_ids = []
            for question, answer in history:
                input_ids += tokenizer.encode(f"{B_INST} {question} {E_INST} {answer}")
                input_ids.append(tokenizer.eos_token_id)
            input_ids += tokenizer.encode(f"{B_INST} {prompt} {E_INST}")

            output_ids = model.generate(
                torch.tensor([input_ids], device=model.device),
                do_sample=True,
                # max_length bounds prompt + completion on this path.
                max_length=max_length if max_length else 4096,
                top_p=top_p if top_p else 1.0,
                temperature=temperature if temperature else 0.1,
                top_k=50,
            )
            response = tokenizer.decode(
                output_ids[0, len(input_ids):], skip_special_tokens=True)
            history.append([prompt, response])

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": timestamp,
    }
    print(f'[{timestamp}] prompt:"{prompt}", response:{response!r}')
    torch_gc()
    return answer
