import torch


class DictDataset(torch.utils.data.Dataset):
    """Map-style dataset whose rows are looked up by an explicit index value.

    ``data`` holds a mandatory ``"index"`` tensor plus any number of other
    fields that are indexable by row position (``v[j]``). Row ``j`` of every
    non-index field is stored under the key ``index[j]``, and
    ``dataset[key]`` returns ``{field_name: field[j]}`` for that row.

    Items are keyed by index *value*, not by position. A default
    ``DataLoader`` sampler (which yields ``0..len-1``) therefore only works
    when the index values are exactly those positions.
    """

    def __init__(self, data: dict):
        """Build the index-value -> row lookup table.

        Args:
            data: mapping with an ``"index"`` tensor and other fields
                indexable by row position. The caller's mapping is not
                modified.

        Raises:
            KeyError: if ``data`` has no ``"index"`` entry.
        """
        if "index" not in data:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            raise KeyError("Data must have an 'index' key")
        # Shallow copy so popping "index" does not mutate the caller's dict.
        data = dict(data)
        self.index = data.pop("index").detach().cpu().numpy()
        # The remaining (non-index) fields.
        self.data = data

        # Keys are the NumPy scalars from ``self.index``. For integer indices
        # they hash and compare equal to the matching Python ints. If the
        # index contains duplicates, later rows overwrite earlier ones.
        self.data_dict = {
            idx: {k: v[j] for k, v in data.items()}
            for j, idx in enumerate(self.index)
        }

    def __getitem__(self, idx):
        """Return the row dict stored under index value ``idx`` (KeyError if absent)."""
        return self.data_dict[idx]

    def __len__(self):
        """Return the number of index entries (duplicates included)."""
        return len(self.index)


class DictDatasetWriter:
    """Buffer (index, ref_score) tensor pairs in memory and save them on close.

    The saved object is ``{"index": Tensor, "ref_score": Tensor}``, where each
    entry is ``torch.stack`` of the tensors passed to :meth:`write`, i.e. with
    a new leading dimension of size "number of writes". Usable as a context
    manager; the file is saved when the ``with`` block exits without error.
    """

    def __init__(self, path):
        """Create an empty writer.

        Args:
            path: destination handed to ``torch.save`` in :meth:`close`.
        """
        self.path = path
        self.dataset = {"index": [], "ref_score": []}

    def write(self, idx: torch.Tensor, data: torch.Tensor):
        """Buffer one entry in memory; nothing is written to disk until close().

        All ``idx`` tensors must share one shape, and likewise all ``data``
        tensors, because ``close`` stacks them. NOTE(review): presumably
        ``idx`` is a 0-dim per-sample tensor. Stacking 1-D per-batch index
        tensors would yield a 2-D index whose rows ``DictDataset`` cannot use
        as dict keys.
        """
        self.dataset["index"].append(idx)
        self.dataset["ref_score"].append(data)

    def close(self):
        """Stack the buffered tensors per field and ``torch.save`` them to ``self.path``.

        If nothing was written, each field is saved as an empty tensor instead
        of crashing (``torch.stack`` rejects an empty list).
        """
        dataset = {
            k: torch.stack(v) if v else torch.empty(0)
            for k, v in self.dataset.items()
        }
        torch.save(dataset, self.path)

    def __enter__(self):
        """Return the writer itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Save on a clean exit; on error, skip saving and let the exception propagate."""
        # Skipping the save on error avoids persisting a partially written dataset.
        if exc_type is None:
            self.close()
        return False
