import copy
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import Compose
from torch_geometric.utils import to_dense_adj, dense_to_sparse
from torch_sparse import coalesce
from .chem import BOND_TYPES, BOND_NAMES, get_atom_symbol
[docs]
class AddHigherOrderEdges(object):
    """Transform that adds edges between nodes up to ``order`` hops apart.

    Existing edges keep their ``edge_type``; an added edge whose hop count is
    ``h > 1`` receives the type ``num_types + h - 1``. The original edge list
    is preserved in ``data.bond_edge_index``.
    """

    def __init__(self, order, num_types=len(BOND_TYPES)):
        super().__init__()
        self.order = order
        self.num_types = num_types

    def binarize(self, x):
        """Return a tensor like ``x`` holding 1 where ``x > 0`` and 0 elsewhere."""
        is_positive = x > 0
        return torch.where(is_positive, torch.ones_like(x), torch.zeros_like(x))

    def get_higher_order_adj_matrix(self, adj, order):
        """
        Args:
            adj: (N, N) adjacency matrix.
            order: largest hop count to label.

        Returns:
            (N, N) tensor like ``adj``. For a non-negative ``adj``, each
            entry holds the smallest hop count in ``1..order`` at which the
            column node is reached from the row node, and 0 if it is not
            reached (the diagonal is always 0).
        """
        identity = torch.eye(adj.size(0), dtype=torch.long, device=adj.device)
        # reach[h] flags the node pairs joined by a walk of at most h edges.
        reach = [identity, self.binarize(adj + identity)]
        for _ in range(2, order + 1):
            reach.append(self.binarize(reach[-1] @ reach[1]))

        hop_matrix = torch.zeros_like(adj)
        for hop in range(1, order + 1):
            # Pairs reachable within `hop` edges but not within `hop - 1`.
            first_reached = reach[hop] - reach[hop - 1]
            hop_matrix += first_reached * hop
        return hop_matrix

    def __call__(self, data: Data):
        """Rewrite ``data`` in place with the extended edge set and return it.

        Sets ``bond_edge_index``, ``edge_index``, ``edge_type``,
        ``edge_order`` and ``is_bond`` on ``data``.
        """
        num_nodes = data.num_nodes
        bond_index = data.edge_index

        dense_adj = to_dense_adj(bond_index).squeeze(0)
        hop_order = self.get_higher_order_adj_matrix(dense_adj, self.order)  # (N, N)
        dense_type = to_dense_adj(bond_index, edge_attr=data.edge_type).squeeze(0)  # (N, N)

        # Only pairs more than one hop apart receive a synthetic edge type.
        no_type = torch.zeros_like(hop_order)
        virtual_type = torch.where(hop_order > 1, self.num_types + hop_order - 1, no_type)
        # An original edge must never coincide with a synthetic one.
        assert (dense_type * virtual_type == 0).all()
        merged_type = dense_type + virtual_type

        sparse_index, sparse_type = dense_to_sparse(merged_type)
        _, sparse_order = dense_to_sparse(hop_order)

        data.bond_edge_index = bond_index  # Keep the original edges
        data.edge_index, data.edge_type = coalesce(
            sparse_index, sparse_type.long(), num_nodes, num_nodes
        )
        order_index, data.edge_order = coalesce(
            sparse_index, sparse_order.long(), num_nodes, num_nodes
        )
        data.is_bond = (data.edge_type < self.num_types)
        # Both coalesce calls must produce the same edge ordering.
        assert (data.edge_index == order_index).all()
        return data
[docs]
class AddEdgeLength(object):
    """Transform that stores the length of every edge in ``data.edge_length``."""

    def __call__(self, data: Data):
        """Set ``data.edge_length`` to the (num_edge, 1) endpoint distances.

        The distance is the norm of the difference between the ``data.pos``
        rows selected by the two rows of ``data.edge_index``.
        """
        coords = data.pos
        src, dst = data.edge_index
        displacement = coords[src] - coords[dst]
        lengths = displacement.norm(dim=-1)
        data.edge_length = lengths.unsqueeze(-1)  # (num_edge, 1)
        return data
# Add placeholder attributes to the data object so that `batch.to_data_list` can be used.
[docs]
class AddPlaceHolder(object):
    """Transform that attaches placeholder tensors filled with -1 to ``data``.

    ``pos_gen`` mirrors ``data.pos``; ``d_gen`` and ``d_recover`` mirror
    ``data.edge_length``.
    """

    def __call__(self, data: Data):
        """Add ``pos_gen``, ``d_gen`` and ``d_recover`` to ``data`` and return it."""
        fill_value = -1.
        data.pos_gen = fill_value * torch.ones_like(data.pos)
        edge_length = data.edge_length
        # Two separate tensors so the placeholders never share storage.
        data.d_gen = fill_value * torch.ones_like(edge_length)
        data.d_recover = fill_value * torch.ones_like(edge_length)
        return data
