import torch
from .base import _BaseAggregator


class Average(_BaseAggregator):
    """Baseline aggregator: combine the inputs by their element-wise mean.

    Every input tensor contributes with equal weight; no input is filtered,
    clipped or re-weighted.
    """

    def __call__(self, inputs):
        """Return the element-wise mean of ``inputs``.

        Args:
            inputs: A sequence of tensors handed to ``torch.stack``, which
                requires them all to share one shape.

        Returns:
            A tensor shaped like a single input whose entries are the
            arithmetic mean of the corresponding entries across all inputs.
        """
        # Gather the inputs along a new leading axis, then reduce that axis
        # away so the result has the shape of one input again.
        batch = torch.stack(inputs, dim=0)
        averaged = batch.mean(dim=0)
        return averaged

    def __str__(self):
        """Return the human-readable name of this aggregator."""
        return "Baseline Averaging"
