Source code for diffalign.models.encoder.edge

from torch.nn import Module, Embedding
from ..common import MultiLayerPerceptron

[docs] class MLPEdgeEncoder(Module): def __init__(self, hidden_dim=100, activation='relu'): super().__init__() self.hidden_dim = hidden_dim self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim) self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation) @property def out_channels(self): return self.hidden_dim
[docs] def forward(self, edge_length, edge_type): """ Input: edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,) Returns: edge_attr: The representation of edges. (E, self.hidden_dim) """ d_emb = self.mlp(edge_length) # (num_edge, hidden_dim) edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim) return d_emb * edge_attr # (num_edge, hidden)