import os
import json
import argparse

# Per-task settings for converting a LaMP train split into flat
# {"id", "source", "target"} training records.
#   data_dir    : directory holding train_questions.json / train_outputs.json;
#                 the outputs (train_new.json, train_user_dict.json) go here too
#   template    : prompt that turns one profile item into a question
#   placeholder : name of the single ``{...}`` slot in ``template``
#   source_key  : profile-item key whose value fills ``placeholder``
#   target_key  : profile-item key used as the record's target
TASK_CONFIGS = {
    "LaMP_2": {
        "data_dir": "data/LaMP/LaMP_2/train",
        "template": "Which category does this article relate to among the following categories? Just answer with the category name without further explanation. categories: [women, religion, politics, style & beauty, entertainment, culture & arts, sports, science & technology, travel, business, crime, education, healthy living, parents, food & drink] article: {article}",
        "placeholder": "article",
        "source_key": "text",
        "target_key": "category",
    },
    "LaMP_2_movie": {
        "data_dir": "data/LaMP/LaMP_2_movie/train",
        "template": "Which tag does this movie relate to among the following tags? Just answer with the tag name without further explanation. tags: [sci-fi, based on a book, comedy, action, twist ending, dystopia, dark comedy, classic, psychology, fantasy, romance, thought-provoking, social commentary, violence, true story] description: {description}",
        "placeholder": "description",
        "source_key": "description",
        "target_key": "tag",
    },
    "LaMP_3": {
        "data_dir": "data/LaMP/LaMP_3/train",
        "template": "What is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. review: {review}",
        "placeholder": "review",
        "source_key": "text",
        "target_key": "score",
    },
    "LaMP_4": {
        "data_dir": "data/LaMP/LaMP_4/train",
        "template": "Generate a headline for the following article: {article}",
        "placeholder": "article",
        "source_key": "text",
        "target_key": "title",
    },
    "LaMP_5": {
        "data_dir": "data/LaMP/LaMP_5/train",
        "template": "Generate a title for the following abstract of a paper: {paper}",
        "placeholder": "paper",
        "source_key": "abstract",
        "target_key": "title",
    },
    # NOTE(review): LaMP_7's template and profile keys were copied verbatim
    # from LaMP_5, and the original branch still contained schema-inspection
    # debug prints, so this mapping looks unfinished. Confirm that LaMP_7
    # profile items really carry "abstract"/"title" before relying on it.
    "LaMP_7": {
        "data_dir": "data/LaMP/LaMP_7/train",
        "template": "Generate a title for the following abstract of a paper: {paper}",
        "placeholder": "paper",
        "source_key": "abstract",
        "target_key": "title",
    },
}


def _load_json(path):
    """Return the parsed JSON document stored at ``path`` (read as UTF-8)."""
    # json.load parses the whole document, so this works whether the file is
    # written on a single line or pretty-printed across many lines.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def build_records(questions, golds, config):
    """Flatten questions, gold outputs and user profiles into training records.

    Questions and golds are paired by position. Each pair whose ``id`` values
    match yields one record for the query itself, followed by one record per
    profile item, with the item rewritten through ``config["template"]``.

    Args:
        questions: list of dicts with at least "id", "input" and "profile".
        golds: list of dicts with at least "id" and "output", aligned with
            ``questions`` by position.
        config: one entry of ``TASK_CONFIGS``.

    Returns:
        A tuple ``(records, user_dict, profile_total, skipped)``:
        records       -- list of {"id", "source", "target"} dicts;
        user_dict     -- maps each kept question id to a dense 0-based index,
                         assigned in first-seen order;
        profile_total -- total number of profile items emitted;
        skipped       -- number of aligned pairs dropped for an id mismatch.
    """
    template = config["template"]
    placeholder = config["placeholder"]
    source_key = config["source_key"]
    target_key = config["target_key"]

    records = []
    user_dict = {}
    profile_total = 0
    skipped = 0
    for question, gold in zip(questions, golds):
        qid = question["id"]
        if qid != gold["id"]:
            skipped += 1
            continue
        # Assign the next dense index only the first time an id is seen.
        user_dict.setdefault(qid, len(user_dict))
        records.append({"id": qid, "source": question["input"], "target": gold["output"]})
        profile = question["profile"]
        for item in profile:
            source = template.format(**{placeholder: item[source_key]})
            records.append({"id": qid, "source": source, "target": item[target_key]})
        profile_total += len(profile)
    return records, user_dict, profile_total, skipped


def process_task(task):
    """Convert the train split of ``task`` and write its outputs.

    Reads ``train_questions.json`` and ``train_outputs.json`` (whose "golds"
    list holds the answers) from the task's ``data_dir`` and prints summary
    statistics. It then overwrites ``train_new.json`` (one JSON record per
    line) and ``train_user_dict.json`` in that same directory.

    Raises:
        KeyError: if ``task`` is not a key of ``TASK_CONFIGS``.
    """
    config = TASK_CONFIGS[task]
    data_dir = config["data_dir"]

    questions = _load_json(os.path.join(data_dir, "train_questions.json"))
    print("The number of user queries: ", len(questions))
    golds = _load_json(os.path.join(data_dir, "train_outputs.json"))["golds"]
    print("The number of user solutions: ", len(golds))
    if len(questions) != len(golds):
        # zip() in build_records truncates to the shorter list; make that visible.
        print(f"Warning: {len(questions)} queries but {len(golds)} solutions; "
              "unpaired entries are ignored.")

    records, user_dict, profile_total, skipped = build_records(questions, golds, config)
    if skipped:
        print(f"Warning: skipped {skipped} queries whose id did not match the aligned solution.")

    # NOTE(review): these are query ids; counting them as "users" assumes one
    # query per user, as the original statistics did — confirm for each task.
    num_users = len(user_dict)
    print("The number of different users: ", num_users)
    print("The number of user history data: ", len(records))
    # Avoid ZeroDivisionError when nothing matched.
    avg = profile_total / num_users if num_users else 0.0
    print("The average number of user history data: ", avg)

    # Mode "w" (not "a"): re-running the script must not duplicate every record.
    with open(os.path.join(data_dir, "train_new.json"), "w", encoding="utf-8") as f:
        for record in records:
            json.dump(record, f)
            f.write("\n")
    with open(os.path.join(data_dir, "train_user_dict.json"), "w", encoding="utf-8") as f:
        json.dump(user_dict, f)


parser = argparse.ArgumentParser()
# choices= turns an unknown task into an argparse error instead of a silent no-op.
parser.add_argument('--task', default='LaMP_7', type=str, choices=sorted(TASK_CONFIGS), help='task')

if __name__ == "__main__":
    args = parser.parse_args()
    process_task(args.task)