#!/usr/bin/env python3


from dataclasses import dataclass, asdict
import functools
import random
import collections
import math
import os
import json
import gc

import hydra
from omegaconf import DictConfig
import torch
import time

from proofsearch import Policy, LMPolicy, make_policy, ProofStateNode, TreeSearchNode, \
    ProofSearchAgent, ProofSearchResult, \
    HolophrasmNode, UniformPolicy, MonteCarloTreeSearch, visualize_search_tree
from problems import ProblemSet, load_problemset
from util import setup_wandb, time_limit, format_blocks_with_indent
from curiosity import CuriositySignal, make_curiosity_signal

import peano


def _compute_depth(proof_state_node: peano.PyProofState) -> int:
    """Placeholder for computing the depth of a proof state.

    The function only queries the construction history of
    ``proof_state_node``; the result is discarded and ``None`` is returned.

    NOTE(review): the ``-> int`` annotation suggests a depth was meant to be
    derived from the history -- the computation looks unfinished; confirm the
    intended formula before relying on this helper.
    """
    proof_state_node.construction_history()
    return None


class PreTrainingNode(ProofStateNode):
    """Proof-state node that wraps another node and rewards it by curiosity.

    Goal, terminality, conjunctivity, string form and proof reconstruction are
    delegated to the wrapped node. The wrapper adds two things: actions for
    which ``is_apply()`` is true are hidden unless ``allow_backward`` is
    truthy, and ``reward()`` comes from the curiosity signal.
    """

    def __init__(self, wrapped_node,
                 curiosity_signal: CuriositySignal,
                 parent=None,
                 allow_backward=None,
                 logger=None):
        """Wrap ``wrapped_node``.

        Args:
            wrapped_node: node every proof-state query is delegated to.
            curiosity_signal: object whose ``reward(node)`` scores this node.
            parent: the PreTrainingNode this one was expanded from, if any.
            allow_backward: whether ``apply`` actions are offered; when
                ``None`` the value is taken from ``parent``.
            logger: optional object exposing ``log_deduction(dtype, reward)``.
        """
        self._wrapped_node = wrapped_node
        self._curiosity_signal = curiosity_signal
        self._parent = parent
        self._logger = logger

        if allow_backward is None:
            # No explicit setting: inherit the parent's flag. Without a
            # parent this evaluates to the (falsy) parent value itself.
            allow_backward = parent and parent._allow_backward
        self._allow_backward = allow_backward

    @property
    def _proof_states(self):
        """Proof states of the wrapped node."""
        return self._wrapped_node._proof_states

    def clone(self) -> ProofStateNode:
        """Return a wrapper around a clone of the wrapped node.

        The curiosity signal, parent, backward flag and logger are shared
        with the original wrapper.
        """
        return PreTrainingNode(
            self._wrapped_node.clone(),
            curiosity_signal=self._curiosity_signal,
            parent=self._parent,
            allow_backward=self._allow_backward,
            logger=self._logger,
        )

    def goal(self) -> str:
        """Goal of the wrapped node."""
        return self._wrapped_node.goal()

    def is_terminal(self) -> bool:
        """Whether the wrapped node is terminal."""
        return self._wrapped_node.is_terminal()

    def is_conjunctive(self) -> bool:
        """Whether the wrapped node is conjunctive."""
        return self._wrapped_node.is_conjunctive()

    @functools.cached_property
    def actions(self) -> list:
        """Actions of the wrapped node, minus ``apply`` actions if disallowed.

        Computed once per node and then cached.
        """
        available = self._wrapped_node.actions

        if self._allow_backward:
            return list(available)

        return [action for action in available if not action.is_apply()]

    def expand(self, action):
        """Apply ``action`` to the wrapped node and wrap the resulting node.

        The child keeps this node's curiosity signal and logger and, since no
        explicit flag is passed, inherits ``allow_backward`` from this node.
        """
        successor = self._wrapped_node.expand(action)
        return PreTrainingNode(
            successor,
            curiosity_signal=self._curiosity_signal,
            parent=self,
            logger=self._logger,
        )

    def __str__(self):
        return str(self._wrapped_node)

    def reconstruct_proof(self, actions, is_root=True) -> list:
        """Delegate proof reconstruction to the wrapped node."""
        return self._wrapped_node.reconstruct_proof(actions, is_root)

    def reward(self) -> float:
        """Score the wrapped node with the curiosity signal.

        When a logger is attached, the wrapped node's last construction dtype
        is logged together with the reward.
        """
        value = self._curiosity_signal.reward(self._wrapped_node)

        logger = self._logger
        if logger is None:
            return value

        logger.log_deduction(self._wrapped_node.last_construction_dtype(), value)
        return value


