Source code for diffalign.models.common

# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import radius_graph, radius
from torch_scatter import scatter_mean, scatter_add, scatter_max
from torch_sparse import coalesce
from torch_geometric.utils import to_dense_adj, dense_to_sparse

from diffalign.utils.chem import BOND_TYPES


class MultiLayerPerceptron(nn.Module):
    """Multi-layer perceptron.

    Note that there is no activation or dropout after the last layer.

    Parameters:
        input_dim (int): input dimension
        hidden_dims (list of int): hidden dimensions
        activation (str or function, optional): activation function
        dropout (float, optional): dropout rate
    """

    def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
        super(MultiLayerPerceptron, self).__init__()
        self.dims = [input_dim] + hidden_dims
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        elif callable(activation):
            # The docstring allows a callable activation; previously any
            # non-string activation was silently discarded.
            self.activation = activation
        else:
            self.activation = None
        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        self.layers = nn.ModuleList()
        for i in range(len(self.dims) - 1):
            self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))
    def forward(self, input):
        """Apply each layer in turn, with activation and dropout between
        layers but not after the final one."""
        x = input
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                if self.activation:
                    x = self.activation(x)
                if self.dropout:
                    x = self.dropout(x)
        return x
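
A minimal usage sketch (the dimensions, dropout rate, and batch size below are illustrative assumptions, not values from diffalign):

mlp = MultiLayerPerceptron(input_dim=16, hidden_dims=[32, 32, 1],
                           activation="relu", dropout=0.1)  # assumed hyperparameters
x = torch.randn(8, 16)  # a batch of 8 feature vectors
out = mlp(x)            # shape (8, 1); no ReLU/dropout after the last Linear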
def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):
    """
    Args:
        num_nodes:  Number of atoms.
        edge_index: Bond indices of the original graph.
        edge_type:  Bond types of the original graph.
        order:      Extension order.
    Returns:
        new_edge_index: Extended edge indices.
        new_edge_type:  Extended edge types.
    """

    def binarize(x):
        return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))

    def get_higher_order_adj_matrix(adj, order):
        """
        Args:
            adj:   (N, N) binary adjacency matrix.
            order: Maximum graph distance to encode.
        Returns:
            (N, N) matrix whose entry (i, j) is k if atoms i and j are exactly
            k hops apart (1 <= k <= order), and 0 otherwise.
        """
        adj_mats = [torch.eye(adj.size(0), dtype=torch.long, device=adj.device),
                    binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device))]
        for i in range(2, order + 1):
            adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))
        order_mat = torch.zeros_like(adj)
        for i in range(1, order + 1):
            order_mat += (adj_mats[i] - adj_mats[i - 1]) * i
        return order_mat

    num_types = len(BOND_TYPES)

    N = num_nodes
    adj = to_dense_adj(edge_index).squeeze(0)
    adj_order = get_higher_order_adj_matrix(adj, order)  # (N, N)

    type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0)  # (N, N)
    type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))
    assert (type_mat * type_highorder == 0).all()
    type_new = type_mat + type_highorder

    new_edge_index, new_edge_type = dense_to_sparse(type_new)
    new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N)

    # [Note] Setting `data.is_bond = (data.edge_type < num_types)` here is not
    # necessary. In earlier versions an `edge_order` attribute was also added
    # to the data object; it no longer seems necessary, so it was removed.

    return new_edge_index, new_edge_type


def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0):
    assert edge_type.dim() == 1
    N = pos.size(0)

    bgraph_adj = torch.sparse.LongTensor(
        edge_index,
        edge_type,
        torch.Size([N, N])
    )

    rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch)  # (2, E_r)
    rgraph_adj = torch.sparse.LongTensor(
        rgraph_edge_index,
        torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,
        torch.Size([N, N])
    )

    composed_adj = (bgraph_adj + rgraph_adj).coalesce()  # Sparse (N, N)

    new_edge_index = composed_adj.indices()
    new_edge_type = composed_adj.values().long()

    # Keep only intra-molecule edges. Compute the mask once, before any
    # filtering, so the same mask applies to both the indices and the types
    # (filtering the indices first would leave a mask whose length no longer
    # matches new_edge_type).
    same_graph = batch[new_edge_index[0]] == batch[new_edge_index[1]]
    new_edge_index = new_edge_index[:, same_graph]
    new_edge_type = new_edge_type[same_graph]

    return new_edge_index, new_edge_type
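
A small worked example of the order extension (encoding a single bond as type 1 is an assumption about BOND_TYPES; edges are listed in both directions, as torch_geometric expects):

edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])  # 4-atom chain 0-1-2-3
edge_type = torch.ones(6, dtype=torch.long)      # assume single bond == type 1
new_index, new_type = _extend_graph_order(4, edge_index, edge_type, order=3)
# new_index additionally contains the 2-hop pairs (0, 2), (1, 3) and the
# 3-hop pair (0, 3), tagged with synthetic types num_types + 1 and num_types + 2.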
def extend_to_cross_attention(pos, cutoff, batch, graph_idx):
    """Build a radius graph and keep only edges whose endpoints have
    different `graph_idx`, i.e. inter-graph edges for cross-attention."""
    rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch)  # (2, E_r)
    inter_graph = graph_idx[rgraph_edge_index[0]] != graph_idx[rgraph_edge_index[1]]
    rgraph_edge_index = rgraph_edge_index[:, inter_graph]
    return rgraph_edge_index
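
An illustrative sketch of the cross-attention edges (positions and indices are made up): three atoms within the cutoff, with graph_idx assigning atom 2 to a different graph than atoms 0 and 1:

pos = torch.tensor([[0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0],
                    [0.5, 1.0, 0.0]])
batch = torch.zeros(3, dtype=torch.long)  # all atoms in the same batch entry
graph_idx = torch.tensor([0, 0, 1])       # atom 2 belongs to the other graph
cross = extend_to_cross_attention(pos, cutoff=2.0, batch=batch, graph_idx=graph_idx)
# cross keeps (0, 2) and (1, 2) in both directions; the intra-graph
# pair (0, 1) is filtered out.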
def extend_graph_order_radius(num_nodes, pos, edge_index, edge_type, batch, order=3,
                              cutoff=10.0, extend_order=True, extend_radius=True):
    """Extend bond edges to higher graph orders and/or to a spatial radius graph."""
    if extend_order:
        edge_index, edge_type = _extend_graph_order(
            num_nodes=num_nodes,
            edge_index=edge_index,
            edge_type=edge_type,
            order=order
        )

    if extend_radius:
        edge_index, edge_type = _extend_to_radius_graph(
            pos=pos,
            edge_index=edge_index,
            edge_type=edge_type,
            cutoff=cutoff,
            batch=batch,
        )

    return edge_index, edge_type
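
An end-to-end sketch combining both extensions (positions, bond types, and the cutoff are assumptions):

pos = torch.randn(4, 3)
batch = torch.zeros(4, dtype=torch.long)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
edge_type = torch.ones(6, dtype=torch.long)  # assume single bond == type 1
ext_index, ext_type = extend_graph_order_radius(
    num_nodes=4, pos=pos, edge_index=edge_index, edge_type=edge_type,
    batch=batch, order=3, cutoff=10.0)
# ext_index covers bonds, 2-/3-hop neighbors, and spatially close pairs;
# radius-only edges carry type 0 (unspecified_type_number).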