import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
class E_GCL(nn.Module):
"""Graph Neural Net with global state and fixed number of nodes per graph.
Args:
hidden_dim: Number of hidden units.
num_nodes: Maximum number of nodes (for self-attentive pooling).
global_agg: Global aggregation function ('attn' or 'sum').
temp: Softmax temperature.
"""
def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.SiLU(), attention=False, norm_diff=True, tanh=False, coords_range=1, norm_constant=0):
super(E_GCL, self).__init__()
input_edge = input_nf * 2
self.attention = attention
self.norm_diff = norm_diff
self.tanh = tanh
self.norm_constant = norm_constant
edge_coords_nf = 1
self.edge_mlp = nn.Sequential(
nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
act_fn,
nn.Linear(hidden_nf, hidden_nf),
act_fn)
self.node_mlp = nn.Sequential(
nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
act_fn,
nn.Linear(hidden_nf, output_nf))
layer = nn.Linear(hidden_nf, 1, bias=False)
torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
coord_mlp = []
coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
coord_mlp.append(act_fn)
coord_mlp.append(layer)
if self.tanh:
coord_mlp.append(nn.Tanh())
self.coords_range = coords_range
self.coord_mlp = nn.Sequential(*coord_mlp)
if self.attention:
self.att_mlp = nn.Sequential(
nn.Linear(hidden_nf, 1),
nn.Sigmoid())
    def edge_model(self, source, target, radial, edge_attr, edge_mask):
        """Compute edge messages from endpoint features, squared distance, and optional edge attributes."""
if edge_attr is None: # Unused.
out = torch.cat([source, target, radial], dim=1)
else:
out = torch.cat([source, target, radial, edge_attr], dim=1)
out = self.edge_mlp(out)
if self.attention:
att_val = self.att_mlp(out)
out = out * att_val
if edge_mask is not None:
out = out * edge_mask
return out
    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Update node features with a residual MLP over aggregated edge messages and optional node attributes."""
row, col = edge_index
agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
if node_attr is not None:
agg = torch.cat([x, agg, node_attr], dim=1)
else:
agg = torch.cat([x, agg], dim=1)
out = x + self.node_mlp(agg)
return out, agg
    def coord_model(self, coord, edge_index, coord_diff, radial, edge_feat, node_mask, edge_mask, coord_mask):
        """Update coordinates by aggregating weighted, normalized coordinate differences along edges."""
row, col = edge_index
if self.tanh:
trans = coord_diff * self.coord_mlp(edge_feat) * self.coords_range
else:
trans = coord_diff * self.coord_mlp(edge_feat)
if edge_mask is not None:
trans = trans * edge_mask
agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0))
        if coord_mask is None:
            coord = coord + agg
        else:
            coord = coord + agg * coord_mask.view(-1, 1)
return coord
def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None, coord_mask=None):
row, col = edge_index
radial, coord_diff = self.coord2radial(edge_index, coord)
edge_feat = self.edge_model(h[row], h[col], radial, edge_attr, edge_mask) # Make message
coord = self.coord_model(coord, edge_index, coord_diff, radial, edge_feat, node_mask, edge_mask, coord_mask) # Update coordinates
h, agg = self.node_model(h, edge_index, edge_feat, node_attr) # Update nodes
if node_mask is not None:
h = h * node_mask
coord = coord * node_mask
return h, coord, edge_attr
    def coord2radial(self, edge_index, coord):
        """Return squared distances and normalized coordinate differences for each edge."""
row, col = edge_index
coord_diff = coord[row] - coord[col]
        radial = torch.sum(coord_diff ** 2, dim=1, keepdim=True)
norm = torch.sqrt(radial + 1e-8)
coord_diff = coord_diff/(norm + self.norm_constant)
return radial, coord_diff
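
# Illustrative usage sketch (not part of the original module): a smoke test for a single
# E_GCL layer on a toy fully connected graph. The helper name, feature sizes, and edge
# construction below are assumptions for demonstration only; the residual in node_model
# assumes input_nf == output_nf.
def _example_e_gcl_usage(n_nodes=4, nf=8):
    layer = E_GCL(nf, nf, nf)
    h = torch.randn(n_nodes, nf)   # node features
    x = torch.randn(n_nodes, 3)    # 3D coordinates
    # Fully connected, directed edges without self-loops.
    rows = torch.tensor([i for i in range(n_nodes) for j in range(n_nodes) if i != j])
    cols = torch.tensor([j for i in range(n_nodes) for j in range(n_nodes) if i != j])
    h_out, x_out, _ = layer(h, [rows, cols], x)
    return h_out.shape, x_out.shape  # (n_nodes, nf), (n_nodes, 3)
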
class EGNN(nn.Module):
def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, recurrent=True, attention=False, norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, agg='sum', norm_constant=0, inv_sublayers=1, sin_embedding=False):
super(EGNN, self).__init__()
if out_node_nf is None:
out_node_nf = in_node_nf
self.hidden_nf = hidden_nf
self.device = device
self.n_layers = n_layers
self.coords_range_layer = float(coords_range)/self.n_layers
if agg == 'mean':
self.coords_range_layer = self.coords_range_layer * 19
#self.reg = reg
### Encoder
#self.add_module("gcl_0", E_GCL(in_node_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, recurrent=False, coords_weight=coords_weight))
self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
for i in range(0, n_layers):
self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, attention=attention, norm_diff=norm_diff, tanh=tanh, coords_range=self.coords_range_layer, norm_constant=norm_constant))
self.to(self.device)
def forward(self, h, x, edges, edge_attr=None, node_mask=None, edge_mask=None, coord_mask=None):
# Edit Emiel: Remove velocity as input
# edge_attr = torch.sum((x[edges[0]] - x[edges[1]]) ** 2, dim=1, keepdim=True)
h = self.embedding(h)
for i in range(0, self.n_layers):
h, x, _ = self._modules["gcl_%d" % i](h, edges, x, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask, coord_mask=coord_mask)
h = self.embedding_out(h)
        # Important: the bias of the final linear layer may be non-zero, so re-apply the node mask.
if node_mask is not None:
h = h * node_mask
return h, x
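
# Illustrative usage sketch (not part of the original module): running EGNN end to end on
# a random point cloud. The helper name, sizes, and the fully connected edge list are
# assumptions for demonstration only (in_edge_nf=0, so no edge attributes are passed).
def _example_egnn_usage(n_nodes=5, in_node_nf=6, hidden_nf=32):
    model = EGNN(in_node_nf=in_node_nf, in_edge_nf=0, hidden_nf=hidden_nf, n_layers=2)
    h = torch.randn(n_nodes, in_node_nf)
    x = torch.randn(n_nodes, 3)
    rows = torch.tensor([i for i in range(n_nodes) for j in range(n_nodes) if i != j])
    cols = torch.tensor([j for i in range(n_nodes) for j in range(n_nodes) if i != j])
    h_out, x_out = model(h, x, [rows, cols])
    return h_out.shape, x_out.shape  # (n_nodes, in_node_nf), (n_nodes, 3)
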
def unsorted_segment_sum(data, segment_ids, num_segments):
"""Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`."""
result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Initialize the result tensor with zeros.
segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
result.scatter_add_(0, segment_ids, data)
return result
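
# Illustrative usage sketch (not part of the original module): unsorted_segment_sum pools
# per-edge rows into per-segment rows, summing every row of `data` onto the segment given
# by `segment_ids`. The concrete values below are assumptions for demonstration only.
def _example_unsorted_segment_sum():
    data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    segment_ids = torch.tensor([0, 0, 2])
    out = unsorted_segment_sum(data, segment_ids, num_segments=3)
    return out  # tensor([[4., 6.], [0., 0.], [5., 6.]])
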
class EGNN_old(nn.Module):
def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, recurrent=True, attention=False, norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, agg='sum'):
super(EGNN_old, self).__init__()
if out_node_nf is None:
out_node_nf = in_node_nf
self.hidden_nf = hidden_nf
self.device = device
self.n_layers = n_layers
self.coords_range_layer = float(coords_range)/self.n_layers
if agg == 'mean':
self.coords_range_layer = self.coords_range_layer * 19
#self.reg = reg
### Encoder
#self.add_module("gcl_0", E_GCL(in_node_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, recurrent=False, coords_weight=coords_weight))
self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
for i in range(0, n_layers):
self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, attention=attention, norm_diff=norm_diff, tanh=tanh, coords_range=self.coords_range_layer))
self.to(self.device)
def forward(self, h, x, edges, edge_attr=None, node_mask=None, edge_mask=None):
# Edit Emiel: Remove velocity as input
edge_attr = torch.sum((x[edges[0]] - x[edges[1]]) ** 2, dim=1, keepdim=True)
h = self.embedding(h)
for i in range(0, self.n_layers):
h, x, _ = self._modules["gcl_%d" % i](h, edges, x, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
h = self.embedding_out(h)
        # Important: the bias of the final linear layer may be non-zero, so re-apply the node mask.
if node_mask is not None:
h = h * node_mask
return h, x
class GNN(nn.Module):
def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4,
attention=False, out_node_nf=None):
super(GNN, self).__init__()
if out_node_nf is None:
out_node_nf = in_node_nf
self.hidden_nf = hidden_nf
self.device = device
self.n_layers = n_layers
### Encoder
self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
self.embedding_out = nn.Linear(self.hidden_nf, out_node_nf)
for i in range(0, n_layers):
self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf,
act_fn=act_fn, attention=attention))
self.to(self.device)
def forward(self, h, edges, edge_attr=None, node_mask=None, edge_mask=None):
# Edit Emiel: Remove velocity as input
h = self.embedding(h)
for i in range(0, self.n_layers):
h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr, node_mask=node_mask,
edge_mask=edge_mask)
h = self.embedding_out(h)
        # Important: the bias of the final linear layer may be non-zero, so re-apply the node mask.
if node_mask is not None:
h = h * node_mask
return h
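
# Illustrative usage sketch (not part of the original module): the coordinate-free GNN
# baseline. It relies on the GCL layer defined elsewhere in this module; the helper name,
# sizes, and random edge attributes below are assumptions for demonstration only.
def _example_gnn_usage(n_nodes=5, in_node_nf=6, hidden_nf=32):
    model = GNN(in_node_nf=in_node_nf, in_edge_nf=1, hidden_nf=hidden_nf, n_layers=2)
    h = torch.randn(n_nodes, in_node_nf)
    rows = torch.tensor([i for i in range(n_nodes) for j in range(n_nodes) if i != j])
    cols = torch.tensor([j for i in range(n_nodes) for j in range(n_nodes) if i != j])
    edge_attr = torch.randn(rows.size(0), 1)  # one scalar attribute per edge (in_edge_nf=1)
    return model(h, [rows, cols], edge_attr=edge_attr).shape  # (n_nodes, in_node_nf)
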
class MLP(nn.Module):
def __init__(self, dim_in: int, dim_out: int, width: int, nb_layers: int, skip=1, bias=True):
"""
Args:
dim_in: input dimension
dim_out: output dimension
width: hidden width
nb_layers: number of layers
            skip: stride of the residual connections (stored but not used in forward)
            bias: whether the linear layers include a bias term
"""
super(MLP, self).__init__()
self.dim_in = dim_in
self.dim_out = dim_out
self.width = width
self.nb_layers = nb_layers
self.hidden = nn.ModuleList()
self.lin1 = nn.Linear(self.dim_in, width, bias)
self.skip = skip
self.residual_start = dim_in == width
self.residual_end = dim_out == width
for i in range(nb_layers-2):
self.hidden.append(nn.Linear(width, width, bias))
self.lin_final = nn.Linear(width, dim_out, bias)
def forward(self, x: Tensor):
out = self.lin1(x)
out = F.relu(out) + (x if self.residual_start else 0)
for layer in self.hidden:
out = out + layer(F.relu(out))
out = self.lin_final(F.relu(out)) + (out if self.residual_end else 0)
return out
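
# Illustrative usage sketch (not part of the original module): the MLP adds the input
# residual only when dim_in == width and the pre-output residual only when dim_out == width.
# The helper name and sizes below are assumptions for demonstration only.
def _example_mlp_usage():
    mlp = MLP(dim_in=16, dim_out=16, width=16, nb_layers=4)  # both residuals active
    x = torch.randn(10, 16)
    return mlp(x).shape  # torch.Size([10, 16])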