"""Implementation of metrics."""
import torch
from typing import Any, Callable, Optional
# from pytorch_lightning.metrics import Metric
# from pytorch_lightning.metrics.functional import confusion_matrix




def _input_format_classification(
        preds: torch.Tensor,
        target: torch.Tensor,
        threshold: float = 0.5
        ):
    """Convert ``preds`` and ``target`` tensors into label tensors.

    Three input layouts for ``preds`` are handled:

    * One more dimension than ``target``: treated as per-class scores
      (probabilities or logits) along dim 1 and reduced to labels with
      ``argmax``.
    * Same number of dimensions as ``target`` and a floating-point dtype
      (any precision): treated as binary / multilabel probabilities and
      binarized with ``preds >= threshold`` into an int64 tensor.
    * Otherwise: assumed to already hold labels and returned unchanged.

    Args:
        preds: either a tensor of labels, a tensor of per-class
            probabilities/logits, or a multilabel probability tensor.
        target: tensor with ground-truth labels; returned unchanged.
        threshold: cutoff used to binarize floating-point ``preds`` that
            have the same number of dimensions as ``target``.

    Returns:
        A ``(preds, target)`` tuple where ``preds`` holds labels.

    Raises:
        ValueError: if ``preds`` has neither the same number of dimensions
            as ``target`` nor exactly one more.
    """
    if preds.ndim not in (target.ndim, target.ndim + 1):
        raise ValueError(
            "preds and target must have same number of dimensions, or one additional dimension for preds"
        )

    if preds.ndim == target.ndim + 1:
        # Multi-class scores; the class axis is taken to be dim 1.
        # argmax yields an integer tensor with the same rank as target.
        preds = torch.argmax(preds, dim=1)

    # At this point preds and target always have the same rank.
    # is_floating_point() covers float16/bfloat16/float32/float64, so
    # probabilities of any float precision are thresholded (a plain
    # ``dtype == torch.float`` check would only match float32).
    if preds.is_floating_point():
        # Binary or multilabel probabilities.
        preds = (preds >= threshold).long()
    return preds, target
