Source code for diffalign.models.encoder.egnn

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F


class E_GCL(nn.Module):
    """E(n)-equivariant graph convolutional layer (the basic EGNN block).

    Args:
        input_nf: Number of input node features.
        output_nf: Number of output node features.
        hidden_nf: Number of hidden units.
        edges_in_d: Dimensionality of the edge attributes.
        nodes_att_dim: Dimensionality of extra node attributes used in the node update.
        act_fn: Activation function used in the MLPs.
        attention: If True, gate edge messages with learned attention weights.
        norm_diff: If True, normalize coordinate differences.
        tanh: If True, squash the coordinate update with a tanh and scale it by coords_range.
        coords_range: Scale applied to the tanh-squashed coordinate update.
        norm_constant: Constant added to the norm when normalizing coordinate differences.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0,
                 act_fn=nn.SiLU(), attention=False, norm_diff=True, tanh=False,
                 coords_range=1, norm_constant=0):
        super(E_GCL, self).__init__()
        input_edge = input_nf * 2
        self.attention = attention
        self.norm_diff = norm_diff
        self.tanh = tanh
        self.norm_constant = norm_constant
        edge_coords_nf = 1

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        coord_mlp = []
        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        coord_mlp.append(act_fn)
        coord_mlp.append(layer)
        if self.tanh:
            coord_mlp.append(nn.Tanh())
        self.coords_range = coords_range
        self.coord_mlp = nn.Sequential(*coord_mlp)

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, radial, edge_attr, edge_mask):
        if edge_attr is None:  # Unused.
            out = torch.cat([source, target, radial], dim=1)
        else:
            out = torch.cat([source, target, radial, edge_attr], dim=1)
        out = self.edge_mlp(out)

        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val

        if edge_mask is not None:
            out = out * edge_mask
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = x + self.node_mlp(agg)
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, radial, edge_feat,
                    node_mask, edge_mask, coord_mask):
        row, col = edge_index
        if self.tanh:
            trans = coord_diff * self.coord_mlp(edge_feat) * self.coords_range
        else:
            trans = coord_diff * self.coord_mlp(edge_feat)
        if edge_mask is not None:
            trans = trans * edge_mask
        agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0))
        if coord_mask is None:
            coord = coord + agg
        else:
            coord = coord + agg * coord_mask.view(-1, 1)
        return coord

    def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None,
                node_mask=None, edge_mask=None, coord_mask=None):
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr, edge_mask)  # Make message
        coord = self.coord_model(coord, edge_index, coord_diff, radial, edge_feat,
                                 node_mask, edge_mask, coord_mask)  # Update coordinates
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)  # Update nodes

        if node_mask is not None:
            h = h * node_mask
            coord = coord * node_mask
        return h, coord, edge_attr

    def coord2radial(self, edge_index, coord):
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum(coord_diff ** 2, 1).unsqueeze(1)
        norm = torch.sqrt(radial + 1e-8)
        coord_diff = coord_diff / (norm + self.norm_constant)
        return radial, coord_diff
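

# Illustrative sketch (not part of the original module): how a single E_GCL
# layer could be applied to a small fully connected graph. The sizes and the
# edge-index construction below are assumptions for demonstration only.
def _example_e_gcl_usage():
    n_nodes, hidden_nf = 4, 32
    layer = E_GCL(hidden_nf, hidden_nf, hidden_nf, edges_in_d=1)
    h = torch.randn(n_nodes, hidden_nf)      # node features
    x = torch.randn(n_nodes, 3)              # 3D coordinates
    # Fully connected edge index (no self-loops) as two index tensors.
    rows = torch.tensor([i for i in range(n_nodes) for j in range(n_nodes) if i != j])
    cols = torch.tensor([j for i in range(n_nodes) for j in range(n_nodes) if i != j])
    edge_attr = torch.ones(rows.size(0), 1)  # dummy edge attributes
    h_out, x_out, _ = layer(h, [rows, cols], x, edge_attr=edge_attr)
    return h_out, x_out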


class EGNN(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(),
                 n_layers=4, recurrent=True, attention=False, norm_diff=True, out_node_nf=None,
                 tanh=False, coords_range=15, agg='sum', norm_constant=0, inv_sublayers=1,
                 sin_embedding=False):
        super(EGNN, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range) / self.n_layers
        if agg == 'mean':
            self.coords_range_layer = self.coords_range_layer * 19
        # self.reg = reg

        ### Encoder
        # self.add_module("gcl_0", E_GCL(in_node_nf, self.hidden_nf, self.hidden_nf,
        #                                edges_in_d=in_edge_nf, act_fn=act_fn,
        #                                recurrent=False, coords_weight=coords_weight))
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf,
                                                edges_in_d=in_edge_nf, act_fn=act_fn,
                                                attention=attention, norm_diff=norm_diff,
                                                tanh=tanh, coords_range=self.coords_range_layer,
                                                norm_constant=norm_constant))
        self.to(self.device)

    def forward(self, h, x, edges, edge_attr=None, node_mask=None, edge_mask=None, coord_mask=None):
        # Edit Emiel: Remove velocity as input
        # edge_attr = torch.sum((x[edges[0]] - x[edges[1]]) ** 2, dim=1, keepdim=True)
        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, x, _ = self._modules["gcl_%d" % i](h, edges, x, edge_attr=edge_attr,
                                                  node_mask=node_mask, edge_mask=edge_mask,
                                                  coord_mask=coord_mask)
        h = self.embedding_out(h)

        # Important, the bias of the last linear might be non-zero
        if node_mask is not None:
            h = h * node_mask
        return h, x
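

# Illustrative sketch (not part of the original module): running the full EGNN
# on a toy fully connected graph. All feature sizes here are assumptions chosen
# only for demonstration.
def _example_egnn_usage():
    n_nodes, in_node_nf, in_edge_nf = 5, 6, 1
    model = EGNN(in_node_nf=in_node_nf, in_edge_nf=in_edge_nf, hidden_nf=64, n_layers=2)
    h = torch.randn(n_nodes, in_node_nf)
    x = torch.randn(n_nodes, 3)
    rows = torch.tensor([i for i in range(n_nodes) for j in range(n_nodes) if i != j])
    cols = torch.tensor([j for i in range(n_nodes) for j in range(n_nodes) if i != j])
    edge_attr = torch.ones(rows.size(0), in_edge_nf)
    h_out, x_out = model(h, x, [rows, cols], edge_attr=edge_attr)
    return h_out, x_out  # shapes: (n_nodes, in_node_nf), (n_nodes, 3)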


def unsorted_segment_sum(data, segment_ids, num_segments):
    """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`."""
    result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, segment_ids, data)
    return result
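

# Illustrative check (not part of the original module): unsorted_segment_sum
# sums the rows of `data` that share a segment id, matching TensorFlow's
# tf.math.unsorted_segment_sum for 2D inputs.
def _example_unsorted_segment_sum():
    data = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])
    segment_ids = torch.tensor([0, 1, 0])
    # Rows 0 and 2 go to segment 0, row 1 to segment 1 -> [[4., 4.], [2., 2.]]
    return unsorted_segment_sum(data, segment_ids, num_segments=2)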


class EGNN_old(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(),
                 n_layers=4, recurrent=True, attention=False, norm_diff=True, out_node_nf=None,
                 tanh=False, coords_range=15, agg='sum'):
        super(EGNN_old, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range) / self.n_layers
        if agg == 'mean':
            self.coords_range_layer = self.coords_range_layer * 19
        # self.reg = reg

        ### Encoder
        # self.add_module("gcl_0", E_GCL(in_node_nf, self.hidden_nf, self.hidden_nf,
        #                                edges_in_d=in_edge_nf, act_fn=act_fn,
        #                                recurrent=False, coords_weight=coords_weight))
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf,
                                                edges_in_d=in_edge_nf, act_fn=act_fn,
                                                attention=attention, norm_diff=norm_diff,
                                                tanh=tanh, coords_range=self.coords_range_layer))
        self.to(self.device)

    def forward(self, h, x, edges, edge_attr=None, node_mask=None, edge_mask=None):
        # Edit Emiel: Remove velocity as input
        edge_attr = torch.sum((x[edges[0]] - x[edges[1]]) ** 2, dim=1, keepdim=True)
        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, x, _ = self._modules["gcl_%d" % i](h, edges, x, edge_attr=edge_attr,
                                                  node_mask=node_mask, edge_mask=edge_mask)
        h = self.embedding_out(h)

        # Important, the bias of the last linear might be non-zero
        if node_mask is not None:
            h = h * node_mask
        return h, x


class GNN(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(),
                 n_layers=4, attention=False, out_node_nf=None):
        super(GNN, self).__init__()
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers

        ### Encoder
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
        for i in range(0, n_layers):
            self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf,
                                              edges_in_d=in_edge_nf, act_fn=act_fn,
                                              attention=attention))
        self.to(self.device)

    def forward(self, h, edges, edge_attr=None, node_mask=None, edge_mask=None):
        # Edit Emiel: Remove velocity as input
        h = self.embedding(h)
        for i in range(0, self.n_layers):
            h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr,
                                               node_mask=node_mask, edge_mask=edge_mask)
        h = self.embedding_out(h)

        # Important, the bias of the last linear might be non-zero
        if node_mask is not None:
            h = h * node_mask
        return h


class TransformerNN(nn.Module):
    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(),
                 n_layers=4, recurrent=True, attention=False, norm_diff=True, out_node_nf=None,
                 tanh=False, coords_range=15, agg='sum', norm_constant=0):
        super(TransformerNN, self).__init__()
        hidden_initial = 128
        initial_mlp_layers = 1
        hidden_final = 128
        final_mlp_layers = 2
        n_heads = 8
        dim_feedforward = 512
        if out_node_nf is None:
            out_node_nf = in_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        # Assumed defaults: use_bn and res are read in forward() but were not set here.
        self.use_bn = False
        self.res = True
        self.initial_mlp = MLP(in_node_nf, hidden_nf, hidden_initial, initial_mlp_layers,
                               skip=1, bias=True)
        self.decoder_layers = nn.ModuleList()
        if self.use_bn:
            self.bn_layers = nn.ModuleList()
        for _ in range(n_layers):
            self.decoder_layers.append(nn.TransformerEncoderLayer(hidden_nf, n_heads,
                                                                  dim_feedforward, dropout=0))
        self.final_mlp = MLP(hidden_nf, out_node_nf, hidden_final, final_mlp_layers,
                             skip=1, bias=True)
        self.to(self.device)

    def forward(self, h, x, edges, edge_attr=None, node_mask=None, edge_mask=None):
        """
        x: batch_size, n, channels
        """
        x = F.relu(self.initial_mlp(x))
        for i in range(len(self.decoder_layers)):
            out = F.relu(self.decoder_layers[i](x))  # [bs, n, d]
            if self.use_bn and type(x).__name__ != 'TransformerEncoderLayer':
                out = self.bn_layers[i](out.transpose(1, 2)).transpose(1, 2)  # bs, n, hidden
            x = out + x if self.res and type(x).__name__ != 'TransformerEncoderLayer' else out
        return x


class SetDecoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        hidden, hidden_final = cfg.hidden_decoder, cfg.hidden_last_decoder
        self.use_bn = cfg.use_batch_norm
        self.res = cfg.use_residual
        self.cosine_channels = cfg.cosine_channels
        self.initial_mlp = MLP(cfg.set_channels, cfg.hidden_decoder, cfg.hidden_initial_decoder,
                               cfg.initial_mlp_layers_decoder, skip=1, bias=True)
        self.decoder_layers = nn.ModuleList()
        if self.use_bn:
            self.bn_layers = nn.ModuleList()
        for layer in cfg.decoder_layers:
            self.decoder_layers.append(create_layer(layer, hidden, hidden, cfg))
            if self.use_bn:
                self.bn_layers.append(nn.BatchNorm1d(hidden))

    def forward(self, x, latent):
        """
        x: batch_size, n, channels
        latent: batch_size, channels2.
        """
        x = F.relu(self.initial_mlp(x, latent[:, self.cosine_channels:].unsqueeze(1)))
        for i in range(len(self.decoder_layers)):
            out = F.relu(self.decoder_layers[i](x))  # [bs, n, d]
            if self.use_bn and type(x).__name__ != 'TransformerEncoderLayer':
                out = self.bn_layers[i](out.transpose(1, 2)).transpose(1, 2)  # bs, n, hidden
            x = out + x if self.res and type(x).__name__ != 'TransformerEncoderLayer' else out
        return x


class MLP(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, width: int, nb_layers: int, skip=1, bias=True):
        """
        Args:
            dim_in: input dimension
            dim_out: output dimension
            width: hidden width
            nb_layers: number of layers
            skip: jump from residual connections
            bias: indicates presence of bias
        """
        super(MLP, self).__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.width = width
        self.nb_layers = nb_layers
        self.hidden = nn.ModuleList()
        self.lin1 = nn.Linear(self.dim_in, width, bias)
        self.skip = skip
        self.residual_start = dim_in == width
        self.residual_end = dim_out == width
        for i in range(nb_layers - 2):
            self.hidden.append(nn.Linear(width, width, bias))
        self.lin_final = nn.Linear(width, dim_out, bias)

    def forward(self, x: Tensor):
        out = self.lin1(x)
        out = F.relu(out) + (x if self.residual_start else 0)
        for layer in self.hidden:
            out = out + layer(F.relu(out))
        out = self.lin_final(F.relu(out)) + (out if self.residual_end else 0)
        return out
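

# Illustrative sketch (not part of the original module): the residual
# connections in MLP are only active when dim_in or dim_out matches the hidden
# width. The sizes below are assumptions for demonstration only.
def _example_mlp_usage():
    mlp = MLP(dim_in=16, dim_out=8, width=16, nb_layers=3)
    out = mlp(torch.randn(10, 16))  # residual_start is True, residual_end is False
    return out  # shape: (10, 8)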