
import time
import os
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from adaptive_softmax.sftm import SFTM
from tests.test_utils import construct_random_example
from adaptive_softmax.constants import (
    GAIN_POINTS,
    NUM_TRIALS,
    NORM_GAINS_DIR,
)

def estimate_uniform(A, X, eps, delta):
    """Return the budget a uniform baseline needs to estimate log-normalizers.

    Every arm (row of ``A``) is estimated from its first ``arm_budget``
    coordinates, rescaled by ``d / arm_budget``. The per-arm budget is doubled
    (never beyond ``d``) until fewer than ``delta * q`` of the ``q`` queries
    have a log-normalizer estimate further than ``eps`` from the exact value.

    Args:
        A: Array of shape ``(n_arms, d)``.
        X: Array of shape ``(q, d)``, one query per row.
        eps: Allowed absolute error on each log-normalizer.
        delta: Allowed fraction of queries whose error exceeds ``eps``.

    Returns:
        ``n_arms * arm_budget`` for the final per-arm budget. This is at most
        ``n_arms * d``; that maximum is also returned when the tolerance
        cannot be met earlier (e.g. ``delta == 0``).
    """
    n_arms, d = A.shape
    n_queries = X.shape[0]

    # logaddexp.reduce is an overflow-safe log(sum(exp(.))) over the arm axis.
    exact_scores = np.asarray(A @ X.T, dtype=float)  # shape (n, q)
    true_log_norms = np.logaddexp.reduce(exact_scores, axis=0)  # shape (q,)

    arm_budget = 1
    while True:
        scale = d / arm_budget
        mu_hats = (A[:, :arm_budget] @ X[:, :arm_budget].T) * scale  # shape (n, q)
        norm_hats = np.logaddexp.reduce(
            np.asarray(mu_hats, dtype=float), axis=0
        )  # shape (q,)

        # Stop once fewer than a delta fraction of the queries miss by more
        # than eps, or when every coordinate is already in use: past d the
        # slices cannot grow, so doubling further would only shrink the
        # rescale factor and corrupt the estimate.
        n_errors = np.sum(np.abs(norm_hats - true_log_norms) > eps)
        if n_errors < n_queries * delta or arm_budget >= d:
            break

        # Otherwise take more samples, capped at the full dimension.
        arm_budget = min(arm_budget * 2, d)

    return n_arms * arm_budget


def estimate_sftm(A, X, eps, delta):
    """Return the total SFTM budget spent on log-norm estimation over all queries.

    A single ``SFTM`` instance is built on ``A`` with importance sampling, the
    randomized Hadamard transform and exact best-arm pulls all disabled. Each
    row of ``X`` is then set as the query, ``log_norm_estimation`` is run, and
    the entries of ``bandits.it`` are summed into the total.

    Args:
        A: Matrix handed to ``SFTM``.
        X: Iterable of query vectors.
        eps: Passed as ``multiplicative_error`` and to ``log_norm_estimation``.
        delta: Passed as ``failure_probability`` and to ``log_norm_estimation``.

    Returns:
        Sum over all queries of ``sum(sftm.bandits.it)``.
    """
    sftm_config = {
        "temperature": 1.0,
        "multiplicative_error": eps,
        "failure_probability": delta,
        "atom_importance_sampling": False,
        "query_importance_sampling": False,
        "randomized_hadamard_transform": False,
        "exact_pull_best_arm": False,
        "max_init_pull_budget": 1,
        "verbose": False,
        "seed": 42,
    }
    estimator = SFTM(A=A, **sftm_config)

    def _budget_for(query):
        # presumably `bandits.it` holds per-arm pull counts for the current
        # query -- TODO confirm against the SFTM implementation.
        estimator.bandits.set_query(query)
        # TODO(review): open question from the original author -- should
        # first_pull_batched be False here?
        estimator.log_norm_estimation(eps, delta, first_pull_batched=True)
        return sum(estimator.bandits.it)

    return sum(_budget_for(query) for query in X)


def run_norm_gains(n, d, curr_time=None):
    """Sweep a mean gap and print the uniform-vs-SFTM budget ratio per setting.

    For each of ``GAIN_POINTS`` values ``c`` in ``[1, 5]`` a mean vector is
    built in which entry 1 is ``c`` times the others, a random example is
    drawn with ``construct_random_example``, and the ratio
    ``estimate_uniform / estimate_sftm`` (eps=0.3, delta=0.01) is recorded
    next to the theoretical gain. The list of ratios is printed at the end.

    Args:
        n: Number of arms; must be at least 2 because entry 1 of the mean
            vector is modified.
        d: Dimension forwarded to ``construct_random_example``.
        curr_time: Name of the output sub-directory under ``NORM_GAINS_DIR``.
            Defaults to the current UTC time formatted as ``HH:MM:SS``.
    """
    if curr_time is None:
        curr_time = time.strftime("%H:%M:%S", time.gmtime())

    # The output directory depends only on curr_time, so it is created once
    # up front rather than on every sweep iteration.
    save_dir = f"{NORM_GAINS_DIR}/{curr_time}"
    os.makedirs(save_dir, exist_ok=True)

    ideal_gains = []
    mus = []  # only consumed by the disabled plotting code below
    gain_ratios = []
    for c in tqdm(np.linspace(1, 5, GAIN_POINTS)):
        mu = np.ones(n)
        mu[1] = mu[1] * c
        mu = mu / (c / n * 2)    # for constant sigma (this is arbitrary choice)
        mus.append(mu)
        A, X = construct_random_example(n=n, d=d, q=NUM_TRIALS, mu=mu)

        # Theoretical gain: n * sum(exp(2*mu)) / (sum(exp(mu)))**2.
        theory = n * np.sum(np.exp(2 * mu)) / (np.sum(np.exp(mu))**2)
        ratio = estimate_uniform(A, X, 0.3, 0.01) / estimate_sftm(A, X, 0.3, 0.01)

        ideal_gains.append(theory)
        gain_ratios.append(ratio)

    print(gain_ratios)

    # Plotting is currently disabled:
    # plt.plot(mus, mus, "b--", label="theoretical_baseline")
    # plt.scatter(ideal_gains, gain_ratios, color="red", label="empirical_gain_lv3")
    # plt.xlabel("theoretical gain")
    # plt.ylabel("empirical_gain")
    # plt.legend()
    # plt.savefig(f"{save_dir}/n{n}_d{d}.png")

if __name__ == "__main__":
    # Script entry point: run the sweep with 100 arms in 10,000 dimensions.
    run_norm_gains(d=10_000, n=100)
