import torch
import torch.nn as nn

class Sigma(nn.Module):
    """Reduce the last axis of the input with a power term plus a scaled linear term.

    For an input ``H`` the module returns
    ``(H ** p).sum(-1) + (alpha * H).sum(-1)``.
    """

    def __init__(self, alpha, p):
        """Store the coefficients used by :meth:`forward`.

        Args:
            alpha: Multiplier applied to ``H`` in the linear term.
            p: Exponent applied elementwise to ``H`` in the power term.
        """
        super().__init__()
        self.alpha = alpha
        self.p = p

    def forward(self, H):
        """Return the sum of the power term and the linear term.

        Args:
            H: Input whose last axis is summed away.

        Returns:
            The result with one fewer axis than ``H``.
        """
        power_term = (H ** self.p).sum(-1)
        linear_term = (self.alpha * H).sum(-1)
        return power_term + linear_term

class Attention(nn.Module):
    """Project the input with ``v`` and mix the result with ``softmax(Q)``.

    The attention weights are computed from ``Q`` alone (softmax over its
    last axis); they do not depend on the input passed to :meth:`forward`.
    """

    def __init__(self, Q, v):
        """Store the two weight matrices.

        Args:
            Q: Matrix whose rows are normalised by a softmax over the last
                axis to form the attention weights.
            v: Matrix that right-multiplies the input, so its first
                dimension must equal the last dimension of the input.
        """
        super(Attention, self).__init__()
        self.Q = Q
        self.v = v
        self.sm = nn.Softmax(dim=-1)

    def forward(self, X):
        """Compute ``(X @ v) @ softmax(Q)^T``.

        Args:
            X: Input tensor of shape ``(..., n, d_in)``; any leading batch
                dimensions are carried through to the output.

        Returns:
            Tensor of shape ``(..., n, m)`` where ``m`` is the number of
            rows of ``Q``.
        """
        attn = self.sm(self.Q)
        v_X = X @ self.v
        # ``@`` (matmul) instead of ``.mm``: ``.mm`` accepts only 2-D
        # operands and would reject a batched ``X`` that ``X @ v`` already
        # handled. For 2-D operands the two are equivalent.
        return v_X @ attn.transpose(-2, -1)

class Transformer(nn.Module):
    """Chain an ``Attention`` layer built from ``Q`` and ``v`` with ``sigma``."""

    def __init__(self, Q, v, sigma):
        """Build the attention layer and keep the supplied callable.

        Args:
            Q: Passed to ``Attention`` and also kept on this module as ``Q``.
            v: Passed to ``Attention`` and also kept on this module as ``v``.
            sigma: Callable applied to the attention output in ``forward``.
        """
        super().__init__()
        self.Q = Q
        self.v = v
        self.attention = Attention(Q, v)
        self.sigma = sigma

    def forward(self, X):
        """Return ``sigma(attention(X))``."""
        return self.sigma(self.attention(X))

