import argparse
import sys
from copy import deepcopy
from pathlib import Path

import gurobipy as gp
import matplotlib.pyplot as plt
import numpy as np
import simdjson as json
import tqdm.auto as tqdm
from gurobipy import GRB

from utils import PriorityQueue, load_data
from run_inference4 import lazy_optimize as lazy_init


def lazy_optimize(
    merges,
    pair_counts,
    byte_counts,
    verbose=True,
    num_merges=3000,
    competitor_batch_size=10,
    max_iters=50,
):
    """Fit per-language mixture weights with a lazily grown linear program.

    The LP has one weight variable in [0, 1] per language (the weights are
    constrained to sum to 1), one non-negative slack ``viol{i}`` per merge
    step and one non-negative slack per competing pair.  Each outer iteration
    replays the first ``num_merges`` merges under the current weights, and
    for every merge adds constraints of the form::

        freq(merge) + viol_i >= freq(competitor) - pair_viol_competitor

    for pairs whose weighted frequency currently exceeds the merge's by more
    than the slack already granted.  The LP then minimises the sum of all
    slacks and the loop repeats with the new weights.  It stops when a pass
    adds no constraint or after ``max_iters`` passes.

    Args:
        merges: Sequence of merges in the order they were applied;
            ``str(merge)`` is used as the pair key.
        pair_counts: Mapping ``lang -> sequence of dicts``; entry ``0`` holds
            the initial ``pair -> count`` table and entry ``i`` is merged
            into the running table before merge ``i`` is examined
            (presumably each entry holds only the updated counts -- confirm
            against ``load_data``).
        byte_counts: Mapping ``lang -> number`` used to normalise counts.
        verbose: When false, silences the Gurobi log and the progress bar.
            The ``print`` diagnostics below are emitted regardless.
        num_merges: Number of leading merges to consider.
        competitor_batch_size: Maximum number of new competitor constraints
            added per merge and per pass.
        max_iters: Maximum number of solve passes.

    Returns:
        dict with keys ``lang_vals`` (lang -> weight), ``violation_vals``
        (per-merge slack values), ``pair_violation_vals`` (pair -> slack
        value), ``missing_merges`` (set of merge keys absent from the queue)
        and ``active_set`` (per merge, the ``(pair, priority)`` entries popped
        during the last pass, or ``None`` for a missing merge; an empty list
        if no pass ran).
    """
    model = gp.Model("tokenizer_attack")
    if not verbose:
        model.setParam("OutputFlag", 0)
    langs = list(pair_counts.keys())
    lang_v = {name: model.addVar(0, 1, name=name) for name in langs}
    violation_v = [model.addVar(lb=0, name=f"viol{i}") for i in range(num_merges)]
    model.addConstr(sum(lang_v.values()) == 1)

    # Start from the uniform mixture and zero slack.
    lang_vals = {l: 1 / len(pair_counts) for l in pair_counts}
    violation_vals = [0 for _ in range(len(violation_v))]
    all_constraints = set()
    missing_merges = set()
    # Bound up front so the return value is well defined even if max_iters <= 0.
    active_set = []

    pair_violation_vs, pair_violation_vals = {}, {}
    for _ in range(max_iters):
        # Running pair -> count table per language, rebuilt on every pass.
        full_pair_counts = {l: {} for l in langs}

        # Weighted, byte-normalised frequency of every pair before any merge.
        combined_props = {}
        for l, p in lang_vals.items():
            for pair, count in pair_counts[l][0].items():
                combined_props.setdefault(pair, 0)
                combined_props[pair] += count / byte_counts[l] * p

        # NOTE(review): PriorityQueue comes from utils; it is assumed to pop
        # the highest-priority (pair, priority) first and to let add()
        # insert or overwrite an entry -- confirm against utils.
        pq = PriorityQueue(combined_props.items())
        added = set()

        merge_subset = merges[:num_merges]
        # Index of the last merge actually available; this can be smaller
        # than num_merges - 1 when fewer merges were supplied.
        last_index = len(merge_subset) - 1
        if verbose:
            merge_subset = tqdm.tqdm(merge_subset, dynamic_ncols=True)

        active_set = [None] * (last_index + 1)
        for i, (merge, violation) in enumerate(zip(merge_subset, violation_vals)):
            merge_key = str(merge)
            for l in lang_vals.keys():
                full_pair_counts[l].update(pair_counts[l][i])

            if merge_key not in pq:
                missing_merges.add(merge_key)
                print(f"Warn: could not find merge '{merge_key}'")
            else:
                merge_prop = pq.lookup(merge_key)
                popped = []
                active_set[i] = []
                # A competitor only matters if it beats the merge by more
                # than the slack the previous solve already granted.
                cutoff = merge_prop + max(0, violation)
                candidates = set()
                while len(candidates) < competitor_batch_size:
                    top, prop = pq.pop()
                    popped.append((top, prop))
                    active_set[i].append((top, prop))
                    if top == merge_key:
                        # Reached the merge itself: stop scanning.
                        break
                    if prop > cutoff + max(0, pair_violation_vals.get(top, 0)):
                        if (i, top) not in all_constraints:
                            candidates.add(top)

                added.update((i, c) for c in candidates)
                candidates.add(merge_key)
                # Linear expressions in the language weights for each pair.
                cand_counts = {
                    cand: sum(
                        full_pair_counts[l].get(cand, 0) / byte_counts[l] * lang_v[l]
                        for l in langs
                    )
                    for cand in candidates
                }

                for cand in candidates:
                    if cand == merge_key:
                        continue
                    if cand not in pair_violation_vs:
                        # NOTE(review): the name is derived from the merge
                        # index, so several pair slacks created for the same
                        # merge share one Gurobi variable name.
                        pair_violation_vs[cand] = model.addVar(
                            lb=0, name=f"pair_viol{i}"
                        )
                    model.addConstr(
                        cand_counts[merge_key] + violation_v[i]
                        >= cand_counts[cand] - pair_violation_vs[cand]
                    )
                    all_constraints.add((i, cand))

                # Put everything that was popped back for the next merge.
                for pair, count in popped:
                    pq.add(pair, count)

            if i == last_index:
                # Nothing follows the last merge, so skip the update below
                # (and avoid indexing pair_counts[l][i + 1] past the data).
                break

            # Change in weighted frequency caused by applying merge i.
            diff = {}
            for l, p in lang_vals.items():
                nextpc = pair_counts[l][i + 1]
                for pair, count in nextpc.items():
                    diff.setdefault(pair, 0)
                    diff[pair] += (
                        (count - full_pair_counts[l].get(pair, 0))
                        / byte_counts[l]
                        * p
                    )

            for pair, delta in diff.items():
                new_prop = combined_props.get(pair, 0) + delta
                pq.add(pair, new_prop)
                combined_props[pair] = new_prop

        if len(added) == 0:
            print("added no constraints -- exiting")
            break
        elif len(added) >= 10:
            print(f"added {len(added)} new constraints")
        else:
            print(f"added constraints {added}")

        # Minimise the total slack over both kinds of violation variables.
        model.setObjective(
            sum(pair_violation_vs.values()) + sum(violation_v), GRB.MINIMIZE
        )
        model.optimize()

        lang_vals = {lang: lang_v[lang].X for lang in langs}
        violation_vals = [v.X for v in violation_v]
        pair_violation_vals = {k: v.X for k, v in pair_violation_vs.items()}
        print(
            f"loss: {model.ObjVal} ({sum(violation_vals)}, {sum(pair_violation_vals.values())})"
        )
        # Show the ten languages with the largest weights.
        print(
            dict(
                sorted(lang_vals.items(), key=lambda langfreq: -langfreq[1])[:10]
            )
        )

    return dict(
        lang_vals=lang_vals,
        violation_vals=violation_vals,
        pair_violation_vals=pair_violation_vals,
        missing_merges=missing_merges,
        active_set=active_set,
    )


