import torch
from torch import nn


# Evaluation metrics and training losses
def accuracy(y_hat, y):
    """Return the fraction of predictions in ``y_hat`` that equal the labels in ``y``.

    If ``y_hat`` is 2-D with more than one column it is treated as per-class
    scores and reduced to predicted indices with ``argmax`` over dim 1. In that
    case ``y`` is reduced the same way only when it is also 2-D with more than
    one column (e.g. one-hot rows). A 1-D ``y`` is taken to already hold class
    indices.

    Args:
        y_hat: predictions, either scores of shape (n, c) or labels of shape
            (n,) / (n, 1).
        y: targets, either per-class rows of shape (n, c) or labels of shape
            (n,) / (n, 1).

    Returns:
        Python float: matches / n, or ``nan`` when ``y`` has no rows.
    """
    if y_hat.dim() > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(dim=1)
        # Only reduce y when it is itself per-class; a 1-D (or single-column) y
        # already holds indices and argmax-ing it would yield all zeros / crash.
        if y.dim() > 1 and y.shape[1] > 1:
            y = y.argmax(dim=1)
    # Comparing e.g. (n, 1) with (n,) would broadcast to (n, n) and give a
    # meaningless count, so align shapes when the element counts agree.
    if y_hat.shape != y.shape and y_hat.numel() == y.numel():
        y_hat = y_hat.reshape(y.shape)
    cmp = y_hat.type(y.dtype) == y
    n = y.shape[0]
    if n == 0:
        # 0 / 0: keep the previous "no data" result instead of raising.
        return float("nan")
    # Count on the bool mask (exact int64 sum) rather than in y's dtype,
    # which could be a low-precision float.
    return float(cmp.sum()) / n

def reconstruction_loss(X_hat, X):
    """Return the per-row sum of squared errors between ``X_hat`` and ``X``.

    Args:
        X_hat: reconstruction; expected to have the same shape as ``X``.
        X: 2-D tensor of shape (n, d).

    Returns:
        1-D tensor of shape (n,) whose entry i is
        ``sum_j (X_hat[i, j] - X[i, j]) ** 2``. No reduction over rows is
        applied, so callers can weight or average samples themselves.

    Raises:
        ValueError: if ``X`` is not 2-D.
    """
    if X.dim() != 2:
        raise ValueError(f"expected X to be 2-D (n, d), got shape {tuple(X.shape)}")
    # Element-wise squared error; the functional form avoids building a module per call.
    sq_err = nn.functional.mse_loss(X_hat, X, reduction='none')
    return sq_err.sum(dim=1)

def LS_loss_new(X_new, S_ori):
    """Return ``trace(X_new^T L X_new)`` for the Laplacian of a symmetrised affinity.

    The weights are symmetrised as ``S = (S_ori + S_ori^T) / 2``. The
    unnormalised Laplacian is ``L = D - S``, where ``D`` is the diagonal matrix
    of the row sums of ``S``.

    Args:
        X_new: 2-D tensor of shape (n, k); row i pairs with row/column i of ``S_ori``.
        S_ori: (n, n) tensor of pairwise weights; it need not be symmetric.

    Returns:
        0-dim tensor equal to ``trace(X_new^T L X_new)``.

    Raises:
        ValueError: if ``X_new`` is not 2-D.
    """
    if X_new.dim() != 2:
        raise ValueError(f"expected X_new to be 2-D (n, k), got shape {tuple(X_new.shape)}")
    S = (S_ori + S_ori.T) / 2
    degree = S.sum(dim=-1)
    # L @ X = D @ X - S @ X. Scaling rows by the degree gives D @ X without
    # materialising the dense n x n diagonal matrix.
    LX = degree.unsqueeze(-1) * X_new - torch.matmul(S, X_new)
    # trace(X^T (L X)) equals the sum of the element-wise product X * (L X),
    # which skips forming the k x k matrix.
    return (X_new * LX).sum()