import asyncio
from dataclasses import dataclass
from typing import Optional

import munch
import numpy as np
import openai
import tiktoken

# Tokenizer used by query_logprobs to convert the API's string tokens back to
# integer token ids (NOTE(review): assumed to match gpt-3.5-turbo-instruct's
# vocabulary — confirm).
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
# Async OpenAI client. It is None here and must be assigned before
# query_logprobs is awaited, because that function calls
# aclient.completions.create (presumably an openai.AsyncOpenAI instance).
aclient = None


@dataclass
class Consumption:
    """Running tally of completions API usage, with a cost estimate."""

    prompt_tokens: int = 0  # total prompt tokens reported by the API
    completion_tokens: int = 0  # total completion tokens reported by the API
    requests: int = 0  # number of API requests made

    def reset(self):
        """Set every counter back to zero."""
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.requests = 0

    def cost(
        self,
        model="gpt-3.5-turbo-instruct",
        prompt_price_per_1k: float = 0.0015,
        completion_price_per_1k: float = 0.0020,
    ):
        """Return the estimated cost of the recorded tokens.

        Args:
            model: kept for backward compatibility only; it is NOT used to
                look up prices. Pass the prices explicitly when costing a
                different model.
            prompt_price_per_1k: price per 1000 prompt tokens. The default is
                the rate this module has always assumed (presumably
                gpt-3.5-turbo-instruct pricing in USD — verify against
                current rates).
            completion_price_per_1k: price per 1000 completion tokens.

        Returns:
            prompt_price_per_1k * prompt_tokens / 1000
            + completion_price_per_1k * completion_tokens / 1000
        """
        return (
            prompt_price_per_1k * self.prompt_tokens / 1000
            + completion_price_per_1k * self.completion_tokens / 1000
        )


# Module-wide usage tally. query_logprobs adds the request count and the
# reported token counts to it after every API call.
consumption = Consumption()


async def query_logprobs(
    prompt, bias: Optional[dict[int, float]] = None, **kwargs
) -> dict:
    """Run one completions request and return its tokens and top logprobs as token ids.

    Args:
        prompt: a string, a single token-id list (or empty list), or a batch
            (list) of such prompts.
        bias: ``logit_bias`` mapping token id -> bias. ``None`` means no bias.
        **kwargs: extra arguments for ``aclient.completions.create``. They
            override the defaults (model gpt-3.5-turbo-instruct,
            max_tokens=1, temperature=0, logprobs=5).

    Returns:
        A ``munch.Munch`` with two fields:
        - ``tokens``: generated token ids.
        - ``logprobs``: one ``{token_id: logprob}`` dict of top alternatives
          per generated position.
        For a batch there is one entry per prompt, in prompt order. If ``n``
        is passed, entries are grouped into lists of ``n`` samples. For a
        single prompt the outer list is unwrapped.
        Tokens the tokenizer cannot map become ``-1``.

    Side effects:
        Adds the request and the reported token counts to ``consumption``.
    """
    if bias is None:
        bias = {}

    # Map an API token string to its integer id so callers only handle ints.
    def encode_token(tok):
        try:
            return tokenizer.encode_single_token(
                # "bytes:"-prefixed tokens carry backslash-escaped raw bytes
                # (presumably partial UTF-8); unescape them back to bytes.
                tok[6:].encode("latin-1").decode("unicode_escape").encode("latin-1")
                if tok.startswith("bytes:")
                else tok.encode()
            )
        except KeyError:
            if tok != "<|diff_marker|>":
                print(f"KeyError: {tok!r}")
            return -1

    def process_completion(compl):
        return [
            {encode_token(tok): logprob for tok, logprob in tlp.items()}
            for tlp in compl.logprobs.top_logprobs
        ]

    def process_tokens(compl):
        return [encode_token(tok) for tok in compl.logprobs.tokens]

    output = await aclient.completions.create(
        prompt=prompt,
        logit_bias=bias,
        **{
            **dict(
                model="gpt-3.5-turbo-instruct",
                max_tokens=1,
                temperature=0,
                logprobs=5,
            ),
            **kwargs,
        },
    )

    consumption.requests += 1

    # `usage` is optional on the response object; only count what is reported.
    usage = output.usage
    if usage is not None:
        if usage.completion_tokens is not None:
            consumption.completion_tokens += usage.completion_tokens
        if usage.prompt_tokens is not None:
            consumption.prompt_tokens += usage.prompt_tokens

    # For batched prompts the API does not guarantee that choices come back in
    # prompt order; each choice carries an `index` identifying its slot.
    # Callers match results to prompts by position, so restore that order.
    choices = sorted(output.choices, key=lambda choice: choice.index)

    tokens = [process_tokens(choice) for choice in choices]
    logprobs = [process_completion(choice) for choice in choices]

    if "n" in kwargs:

        # Group consecutive choices into chunks of n (the samples per prompt).
        def list_reshape(lst):
            return list(map(list, zip(*[iter(lst)] * kwargs["n"])))

        tokens = list_reshape(tokens)
        logprobs = list_reshape(logprobs)

    # A single prompt (a string or a flat token-id list) yields exactly one
    # entry, so unwrap it.
    if isinstance(prompt, str) or (
        isinstance(prompt, list) and (len(prompt) == 0 or isinstance(prompt[0], int))
    ):
        assert len(tokens) == 1
        assert len(logprobs) == 1
        tokens, logprobs = tokens[0], logprobs[0]

    return munch.Munch(tokens=tokens, logprobs=logprobs)


