"""Prompt GPT models to infer task steps and task invocation path"""
import json
import click
import os
from openai import AzureOpenAI
import time
import sys
sys.path.append("../")
from utils import get_cur_time


def inference_one_case(input, tool_string, write_file, demo_string):
    """Prompt the model for one user request and append the result to ``write_file``.

    The prompt is ``tool_string`` + dataset-specific instructions +
    ``demo_string`` + the user request. The module-level globals
    ``dataset_name`` and ``gpt`` (assigned in ``main``) select which
    instruction text is used.

    Args:
        input: Mapping with at least the keys ``"id"`` and ``"user_request"``.
        tool_string: Text placed at the very start of the prompt.
        write_file: Writable text file object; one JSON line is written and
            flushed per successful case.
        demo_string: Example text appended after the requirements (may be "").

    Returns:
        True if a result line was written, False if the model call or the
        parsing of its reply failed (the failure is printed, not raised).
    """
    user_request = input["user_request"]

    if dataset_name != "dailylife":
        prompt = """\n# GOAL #: Based on the above tools, I want you generate task steps and task nodes to solve the # USER REQUEST #. The format must in a strict JSON format, like: {"task_steps": [ step description in the format 'Step x Use xx (tool's name) to do xx' ], "task_nodes": [{"task": "task name must be from # TASK LIST #", "arguments": [ {"name": "parameter name", "value": "parameter value"} ]}], "task_links": [{"source": "task name i", "target": "task name j"}]}"""
        prompt += """\n\n# REQUIREMENTS #: \n1. the generated task steps and task nodes can resolve the given user request # USER REQUEST # perfectly. Task name must be selected from # TASK LIST #; \n2. the task steps should strictly aligned with the task nodes; \n3. the dependencies among task steps should align with the argument dependencies of the task nodes; \n"""
        prompt += """4. the task steps should not contain ANY parameter arguments, instead, each step is a decomposed task to solve the user's request in a manageable way. The task step should be relatively short.\n"""
        prompt += """5. you should carefully analyze each tool's input and output requirements to ensure that each step's output can be fed into its next step.\n"""

        if dataset_name == 'huggingface' and gpt == "gpt-35-turbo":
            prompt += """6. [Important]in each inferred step, try to avoid the functional word (like 'translated', 'classified', 'generated', 'transcribed', etc) used in PREVIOUS step to avoid confusion.\n\n"""
    else:
        # Prompt for dailylife
        # NOTE(review): the third requirement below has no "3." prefix, so the
        # list reads 1, 2, <unnumbered>, 4. Left untouched on purpose because
        # changing the prompt text changes what the model sees - confirm
        # before editing.
        prompt = """\n# GOAL #: Based on the above tools, I want you generate task steps and task nodes to solve the # USER REQUEST #. The format must in a strict JSON format, like: {"task_steps": [ concrete steps, format as Step x: Call xxx tool ], "task_nodes": [{"task": "task name must be from # TASK LIST #", "arguments": [ {"name": "parameter name", "value": "parameter value"} ]}], "task_links": [{"source": "task name i", "target": "task name j"}]}"""
        prompt += """\n\n# REQUIREMENTS #: \n1. the generated task steps and task nodes can resolve the given user request # USER REQUEST # perfectly. Task name must be selected from # TASK LIST #; \n2. the task steps should strictly aligned with the task nodes; \nThe task links (task_links) should reflect the temporal dependencies among task nodes, i.e. the order in which the APIs are invoked; \n"""
        prompt += """4. each step is a selected task to solve the user's request in a manageable way. \n"""

    prompt += demo_string
    prompt += """\n\n# USER REQUEST #: {{user_request}}\nnow please generate your result in a strict JSON format:\n# RESULT #:"""
    final_prompt = tool_string + prompt.replace("{{user_request}}", user_request)

    st_time = time.time()
    try:
        returned_content = get_response(final_prompt)
        # "task_steps" is the only mandatory field. Look it up inside the try
        # block so that a reply without it is recorded as one failed case
        # instead of raising KeyError and aborting the whole run.
        task_steps = returned_content["task_steps"]
    except Exception as e:
        print(f"Failed #id {input['id']}: {type(e)} {e}")
        return False

    res = {"id": input["id"], "user_request": input["user_request"]}
    res["task_steps"] = task_steps
    res["task_nodes"] = returned_content.get("task_nodes", [])
    res["task_links"] = returned_content.get("task_links", [])
    # Wall-clock seconds spent on the model call(s) and parsing.
    res["cost_time"] = round(time.time() - st_time, 4)

    print(input["id"], res["user_request"], res["task_steps"], "\n")

    # Flush after every case so finished results survive an interrupted run.
    write_file.write(json.dumps(res) + "\n")
    write_file.flush()
    return True


