#!/usr/bin/env python3

import argparse
import os
import json
import collections

import torch
import numpy as np

import proofsearch
import peano


def _read_problems(problems_path: str) -> list:
    """Parse a problems file into a list of ``(problem_id, statement)`` pairs.

    Each non-blank line has the form ``<id>. <statement>``: the text before the
    first '.' is the id and the remainder is the statement, both stripped of
    surrounding whitespace. Blank lines are ignored.

    Raises:
        ValueError: if a non-blank line contains no '.'.
    """
    problems = []
    with open(problems_path, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Tolerate blank separator lines instead of failing to unpack them.
                continue
            problem_id, sep, statement = line.partition('.')
            if not sep:
                raise ValueError(
                    f'{problems_path}:{lineno}: expected "<id>. <statement>", '
                    f'got {line!r}')
            problems.append((problem_id.strip(), statement.strip()))
    return problems


def _load_outcomes(output_path: str) -> list:
    """Return the outcome records already saved at ``output_path`` ([] if absent)."""
    if not os.path.exists(output_path):
        return []
    with open(output_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _save_outcomes(outcomes: list, output_path: str) -> None:
    """Write ``outcomes`` to ``output_path`` as JSON, atomically.

    The data goes to a sibling temp file that then replaces the target. A crash
    or serialization error mid-write therefore never truncates the previously
    saved results, which are what make resuming possible.
    """
    tmp_path = output_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(outcomes, f)
    os.replace(tmp_path, output_path)


def _print_summary(outcomes: list) -> None:
    """Print ``<checkpoint> <mean success>`` for each checkpoint in ``outcomes``."""
    success = collections.defaultdict(list)
    for o in outcomes:
        success[o['checkpoint']].append(o['success'])
    for checkpoint, results in success.items():
        print(checkpoint, np.mean(results))


def evaluate_agent(
        theory_path: str,
        problems_path: str,
        run_path: str,
        output_path: str,
        *,
        max_mcts_nodes: int = 10000):
    """Evaluate every saved checkpoint of a run on a fixed set of problems.

    Checkpoints ``<run_path>/0.pt``, ``1.pt``, ... are visited in order until the
    first missing index. Each checkpoint is loaded only when it still has
    unevaluated problems. Every (checkpoint, problem) outcome is appended to
    a JSON list at ``output_path``, which is rewritten after each attempt. An
    interrupted evaluation can therefore be resumed: pairs already present in
    the file are skipped. At the end, the mean success rate of each checkpoint
    is printed.

    Args:
        theory_path: file containing the theory text. ``<theory_path>.premises``
            must list one premise per line.
        problems_path: file with one ``<id>. <statement>`` problem per line.
        run_path: directory containing the numbered ``.pt`` checkpoints.
        output_path: JSON results file; created, or extended if it exists.
        max_mcts_nodes: value assigned to each loaded agent's ``_max_mcts_nodes``.

    Raises:
        ValueError: if a non-blank line of the problems file has no '.'.
    """
    with open(theory_path, 'r', encoding='utf-8') as f:
        theory = f.read()

    with open(theory_path + '.premises', 'r', encoding='utf-8') as f:
        premises = f.read().strip().split('\n')

    problems = _read_problems(problems_path)

    outcomes = _load_outcomes(output_path)
    existing_keys = {(o['checkpoint'], o['problem']) for o in outcomes}

    i = 0
    while True:
        checkpoint_path = os.path.join(run_path, f'{i}.pt')
        if not os.path.exists(checkpoint_path):
            break

        print(f'# Checkpoint {i}')
        pending = [(problem_id, statement)
                   for problem_id, statement in problems
                   if (i, problem_id) not in existing_keys]

        if pending:
            # Load one checkpoint at a time to bound memory, and skip loading
            # entirely when all of its problems were already evaluated.
            # NOTE(review): the loaded object is used as a full agent object,
            # so this unpickles arbitrary Python objects; only load trusted runs.
            # PyTorch >= 2.6 defaults torch.load to weights_only=True, which may
            # reject such files. Pass weights_only=False there if loading fails.
            agent = torch.load(checkpoint_path)
            agent._max_mcts_nodes = max_mcts_nodes

            for problem_id, statement in pending:
                state = peano.PyProofState(theory, premises, statement)
                agent_result = agent.proof_search(statement, state)

                if agent_result.success:
                    print(f'Problem {problem_id} solved by checkpoint {i}')
                    solution_actions = agent_result.root.get_solution_actions()
                    proof = agent_result.root.state_node.reconstruct_proof(
                        solution_actions)
                    logprob = agent_result.root.solution_logprob_under_policy(
                        agent._policy, solution_actions)
                else:
                    proof, logprob = None, None

                # NOTE(review): json.dump needs plain Python values here. If
                # `success`/`logprob` can be tensors or numpy scalars, convert
                # them (e.g. float(...)). Confirm against proofsearch.
                outcomes.append({
                    'checkpoint': i,
                    'problem': problem_id,
                    'success': agent_result.success,
                    'proof': proof,
                    'logprob': logprob
                })

                # Persist after every attempt so an interrupted run can resume.
                _save_outcomes(outcomes, output_path)

            # Release this checkpoint before the next one is loaded.
            del agent

        i += 1

    _print_summary(outcomes)


if __name__ == '__main__':
    # Script entry point: every flag is a mandatory string path that is
    # forwarded unchanged to evaluate_agent.
    cli = argparse.ArgumentParser()
    for flag in ('theory', 'problems', 'run', 'output'):
        cli.add_argument(f'--{flag}', type=str, required=True)
    opts = cli.parse_args()

    evaluate_agent(
        theory_path=opts.theory,
        problems_path=opts.problems,
        run_path=opts.run,
        output_path=opts.output,
    )