def infer_logprob_single(biased_logprob: float, bias: float) -> float:
    """Undo a logit bias: recover a token's unbiased logprob from its biased one.

    Adding ``bias`` (b) to one token's logit turns its probability p into
    p' = p*e^b / (p*e^b + 1 - p). Solving for p gives
    log p = log p' - log(p' + e^b * (1 - p')), which is evaluated in log space.

    Only valid when ``bias`` was applied to this token alone in the request.
    Accuracy degrades as ``biased_logprob`` approaches 0, because 1 - p'
    then carries little information.
    """
    # log(1 - p'); expm1 avoids cancellation in 1 - exp(x) when p' is near 1.
    log_complement = np.log(-np.expm1(biased_logprob))
    # log(p' + e^b * (1 - p'))
    normalizer = np.logaddexp(biased_logprob, bias + log_complement)
    return biased_logprob - normalizer


async def test_recovery(
    bias, token: int = 271, prompt: str = "this is a test of a long prompt"
):
    """Check infer_logprob_single against the live API.

    Queries the next-token logprob of ``token`` after ``prompt`` with
    ``logit_bias={token: bias}``. It prints the biased value and returns the
    bias-corrected estimate. Presumably meant to be run with several ``bias``
    values to confirm that the corrected results agree.

    Args:
        bias: logit bias applied to ``token``.
        token: token id to probe (defaults to the previously hard-coded 271).
        prompt: prompt text (defaults to the previously hard-coded prompt).

    Raises:
        KeyError: if ``token`` is not among the returned top logprobs.
    """
    response = await query_logprobs(prompt=prompt, max_tokens=1, bias={token: bias})
    # A string prompt is unwrapped, so logprobs[0] is the first position's dict.
    biased_logprob = response.logprobs[0][token]
    print(biased_logprob)
    return infer_logprob_single(biased_logprob, bias)


# Maximum number of prompts sent in a single completions request by
# score_single_uncorrected_batched.
BATCH_SIZE = 20
MAX_LOGPROB = -1e-3  # largest value we will tolerate passing to infer_logprob_single
# Amount the bias is raised when the target token is absent from the returned
# top logprobs, i.e. it is too unlikely to observe at the current bias.
BIAS_SHIFT_UP = 10
# Amount the bias is lowered when the biased logprob exceeds MAX_LOGPROB, where
# infer_logprob_single loses numerical accuracy.
BIAS_SHIFT_DOWN = 5


async def score_single_uncorrected_batched(
    prompts: list[list[int]], target: int, bias: float = 0
) -> list[Optional[float]]:
    """Return the biased logprob of ``target`` as the next token for each prompt.

    Prompts are split into chunks of at most BATCH_SIZE, and all chunks are
    queried concurrently with ``logit_bias={target: bias}``. No correction for
    the bias is applied. An entry is ``None`` when ``target`` is not among the
    top logprobs returned for that prompt.
    """
    if not prompts:
        return []

    chunks = [
        prompts[start : start + BATCH_SIZE]
        for start in range(0, len(prompts), BATCH_SIZE)
    ]
    responses = await asyncio.gather(
        *(
            query_logprobs(prompt=chunk, max_tokens=1, bias={target: bias})
            for chunk in chunks
        )
    )

    scores: list[Optional[float]] = []
    for response in responses:
        # One list of per-position top-logprob dicts per prompt in the chunk.
        for token_dists in response.logprobs:
            scores.append(token_dists[0].get(target) if token_dists else None)
    return scores