class PretrainingDataGenerator:
    """Interface for objects that sample pre-training episodes."""

    def sample_episodes(self, problemset: ProblemSet) -> list['TreeSearchNode']:
        """Sample episodes from ``problemset``; subclasses must override.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()


class ForwardChainingPretraining:
    """Generates pre-training data by expanding search trees with MCTS.

    Trees are grown with ``MonteCarloTreeSearch`` under a ``UniformPolicy``;
    nodes are then relabelled in hindsight with goals and the resulting paths
    are converted into training examples by the policy.
    """

    def __init__(self, config: DictConfig, logger=None):
        """Store the expansion budget (``config.expansions``) and the logger."""
        self._expansions = config.expansions
        self._logger = logger

    def generate_training_examples(self, problemset: ProblemSet, problems: list[str], policy: Policy) -> list['TreeSearchNode']:
        """Collect training examples for every problem in ``problems``.

        For each problem, starting states are sampled, each one is expanded,
        up to 1000 goals are picked from the tree, and every (node, goal) pair
        is hindsight-relabelled into a path handed to
        ``policy.extract_examples_from_path``.

        Raises:
            ValueError: if hindsight relabelling yields no path for a goal.
        """
        examples = []

        for p in problems:
            initial_state = problemset.initialize_problem(p)
            roots = self._sample_starting_states(initial_state, policy)

            for root in roots:
                self.expand(root)

                for node, goal, _ in self._sample_goals(root, 1000):
                    path = node.hindsight_relabel(goal, root, [])

                    if path is None:
                        raise ValueError('Path is None...')

                    examples.extend(policy.extract_examples_from_path(path))

        return examples

    def _sample_goals(self, root: TreeSearchNode, n: int) -> list[tuple]:
        """Return the ``n`` best ``(node, goal, score)`` triples under ``root``.

        For a terminal node the goal is the goal of the parent's first proof
        state; otherwise it is the last proven proposition of the node's own
        first proof state. Nodes with an empty goal are skipped. The score is
        ``_novelty * _construction_depth``, highest first.

        NOTE(review): PreTrainingNode as defined in this file sets neither
        ``_novelty`` nor ``_construction_depth``; confirm where these
        attributes are expected to come from.
        """
        candidates = []

        for node in root:
            wrapped = node.state_node._wrapped_node

            if wrapped.is_terminal():
                goal = node._parent[0].state_node._wrapped_node._proof_states[0].goal()
            else:
                goal = wrapped._proof_states[0].last_proven_proposition()

            if goal:
                score = node.state_node._novelty * node.state_node._construction_depth
                candidates.append((node, goal, score))

        candidates.sort(key=lambda c: c[-1], reverse=True)
        return candidates[:n]

    def _sample_propositions(self, root: TreeSearchNode) -> list:
        """Return the non-``None`` construction dtypes of all nodes under ``root``."""
        dtypes = (node.state_node._wrapped_node._construction_dtype for node in root)
        return [dtype for dtype in dtypes if dtype is not None]

    def _sample_starting_states(self, initial_state, policy):
        """Return fresh search roots: the initial state plus one per sub-goal.

        A first tree (``apply`` actions allowed) is searched from
        ``initial_state``. For every conjunctive node, each child's goal is
        recorded with the shallowest child reaching it. The initial root and
        those children are then re-wrapped as new roots with
        ``allow_backward=False`` and linked back to their original parents.

        NOTE(review): ``policy`` is currently unused. It was only passed as a
        ``pi=`` keyword that ``PreTrainingNode.__init__`` does not accept.
        """
        # NOTE(review): called without arguments as in the original code;
        # CuriosityGuidedProofSearchAgent passes (config.curiosity, policy) --
        # confirm that the factory has usable defaults.
        curiosity_signal = make_curiosity_signal()

        # PreTrainingNode takes the wrapped node first and the curiosity
        # signal second.
        root = TreeSearchNode(PreTrainingNode(
            HolophrasmNode([initial_state]),
            curiosity_signal,
            allow_backward=True,
            logger=self._logger,
        ))
        mcts = MonteCarloTreeSearch(UniformPolicy({}))
        mcts.evaluate(root)

        # goal -> (shallowest child proving it, length of its path from root)
        shortest_path_to_goal = {}

        for node in root:
            if not node.state_node.is_conjunctive():
                continue

            if node.is_leaf():
                node.expand()

            for c in node._children:
                goal = c.state_node.goal()
                depth = len(c.get_path_from_root())
                best = shortest_path_to_goal.get(goal)

                if best is None or best[1] > depth:
                    shortest_path_to_goal[goal] = (c, depth)

        roots = []
        for r in [root] + [node for node, _ in shortest_path_to_goal.values()]:
            # Every node of the first tree shares `curiosity_signal`
            # (PreTrainingNode.expand propagates it), so the re-wrapped roots
            # reuse it as well.
            new_root = TreeSearchNode(PreTrainingNode(
                HolophrasmNode(r.state_node._wrapped_node._proof_states),
                curiosity_signal,
                allow_backward=False,
                parent=r.state_node._parent,
                logger=self._logger,
            ), parent=r._parent)
            new_root.update_parent_link()
            roots.append(new_root)

        return roots

    def expand(self, root: TreeSearchNode, exploration_prefix: list = None, on_expand=None):
        """Run MCTS from ``root`` with a budget of ``config.expansions``.

        Args:
            root: node to search from.
            exploration_prefix: forwarded to ``MonteCarloTreeSearch``.
            on_expand: optional callback. Accepted so that
                ``CuriosityGuidedProofSearchAgent.proof_search`` can call
                ``expand(st, on_expand=...)`` on this object as it does on a
                ``MonteCarloTreeSearch``.

        NOTE(review): without a callback the original ``evaluate(root)`` call
        is kept. With a callback the search uses
        ``MonteCarloTreeSearch.expand(root, on_expand=...)``, the same call
        ``proof_search`` makes -- confirm both entry points are equivalent.
        """
        mcts = MonteCarloTreeSearch(UniformPolicy({}),
                                    budget=self._expansions,
                                    exploration_prefix=exploration_prefix)

        if on_expand is None:
            mcts.evaluate(root)
        else:
            mcts.expand(root, on_expand=on_expand)

        if root.is_solved():
            print('Root is solved!')

    def make_root(self, initial_state, path=None, allow_backward=False) -> TreeSearchNode:
        """Build a search node for ``initial_state`` and descend along ``path``.

        Args:
            initial_state: proof state wrapped by the root node.
            path: action strings to follow from the root; each must equal the
                ``str`` of a child's incoming action. Defaults to no steps.
            allow_backward: backward flag set on the node that is returned
                (the descent itself runs with ``apply`` actions allowed).

        Returns:
            The node reached after following ``path``.

        Raises:
            ValueError: if some action in ``path`` matches no child.
        """
        path = [] if path is None else path

        # NOTE(review): the curiosity signal is built without arguments, as
        # in _sample_starting_states; confirm the factory has defaults.
        root = TreeSearchNode(PreTrainingNode(
            HolophrasmNode([initial_state]),
            make_curiosity_signal(),
            allow_backward=True,
            logger=self._logger,
        ))

        for a in path:
            root.expand()
            child = next((c for c in root.children() if str(c._parent[1]) == a), None)

            if child is None:
                raise ValueError(f'Impossible action {a} in path {" => ".join(path)}')

            root = child

        root.state_node._allow_backward = allow_backward
        return root


class CuriosityGuidedProofSearchAgent:
    """Proof-search agent that runs MCTS over curiosity-rewarded nodes.

    Each call to ``proof_search`` runs up to ``iterations`` rounds of search
    from one or more starting states and, between unsuccessful rounds,
    retrains the policy on the examples accumulated so far.
    """

    def __init__(self, config: DictConfig, logger=None, policy=None):
        """Read the search settings from ``config``.

        Args:
            config: must provide ``max_mcts_nodes``; ``iterations``,
                ``stages``, ``mcts_time_limit``, ``persist_policy`` and
                ``dump_result`` are optional. ``policy`` and ``curiosity``
                sub-configs are read lazily by ``proof_search``.
            logger: optional logger exposing ``set_context`` and
                ``log_mcts_expansion``.
            policy: optional pre-built policy to use instead of creating one.
        """
        self._max_mcts_nodes = config.max_mcts_nodes
        self._config = config
        self._policy = policy
        self._curiosity = None
        self._iterations = config.get('iterations', 1)
        self._stages = config.get('stages', 'single')
        self._mcts_time_limit = config.get('mcts_time_limit', 60*30)
        self._persist_policy = config.get('persist_policy', True)
        self._logger = logger
        self._examples = []
        # HACK: Option for debugging -- write intermediate results to a file.
        # NOTE(review): the config key is singular ('dump_result') while the
        # attribute is plural; confirm which spelling the configs use.
        self._dump_results = config.get('dump_result', False)

    def _initialize_root(self, state):
        """Wrap ``state`` in a search root that allows ``apply`` actions."""
        return TreeSearchNode(PreTrainingNode(
            HolophrasmNode([state]),
            self._curiosity,
            allow_backward=True,
            logger=self._logger,
        ))

    def _ensure_policy_and_curiosity(self):
        """Create the policy and curiosity signal when missing or not persisted."""
        if self._policy is None or not self._persist_policy:
            self._policy = make_policy(self._config.policy)
            self._curiosity = make_curiosity_signal(self._config.curiosity, self._policy)
        elif self._curiosity is None:
            # A policy injected through the constructor is kept, but it still
            # needs a curiosity signal: without one every root would be built
            # with `None` as its signal.
            self._curiosity = make_curiosity_signal(self._config.curiosity, self._policy)

    def _dump_debug_results(self, problem, proof):
        """Overwrite examples.json and append this problem's proof to proofs.jsonl."""
        with open('examples.json', 'w') as f:
            json.dump(self._examples, f)

        with open('proofs.jsonl', 'a') as out:
            out.write(json.dumps({'problem': problem, 'proof': proof}))
            out.write('\n')

    def proof_search(self, problem, state):
        """Search for a proof of ``problem`` starting from ``state``.

        Args:
            problem: problem identifier, used for logging and output only.
            state: initial proof state.

        Returns:
            ``ProofSearchResult(problem, success, root, new_examples, None)``
            where ``root`` is the root of the last tree searched, or ``None``
            when no starting state was explored.
        """
        success = False
        on_expand = None
        # Bound up front so the code after the loops works even when nothing
        # was explored (zero iterations or no starting states).
        proof = None
        st = None
        new_examples = []

        # Initialize if either we haven't done so yet, or if we're not
        # persisting the policy across proof searches.
        self._ensure_policy_and_curiosity()

        for i in range(self._iterations):
            if self._stages == 'backward-forward':
                fc = ForwardChainingPretraining(self._config, logger=self._logger)
                ss = fc._sample_starting_states(state, self._policy)
                print(len(ss), 'starting states.')
                expander = fc
            else:
                print('Running single-stage exploration from root.')
                ss = [self._initialize_root(state)]
                mcts = MonteCarloTreeSearch(UniformPolicy({}), self._max_mcts_nodes)
                expander = mcts

            for st in ss:
                st_path = st.get_path_from_root()
                print('Starting state:', st_path)

                if self._logger is not None:
                    self._logger.set_context({'problem': problem,
                                              'iteration': i,
                                              'starting_state': str(st_path)})
                    on_expand = lambda path: self._logger.log_mcts_expansion(path)

                try:
                    gc.collect()

                    with time_limit(self._mcts_time_limit):
                        expander.expand(st, on_expand=on_expand)
                except TimeoutError:
                    # A timed-out search keeps whatever tree was built so far.
                    print('MCTS timed out.')

                # NOTE(review): example extraction is disabled, so
                # `new_examples` stays empty and train() sees no new data.
                # if self._policy:
                #    new_examples.extend(self._policy.extract_examples(st))

                if st.is_solved():
                    print('st was solved!')

                if st.get_root().is_solved():
                    print('Problem solved!')
                    print('begin proof', problem)
                    proof = format_blocks_with_indent(st.get_root().reconstruct_proof())
                    print(proof)
                    print('end proof', problem)
                    success = True
                    break

            if success:
                break

            if i + 1 < self._iterations:
                print('Not solved, training likelihood model...')
                self.train()

        self._examples.extend(new_examples)

        if self._dump_results:
            self._dump_debug_results(problem, proof)

        root = st.get_root() if st is not None else None
        return ProofSearchResult(problem, success, root, new_examples, None)

    def train(self):
        """Train the policy on the ``'str'`` field of all stored examples."""
        self._policy.train([e['str'] for e in self._examples])

    def set_logger(self, logger):
        """Replace the logger used by subsequent searches."""
        self._logger = logger


