from typing import List

import numpy as np
import torch
import random

from gflownet.config import Config


class ReplayBuffer(object):
    """Fixed-capacity store of past experiences for replay during training.

    The storage and sampling strategy is selected by ``cfg.replay.method``:

    * ``'Random'``: ring buffer. ``push(traj, reward)``; once full, the oldest
      slot is overwritten. Sampling is uniform with replacement.
    * ``'Prioritized'``: keeps the ``capacity`` highest-reward entries, stored
      sorted by ascending reward. ``push(traj, reward)``; when full, a new entry
      is admitted only if its reward is strictly greater than the current
      minimum. Sampling is uniform with replacement over the kept entries.
    * ``'PER'``: ``push(traj, reward, priority)``; when full, the entry with the
      lowest priority is replaced by the new one (regardless of the new
      entry's priority). Sampling is with replacement, with probability
      proportional to priority.

    All randomness in :meth:`sample` comes from ``self.rng``, so passing a
    seeded generator makes sampling reproducible for every method.
    """

    _METHODS = ("Random", "Prioritized", "PER")

    def __init__(self, cfg: Config, rng: np.random.Generator = None):
        """Create an empty buffer.

        Args:
            cfg: configuration object; reads ``cfg.replay.capacity``,
                ``cfg.replay.warmup`` and ``cfg.replay.method``.
            rng: numpy random generator used by :meth:`sample`. If ``None``,
                an unseeded ``np.random.default_rng()`` is created.
        """
        self.capacity = cfg.replay.capacity
        self.warmup = cfg.replay.warmup
        assert self.warmup <= self.capacity, "ReplayBuffer warmup must be smaller than capacity"
        self.method = cfg.replay.method
        # Only used by 'PER'; index-aligned with self.buffer.
        self.priorities = []

        self.buffer: List[tuple] = []
        # Next slot to overwrite in 'Random' (ring-buffer) mode.
        self.position = 0
        # Fall back to a fresh generator so sample() works without an explicit rng.
        self.rng = rng if rng is not None else np.random.default_rng()

    def push(self, *args):
        """Add one experience to the buffer.

        Args:
            *args: ``(traj, reward)`` for 'Random' and 'Prioritized';
                ``(traj, reward, priority)`` for 'PER'. The number of
                arguments must be the same for every call.

        Raises:
            ValueError: if ``self.method`` is not a known replay method.
        """
        if len(self.buffer) == 0:
            self._input_size = len(args)
        else:
            assert self._input_size == len(args), "ReplayBuffer input size must be constant"

        if self.method == 'Random':
            traj, reward = args
            if len(self.buffer) < self.capacity:
                self.buffer.append(None)
            self.buffer[self.position] = (traj, reward)
            self.position = (self.position + 1) % self.capacity
        elif self.method == 'Prioritized':
            traj, reward = args
            # The buffer is kept sorted by ascending reward, so buffer[0] holds the
            # lowest reward; once full, only strictly better rewards are admitted.
            if len(self.buffer) < self.capacity or reward > self.buffer[0][1]:
                self.buffer.append((traj, reward))
                # Re-sort by reward and keep the top `capacity` entries (drops the
                # lowest one when the append pushed us over capacity).
                self.buffer = sorted(self.buffer, key=lambda item: item[1])[-self.capacity:]
        elif self.method == 'PER':
            traj, reward, priority = args
            if len(self.buffer) < self.capacity:
                self.buffer.append((traj, reward))
                self.priorities.append(priority)
            else:
                # Replace the experience with the lowest priority (first one on ties).
                min_idx = min(range(len(self.priorities)), key=self.priorities.__getitem__)
                self.buffer[min_idx] = (traj, reward)
                self.priorities[min_idx] = priority
        else:
            raise ValueError(
                f"Unknown replay method {self.method!r}; expected one of {self._METHODS}"
            )

    def sample(self, batch_size):
        """Draw ``batch_size`` experiences with replacement.

        Args:
            batch_size: number of experiences to draw.

        Returns:
            A tuple with one entry per pushed field (e.g. ``(trajs, rewards)``).
            A field is stacked along a new leading axis into a single
            ``np.ndarray`` / ``torch.Tensor`` when every sampled element of it is
            an ``np.ndarray`` / ``torch.Tensor``; otherwise it is a tuple of the
            sampled elements.

        Raises:
            ValueError: for an unknown replay method, or in 'PER' mode when the
                priorities do not sum to a positive value.
        """
        if self.method in ('Random', 'Prioritized'):
            idxs = self.rng.choice(len(self.buffer), batch_size)
        elif self.method == 'PER':
            # Draw through self.rng (not the global `random` module) so a seeded
            # generator makes PER sampling reproducible like the other methods.
            weights = np.asarray([float(p) for p in self.priorities], dtype=np.float64)
            total = weights.sum()
            if not total > 0:
                raise ValueError("ReplayBuffer PER priorities must sum to a positive value")
            idxs = self.rng.choice(len(self.buffer), batch_size, p=weights / total)
        else:
            raise ValueError(
                f"Unknown replay method {self.method!r}; expected one of {self._METHODS}"
            )

        out = list(zip(*[self.buffer[idx] for idx in idxs]))
        for i, column in enumerate(out):
            # Stack homogeneous array/tensor fields: one big array is much cheaper
            # to send through multiprocessing queues than many small ones.
            if all(isinstance(x, np.ndarray) for x in column):
                out[i] = np.stack(column, axis=0)
            elif all(isinstance(x, torch.Tensor) for x in column):
                out[i] = torch.stack(column, dim=0)
        return tuple(out)

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.buffer)