import numpy as np
import scipy.stats as stats


class Environment:
    """Simulate customers choosing from assortments, with delayed feedback.

    Each step, a customer picks one item from the offered assortment with
    probability proportional to that item's utility. Choosing item 0 is
    reported immediately. Choosing any other item schedules feedback that is
    released after a randomly sampled delay. In :meth:`step`, that feedback is
    dropped entirely when the sampled delay exceeds ``delay_threshold``.

    Args:
        item_utilities: Per-item utility weights, indexed by item id.
            Presumably non-negative with a positive sum over any offered
            assortment (required for valid choice probabilities) — TODO confirm.
        mean_delay: Mean of the feedback-delay distribution.
        delay_threshold: Largest sampled delay for which :meth:`step` still
            records feedback.
        upper_bound_delay: Optional cap applied to every sampled delay;
            ``None`` disables the cap.
        skew: ``"uniform"`` samples delays from U(0, 2 * mean_delay); any
            other value samples from an exponential distribution with scale
            ``mean_delay``.
    """

    def __init__(self, item_utilities, mean_delay, delay_threshold, upper_bound_delay=None,
                 skew="left"):
        self.item_utilities = item_utilities
        self.current_time = 0
        self.mean_delay = mean_delay
        self.delay_threshold = delay_threshold
        self.upper_bound_delay = upper_bound_delay
        self.skew = skew
        # item id -> list of pending release times for that item's feedback
        self.feedback_release_times = self._empty_feedback()

    def _empty_feedback(self):
        """Return a fresh mapping with an empty pending-release list per item."""
        return {i: [] for i in range(len(self.item_utilities))}

    def simulate_customer_choice(self, assortment):
        """Sample one item from *assortment* with probability proportional to utility.

        Args:
            assortment: Sequence of item ids (indices into ``item_utilities``).

        Returns:
            The chosen item id as a Python ``int``.

        Raises:
            ValueError: If *assortment* is empty (raised by ``np.random.choice``).
        """
        weights = np.asarray([self.item_utilities[i] for i in assortment], dtype=float)
        probabilities = weights / weights.sum()
        return int(np.random.choice(assortment, p=probabilities))

    def calculate_delay(self):
        """Sample one feedback delay, capped at ``upper_bound_delay`` when set.

        Returns:
            A non-negative delay. It is drawn from U(0, 2 * mean_delay) when
            ``skew == "uniform"`` and from Exponential(scale=mean_delay)
            otherwise.
        """
        # Draw only from the distribution actually used, so the random stream
        # is not advanced by a discarded sample.
        if self.skew == "uniform":
            delay = stats.uniform.rvs(loc=0, scale=2 * self.mean_delay)
        else:
            delay = np.random.exponential(scale=self.mean_delay)

        # `is not None` so that an explicit cap of 0 is honoured.
        if self.upper_bound_delay is not None:
            delay = min(delay, self.upper_bound_delay)
        return delay

    def get_decision(self, assortment):
        """Sample a choice and schedule its feedback at ``current_time + delay``.

        Unlike :meth:`step`, this does not advance time and does not apply
        ``delay_threshold``.

        Returns:
            The chosen item id.
        """
        chosen_item = self.simulate_customer_choice(assortment=assortment)
        release_time = self.current_time + self.calculate_delay()
        # Append so that multiple pending releases for one item are kept.
        self.feedback_release_times[chosen_item].append(release_time)
        return chosen_item

    def step(self, assortment):
        """Advance time by one unit, sample a choice, and release due feedback.

        Returns:
            ``(immediate_decision, delayed_decisions)``.
            ``immediate_decision`` is ``True`` when item 0 was chosen this
            step and ``None`` otherwise. ``delayed_decisions`` lists the item
            id once per pending feedback whose release time is
            ``<= current_time``, grouped in item-id order.
        """
        self.current_time += 1
        immediate_decision = None

        chosen_item = self.simulate_customer_choice(assortment=assortment)
        if chosen_item == 0:
            immediate_decision = True
        else:
            delay = self.calculate_delay()
            # Feedback slower than the threshold is censored (never recorded).
            if delay <= self.delay_threshold:
                # Reuse the very delay that passed the threshold check.
                self.feedback_release_times[chosen_item].append(self.current_time + delay)

        delayed_decisions = self._release_due_feedback()
        return immediate_decision, delayed_decisions

    def _release_due_feedback(self):
        """Pop every pending release time that is due; return the released item ids."""
        now = self.current_time
        released = []
        for item, release_times in self.feedback_release_times.items():
            due_count = sum(1 for t in release_times if t <= now)
            if due_count:
                released.extend([item] * due_count)
                # Rebuild the list instead of removing while iterating, which
                # would skip the element following each removal.
                self.feedback_release_times[item] = [t for t in release_times if t > now]
        return released

    def reset(self):
        """Restore the initial state: time 0 and no pending feedback for any item."""
        self.current_time = 0
        self.feedback_release_times = self._empty_feedback()
