import random
import gzip
import pickle
import os

class UpdateBuffer:
    """Bounded collection of per-start-state reward lists.

    ``returns_dict`` maps a key to the list of rewards appended since that
    key was (re)started.  While the mapping holds fewer than
    ``max_start_states`` entries new keys may be registered; afterwards an
    entry is overwritten at random with a probability that shrinks as
    ``age`` grows.
    """

    def __init__(self, returns_dict: dict, age: int, max_start_states: int):
        """Store the mapping, the current age counter and the capacity.

        Args:
            returns_dict: Mapping from key to list of rewards.  It is stored
                by reference and mutated in place.
            age: Initial value of the counter incremented on every call.
            max_start_states: Number of entries after which new keys stop
                being registered and random replacement starts.
        """
        self.age = age
        self.max_start_states = max_start_states
        self.returns_dict = returns_dict

    def __call__(self, reward: int, idx) -> None:
        """Record ``reward`` and possibly start or restart an entry.

        The reward is first appended to every existing entry.  Then:

        * below capacity, and only if ``idx > len(self.returns_dict)``, the
          reward is also appended to the entry for ``idx`` (created if
          missing);
        * at capacity, a slot is drawn uniformly from ``[0, age]`` and, if it
          is below ``max_start_states``, that key is reset to ``[reward]``.

        Finally ``age`` is incremented.

        Args:
            reward: Value to record.
            idx: Key of the candidate new entry.
        """
        # Evaluate the capacity check before touching the lists, exactly as
        # the original ordering did.
        has_room = len(self.returns_dict) < self.max_start_states
        self.add(reward)

        if has_room:
            # NOTE(review): with this strict comparison an ``idx`` equal to the
            # current size (e.g. 0 on an empty buffer) is never registered;
            # presumably callers pass 1-based indices -- confirm.
            if idx > len(self.returns_dict):
                # A plain dict has no entry for a new key yet; create one
                # instead of raising KeyError.  A defaultdict never raises
                # here, so its own factory keeps being used.
                try:
                    rewards = self.returns_dict[idx]
                except KeyError:
                    rewards = self.returns_dict[idx] = []
                rewards.append(reward)
        else:
            # randint is inclusive on both ends, so the replacement
            # probability is max_start_states / (age + 1).  This resembles
            # reservoir sampling, assuming ``age`` counts the calls made so
            # far -- TODO confirm against the caller.
            sampled_idx = random.randint(0, self.age)
            if sampled_idx < self.max_start_states:
                # NOTE(review): sampled keys lie in [0, max_start_states)
                # while keys added above are caller-supplied ``idx`` values;
                # if those differ this inserts a new key rather than
                # replacing an existing one -- verify.
                self.returns_dict[sampled_idx] = [reward]

        self.update()

    def update(self) -> None:
        """Advance the age counter by one."""
        self.age += 1

    def add(self, rew: int) -> None:
        """Append ``rew`` to the reward list of every existing entry."""
        for rewards in self.returns_dict.values():
            rewards.append(rew)

    def save(self, directory, logs, name, index=None) -> None:
        """Pickle the buffer to a gzip file and append ``logs`` to a text file.

        The mapping is written to ``<directory>/<name>[_<index>].pkl.gz``
        (skipped when ``returns_dict`` is ``None``).  ``logs`` is always
        appended to ``<directory>/<name>.txt``; the log file name never
        includes ``index``.

        Args:
            directory: Directory that receives both files.
            logs: Text appended to the log file.
            name: Base name shared by both files.
            index: Optional suffix for the pickle file name.
        """
        if self.returns_dict is not None:
            file_name = name if index is None else name + '_{}'.format(index)
            pickle_path = os.path.join(directory, file_name + '.pkl.gz')
            with gzip.open(pickle_path, "wb") as f_dict:
                pickle.dump(self.returns_dict, f_dict, protocol=pickle.HIGHEST_PROTOCOL)

        with open(os.path.join(directory, name + ".txt"), "a") as f_logs:
            f_logs.write(logs)







