import torch
import torch.nn.functional as F


# Lookup table: maps each digit 0-9 to a list of other digits associated
# with it.
# NOTE(review): judging by the name, these are presumably the digits most
# easily mistaken for the key (e.g. by a digit classifier) -- confirm against
# the code that consumes this table.
# The relation is not symmetric: e.g. 4 lists 7, but 7 does not list 4.
CONFUSING_NUMBERS = {
    1: [7, 2],
    2: [1, 7],
    3: [5, 8],
    4: [9, 7],
    5: [3, 6, 8],
    6: [0, 5, 8],
    7: [1, 2],
    8: [3, 5, 6, 9],
    9: [0, 4, 7, 8],
    0: [6, 9]
}


def pred_label(y, classifier):
    """Return the class index the classifier scores highest for each input.

    The input is zero-padded by 2 elements on both sides of each of its last
    two dimensions before being handed to ``classifier``.
    NOTE(review): presumably this lifts 28x28 digit images to the 32x32 size
    the classifier expects -- confirm against the caller.

    Args:
        y: Input tensor; its last two dimensions are the ones padded.
        classifier: Callable taking the padded tensor and returning logits
            with the class scores along dimension 1.

    Returns:
        Tensor of indices of the largest logit along dimension 1.
    """
    # Pad spec is (left, right, top, bottom) for the two trailing dimensions.
    scores = classifier(F.pad(y, (2, 2, 2, 2), value=0))
    return scores.argmax(dim=1)


def pred_prob(y, classifier):
    """Return the softmax probability of the classifier's top class per input.

    The input is zero-padded by 2 elements on both sides of each of its last
    two dimensions before being handed to ``classifier``.
    NOTE(review): presumably this lifts 28x28 digit images to the 32x32 size
    the classifier expects -- confirm against the caller.

    Args:
        y: Input tensor; its last two dimensions are the ones padded.
        classifier: Callable taking the padded tensor and returning logits
            with the class scores along dimension 1.

    Returns:
        Tensor with the class dimension (dim 1) reduced away, holding the
        softmax probability assigned to the highest-scoring class.
    """
    # Pad spec is (left, right, top, bottom) for the two trailing dimensions.
    logits = classifier(F.pad(y, (2, 2, 2, 2), value=0))
    probs = F.softmax(logits, dim=1)
    # Softmax is monotonic, so the probability at argmax(logits) is simply the
    # maximum probability. Taking the max directly avoids building an arange
    # index tensor and also works when logits have extra trailing dimensions.
    return probs.max(dim=1).values