[docs]
class AddEdgeName(object):
    """Transform that builds a human-readable name for every edge.

    Names are stored as a list of strings in ``data.edge_name``, one entry
    per column of ``data.edge_index``.
    """

    def __init__(self, asymmetric=True):
        """
        Args:
            asymmetric: when true, edges whose tail index is not smaller
                than their head index get an empty name.
        """
        super().__init__()
        name_table = copy.deepcopy(BOND_NAMES)
        base = len(BOND_NAMES)
        name_table[base + 1] = 'Angle'
        name_table[base + 2] = 'Dihedral'
        self.bonds = name_table
        self.asymmetric = asymmetric

    def __call__(self, data:Data):
        """Fill ``data.edge_name`` and return ``data``.

        A name has the form ``<type>_<tail symbol>_<head symbol>_<tail>_<head>``
        and gains a ``_<length>`` suffix when ``data`` has ``edge_length``.
        Edge types missing from the name table are rendered as ``E<type>``.
        """
        data.edge_name = []
        num_edges = data.edge_index.size(1)
        for edge_id in range(num_edges):
            tail = data.edge_index[0, edge_id]
            head = data.edge_index[1, edge_id]

            skip_edge = self.asymmetric and tail >= head
            if skip_edge:
                data.edge_name.append('')
                continue

            tail_symbol = get_atom_symbol(data.atom_type[tail].item())
            head_symbol = get_atom_symbol(data.atom_type[head].item())

            type_id = data.edge_type[edge_id].item()
            if type_id in self.bonds:
                type_label = self.bonds[type_id]
            else:
                type_label = 'E' + str(type_id)

            label = '%s_%s_%s_%d_%d' % (type_label, tail_symbol, head_symbol, tail, head)
            if hasattr(data, 'edge_length'):
                label += '_%.3f' % (data.edge_length[edge_id].item())
            data.edge_name.append(label)
        return data
[docs]
class AddAngleDihedral(object):
    """Transform that enumerates bond-angle triplets and dihedral quartets.

    Results are stored as ``data.angle_index`` with shape (3, num_angles) and
    ``data.dihedral_index`` with shape (4, num_dihedrals).
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def iter_angle_triplet(bond_mat):
        """Yield ``(j, k, l)`` with ``j < l`` where ``j-k`` and ``k-l`` are bonded.

        Args:
            bond_mat: (N, N) matrix whose non-zero entries mark bonds.

        Triplets are produced in ascending ``(j, k, l)`` order; ``k`` differs
        from both ``j`` and ``l``.
        """
        n_atoms = bond_mat.size(0)
        for j in range(n_atoms):
            for k in range(n_atoms):
                if j == k or bond_mat[j, k].item() == 0:
                    continue
                # Starting at j + 1 enforces j < l, so each angle appears once.
                for l in range(j + 1, n_atoms):
                    if k == l or bond_mat[k, l].item() == 0:
                        continue
                    yield (j, k, l)

    @staticmethod
    def iter_dihedral_quartet(bond_mat):
        """Yield ``(k, i, j, l)`` for each bonded pair ``i < j`` with outer atoms.

        Args:
            bond_mat: (N, N) matrix whose non-zero entries mark bonds.

        ``k`` is bonded to ``i`` and ``l`` is bonded to ``j``; neither may be
        ``i`` or ``j``. ``k`` and ``l`` are allowed to coincide.
        """
        n_atoms = bond_mat.size(0)
        for i in range(n_atoms):
            for j in range(i + 1, n_atoms):
                if bond_mat[i, j].item() == 0:
                    continue
                for k in range(n_atoms):
                    if k in (i, j) or bond_mat[k, i].item() == 0:
                        continue
                    for l in range(n_atoms):
                        if l in (i, j) or bond_mat[l, j].item() == 0:
                            continue
                        yield (k, i, j, l)

    @staticmethod
    def _to_index_tensor(tuples, width):
        """Pack equal-length index tuples into a (width, num_tuples) LongTensor.

        The explicit reshape keeps the result two-dimensional when there are
        no tuples: ``torch.LongTensor([]).t()`` alone has shape (0,), not
        (width, 0).
        """
        rows = list(tuples)
        return torch.LongTensor(rows).reshape(len(rows), width).t()

    def __call__(self, data:Data):
        """Set ``data.angle_index`` and ``data.dihedral_index``; return ``data``.

        Bonds are taken from ``data.is_bond`` when that attribute is present,
        otherwise from the non-zero entries of ``data.edge_type``.
        """
        if 'is_bond' in data:
            bond_mat = to_dense_adj(data.edge_index, edge_attr=data.is_bond).long().squeeze(0) > 0
        else:
            bond_mat = to_dense_adj(data.edge_index, edge_attr=data.edge_type).long().squeeze(0) > 0

        # Note: if the name of an attribute contains `index`, it is
        # automatically incremented during batching.
        data.angle_index = self._to_index_tensor(self.iter_angle_triplet(bond_mat), 3)
        data.dihedral_index = self._to_index_tensor(self.iter_dihedral_quartet(bond_mat), 4)
        return data
[docs]
class CountNodesPerGraph(object):
    """Transform that records the node count in ``data.num_nodes_per_graph``."""

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, data):
        """Store ``data.num_nodes`` as a one-element LongTensor and return ``data``."""
        node_count = data.num_nodes
        data.num_nodes_per_graph = torch.LongTensor([node_count])
        return data