"""Riemannian gradient descent optimizer."""

from torch.optim.optimizer import Optimizer, required

from utils.poincare import expmap, egrad2rgrad


class RSGD(Optimizer):
    """Riemannian stochastic gradient descent.

    For every parameter that has a gradient, each step converts the
    gradient with ``egrad2rgrad`` and moves the parameter along the negative
    of the result with ``expmap`` (both from ``utils.poincare``; presumably
    the Poincare-ball Riemannian gradient and exponential map -- TODO
    confirm against that module).
    """

    def __init__(self, params, lr=required, param_names=None):
        """Create the optimizer.

        Args:
            params: Iterable of parameters or parameter-group dicts, as
                accepted by ``torch.optim.Optimizer``.
            lr: Default learning rate stored in each group's ``"lr"``. Left
                as ``required``, every group must provide its own ``"lr"``.
            param_names: Optional list of parameter names. It is stored on
                ``self.param_names`` and is not read by ``step``. When
                omitted, each instance gets its own fresh empty list.
        """
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        # A ``[]`` default would be one list shared by every optimizer built
        # without ``param_names``; create a new one per instance instead.
        self.param_names = [] if param_names is None else param_names

    def step(self, lr=None, closure=None):
        """Apply one Riemannian update to every parameter with a gradient.

        Args:
            lr: Learning rate that overrides every group's ``"lr"`` for this
                call only. If ``None``, each group uses its own ``"lr"``.
            closure: Optional callable that re-evaluates the model and
                returns the loss (``torch.optim`` convention). It is called
                before any parameter is updated.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            # Resolve the rate per group without overwriting ``lr``, so one
            # group's rate can never leak into the groups that follow it.
            group_lr = group["lr"] if lr is None else lr
            for p in group["params"]:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # Presumably maps the Euclidean gradient to the Riemannian
                # gradient at p (see utils.poincare).
                d_pr = egrad2rgrad(p, d_p)
                v = -group_lr * d_pr
                # Replace the parameter's data with the expmap image of the
                # descent step taken from its current value.
                p.data = expmap(v, p.data)
        return loss
