import argparse
import sys
from copy import deepcopy
from functools import partial
from pathlib import Path

import gurobipy as gp
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import simdjson as json
import tqdm.auto as tqdm
from gurobipy import GRB

from utils import PriorityQueue, load_data, sp_minimize


def lazy_optimize(
    merges,
    pair_counts,
    byte_counts,
    verbose=True,
    num_merges=1000,
    prefix_skip=500,
    zipf_power=1.15,
):
    """Fit per-language mixture proportions from a tokenizer's merge list.

    For each of the first ``num_merges`` merges, the current pair count of
    that merge in every language (divided by the language's byte count) is
    recorded as one row of ``terms``. Proportions ``softmax(x)`` are then
    fitted so that the mixture-weighted counts, multiplied by
    ``rank ** zipf_power``, vary as little as possible from rank to rank.

    Args:
        merges: Ordered sequence of merges; ``str(m)`` is looked up in the
            pair-count dicts.
        pair_counts: Mapping ``lang -> sequence`` whose ``i``-th element is a
            dict folded into that language's running pair counts at merge
            step ``i`` (presumably the counts updated by merge ``i`` -- TODO
            confirm against ``load_data``).
        byte_counts: Mapping ``lang -> byte count`` used to normalise counts.
        verbose: Show a tqdm progress bar and pass ``iprint=50`` to the
            optimiser.
        num_merges: Number of leading merges to consider.
        prefix_skip: Number of leading non-zero rows excluded from the loss.
        zipf_power: Exponent of the rank scaling applied to the counts.

    Returns:
        dict with ``lang_vals`` (language -> proportion; values sum to 1),
        plus ``success``, ``loss``, ``iterations`` and ``status`` taken from
        the optimiser result.

    Raises:
        ValueError: if fewer than two rows remain after dropping all-zero
            rows and the first ``prefix_skip`` rows. The loss would then be
            identically zero and the returned proportions meaningless.
    """
    langs = list(pair_counts.keys())
    apc = {l: {} for l in langs}  # running pair counts per language
    terms = np.zeros((num_merges, len(langs)))
    P = partial(tqdm.tqdm, dynamic_ncols=True) if verbose else lambda x: x
    for i, m in enumerate(P(merges[:num_merges])):
        for l, pc in pair_counts.items():
            apc[l].update(pc[i])
        terms[i] = [apc[l].get(str(m), 0) / byte_counts[l] for l in langs]
    # Drop merges no language observed (this also drops unused trailing rows
    # when there are fewer than num_merges merges). The mask is computed in
    # NumPy float64, avoiding a float32 device round-trip.
    terms = terms[np.linalg.norm(terms, axis=1) > 1e-10]

    tail_terms = terms[prefix_skip:]
    if len(tail_terms) < 2:
        # With < 2 rows, jnp.diff is empty, the loss is constantly 0 and the
        # optimiser would "succeed" with uniform proportions.
        raise ValueError(
            f"only {len(terms)} non-zero merge rows available and "
            f"prefix_skip={prefix_skip}; need at least 2 rows after skipping "
            "(increase num_merges or lower prefix_skip)"
        )

    # The rank scaling does not depend on x, so build it once rather than on
    # every loss/gradient evaluation.
    scale = jnp.arange(1, len(terms) + 1) ** zipf_power
    tail_scale = scale[prefix_skip:]

    def loss(x):
        props = jax.nn.softmax(x)
        counts = (tail_terms @ props) * tail_scale
        # Dividing by the mean makes the loss invariant to an overall scale
        # of the counts; the squared norm of successive differences penalises
        # rank-to-rank variation.
        return jnp.linalg.norm(jnp.diff(counts / counts.mean())) ** 2

    x0 = jnp.zeros(terms.shape[1])  # uniform proportions after softmax

    opt = sp_minimize(
        jax.value_and_grad(loss),
        x0,
        method="L-BFGS-B",
        options=dict(iprint=50) if verbose else {},
    )
    lang_vals = {l: v.item() for l, v in zip(langs, jax.nn.softmax(opt.x))}
    return dict(
        lang_vals=lang_vals,
        success=opt.success,
        loss=opt.fun.item(),
        iterations=opt.nit,
        status=opt.status,
    )


if __name__ == "__main__":
    # CLI: infer language proportions for one tokenizer data directory and
    # write them to solution4_domains_<merges>.json inside that directory.
    cli = argparse.ArgumentParser(prog="TokenizerInference")
    cli.add_argument("data_root")
    cli.add_argument(
        "--merges", type=int, default=1000, help="Number of merges to consider"
    )
    cli_args = cli.parse_args()
    data_dir = Path(cli_args.data_root)

    # NOTE(review): sys.exit() with no argument exits with status 0, so a
    # caller checking the exit code cannot distinguish "incomplete" from
    # success -- only the printed text signals it.
    if not (data_dir / "merges.txt").exists():
        print("incomplete")
        sys.exit()

    # Languages come from meta.json's byte_count keys when available,
    # otherwise from the entries of the data directory.
    meta_path = data_dir / "meta.json"
    if meta_path.exists():
        with meta_path.open() as fh:
            langs = json.load(fh)["byte_count"].keys()
    else:
        langs = [entry.name for entry in data_dir.iterdir()]

    # First language (ignoring names containing ".", e.g. plain files) whose
    # all_pair_counts.json has not been produced yet, if any.
    missing = next(
        (
            lang
            for lang in langs
            if "." not in (data_dir / lang).name
            and not (data_dir / lang / "all_pair_counts.json").exists()
        ),
        None,
    )
    if missing is not None:
        print(f"incomplete: {missing}")
        sys.exit()

    loaded = load_data(data_dir, verbose=True)
    solution = lazy_optimize(*loaded, num_merges=cli_args.merges)
    # Order languages from smallest to largest inferred proportion.
    ranked = sorted(solution["lang_vals"].items(), key=lambda pair: pair[1])
    solution["lang_vals"] = dict(ranked)
    print(solution["lang_vals"])
    out_path = data_dir / f"solution4_domains_{cli_args.merges}.json"
    with out_path.open("w") as fh:
        json.dump(solution, fh)