def get_path_statistics(node, path) -> list[int]:
    """Follow ``path`` down the tree from ``node`` and record visit statistics.

    At each step the first child whose incoming action (``str`` of
    ``child._parent[1]``) starts with the path entry is chosen, so an empty
    string selects the first child.

    Returns:
        One ``(action, visits, novelty)`` tuple per step that was followed.
        If a step has no matching child, ``(0, None)`` is appended and the
        walk stops.
    """
    stats = []
    current = node

    for action in path:
        matching = (child for child in (current.children() or [])
                    if str(child._parent[1]).startswith(action))
        chosen = next(matching, None)

        if not chosen:
            stats.append((0, None))
            return stats

        stats.append((action, chosen._visits, chosen.state_node._novelty))
        current = chosen

    return stats


class ExplorationLogger:
    '''Collects records of deductions and MCTS expansions during exploration.'''

    def __init__(self):
        self._current_context = {}
        self._records = []

    def set_context(self, ctx: dict):
        '''Merge `ctx` into the current context; keys in `ctx` win.

        A new dict is built each time, so records logged earlier keep the
        context they were logged with.
        '''
        merged = dict(self._current_context)
        merged.update(ctx)
        self._current_context = merged

    def log_deduction(self, dtype, reward):
        '''Record a deduction and its reward under the current context.'''
        record = {
            'type': 'deduction',
            'context': self._current_context,
            'deduction': dtype,
            'reward': reward,
        }
        self._records.append(record)

    def log_mcts_expansion(self, path):
        '''Record an MCTS expansion of `path` under the current context.'''
        record = {
            'type': 'expansion',
            'context': self._current_context,
            'path': path,
        }
        self._records.append(record)

    def dump_records(self) -> list:
        '''Return a shallow copy of all records logged so far.'''
        return list(self._records)


