Source code for diffalign.models.encoder.egnn

import torch
from torch import nn
import torch.nn.functional as F

def unsorted_segment_sum(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
    """TensorFlow-style unsorted_segment_sum for 2D tensors [E, C]."""
    result = data.new_zeros((num_segments, data.size(1)))
    scatter_index = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, scatter_index, data)
    return result
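
# A minimal sanity check of the scatter-based reduction (example values are
# illustrative, not part of the module): rows of `data` that share a segment
# id are summed into the same output row.
#
#   data        = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])  # [E=3, C=2]
#   segment_ids = torch.tensor([0, 0, 1])
#   unsorted_segment_sum(data, segment_ids, num_segments=2)
#   # -> tensor([[4., 6.],
#   #            [5., 6.]])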
class E_GCL(nn.Module):
    """E(n)-equivariant graph convolutional layer: invariant message passing
    plus an equivariant coordinate update."""

    def __init__(
        self,
        input_nf: int,
        output_nf: int,
        hidden_nf: int,
        edges_in_d: int = 0,
        nodes_att_dim: int = 0,
        act_fn=nn.SiLU(),
        attention: bool = False,
        norm_diff: bool = True,
        tanh: bool = False,
        coords_range: float = 1.0,
        norm_constant: float = 0.0,
    ):
        super().__init__()
        self.attention = attention
        self.norm_diff = norm_diff
        self.tanh = tanh
        self.norm_constant = norm_constant
        edge_coords_nf = 1  # ||x_i - x_j||^2

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_nf * 2 + edge_coords_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
        )
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf),
        )

        # Small-gain init keeps the initial coordinate updates near zero.
        coord_head = nn.Linear(hidden_nf, 1, bias=False)
        nn.init.xavier_uniform_(coord_head.weight, gain=1e-3)
        coord_mlp = [
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            coord_head,
        ]
        if self.tanh:
            coord_mlp.append(nn.Tanh())
        self.coords_range = coords_range
        self.coord_mlp = nn.Sequential(*coord_mlp)

        if self.attention:
            self.att_mlp = nn.Sequential(nn.Linear(hidden_nf, 1), nn.Sigmoid())
    def coord2radial(self, edge_index: torch.Tensor, coord: torch.Tensor):
        row, col = edge_index
        coord_diff = coord[row] - coord[col]                    # [E, 3]
        radial = torch.sum(coord_diff ** 2, dim=1, keepdim=True)  # [E, 1]
        if self.norm_diff:
            norm = torch.sqrt(radial + 1e-8)
            coord_diff = coord_diff / (norm + self.norm_constant)
        return radial, coord_diff
    def edge_model(self, source, target, radial, edge_attr, edge_mask):
        if edge_attr is None:
            m_in = torch.cat([source, target, radial], dim=1)
        else:
            m_in = torch.cat([source, target, radial, edge_attr], dim=1)
        m = self.edge_mlp(m_in)
        if self.attention:
            m = m * self.att_mlp(m)
        if edge_mask is not None:
            if edge_mask.dim() == 1:  # accept [E] as well as [E, 1]
                edge_mask = edge_mask.unsqueeze(-1)
            m = m * edge_mask
        return m
    def node_model(self, x, edge_index, edge_feat, node_attr):
        row, col = edge_index
        agg = unsorted_segment_sum(edge_feat, row, num_segments=x.size(0))  # [N, H]
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        x_out = x + self.node_mlp(agg)
        return x_out
    def coord_model(self, coord, edge_index, coord_diff, edge_feat, edge_mask, coord_mask):
        row, col = edge_index
        trans = coord_diff * self.coord_mlp(edge_feat)  # [E,3] * [E,1] -> [E,3]
        if self.tanh:
            trans = trans * self.coords_range
        if edge_mask is not None:
            if edge_mask.dim() == 1:  # accept [E] as well as [E, 1]
                edge_mask = edge_mask.unsqueeze(-1)
            trans = trans * edge_mask
        dx = unsorted_segment_sum(trans, row, num_segments=coord.size(0))  # [N, 3]
        if coord_mask is None:
            coord = coord + dx
        else:
            mask = coord_mask.unsqueeze(-1) if coord_mask.dim() == 1 else coord_mask
            coord = coord + dx * mask
        return coord
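
    # For orientation (a paraphrase of the EGNN update of Satorras et al. 2021
    # as implemented here, with a residual on h; not author-supplied text):
    #   m_ij = phi_e(h_i, h_j, ||x_i - x_j||^2, a_ij)
    #   x_i' = x_i + sum_{j != i} (x_i - x_j) / (||x_i - x_j|| + c) * phi_x(m_ij)
    #   h_i' = h_i + phi_h(h_i, sum_{j != i} m_ij)
    # where the (x_i - x_j) normalization applies when norm_diff=True and
    # c = norm_constant.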
    def forward(
        self,
        h: torch.Tensor,                  # [N, C_h]
        edge_index: torch.Tensor,         # [2, E]
        coord: torch.Tensor,              # [N, 3]
        edge_attr: torch.Tensor = None,   # [E, C_e]
        node_attr: torch.Tensor = None,   # [N, C_n]
        node_mask: torch.Tensor = None,   # [N, 1] or [N]
        edge_mask: torch.Tensor = None,   # [E, 1] or [E]
        coord_mask: torch.Tensor = None,  # [N, 1] or [N]; 1=update, 0=freeze
    ):
        radial, coord_diff = self.coord2radial(edge_index, coord)
        row, col = edge_index

        # Edge/message
        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr, edge_mask)
        # Coordinate update (masked)
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat, edge_mask, coord_mask)
        # Node update
        h = self.node_model(h, edge_index, edge_feat, node_attr)

        if node_mask is not None:
            if node_mask.dim() == 1:  # accept [N] as well as [N, 1]
                node_mask = node_mask.unsqueeze(-1)
            h = h * node_mask
            coord = coord * node_mask
        return h, coord, edge_feat  # edge_feat is an optional extra output
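
# Hedged usage sketch (not part of the original module): a quick way to
# exercise one layer and spot-check its E(3) equivariance. The sizes, the
# random orthogonal matrix, and the tolerance are illustrative assumptions;
# the helper is defined but never called at import time.
def _egcl_equivariance_smoke_test():
    torch.manual_seed(0)
    layer = E_GCL(input_nf=8, output_nf=8, hidden_nf=16)
    n = 5
    h = torch.randn(n, 8)
    x = torch.randn(n, 3)
    # Fully connected graph without self-loops.
    row, col = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')
    keep = row != col
    edge_index = torch.stack([row[keep], col[keep]])  # [2, E]
    h1, x1, _ = layer(h, edge_index, x)
    # Rotating/reflecting the inputs should rotate the output coordinates
    # and leave the invariant features (numerically) unchanged.
    q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix
    h2, x2, _ = layer(h, edge_index, x @ q.T)
    assert torch.allclose(h1, h2, atol=1e-4)
    assert torch.allclose(x1 @ q.T, x2, atol=1e-4)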
class EGNN(nn.Module):
    def __init__(
        self,
        in_node_nf: int,
        in_edge_nf: int,
        hidden_nf: int,
        device: str = 'cpu',
        act_fn=nn.SiLU(),
        n_layers: int = 4,
        attention: bool = False,
        norm_diff: bool = True,
        out_node_nf: int = None,
        tanh: bool = False,
        coords_range: float = 15.0,
        agg: str = 'sum',
        norm_constant: float = 0.0,
    ):
        super().__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        # Spread the total coordinate-update range evenly across layers.
        self.coords_range_layer = float(coords_range) / max(1, n_layers)
        if agg == 'mean':
            self.coords_range_layer = self.coords_range_layer * 19
        self.embedding = nn.Linear(in_node_nf, hidden_nf)
        self.embedding_out = nn.Linear(hidden_nf, out_node_nf)
        for i in range(n_layers):
            self.add_module(
                f"gcl_{i}",
                E_GCL(
                    input_nf=hidden_nf,
                    output_nf=hidden_nf,
                    hidden_nf=hidden_nf,
                    edges_in_d=in_edge_nf,
                    act_fn=act_fn,
                    attention=attention,
                    norm_diff=norm_diff,
                    tanh=tanh,
                    coords_range=self.coords_range_layer,
                    norm_constant=norm_constant,
                ),
            )
        self.to(self.device)
    def forward(
        self,
        h: torch.Tensor,                  # [N, C_in]
        x: torch.Tensor,                  # [N, 3]
        edges: torch.Tensor,              # [2, E]
        edge_attr: torch.Tensor = None,   # [E, C_e]
        node_mask: torch.Tensor = None,   # [N, 1] or [N]
        edge_mask: torch.Tensor = None,   # [E, 1] or [E]
        coord_mask: torch.Tensor = None,  # [N, 1] or [N]; 1=update, 0=freeze
        node_attr: torch.Tensor = None,   # [N, C_n]
    ):
        h = self.embedding(h)
        for i in range(self.n_layers):
            h, x, _ = self._modules[f"gcl_{i}"](
                h,
                edges,
                x,
                edge_attr=edge_attr,
                node_attr=node_attr,
                node_mask=node_mask,
                edge_mask=edge_mask,
                coord_mask=coord_mask,
            )
        h = self.embedding_out(h)
        if node_mask is not None:
            if node_mask.dim() == 1:  # accept [N] as well as [N, 1]
                node_mask = node_mask.unsqueeze(-1)
            h = h * node_mask
        return h, x
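
# Hedged end-to-end sketch (not part of the original module): all sizes and
# the fully connected edge construction are illustrative assumptions. The
# coordinate mask freezes the first two nodes, e.g. a fixed scaffold.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = EGNN(in_node_nf=16, in_edge_nf=0, hidden_nf=32, n_layers=3)
    n = 6
    h = torch.randn(n, 16)
    x = torch.randn(n, 3)
    row, col = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')
    keep = row != col
    edges = torch.stack([row[keep], col[keep]])  # [2, E]
    coord_mask = torch.ones(n, 1)
    coord_mask[:2] = 0.0
    h_out, x_out = model(h, x, edges, coord_mask=coord_mask)
    print(h_out.shape, x_out.shape)  # torch.Size([6, 16]) torch.Size([6, 3])
    assert torch.allclose(x_out[:2], x[:2])  # frozen nodes did not move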