if __name__ == "__main__":
    # Command line: the data directory and how many merges to fit against.
    cli = argparse.ArgumentParser(prog="TokenizerInference")
    cli.add_argument("data_root")
    cli.add_argument(
        "--merges", default=1000, type=int, help="Number of merges to consider"
    )
    args = cli.parse_args()

    root = Path(args.data_root)
    merges_file = root / "merges.txt"
    if not merges_file.exists():
        print("incomplete")
        sys.exit()

    # Language names come from meta.json when present, otherwise from the
    # entries found directly under the data directory.
    meta_file = root / "meta.json"
    if meta_file.exists():
        with meta_file.open() as f:
            meta = json.load(f)
            langs = meta["byte_count"].keys()
    else:
        langs = [entry.name for entry in root.iterdir()]

    # Every language directory (names containing a dot are skipped) must
    # already hold its pair-count file; otherwise stop before loading.
    for lang in langs:
        lang_dir = root / lang
        counts_missing = not (lang_dir / "1e07/all_pair_counts.json").exists()
        if "." not in lang_dir.name and counts_missing:
            print(f"incomplete: {lang}")
            sys.exit()

    data = load_data(root, verbose=True)
    num_merges = args.merges
    solution = lazy_optimize(*data, num_merges=num_merges)

    # Make the result JSON-friendly: weights in ascending order, set -> list.
    weights_ascending = sorted(
        solution["lang_vals"].items(), key=lambda item: item[1]
    )
    solution["lang_vals"] = dict(weights_ascending)
    solution["missing_merges"] = list(solution["missing_merges"])
    print(solution["lang_vals"])

    out_path = root / f"solution5.2_1e07_{num_merges}.json"
    with out_path.open("w") as f:
        json.dump(solution, f)
