import torch


class CoachSampler:
    """Linearly annealed upper bound on the number of coach steps.

    The bound equals ``init_step`` at epoch 0 and decreases linearly with the
    epoch, reaching 0 at ``max_epoch``. When disabled, the bound is always 0.
    """

    def __init__(self, init_step, max_epoch):
        """
        Args:
            init_step: value of the bound at epoch 0.
            max_epoch: epoch at which the bound reaches 0; must be non-zero
                because ``get_coach_step`` divides by it.
        """
        self.init_step = init_step
        self.max_epoch = max_epoch
        self.enabled = True

    def enable(self, e):
        """Turn the schedule on (truthy ``e``) or off (falsy ``e``)."""
        self.enabled = e

    def get_coach_step(self, epoch, batchsize):
        """Return the per-sample coach-step bound for ``epoch``.

        Args:
            epoch: current epoch.
            batchsize: number of rows in the returned tensor.

        Returns:
            dict with key ``'max_coach_step'`` mapping to a ``(batchsize, 1)``
            ``torch.long`` tensor filled with the bound (zeros when disabled).

        Raises:
            ZeroDivisionError: if enabled and ``max_epoch`` is 0.
        """
        if not self.enabled:
            return {'max_coach_step': torch.zeros(batchsize, 1, dtype=torch.long)}

        # Multiply before dividing so exact integer results are not lost to
        # float rounding (e.g. 1 / 49 * 49 == 0.999...). The value is then
        # truncated to an int: adding a float in place to a long tensor
        # raises a type-promotion error in PyTorch.
        step = self.init_step * (self.max_epoch - epoch) / self.max_epoch
        # Past max_epoch the linear formula turns negative; a step bound
        # cannot be below zero.
        step = max(0, int(step))
        return {'max_coach_step': torch.full((batchsize, 1), step, dtype=torch.long)}
