import numpy as np
import torch


class Buffer:
    """Fixed-capacity ring buffer backed by one preallocated torch tensor.

    Element ``i`` (counting every element ever added, starting at 0) is
    stored in row ``i % n`` of ``buf``. Once more than ``n`` elements have
    been added, new writes overwrite the oldest rows.

    Attributes:
        shape: per-element shape, i.e. the trailing dimensions of ``buf``.
        n: capacity (number of rows in ``buf``).
        buf: storage tensor of shape ``[n, *shape]``.
        index: total number of elements added so far (never wrapped).
    """

    def __init__(self, n, shape, **kwargs):
        """Allocate a zero-filled buffer.

        Args:
            n: capacity (number of rows).
            shape: iterable of per-element dimensions.
            **kwargs: forwarded to ``torch.zeros`` (e.g. ``dtype``, ``device``).
        """
        self.shape = shape
        self.n = n
        self.buf = torch.zeros([n, *shape], **kwargs)
        self.index = 0

    @torch.no_grad()
    def add(self, elem):
        """Write one element into the next slot, overwriting the oldest when full."""
        self.buf[self.index % self.n] = torch.as_tensor(elem)
        self.index += 1

    @torch.no_grad()
    def add_many(self, elements):
        """Append a batch of elements laid out along dimension 0.

        The result is identical to calling ``add`` on each row in order. If
        the batch has more than ``n`` rows, only the last ``n`` survive, and
        ``index`` always advances by the full batch size.

        Args:
            elements: tensor, ndarray, or nested sequence convertible with
                ``torch.as_tensor``; its first dimension is the batch size.
        """
        # Convert like ``add`` does, so numpy arrays / sequences are accepted.
        elements = torch.as_tensor(elements)
        k = elements.shape[0]
        if k == 0:
            return
        n = self.n
        # Rows older than the final n would be overwritten by later rows of
        # this same batch, so only the tail needs to be written.
        m = min(k, n)
        tail = elements[k - m:]
        # Slot that sequential ``add`` calls would use for the first tail row.
        start = (self.index + k - m) % n
        first = min(m, n - start)  # rows that fit before wrapping to slot 0
        self.buf[start:start + first] = tail[:first]
        if first < m:
            # Wrap-around: the remaining rows go to the front of the buffer.
            # m <= n guarantees they end before ``start``.
            self.buf[:m - first] = tail[first:]
        self.index += k

    @torch.no_grad()
    def sample(self, n=0, index=None):
        """Return rows of ``buf``.

        Args:
            n: number of rows to draw uniformly with replacement (via
                ``np.random.choice``) from the filled part of the buffer.
                Used only when ``index`` is None.
            index: explicit row indices; when given, ``n`` is ignored.

        Returns:
            ``self.buf[index]``.

        Note: when the buffer is still empty and ``n > 0``,
        ``np.random.choice`` raises ``ValueError``.
        """
        if index is None:
            index = np.random.choice(min(self.index, self.n), n)
        return self.buf[index]

    @torch.no_grad()
    def all(self):
        """Return the filled rows of ``buf`` (a slice view, not a copy).

        Rows are in storage order. Once the buffer has wrapped, that is not
        chronological order.
        """
        return self.buf[:min(self.index, self.n)]
