import requests
from bs4 import BeautifulSoup
import argparse
import orjson, os
from tqdm import tqdm
import random
import time

# JSON-lines task file read by main(); each line is an object that main()
# reads the keys "contestId", "index", "status" and "count" from.
TASK_LIST = "/cmlscratch/dengch/w2s/Code_new/gdown/cml_solution/task_list.jsonl"
# Output directory: one JSON file per downloaded submission. main() lists it
# to skip tasks whose file already exists.
SAVE_DIR = "/fs/cml-projects/E2H/Codeforces/contest_solution/contest_solution_json_local"
# NOTE(review): not referenced anywhere in this file's visible code; presumably
# contest IDs meant to be skipped — confirm whether it should be applied or removed.
BLACK_LIST = [1275, 1252, 751, 1482, 1250, 1773, 1578, 1302, 1402]

### Session keys: fetch a new JSESSIONID / CSRF token only when the current session has expired


def generate_sucessful_wait():
    """Return a short randomized delay in seconds, drawn uniformly from [10, 20].

    Used to throttle consecutive requests to the server.
    """
    shortest, longest = 10, 20
    return random.uniform(shortest, longest)


def generate_failure_wait():
    """Return a long randomized back-off in seconds, drawn uniformly from [1000, 1200].

    Passed to get_keys() as the pause to take when fetching session keys fails.
    """
    floor_seconds = 1000
    ceiling_seconds = 1200
    return random.uniform(floor_seconds, ceiling_seconds)


def get_keys(contest_id, wait_time, timeout=30):
    """Fetch a fresh Codeforces session cookie and CSRF token.

    Loads ``https://codeforces.com/contest/{contest_id}/status`` with a new
    ``requests.Session``, reads the ``JSESSIONID`` cookie and the content of
    the ``<meta name="X-Csrf-Token">`` tag. On success it sleeps for
    ``generate_sucessful_wait()`` seconds before returning, to throttle requests.

    Args:
        contest_id: Contest whose status page is used to obtain the keys.
        wait_time: Seconds to sleep before returning after a failure.
        timeout: Per-request timeout in seconds, so a stalled connection
            cannot hang the crawler indefinitely.

    Returns:
        ``(jsessionid, csrf_token)`` on success; ``jsessionid`` may still be
        ``None`` if the server did not set that cookie. ``(None, None)`` on a
        network/HTTP error or when the CSRF meta tag is missing or empty.
    """
    headers = {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
    }

    try:
        # Context manager closes the session's connection pool when done.
        with requests.Session() as session:
            response = session.get(
                f"https://codeforces.com/contest/{contest_id}/status",
                headers=headers,
                timeout=timeout,
            )
            response.raise_for_status()  # Raise an exception for 4xx or 5xx status codes

            # Extract the JSESSIONID cookie from the response
            jsessionid = session.cookies.get("JSESSIONID")
        print("JSESSIONID:", jsessionid)

        # Extract the CSRF token from the response. A missing tag (e.g. a
        # captcha or error page) previously raised an uncaught TypeError.
        soup = BeautifulSoup(response.text, "html.parser")
        csrf_meta = soup.find("meta", {"name": "X-Csrf-Token"})
        if csrf_meta is None or not csrf_meta.get("content"):
            raise ValueError("X-Csrf-Token meta tag not found on status page")
        csrf_token = csrf_meta["content"]

    except (requests.exceptions.RequestException, ValueError) as e:
        print("An error occurred in getting key:", e)
        print(f"Sleep {wait_time} seconds")
        time.sleep(wait_time)
        return None, None

    time.sleep(generate_sucessful_wait())
    return jsessionid, csrf_token


### Download the source code of individual submissions

