"""
Implementation of Focal Loss.
Reference:
[1]  T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollar, Focal loss for dense object detection.
     arXiv preprint arXiv:1708.02002, 2017.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


# from: https://github.com/torrvision/focal_calibration/blob/main/Losses.py
class FocalLoss(nn.Module):
    """Focal loss with an optional sample-dependent (adaptive) focusing parameter.

    For each sample the loss is ``-(1 - p_t) ** gamma * log(p_t)``, where
    ``p_t`` is the softmax probability assigned to the target class. With
    ``gamma == 0`` this reduces to cross-entropy.

    In adaptive mode (the scheme used by the ``focal_calibration`` code,
    Mukhoti et al., "Calibrating Deep Neural Networks using Focal Loss",
    NeurIPS 2020), samples with ``p_t >= 0.5`` use ``gamma``. Samples with
    ``p_t < 0.5`` use the ``gamma_dic`` value of the smallest threshold key
    strictly greater than ``p_t``.

    Args:
        adaptive: If True, choose gamma per sample via ``get_gamma_list``.
        gamma: Focusing parameter. In adaptive mode it is used for samples
            with ``p_t >= 0.5`` and for samples that no ``gamma_dic`` key covers.
        size_average: If True return the mean of the per-sample losses,
            otherwise their sum.
        device: Optional device to move the log-probabilities and the
            per-sample gamma tensor to. ``None`` keeps them on the input's device.
        gamma_dic: Optional mapping ``threshold -> gamma`` for adaptive mode.
            Defaults to ``{0.2: 5.0, 0.5: 3.0}``. Keys are only consulted for
            samples with ``p_t < 0.5``.
    """

    def __init__(self, adaptive=False, gamma=0, size_average=False, device=None,
                 gamma_dic=None):
        super().__init__()
        self.adaptive = adaptive
        self.gamma = gamma
        self.size_average = size_average
        self.device = device
        # Threshold -> gamma schedule for adaptive mode. Copied so a caller's
        # dict is not aliased (and never shared between instances).
        self.gamma_dic = {0.2: 5.0, 0.5: 3.0} if gamma_dic is None else dict(gamma_dic)

    def get_gamma_list(self, pt):
        """Return a per-sample gamma tensor for the target probabilities ``pt``.

        Args:
            pt: Tensor of target-class probabilities (one entry per sample).

        Returns:
            Tensor with the same shape and dtype as ``pt``. Each entry holds
            ``self.gamma`` when ``pt >= 0.5``, and also when no ``gamma_dic``
            key is greater than ``pt`` (including NaN). Otherwise it holds the
            ``gamma_dic`` value of the smallest key strictly greater than ``pt``.
            The tensor is placed on ``self.device`` when set, else on ``pt.device``.
        """
        # Vectorised rather than a per-element ``.item()`` loop. The loop forced
        # a host/device sync per sample. It could also skip samples no key
        # covered, which misaligned the gammas with ``pt``.
        gammas = torch.full_like(pt, float(self.gamma))
        below_half = pt < 0.5
        # Visit thresholds from largest to smallest so that, for every sample,
        # the smallest key greater than pt is the one that is written last.
        for key in sorted(self.gamma_dic, reverse=True):
            gammas = gammas.masked_fill(below_half & (pt < key), float(self.gamma_dic[key]))
        # Without an explicit device, follow pt; ``.to(None)`` would leave the
        # gammas on the CPU and break GPU inputs.
        device = self.device if self.device is not None else pt.device
        return gammas.to(device)

    def forward(self, input, target):
        """Compute the (summed or averaged) focal loss.

        Args:
            input: Unnormalised logits of shape (N, C) or (N, C, d1, ..., dk).
                Extra spatial dims are flattened so that each position becomes
                one sample.
            target: Integer class indices of shape (N,) or (N, d1, ..., dk).
                They are used as ``gather`` indices, so they must be int64.

        Returns:
            Scalar tensor: the mean loss if ``size_average`` else the sum.
        """
        if input.dim() > 2:
            # reshape (not view) so non-contiguous inputs are accepted.
            input = input.reshape(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                             # N,C,H*W => N,H*W,C
            input = input.reshape(-1, input.size(2))                  # N,H*W,C => N*H*W,C
        target = target.reshape(-1, 1)

        # Log-probability of the target class for every sample.
        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        if self.device is not None:
            logpt = logpt.to(self.device)
        pt = logpt.exp()

        gamma = self.get_gamma_list(pt) if self.adaptive else self.gamma
        loss = -1 * (1 - pt) ** gamma * logpt

        return loss.mean() if self.size_average else loss.sum()
