import numpy as np
import random
import torch
from learner_diag import GP_NTK, Network
from misc import N_TOKENS


class EVO:
    """Parent-selection helpers and budget bookkeeping for an evolutionary prompt search.

    Throughout this class a *population* is a list of ``(instruction, score)``
    tuples.
    """

    def __init__(self, evo_opts):
        """Initialise optimiser state from an options mapping.

        Args:
            evo_opts (dict): Must provide the keys ``'maxiter'``, ``'N'`` and
                ``'algo'``; a missing key raises ``KeyError``.
        """
        # Iteration counter compared against ``max_iters`` in ``stop()``.
        # NOTE(review): not incremented in this class; presumably advanced by
        # the caller or a subclass — confirm.
        self.k = 0
        # NOTE(review): placeholder, not read in this class; presumably
        # assigned externally before use.
        self.api = None
        self.max_iters = evo_opts['maxiter']
        # NOTE(review): likely the population size; not read in this class.
        self.N = evo_opts['N']
        # NOTE(review): algorithm selector; not read in this class.
        self.algo = evo_opts['algo']

    def roulette_wheel_selection(self, population):
        """
        Select two parent prompts by fitness-proportionate (roulette wheel) sampling.

        Sampling is with replacement, so both parents may be the same prompt.

        Args:
            population (list): ``(instruction, score)`` tuples. Scores are used
                as sampling weights and must be non-negative.

        Returns:
            tuple: The two selected instructions.

        Raises:
            ValueError: If any score is negative, or if the scores sum to zero
                (which includes an empty population).
        """
        scores = [score for _, score in population]
        # Negative scores are meaningless as roulette weights: mixed signs make
        # the cumulative weights used by random.choices non-monotone, and an
        # all-negative population would invert selection pressure after
        # normalisation. Reject them instead of sampling incorrectly.
        if any(score < 0 for score in scores):
            raise ValueError(
                "Scores must be non-negative for roulette wheel selection.")

        total_score = sum(scores)
        if total_score == 0:
            raise ValueError(
                "Total score cannot be zero for roulette wheel selection.")

        # Normalise to probabilities (random.choices would also accept raw weights).
        selection_probabilities = [score / total_score for score in scores]
        first, second = random.choices(
            range(len(population)), weights=selection_probabilities, k=2)
        return population[first][0], population[second][0]

    def random_selection(self, population):
        """Pick two members uniformly at random, without replacement.

        Draws from NumPy's global RNG (``np.random``), not the ``random``
        module used by ``roulette_wheel_selection``, so the two methods are
        seeded independently.

        Args:
            population (list): ``(instruction, score)`` tuples; at least two
                entries are required.

        Returns:
            tuple: Instructions from two distinct positions of ``population``.

        Raises:
            ValueError: From ``np.random.choice`` when fewer than two entries
                are available.
        """
        selected_indices = np.random.choice(len(population), 2, replace=False)
        return population[selected_indices[0]][0], population[selected_indices[1]][0]

    def best_prompt(self, population):
        """Return the *index* (not the instruction) of the highest-scoring member.

        Ties resolve to the earliest index. An empty population raises
        ``ValueError`` (from ``max``).
        """
        return max(range(len(population)), key=lambda i: population[i][1])

    def stop(self):
        """Return True once the iteration counter ``k`` reaches ``max_iters`` (query budget met)."""
        return self.k >= self.max_iters