def _chat_with_retry(prompt, sleep_time=30):
    """Send ``prompt`` as a single user message and return the reply text.

    If the first request raises, wait ``sleep_time`` seconds and try exactly
    once more; an exception from the second attempt propagates to the caller.
    """
    messages = [{"role": "user", "content": prompt}]
    try:
        response = client.chat.completions.create(model=gpt, messages=messages)
    except Exception:
        # Best-effort back-off for transient API errors; the cause of the
        # first failure is intentionally not inspected.
        time.sleep(sleep_time)
        response = client.chat.completions.create(model=gpt, messages=messages)

    reply = response.choices[0].message.content
    if reply is None:
        # Fail with an explicit message rather than an AttributeError on
        # the string clean-up that follows.
        raise ValueError("model returned a message without text content")
    return reply


def _outer_json_span(text):
    """Return the slice of ``text`` from its first "{" to its last "}" (inclusive)."""
    return text[text.find("{"):text.rfind("}") + 1]


def get_response(payload):
    """Query the model with ``payload`` and return its reply parsed as JSON.

    The reply is cleaned (newlines, ``\\_`` escapes, code fences and
    backslashes removed) and the outermost ``{...}`` span is parsed. If that
    fails, the model is asked once to reformat its own reply into strict JSON
    and the second reply is parsed the same way.

    Uses the module-level ``client`` and ``gpt``.

    Args:
        payload: Full prompt text, sent as a single user message.

    Returns:
        The object produced by ``json.loads`` on the extracted ``{...}`` span.

    Raises:
        Exception: If the reformatted reply still cannot be decoded as JSON,
            or if an API request fails twice in a row.
    """
    raw_reply = _chat_with_retry(payload)
    origin_content = (
        raw_reply.replace("\n", "").replace("\\_", "_").replace("```", "")
    )
    # Backslashes are dropped only for parsing; origin_content keeps them
    # because it is what gets sent back to the model below.
    content = _outer_json_span(origin_content.replace("\\", ""))
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        print(content)

    # encounter JSON decoder error
    # prompt LLM to reformat the response into strict JSON format
    prompt = """Please format the result # RESULT # to a strict JSON format # STRICT JSON FORMAT #. \nRequirements:\n1. Do not change the meaning of task steps, task nodes and task links;\n2. Don't tolerate any possible irregular formatting to ensure that the generated content can be converted by json.loads();\n3. Pay attention to the matching of brackets. Write in a compact format and avoid using too many space formatting controls;\n4. You must output the result in this schema: {"task_steps": [ step description in the format 'Step x Use xx (tool's name) to do xx' ], "task_nodes": [{"task": "task name must be from # TASK LIST #", "arguments": [ {"name": "parameter name", "value": "parameter value"} ]}], "task_links": [{"source": "task name i", "target": "task name j"}]}\n# RESULT #:{{illegal_result}}\n# STRICT JSON FORMAT #:"""
    prompt = prompt.replace("{{illegal_result}}", origin_content)
    # NOTE: the payload is plain prompt text, not JSON. The former
    # json.loads(payload) call here always raised and made this reformat
    # step unreachable, so it has been removed.

    raw_reply = _chat_with_retry(prompt)
    content = raw_reply.replace("\n", "").replace("\\_", "_").replace("\\", "")

    # The model sometimes echoes the section header; keep only what follows it.
    marker = "STRICT JSON FORMAT #:"
    start_pos = content.find(marker)
    if start_pos != -1:
        content = content[start_pos + len(marker):]

    content = _outer_json_span(content)
    try:
        return json.loads(content)
    except json.JSONDecodeError as e:
        print(content)
        raise Exception(f"JSON Decoding Error {e}") from e