def evaluate_exploration_policy():
    """Expand a fixed search state of an 'nng' problem and report visit counts.

    For each listed problem the function builds a node by following a fixed
    action prefix, runs a 4000-expansion search from it, prints the
    per-step statistics along a hand-written reference path, and writes the
    search tree to ``pretraining_<problem>.dot`` next to this file.

    The leftover ``breakpoint()`` that used to precede the search was removed:
    it dropped every run into the debugger, which stalls non-interactive jobs.
    """
    p = load_problemset('nng')

    fc = ForwardChainingPretraining(DictConfig({'expansions': 4000}))

    for problem in ['a_succ_add']:
        # Action prefix leading from the problem's initial state to the node
        # the search starts from. Commented entries are disabled alternatives.
        prefix = [
            'intro.',
            # 'intro.',
            'a nat_ind',
            "=> (= (+ (s x) z) (s (+ x z))); [('n : nat) -> (= (+ (s x) 'n) (s (+ x 'n))) -> (= (+ (s x) (s 'n)) (s (+ x (s 'n))))].",
            # "=> (= (+ (+ x x0) z) (+ x (+ x0 z))); [('n : nat) -> (= (+ (+ x x0) 'n) (+ x (+ x0 'n))) -> (= (+ (+ x x0) (s 'n)) (+ x (+ x0 (s 'n))))].",
            '0',
        ]
        root = fc.make_root(p.initialize_problem(problem), prefix)

        fc.expand(root)

        # Reference path whose visit counts are reported. Entries are matched
        # by prefix, so '' selects the first child at that step.
        hardest_path = [
            'c +_z',
            '',
            'c +_z',
            '',
            'c eq_symm',
            '=> (= x (+ x z))',
            'c rewrite',
            '=> (= (+ (s x) z) (s (+ x z)))',
        ]
        # Disabled alternative reference path:
        # [
        #     'intro.',
        #     'intro.',
        #     'c +_s',
        #     '=> (= (+ (s x) (s x0)) (s (+ (s x) x0)))',
        #     'c rewrite',
        #     '=> (= (+ (s x) (s x0)) (s (s (+ x x0))))',
        #     'c +_s',
        #     '=> (= (+ x (s x0)) (s (+ x x0)))',
        #     'c eq_symm',
        #     '=> (= (s (+ x x0)) (+ x (s x0)))',
        #     'c rewrite',
        #     '=> (= (+ (s x) (s x0)) (s (+ x (s x0))))'
        # ]
        visit_counts_on_hardest_path = get_path_statistics(root, hardest_path)

        print('Visit counts on', problem)
        print(visit_counts_on_hardest_path)

        visualize_search_tree(root, os.path.join(os.path.dirname(__file__), f'pretraining_{problem}.dot'))


