import torch
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import APPNP, GATConv, Linear, GCNConv



class GCN(torch.nn.Module):
    """Stack of ``GCNConv`` layers with one shared activation after every layer.

    Args:
        input_dim: Feature size expected by the first ``GCNConv`` layer.
        hidden_dim: Output size of every layer; also the input size of every
            layer after the first.
        activation: Activation *class* (e.g. ``torch.nn.ReLU``). It is
            instantiated once and applied after each layer, including the last.
        num_layers: Number of ``GCNConv`` layers. Values below 1 still build a
            single layer, because the first layer is always created.
    """

    def __init__(self, input_dim, hidden_dim, activation, num_layers):
        super().__init__()
        # The activation is registered before the layers so the submodule order
        # (and hence state_dict / parameters() order) stays as before.
        self.activation = activation()
        self.layers = torch.nn.ModuleList()
        self.layers.append(GCNConv(input_dim, hidden_dim))
        for _ in range(num_layers - 1):
            self.layers.append(GCNConv(hidden_dim, hidden_dim))

    def forward(self, x, edge_index, enhance=None, rw_embeddings=None, edge_weight=None):
        """Optionally augment ``x``, then run every conv layer plus the activation.

        Args:
            x: Input features. Presumably a (num_nodes, input_dim) node-feature
                matrix; TODO confirm against callers.
            edge_index: Graph connectivity, passed unchanged to every layer.
            enhance: Optional object exposing ``augment(x, rw_embeddings)``
                that returns a pair. When given, the first element of that pair
                replaces ``x`` as the input to the first layer.
            rw_embeddings: Forwarded only to ``enhance.augment``. It is ignored
                when ``enhance`` is None.
            edge_weight: Optional edge weights, passed to every layer.

        Returns:
            The output of the last layer after the activation.
        """
        # Use an identity check rather than ``!=``. Objects with a custom
        # ``__ne__`` (e.g. tensor-like ones) can return a non-boolean from
        # ``!= None`` and make this ``if`` raise.
        if enhance is not None:
            z, _ = enhance.augment(x, rw_embeddings)
        else:
            z = x
        for conv in self.layers:
            z = self.activation(conv(z, edge_index, edge_weight))
        return z

class Encoder(torch.nn.Module):
    """Run one shared encoder under two enhancements and expose a projection head.

    Args:
        encoder: Callable module invoked as
            ``encoder(x, edge_index, enhance, rw_embeddings, edge_weight)``.
        input_dim: Input size of the first projection layer ``fc1``.
        hidden_dim: Output size of ``fc1`` and both sizes of ``fc2``.
        enhance: Two-item iterable; its items become ``enhance1`` and
            ``enhance2``, one per view. Any other length raises ``ValueError``.

    NOTE: ``Linear`` resolves to ``torch_geometric.nn.Linear`` here, because
    the later import shadows ``torch.nn.Linear``.
    """

    def __init__(self, encoder, input_dim, hidden_dim, enhance):
        super().__init__()
        self.encoder = encoder
        first_view, second_view = enhance
        self.enhance1 = first_view
        self.enhance2 = second_view
        # fc1 is built before fc2, so submodule registration and parameter
        # initialisation happen in that order.
        self.fc1 = Linear(input_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, rw_embeddings=None, edge_weight=None):
        """Encode the input once per enhancement and concatenate the two views.

        The encoder is called with ``enhance1`` first and then ``enhance2``.

        Returns:
            ``(z, z1, z2)``. ``z1`` and ``z2`` are the encoder outputs for the
            two enhancements, and ``z`` is their concatenation along the last
            dimension.
        """
        views = [
            self.encoder(x, edge_index, view_enhance, rw_embeddings, edge_weight)
            for view_enhance in (self.enhance1, self.enhance2)
        ]
        joint = torch.cat(views, dim=-1)
        return joint, views[0], views[1]

    def project(self, z: torch.Tensor) -> torch.Tensor:
        """Map ``z`` through the projection head ``fc2(elu(fc1(z)))``."""
        hidden = self.fc1(z)
        hidden = F.elu(hidden)
        projected = self.fc2(hidden)
        return projected