@click.command()
@click.option("--dataset", default="huggingface", help="The directory of the data")
@click.option("--use_demos", type=int, default=1)
@click.option("--gpt_type", type=str, default="gpt-35-turbo") # gpt-4 or gpt-35-turbo
def main(dataset, use_demos, gpt_type):
    """Run direct-prompt inference on a dataset's test split and append predictions to a JSON-lines file."""
    print('= ' * 20)
    print('## Starting Time:', get_cur_time(), flush=True)

    # inference_one_case() and get_response() read these module-level names.
    global gpt, dataset_name
    gpt = gpt_type
    dataset_name = dataset

    prediction_dir = f"../prediction/{dataset}/{gpt_type}"
    os.makedirs(prediction_dir, exist_ok=True)
    infer_step_file = f"{prediction_dir}/direct_promptdemo.json"

    with open(f"../data/{dataset}/split_ids.json", 'r') as split_file:
        alignment_ids = set(json.load(split_file)["test_ids"]["chain"])

    # Ids already present in the output file are skipped, so an interrupted
    # run can be resumed by launching the same command again.
    has_inferenced = set()
    if os.path.exists(infer_step_file):
        with open(infer_step_file, 'r') as rf:
            for line in rf:
                if line.strip():  # tolerate blank lines in the resume file
                    has_inferenced.add(json.loads(line)["id"])

    with open(f"../data/{dataset}/user_requests.json", 'r') as user_request_file:
        all_requests = (json.loads(line) for line in user_request_file)
        inputs = [
            request for request in all_requests
            if request["id"] not in has_inferenced and request["id"] in alignment_ids
        ]

    # Prepare Tool String to prompt LLM
    #  - Different from TaskBench: Remove input and output types
    with open(f"../data/{dataset}/tool_desc.json", "r") as tool_file:
        tool_list = json.load(tool_file)["nodes"]
    tool_string = "# TASK LIST #:\n" + "".join(
        json.dumps(tool) + "\n" for tool in tool_list
    )

    # Prepare Demo(s) String to prompt LLM
    demo_string = ""
    if use_demos:
        demos_dict = {
            "huggingface": {
                "id": "14611002",
                "user_request": "I have this image of a document (example.jpg), and I'd like to know which category it belongs to. Also, I have a specific question about its content: 'What is the main topic of the document?'",
                "result": {
                    "task_steps": ["Step 1: Use Image Classification to classify the image", "Step 2: Use Document Question Answering to answer the user's question regarding the image's content"],
                    "task_nodes": [{"task": "Document Question Answering", "arguments": ["example.jpg", "What is the main topic of the document?"]}, {"task": "Image Classification", "arguments": ["example.jpg"]}],
                    "task_links": [{"source": "Image Classification", "target": "Document Question Answering"}]
                }
            },
            "multimedia": {
                "id": "30934207",
                "user_request": "Please create a panorama using two images 'example1.jpg' and 'example2.jpg'.",
                "result": {
                    "task_steps": ["Step 1: Use Image Stitcher to stitch together two images."],
                    "task_nodes": [{"task": "Image Stitcher", "arguments": ["example1.jpg", "example2.jpg"]}],
                    "task_links": []
                }
            },
            "dailylife": {
                "id": "27267145",
                "user_request": "I want to watch the movie 'The Avengers' and then organize an online meeting to discuss the movie with my friends.",
                "result": {
                    'task_steps': ["Step 1: Call play_movie_by_title API", "Step 2: Call organize_meeting_online API"],
                    "task_nodes": [{"task": "organize_meeting_online", "arguments": [{"name": "topic", "value": "Discuss The Avengers Movie"}]}, {"task": "play_movie_by_title", "arguments": [{"name": "title", "value": "The Avengers"}]}],
                    "task_links": [{"source": "play_movie_by_title", "target": "organize_meeting_online"}],
                }
            },
            "tmdb": {
                "id": 1,
                "user_request": "Who was the lead actor in the movie The Dark Knight?",
                "result": {
                    "task_steps": ["Step 1 Call SearchMovie to find the movie titled The Dark Knight", "Step 2 Call GetMovieCredit to retrieve the lead actor in the movie"],
                    "task_nodes": [{"task": "SearchMovie"}, {"task": "GetMovieCredit"}],
                    "task_links": [{"source": "SearchMovie", "target": "GetMovieCredit"}],
                }
            }
        }

        demo = demos_dict.get(dataset)
        if demo is None:
            # Report a clean CLI error instead of a bare KeyError traceback.
            raise click.BadParameter(
                f"no demo is defined for dataset {dataset!r}; "
                f"choose one of {sorted(demos_dict)} or pass --use_demos 0",
                param_hint="--dataset",
            )
        demo_string += "\nHere are provided examples for your reference.\n"
        demo_string += f"""\n# EXAMPLE #:\n# USER REQUEST #: {demo["user_request"]}\n# RESULT #: {json.dumps(demo["result"])}"""

    if not inputs:
        print("All Completed!")
        return

    print(f"Detected {len(has_inferenced)} has been inferenced, ")
    print(f"Start inferencing {len(inputs)} tasks ... ")

    num_succ, num_error = 0, 0
    # Append mode keeps results from earlier runs; the context manager
    # guarantees the file is closed even if a case raises unexpectedly.
    with open(infer_step_file, "a") as write_file:
        for request in inputs:
            if inference_one_case(request, tool_string, write_file, demo_string):
                num_succ += 1
            else:
                num_error += 1

    print(f"\nCompleted {num_succ} Failed {num_error}")

    print('\n## Finishing Time:', get_cur_time(), flush=True)
    print('= ' * 20)
    print("Done!")


if __name__ == "__main__":
    # `client` is created at module level because get_response() looks it up
    # as a global. Credentials are taken from the environment when available
    # so that secrets do not have to be written into this file; the original
    # placeholder literals remain as fallbacks.
    client = AzureOpenAI(
        azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", "END_POINT"),
        api_key=os.environ.get("AZURE_OPENAI_API_KEY", "YOUR_API_KEY"),
        api_version="2023-03-15-preview",
    )

    main()