def test():
    """Generate training examples for two 'nng' problems and train a policy.

    The examples are written to ``examples.txt`` (one per line), a wandb run
    is started in project 'peano', an ``LMPolicy`` is trained on the examples,
    and the trained policy object is saved to ``pretrained.pt``.
    """
    generator = ForwardChainingPretraining(DictConfig({'expansions': 5000}))
    problemset = load_problemset('nng')

    policy_settings = {
        'value_prior_weight': 0,
        'max_pos_neg_ratio': 10,
        'batch_size': 2000,
        'train_iterations': 10**5,
    }
    policy = LMPolicy(DictConfig(policy_settings))

    target_problems = ['a_succ_add', 'a_add_assoc']
    examples = generator.generate_training_examples(problemset, target_problems, policy)
    print(len(examples), 'training examples extracted.')

    with open('examples.txt', 'w') as examples_file:
        examples_file.write('\n'.join(examples))
    print('Wrote examples to examples.txt')

    # Imported here so that wandb is only required for this task.
    import wandb
    wandb.init(project='peano')

    print('Training...')
    policy.train(examples, verbose=True)

    torch.save(policy, 'pretrained.pt')
    print('Saved pretrained.pt')


def exploration_iteration(cfg: DictConfig):
    """Run proof search on every problem in ``cfg.problems`` and log it.

    The agent is loaded from ``cfg.pretrained_agent`` when that entry is set,
    otherwise it is built from ``cfg.agent``. All records collected by the
    exploration logger are written as JSON to ``exploration.log``.
    """
    import problems

    problemset = problems.load_problemset('nng')
    logger = ExplorationLogger()

    pretrained_path = cfg.get('pretrained_agent')
    if pretrained_path:
        # NOTE(review): torch.load unpickles arbitrary objects -- only point
        # this at trusted files. Newer torch releases may also need
        # weights_only=False to load a whole agent object; confirm.
        agent = torch.load(pretrained_path)
        agent.set_logger(logger)
    else:
        agent = CuriosityGuidedProofSearchAgent(cfg.agent, logger=logger)

    for name in cfg.problems:
        print('Problem:', name)
        initial_state = problemset.initialize_problem(name)
        outcome = agent.proof_search(name, initial_state)
        print('Success?', outcome.success)

    records = logger.dump_records()
    with open('exploration.log', 'w') as out:
        json.dump(records, out)

def test_proof_search_agent():
    """Smoke-run the curiosity-guided agent on the 'a_zero_add' problem."""
    import problems

    problem_name = 'a_zero_add'
    problemset = problems.load_problemset('nng')
    initial_state = problemset.initialize_problem(problem_name)

    agent_config = DictConfig({'expansions': 3000,
                               'max_mcts_nodes': 5000})
    agent = CuriosityGuidedProofSearchAgent(agent_config)

    outcome = agent.proof_search(problem_name, initial_state)
    print('Success?', outcome.success)


@hydra.main(version_base="1.2", config_path="config", config_name="pretraining")
def main(cfg: DictConfig):
    """Hydra entry point: run the routine selected by ``cfg.task``.

    Recognised tasks are 'generate', 'iterate', 'proofsearch' and 'eval'; any
    other value only prints the working directory.
    """
    print('Running from', os.getcwd())

    task = cfg.task

    if task == 'generate':
        test()
        return

    if task == 'iterate':
        setup_wandb(cfg)
        exploration_iteration(cfg)
        return

    if task == 'proofsearch':
        test_proof_search_agent()
        return

    if task == 'eval':
        evaluate_exploration_policy()

# Script entry point: hand control to the Hydra-decorated main().
if __name__ == '__main__':
    main()
