import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms


def err(logits, labels):
    """Return the top-1 error rate of ``logits`` against integer ``labels``.

    The predicted class is the argmax over dim 1 of ``logits``. The result is
    the fraction of predictions that differ from ``labels``, as a 0-dim
    float32 tensor.
    """
    predicted = logits.argmax(dim=1)
    wrong = predicted.ne(labels)
    return wrong.float().mean()


def trans_err(logits_s, logits_t):
    """Return how often two models disagree on their top-1 prediction.

    Each model's prediction is the argmax over dim 1 of its logits. The result
    is the fraction of rows where ``logits_s`` and ``logits_t`` pick different
    classes, as a 0-dim float32 tensor.
    """
    student_pred = logits_s.argmax(dim=1)
    teacher_pred = logits_t.argmax(dim=1)
    disagree = student_pred.ne(teacher_pred)
    return disagree.float().mean()


# Shared hard-label criterion; "mean" (the default) averages over the batch.
loss = nn.CrossEntropyLoss(reduction="mean")


def soft_loss(logits_s, logits_t):
    """Cross-entropy of student logits against the teacher's soft targets.

    The teacher distribution is ``softmax(logits_t)`` over dim 1. It is
    multiplied elementwise with the student's ``log_softmax(logits_s)``.

    The negated product is averaged over *all* elements, not just the batch.
    The result is therefore the usual per-sample soft cross-entropy divided
    by the size of dim 1 (the number of classes).
    """
    target_probs = logits_t.softmax(dim=1)
    student_log_probs = logits_s.log_softmax(dim=1)
    weighted = target_probs * student_log_probs
    return weighted.mean().neg()


# Training-time augmentation: random 32x32 crop from a 4-pixel zero-padded
# image, random horizontal flip, conversion to a tensor, then per-channel
# normalization. The order matters because the random ops consume RNG state
# in sequence.
# NOTE(review): the mean looks like the usual CIFAR-10 per-channel mean, but
# the std triple (0.2023, 0.1994, 0.2010) differs from the commonly cited
# per-pixel CIFAR-10 std (~0.247, 0.243, 0.261). It is kept as-is, presumably
# to match existing checkpoints; confirm before changing.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)

# Evaluation-time pipeline: no augmentation, the same normalization as above.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)


def parameters_vec(net, *, module_types=(nn.Conv2d, nn.Linear)):
    """Flatten the parameters of selected layer types in ``net`` into one vector.

    Args:
        net: Module to scan. ``net.modules()`` is walked, so ``net`` itself
            and all nested submodules are considered, in that traversal order.
        module_types: Layer class, or tuple of classes, whose parameters are
            collected. The default ``(nn.Conv2d, nn.Linear)`` matches the
            original behavior. Parameters of any other module (e.g.
            normalization layers) are skipped.

    Returns:
        A 1-D tensor from ``nn.utils.parameters_to_vector``. It holds the
        parameters of the selected layers in module-traversal order, each
        flattened, with each layer's parameters in registration order.

    Raises:
        Propagates errors from ``nn.utils.parameters_to_vector``. An example
        is when no matching parameters exist, which is an empty concatenation.
    """
    # Gathering the layers into a ModuleList (instead of chaining each
    # layer's .parameters()) keeps Module.parameters()' de-duplication.
    # A weight tied between two selected layers therefore appears only once.
    selected = nn.ModuleList(
        [m for m in net.modules() if isinstance(m, module_types)]
    )
    return nn.utils.parameters_to_vector(selected.parameters())


def reverse_sigmoid(x, *, max_value=16.0):
    """Elementwise inverse sigmoid (logit), capped from above at ``max_value``.

    Computes ``log(x) - log(1 - x)`` and takes the elementwise minimum with
    ``max_value``. The default of 16.0 matches the original behavior.

    Only the upper side is clamped:
      * ``x == 1`` gives ``+inf`` before the cap, so the result is ``max_value``.
      * ``x == 0`` gives ``-inf``, which is returned unchanged.
      * Inputs outside ``[0, 1]`` give NaN from ``torch.log``, and the
        minimum propagates it.

    Args:
        x: Floating-point tensor of probabilities.
        max_value: Upper bound applied to the logit.

    Returns:
        A tensor with the same shape, dtype and device as ``x``.
    """
    logit = torch.log(x) - torch.log(1 - x)
    # Build the cap with x's dtype/device so torch.min neither promotes the
    # dtype nor mixes devices.
    cap = torch.tensor(max_value, dtype=x.dtype, device=x.device)
    return torch.min(logit, cap)