def crawl_solution(contest_id, submission_id, problem_id, candidate_id, jsession_id, csrf_token, timeout=30):
    """Download one submission's source via ``/data/submitSource`` and save it.

    The JSON response is written to
    ``{SAVE_DIR}/contest_{contest_id}_problem_{problem_id}_solution_{candidate_id}_{submission_id}.json``.
    main() uses this same file name to decide whether a task is already finished.

    Args:
        contest_id: Contest the submission belongs to (used in the Referer and file name).
        submission_id: Codeforces submission id to fetch.
        problem_id: Problem index, used only in the output file name.
        candidate_id: Per-problem counter, used only in the output file name.
        jsession_id: Session cookie value obtained from get_keys().
        csrf_token: CSRF token obtained from get_keys().
        timeout: Per-request timeout in seconds.

    Returns:
        True after saving the file, followed by a short throttling sleep.
        False on any failure (HTTP error, non-JSON or error response, I/O
        error), after a short sleep. Failures are deliberately swallowed so
        the caller can refresh its keys and retry.
    """
    headers = {
        "Referer": f"https://codeforces.com/contest/{contest_id}/submission/{submission_id}",
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
    }

    cookies = {
        "JSESSIONID": jsession_id,
    }

    payload = {
        "submissionId": str(submission_id),
        "csrf_token": csrf_token,
    }

    save_path = f"{SAVE_DIR}/contest_{contest_id}_problem_{problem_id}_solution_{candidate_id}_{submission_id}.json"
    tmp_path = f"{save_path}.tmp"

    try:
        response = requests.post(
            "https://codeforces.com/data/submitSource",
            headers=headers,
            data=payload,
            cookies=cookies,
            timeout=timeout,
        )
        response.raise_for_status()  # Raise an exception for 4xx or 5xx status codes

        data = response.json()
        # Explicit check rather than `assert`, which is stripped under `python -O`
        # and would let error payloads be saved as solutions.
        if not isinstance(data, dict) or "error" in data:
            raise ValueError("crawling error")

        # Write bytes (orjson output is UTF-8) to a temp file, then rename
        # atomically. main() treats any existing final file as done, so a
        # truncated file must never appear under the final name.
        with open(tmp_path, "wb") as wf:
            wf.write(orjson.dumps(data, option=orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY))
        os.replace(tmp_path, save_path)
    except Exception as e:
        # Best-effort cleanup of a partially written temp file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        wait_time = generate_sucessful_wait()
        print(f"An error occurred in crawling {submission_id}:", e)
        print(f"Sleep {wait_time} seconds")
        time.sleep(wait_time)
        return False

    time.sleep(generate_sucessful_wait())
    return True


def _pending_solutions():
    """Return the tasks from TASK_LIST whose solution file is not yet in SAVE_DIR.

    A task is skipped when its ``status`` is -1 or when
    ``contest_{contestId}_problem_{index}_solution_{count}_{status}.json``
    already exists in SAVE_DIR. Each returned dict has the keys
    ``contestId``, ``status``, ``index`` and ``count``.
    """
    # A set gives one O(1) lookup per task instead of scanning the listing each time.
    finished = set(os.listdir(SAVE_DIR))
    pending = []
    with open(TASK_LIST, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing in orjson.loads
            task = orjson.loads(line)
            contest_id = task["contestId"]
            index = task["index"]
            # `status` is passed to crawl_solution() as the submission id;
            # NOTE(review): -1 presumably marks tasks with no submission — confirm.
            status = task["status"]
            count = task["count"]
            if status == -1:
                continue
            if f"contest_{contest_id}_problem_{index}_solution_{count}_{status}.json" in finished:
                continue
            pending.append({
                "contestId": contest_id,
                "status": status,
                "index": index,
                "count": count,
            })
    return pending


def main(chunk_id=0, max_retry=1):
    """Download every pending solution, repeating passes until none remain.

    Each pass re-reads TASK_LIST and SAVE_DIR, shuffles the pending tasks and
    tries each one up to ``max_retry`` times. Session keys are fetched lazily
    and discarded after any failed download so the next attempt gets fresh ones.

    Args:
        chunk_id: Only used to choose the contest page for fetching session
            keys (``chunk_id * 20 + 37``); every call processes the full task list.
        max_retry: Download attempts per task within a single pass.

    NOTE(review): a submission that can never be downloaded keeps this loop
    running forever, because it stays pending on every pass.
    """
    while True:
        solution_list = _pending_solutions()
        if not solution_list:
            break
        random.shuffle(solution_list)

        jsession_id = None
        csrf_token = None
        for solution in tqdm(solution_list):
            for _ in range(max_retry):
                key_contest_id = chunk_id * 20 + 37
                # Require both keys before crawling; on failure, try a nearby contest page.
                while not (jsession_id and csrf_token):
                    jsession_id, csrf_token = get_keys(contest_id=key_contest_id, wait_time=generate_failure_wait())
                    if jsession_id and csrf_token:
                        print(f"Update: jsession_id {jsession_id} csrf_token {csrf_token}")
                    else:
                        key_contest_id += random.randint(1, 2)

                if crawl_solution(
                    solution["contestId"], solution["status"], solution["index"], solution["count"], jsession_id, csrf_token
                ):
                    break
                # The session may have expired; force a key refresh before retrying.
                jsession_id = None
                csrf_token = None


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description="Download Codeforces submission sources listed in TASK_LIST into SAVE_DIR.",
    )
    argparser.add_argument(
        "--chunk_id",
        type=int,
        default=0,
        help="Offset used to pick the contest page for fetching session keys (default: 0).",
    )
    argparser.add_argument(
        "--max_retry",
        type=int,
        default=1,
        help="Download attempts per submission within one pass (default: 1).",
    )
    args = argparser.parse_args()
    # With 0 attempts nothing is ever downloaded and main() would loop forever.
    if args.max_retry < 1:
        argparser.error("--max_retry must be at least 1")

    main(chunk_id=args.chunk_id, max_retry=args.max_retry)
