Source code for bapred.model.GraphGPS

import dgl, torch
import torch.nn as nn
import torch.nn.functional as F

from .GatedGCNLSPE import GatedGCNLSPELayer
from .MHA import MultiHeadAttention

[docs] class GraphGPS(nn.Module): def __init__(self, emb_dim, num_heads): super(GraphGPS, self).__init__() self.mpnn_layer = GatedGCNLSPELayer( emb_dim, emb_dim, dropout=0.1, batch_norm=True, use_lapeig_loss=False, residual=True ) self.mha_layer = MultiHeadAttention( emb_dim, num_heads ) self.MLP = nn.Sequential( nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.GELU(), nn.Dropout(p=0.1), nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.GELU(), nn.Dropout(p=0.1), ) self.mpnn_bn = nn.BatchNorm1d(emb_dim) self.mha_bn = nn.BatchNorm1d(emb_dim) self.mlp_bn = nn.BatchNorm1d(emb_dim) self.mpnn_weight = nn.Parameter( torch.tensor( [1.0] ) ) self.mha_weight = nn.Parameter( torch.tensor( [1.0] ) )
[docs] def forward(self, g, h, p, e): h_i = h e_i = e p_i = p h_mpnn, p, e = self.mpnn_layer(g, h, p, e) h_mha, h_wight = self.mha_layer(h) h_mpnn += h_i h_mha += h_i h_mpnn = self.mpnn_bn(h_mpnn) h_mha = self.mha_bn(h_mha) h_j = h_mpnn * self.mpnn_weight + h_mha * self.mha_weight h = self.MLP(h_j) h += h_j h = self.mlp_bn(h) return h, p, e