async def score_single_corrected_batched(
    prompts: list[list[int]], target: int, bias: float = 0
) -> list[float]:
    """Return bias-corrected logprobs of ``target`` as the next token for each prompt.

    Every prompt is first queried with logit bias ``bias``. Results usable at
    that bias are converted with infer_logprob_single. Other prompts are
    retried recursively, with no cap on the number of retries:
    - ``target`` missing from the top logprobs: bias raised by BIAS_SHIFT_UP.
    - biased logprob above MAX_LOGPROB: bias lowered by BIAS_SHIFT_DOWN.
    """
    if not prompts:
        return []

    raw_scores = await score_single_uncorrected_batched(prompts, target, bias)
    corrected = [None] * len(raw_scores)

    # (index, prompt) pairs that must be re-scored at a different bias.
    retry_higher, retry_lower = [], []
    for idx, (raw, prompt) in enumerate(zip(raw_scores, prompts)):
        if raw is None:
            retry_higher.append((idx, prompt))
        elif raw > MAX_LOGPROB:
            retry_lower.append((idx, prompt))
        else:
            corrected[idx] = infer_logprob_single(raw, bias)

    higher_scores, lower_scores = await asyncio.gather(
        score_single_corrected_batched(
            [prompt for _, prompt in retry_higher], target, bias + BIAS_SHIFT_UP
        ),
        score_single_corrected_batched(
            [prompt for _, prompt in retry_lower], target, bias - BIAS_SHIFT_DOWN
        ),
    )

    for (idx, _), value in zip(retry_higher, higher_scores):
        corrected[idx] = value
    for (idx, _), value in zip(retry_lower, lower_scores):
        corrected[idx] = value

    return corrected


async def score_batch(
    prompts: list[list[int]], suffix: list[int], initial_bias=None
) -> list[list[float]]:
    """Return the bias-corrected logprob of every ``suffix`` token for each prompt.

    Row ``r`` of the result holds, for every position ``j``, the logprob of
    ``suffix[j]`` given ``prompts[r] + suffix[:j]``. All positions are scored
    concurrently. ``initial_bias`` gives the starting logit bias per suffix
    position and defaults to 0 everywhere.
    """
    if not suffix:
        return [[] for _ in range(len(prompts))]

    biases = [0] * len(suffix) if initial_bias is None else initial_bias

    per_position = await asyncio.gather(
        *(
            score_single_corrected_batched(
                [prompt + list(suffix[:pos]) for prompt in prompts],
                token,
                biases[pos],
            )
            for pos, token in enumerate(suffix)
        )
    )
    # per_position[j][r] is position j for prompt r; transpose to rows per prompt.
    return [list(row) for row in zip(*per_position)]


async def score_batch_limit(
    prompts: list[list[int]], suffix: list[int], initial_bias=None, limits=None
) -> list[list[float]]:
    """Score ``suffix`` after each prompt, abandoning prompts that fall below a limit.

    Works one suffix token at a time:
    1. Compute the corrected logprob of ``suffix[0]`` for every prompt.
    2. A prompt whose score is below its limit stops there; its remaining
       ``len(suffix) - 1`` positions are filled with ``-inf``.
    3. Every other prompt is extended with ``suffix[0]`` and scored
       recursively on ``suffix[1:]``, with its limit reduced by the score
       just obtained.

    Args:
        prompts: token-id prompts to score.
        suffix: token ids to score after each prompt.
        initial_bias: starting logit bias per suffix position (default 0).
        limits: lower bound on each prompt's cumulative logprob, either one
            value per prompt or a single scalar for all prompts. ``None``
            means no limit, so no prompt is pruned.

    Returns:
        One list of ``len(suffix)`` logprobs per prompt.

    Raises:
        ValueError: if ``limits`` is a sequence shorter than ``prompts``.
    """
    n, k = len(prompts), len(suffix)
    if not n:
        return []
    if not k:
        return [[] for _ in range(n)]

    if limits is None:
        # No limit: -inf is never greater than a score, so nothing is pruned.
        limits = [-float("inf")] * n
    elif isinstance(limits, (int, float, np.number)):
        limits = [limits] * n
    else:
        limits = list(limits)
        if len(limits) < n:
            # A short list would otherwise leave None rows in the result.
            raise ValueError(
                f"limits has {len(limits)} entries but there are {n} prompts"
            )

    initial_bias = [0] * k if initial_bias is None else initial_bias

    current_scores = await score_single_corrected_batched(
        prompts, suffix[0], initial_bias[0]
    )

    results = [None] * n
    cand_idxs, cand_prompts, cand_limits = [], [], []
    for i, (score, prompt, limit) in enumerate(zip(current_scores, prompts, limits)):
        results[i] = [score]
        if score < limit:
            # Pruned: the remaining positions are not scored.
            results[i] += [-float("inf")] * (k - 1)
        else:
            cand_idxs.append(i)
            cand_prompts.append(prompt + [suffix[0]])
            # Remaining budget for the rest of the suffix.
            cand_limits.append(limit - score)

    cand_scores = await score_batch_limit(
        cand_prompts, suffix[1:], initial_bias[1:], cand_limits
    )
    for i, cand_score in zip(cand_idxs, cand_scores):
        results[i] += cand_score

    return results
