#!/usr/bin/python
# -*- coding:utf-8 -*-
'''
Most of the code is adapted from https://github.com/vgsatorras/egnn, the official implementation of the paper:
E(n) Equivariant Graph Neural Networks
Victor Garcia Satorras, Emiel Hoogeboom, Max Welling
'''
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch_scatter import scatter_softmax, scatter_add, scatter_sum
from torch_geometric.utils import to_dense_batch
from .cross_att import CrossAttentionModule
from .model_utils import InteractionModule
class MC_E_GCL(nn.Module):
"""
Multi-Channel E(n) Equivariant Convolutional Layer
"""
def __init__(self, args, input_nf, output_nf, hidden_nf, n_channel, edges_in_d=0, act_fn=nn.SiLU(),
residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False,
dropout=0.1, coord_change_maximum=10):
super(MC_E_GCL, self).__init__()
input_edge = input_nf * 2
self.args = args
self.residual = residual
self.attention = attention
self.normalize = normalize
self.coords_agg = coords_agg
self.tanh = tanh
self.epsilon = 1e-8
self.dropout = nn.Dropout(dropout)
self.edge_mlp = nn.Sequential(
nn.Linear(input_edge + n_channel**2 + edges_in_d, hidden_nf),
act_fn,
nn.Linear(hidden_nf, hidden_nf),
act_fn)
self.node_mlp = nn.Sequential(
nn.Linear(hidden_nf + input_nf, hidden_nf),
act_fn,
nn.Linear(hidden_nf, output_nf))
layer = nn.Linear(hidden_nf, n_channel, bias=False)
torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
coord_mlp = []
coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
coord_mlp.append(act_fn)
coord_mlp.append(layer)
if self.tanh:
coord_mlp.append(nn.Tanh())
self.coord_mlp = nn.Sequential(*coord_mlp)
if self.attention:
self.att_mlp = nn.Sequential(
nn.Linear(hidden_nf, 1),
nn.Sigmoid())
self.coord_change_maximum = coord_change_maximum
def edge_model(self, source, target, radial, edge_attr):
'''
:param source: [n_edge, input_size]
:param target: [n_edge, input_size]
:param radial: [n_edge, n_channel, n_channel]
:param edge_attr: [n_edge, edge_dim]
'''
radial = radial.reshape(radial.shape[0], -1) # [n_edge, n_channel ^ 2]
if edge_attr is None: # Unused.
out = torch.cat([source, target, radial], dim=1)
else:
out = torch.cat([source, target, radial, edge_attr], dim=1)
out = self.edge_mlp(out)
out = self.dropout(out)
if self.attention:
att_val = self.att_mlp(out)
out = out * att_val
return out
def node_model(self, x, edge_index, edge_attr, node_attr):
'''
:param x: [bs * n_node, input_size]
:param edge_index: list of [n_edge], [n_edge]
:param edge_attr: [n_edge, hidden_size], refers to message from i to j
:param node_attr: [bs * n_node, node_dim]
'''
row, col = edge_index
agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0)) # [bs * n_node, hidden_size]
# print_log(f'agg1, {torch.isnan(agg).sum()}', level='DEBUG')
if node_attr is not None:
agg = torch.cat([x, agg, node_attr], dim=1)
else:
agg = torch.cat([x, agg], dim=1) # [bs * n_node, input_size + hidden_size]
# print_log(f'agg, {torch.isnan(agg).sum()}', level='DEBUG')
out = self.node_mlp(agg) # [bs * n_node, output_size]
# print_log(f'out, {torch.isnan(out).sum()}', level='DEBUG')
out = self.dropout(out)
if self.residual:
out = x + out
return out, agg
def coord_model(self, coord, edge_index, coord_diff, edge_feat):
'''
coord: [bs * n_node, n_channel, d]
edge_index: list of [n_edge], [n_edge]
coord_diff: [n_edge, n_channel, d]
edge_feat: [n_edge, hidden_size]
'''
row, col = edge_index
trans = coord_diff * self.coord_mlp(edge_feat).unsqueeze(-1) # [n_edge, n_channel, d]
if self.coords_agg == 'sum':
agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0))
elif self.coords_agg == 'mean':
agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0)) # [bs * n_node, n_channel, d]
else:
            raise Exception('Wrong coords_agg parameter: %s' % self.coords_agg)
coord = coord + agg.clamp(-self.coord_change_maximum, self.coord_change_maximum)
return coord
def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None, batch_id=None):
'''
h: [bs * n_node, hidden_size]
edge_index: list of [n_row] and [n_col] where n_row == n_col (with no cutoff, n_row == bs * n_node * (n_node - 1))
coord: [bs * n_node, n_channel, d]
'''
row, col = edge_index
radial, coord_diff = coord2radial(edge_index, coord, self.args.rm_F_norm, batch_id=batch_id, norm_type=self.args.norm_type)
edge_feat = self.edge_model(h[row], h[col], radial, edge_attr) # [n_edge, hidden_size]
# print_log(f'edge_feat, {torch.isnan(edge_feat).sum()}', level='DEBUG')
coord = self.coord_model(coord, edge_index, coord_diff, edge_feat) # [bs * n_node, n_channel, d]
h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
return h, coord
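
# Usage sketch for MC_E_GCL (illustrative only; the Namespace fields are
# assumptions, and the layer itself only reads args.rm_F_norm / args.norm_type
# through coord2radial):
#   from argparse import Namespace
#   layer = MC_E_GCL(Namespace(rm_F_norm=True, norm_type=None),
#                    input_nf=32, output_nf=32, hidden_nf=32, n_channel=5)
#   h, coord = torch.randn(8, 32), torch.randn(8, 5, 3)
#   edges = [torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])]
#   h, coord = layer(h, edges, coord)  # h: [8, 32], coord: [8, 5, 3]
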
class MC_E_GCL_Prompt(nn.Module):
"""
Multi-Channel E(n) Equivariant Convolutional Layer
"""
def __init__(self, args, input_nf, output_nf, hidden_nf, prompt_nf, n_channel, edges_in_d=0, act_fn=nn.SiLU(),
residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False,
dropout=0.1, coord_change_maximum=10):
super(MC_E_GCL_Prompt, self).__init__()
input_edge = input_nf * 2
self.args = args
self.residual = residual
self.attention = attention
self.normalize = normalize
self.coords_agg = coords_agg
self.tanh = tanh
self.epsilon = 1e-8
self.prompt_nf = prompt_nf
self.attn_nf = 4
#######################################################
self.prompt_node_mlp_0 = nn.Sequential(
nn.Linear(hidden_nf, prompt_nf),
act_fn,
nn.Linear(prompt_nf, prompt_nf**2),
act_fn)
self.prompt_node_mlp_1 = nn.Sequential(
nn.Linear(prompt_nf**2, prompt_nf),
act_fn,
nn.Linear(prompt_nf, hidden_nf),
act_fn)
self.prompt_coord_mlp_0 = nn.Sequential(
nn.Linear(hidden_nf, prompt_nf),
act_fn,
nn.Linear(prompt_nf, prompt_nf**2),
act_fn)
self.prompt_coord_mlp_1 = nn.Sequential(
nn.Linear(prompt_nf**2, prompt_nf),
act_fn,
nn.Linear(prompt_nf, hidden_nf),
act_fn)
#######################################################
self.prompt_node_mlp_agg_0 = nn.Linear(hidden_nf*2, hidden_nf)
self.prompt_node_mlp_agg_1 = nn.Linear(self.attn_nf**2, hidden_nf)
self.node_q, self.node_k, self.node_v = nn.Linear(hidden_nf, self.attn_nf**2), nn.Linear(hidden_nf, self.attn_nf**2), nn.Linear(hidden_nf, self.attn_nf**2)
self.prompt_coord_mlp_agg_0 = nn.Linear(hidden_nf*2, hidden_nf)
self.prompt_coord_mlp_agg_1 = nn.Linear(self.attn_nf**2, hidden_nf)
self.coord_q, self.coord_k, self.coord_v = nn.Linear(hidden_nf, self.attn_nf**2), nn.Linear(hidden_nf, self.attn_nf**2), nn.Linear(hidden_nf, self.attn_nf**2)
#######################################################
self.dropout = nn.Dropout(dropout)
self.edge_mlp = nn.Sequential(
nn.Linear(input_edge + n_channel**2 + edges_in_d, hidden_nf),
act_fn,
nn.Linear(hidden_nf, hidden_nf),
act_fn)
self.node_mlp = nn.Sequential(
nn.Linear(hidden_nf + input_nf, hidden_nf),
act_fn,
nn.Linear(hidden_nf, output_nf))
layer = nn.Linear(hidden_nf, n_channel, bias=False)
torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
coord_mlp = []
coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
coord_mlp.append(act_fn)
coord_mlp.append(layer)
if self.tanh:
coord_mlp.append(nn.Tanh())
self.coord_mlp = nn.Sequential(*coord_mlp)
if self.attention:
self.att_mlp = nn.Sequential(
nn.Linear(hidden_nf, 1),
nn.Sigmoid())
self.coord_change_maximum = coord_change_maximum
def edge_model(self, source, target, radial, edge_attr):
'''
:param source: [n_edge, input_size]
:param target: [n_edge, input_size]
:param radial: [n_edge, n_channel, n_channel]
:param edge_attr: [n_edge, edge_dim]
'''
radial = radial.reshape(radial.shape[0], -1) # [n_edge, n_channel ^ 2]
if edge_attr is None: # Unused.
out = torch.cat([source, target, radial], dim=1)
else:
out = torch.cat([source, target, radial, edge_attr], dim=1)
out = self.edge_mlp(out)
out = self.dropout(out)
if self.attention:
att_val = self.att_mlp(out)
out = out * att_val
return out
def node_model(self, x, edge_index, edge_attr_agg, node_attr):
'''
:param x: [bs * n_node, input_size]
:param edge_index: list of [n_edge], [n_edge]
        :param edge_attr_agg: [n_edge, hidden_size], per-edge message (already fused with the prompt features) from i to j
:param node_attr: [bs * n_node, node_dim]
'''
row, col = edge_index
#agg = unsorted_segment_sum(torch.cat([edge_attr, self.prompt_node_mlp(prompt)], dim=-1), row, num_segments=x.size(0)) # [bs * n_node, hidden_size]
agg = unsorted_segment_sum(edge_attr_agg, row, num_segments=x.size(0))
# print_log(f'agg1, {torch.isnan(agg).sum()}', level='DEBUG')
if node_attr is not None:
agg = torch.cat([x, agg, node_attr], dim=1)
else:
agg = torch.cat([x, agg], dim=1) # [bs * n_node, input_size + hidden_size]
# print_log(f'agg, {torch.isnan(agg).sum()}', level='DEBUG')
out = self.node_mlp(agg) # [bs * n_node, output_size]
# print_log(f'out, {torch.isnan(out).sum()}', level='DEBUG')
out = self.dropout(out)
if self.residual:
out = x + out
return out, agg
def coord_model(self, coord, edge_index, coord_diff, edge_feat_agg):
'''
coord: [bs * n_node, n_channel, d]
edge_index: list of [n_edge], [n_edge]
coord_diff: [n_edge, n_channel, d]
        edge_feat_agg: [n_edge, hidden_size]
'''
row, col = edge_index
n_edge, n_channel, d = coord_diff.size()
#trans = coord_diff * self.coord_mlp(torch.cat([edge_feat, self.prompt_coord_mlp(prompt)], dim=1)).unsqueeze(-1) # [n_edge, n_channel, d]
trans = coord_diff * self.coord_mlp(edge_feat_agg).unsqueeze(-1)
if self.coords_agg == 'sum':
agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0))
elif self.coords_agg == 'mean':
agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0)) # [bs * n_node, n_channel, d]
else:
            raise Exception('Wrong coords_agg parameter: %s' % self.coords_agg)
coord = coord + agg.clamp(-self.coord_change_maximum, self.coord_change_maximum)
return coord
def prompt_generation_module(self, edge_feat, prompt_node, prompt_coord) :
weight_feat_node = self.prompt_node_mlp_0(edge_feat)
weight_feat_node = F.softmax(weight_feat_node.view(-1, self.prompt_nf, self.prompt_nf), dim=-1)
prompt_node_feat = torch.flatten(weight_feat_node*prompt_node, start_dim=1)
prompt_node_feat = self.prompt_node_mlp_1(prompt_node_feat)
weight_feat_coord = self.prompt_coord_mlp_0(edge_feat)
weight_feat_coord = F.softmax(weight_feat_coord.view(-1, self.prompt_nf, self.prompt_nf), dim=-1)
prompt_coord_feat = torch.flatten(weight_feat_coord*prompt_coord, start_dim=1)
prompt_coord_feat = self.prompt_coord_mlp_1(prompt_coord_feat)
        return prompt_node_feat, prompt_coord_feat  # both [n_edge, hidden_size]
def prompt_interaction_module(self, edge_feat, prompt_node_feat, prompt_coord_feat) :
prompt_node_feat_agg = self.prompt_node_mlp_agg_0(torch.cat([edge_feat, prompt_node_feat], dim=1))
node_q, node_k, node_v = self.node_q(prompt_node_feat_agg), self.node_k(prompt_node_feat_agg), self.node_v(prompt_node_feat_agg)
node_q, node_k, node_v = node_q.view(-1, self.attn_nf, self.attn_nf), node_k.view(-1, self.attn_nf, self.attn_nf), node_v.view(-1, self.attn_nf, self.attn_nf)
node_weight = F.softmax(torch.bmm(node_q, node_k.permute(0, 2, 1)), dim=-1)
node_residual = self.prompt_node_mlp_agg_1(torch.bmm(node_weight, node_v).view(-1, (self.attn_nf)**2))
prompt_node_final = edge_feat + node_residual
###############################################################################################################################################
prompt_coord_feat_agg = self.prompt_coord_mlp_agg_0(torch.cat([edge_feat, prompt_coord_feat], dim=1))
coord_q, coord_k, coord_v = self.coord_q(prompt_coord_feat_agg), self.coord_k(prompt_coord_feat_agg), self.coord_v(prompt_coord_feat_agg)
coord_q, coord_k, coord_v = coord_q.view(-1, self.attn_nf, self.attn_nf), coord_k.view(-1, self.attn_nf, self.attn_nf), coord_v.view(-1, self.attn_nf, self.attn_nf)
coord_weight = F.softmax(torch.bmm(coord_q, coord_k.permute(0, 2, 1)), dim=-1)
coord_residual = self.prompt_coord_mlp_agg_1(torch.bmm(coord_weight, coord_v).view(-1, (self.attn_nf)**2))
prompt_coord_final = edge_feat + coord_residual
return prompt_node_final, prompt_coord_final
def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None, batch_id=None, prompt_node=None, prompt_coord=None):
'''
h: [bs * n_node, hidden_size]
edge_index: list of [n_row] and [n_col] where n_row == n_col (with no cutoff, n_row == bs * n_node * (n_node - 1))
coord: [bs * n_node, n_channel, d]
'''
row, col = edge_index
radial, coord_diff = coord2radial(edge_index, coord, self.args.rm_F_norm, batch_id=batch_id, norm_type=self.args.norm_type)
edge_feat = self.edge_model(h[row], h[col], radial, edge_attr) # [n_edge, hidden_size]
prompt_node_feat, prompt_coord_feat = self.prompt_generation_module(edge_feat, prompt_node, prompt_coord) # Newly Added
edge_feat_node_agg, edge_feat_coord_agg = self.prompt_interaction_module(edge_feat, prompt_node_feat, prompt_coord_feat) # Newly Added
# print_log(f'edge_feat, {torch.isnan(edge_feat).sum()}', level='DEBUG')
coord = self.coord_model(coord, edge_index, coord_diff, edge_feat_coord_agg) # [bs * n_node, n_channel, d]
h, agg = self.node_model(h, edge_index, edge_feat_node_agg, node_attr)
return h, coord, edge_feat_node_agg, edge_feat_coord_agg
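
# Usage sketch for MC_E_GCL_Prompt (illustrative; the Namespace fields and the
# [prompt_nf, prompt_nf] prompt shape are assumptions -- the prompts only need to
# broadcast against the per-edge [n_edge, prompt_nf, prompt_nf] softmax weights
# computed in prompt_generation_module):
#   from argparse import Namespace
#   layer = MC_E_GCL_Prompt(Namespace(rm_F_norm=True, norm_type=None),
#                           input_nf=32, output_nf=32, hidden_nf=32,
#                           prompt_nf=4, n_channel=5)
#   h, coord = torch.randn(8, 32), torch.randn(8, 5, 3)
#   edges = [torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])]
#   prompt = torch.randn(4, 4)
#   h, coord, e_node, e_coord = layer(h, edges, coord,
#                                     prompt_node=prompt, prompt_coord=prompt)
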
class MC_Att_L(nn.Module):
"""
Multi-Channel Attention Layer
"""
def __init__(self, args, input_nf, output_nf, hidden_nf, n_channel, edges_in_d=0,
act_fn=nn.SiLU(), dropout=0.1, coord_change_maximum=10, opm=False, normalize_coord=None):
super().__init__()
self.args = args
self.hidden_nf = hidden_nf
self.dropout = nn.Dropout(dropout)
self.linear_q = nn.Linear(input_nf, hidden_nf)
self.linear_kv = nn.Linear(input_nf + n_channel ** 2 + edges_in_d, hidden_nf * 2) # parallel calculate kv
layer = nn.Linear(hidden_nf, n_channel, bias=False)
torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
coord_mlp = []
coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
coord_mlp.append(act_fn)
coord_mlp.append(layer)
self.coord_mlp = nn.Sequential(*coord_mlp)
self.coord_change_maximum = coord_change_maximum
if args.add_cross_attn_layer and args.explicit_pair_embed:
self.cross_attn_module = CrossAttentionModule(node_hidden_dim=input_nf, pair_hidden_dim=input_nf, rm_layernorm=args.rm_layernorm, keep_trig_attn=args.keep_trig_attn, dist_hidden_dim=input_nf, normalize_coord=normalize_coord)
        elif args.add_cross_attn_layer and not args.explicit_pair_embed:
            raise AssertionError('add_cross_attn_layer requires explicit_pair_embed')
        elif args.add_cross_attn_layer and not args.add_attn_pair_bias:
            raise AssertionError('add_cross_attn_layer requires add_attn_pair_bias')
if args.add_attn_pair_bias:
self.inter_layer = InteractionModule(input_nf, output_nf, hidden_nf, opm=opm, rm_layernorm=args.rm_layernorm)
self.attn_bias_proj = nn.Linear(hidden_nf, 1)
def att_model(self, h, edge_index, radial, edge_attr, pair_embed=None):
'''
:param h: [bs * n_node, input_size]
:param edge_index: list of [n_edge], [n_edge]
:param radial: [n_edge, n_channel, n_channel]
:param edge_attr: [n_edge, edge_dim]
'''
row, col = edge_index
source, target = h[row], h[col] # [n_edge, input_size]
# qkv
q = self.linear_q(source) # [n_edge, hidden_size]
n_channel = radial.shape[1]
radial = radial.reshape(radial.shape[0], n_channel * n_channel) # [n_edge, n_channel ^ 2]
if edge_attr is not None:
target_feat = torch.cat([radial, target, edge_attr], dim=1)
else:
target_feat = torch.cat([radial, target], dim=1)
kv = self.linear_kv(target_feat) # [n_edge, hidden_size * 2]
k, v = kv[..., 0::2], kv[..., 1::2] # [n_edge, hidden_size]
if self.args.add_attn_pair_bias:
attn_bias = self.attn_bias_proj(pair_embed).squeeze(-1) # [n_edge]
# attention weight
alpha = torch.sum(q * k, dim=1) + attn_bias # [n_edge]
else:
# attention weight
alpha = torch.sum(q * k, dim=1) # [n_edge]
# print_log(f'alpha1, {torch.isnan(alpha).sum()}', level='DEBUG')
# alpha = scatter_softmax(alpha, row, h.shape[0]) # [n_edge]
alpha = scatter_softmax(alpha, row) # [n_edge]
# print_log(f'alpha2, {torch.isnan(alpha).sum()}', level='DEBUG')
return alpha, v
def node_model(self, h, edge_index, att_weight, v):
'''
:param h: [bs * n_node, input_size]
:param edge_index: list of [n_edge], [n_edge]
:param att_weight: [n_edge, 1], unsqueezed before passed in
:param v: [n_edge, hidden_size]
'''
row, _ = edge_index
agg = unsorted_segment_sum(att_weight * v, row, h.shape[0]) # [bs * n_node, hidden_size]
agg = self.dropout(agg)
return h + agg # residual
def coord_model(self, coord, edge_index, coord_diff, att_weight, v):
'''
:param coord: [bs * n_node, n_channel, d]
:param edge_index: list of [n_edge], [n_edge]
:param coord_diff: [n_edge, n_channel, d]
:param att_weight: [n_edge, 1], unsqueezed before passed in
:param v: [n_edge, hidden_size]
'''
row, _ = edge_index
coord_v = att_weight * self.coord_mlp(v) # [n_edge, n_channel]
trans = coord_diff * coord_v.unsqueeze(-1)
agg = unsorted_segment_sum(trans, row, coord.size(0))
coord = coord + agg.clamp(-self.coord_change_maximum, self.coord_change_maximum)
return coord
def trio_encoder(self, h, edge_index, coord, pair_embed_batched=None, pair_mask=None, batch_id=None, segment_id=None, reduced_tuple=None, LAS_mask=None, p_p_dist_embed=None, c_c_dist_embed=None):
row, col = edge_index
# pair wise feature
c_batch = batch_id[segment_id == 0]
p_batch = batch_id[segment_id == 1]
c_embed = h[segment_id == 0]
p_embed = h[segment_id == 1]
p_embed_batched, p_mask = to_dense_batch(p_embed, p_batch)
c_embed_batched, c_mask = to_dense_batch(c_embed, c_batch)
if self.args.add_cross_attn_layer:
p_embed_batched, c_embed_batched, pair_embed_batched = self.cross_attn_module(
p_embed_batched, p_mask,
c_embed_batched, c_mask,
pair_embed_batched, pair_mask,
p_p_dist_embed=p_p_dist_embed, c_c_dist_embed=c_c_dist_embed
)
for i in range(batch_id.max()+1):
if i == 0:
new_h = torch.cat((c_embed_batched[i][c_mask[i]], p_embed_batched[i][p_mask[i]]), dim=0)
else:
new_sample = torch.cat((c_embed_batched[i][c_mask[i]], p_embed_batched[i][p_mask[i]]), dim=0)
new_h = torch.cat((new_h, new_sample), dim=0)
else:
new_h = h
if self.args.explicit_pair_embed:
pair_embed_batched = pair_embed_batched + self.inter_layer(p_embed_batched, c_embed_batched, p_mask, c_mask)[0]
else:
pair_embed_batched = self.inter_layer(p_embed_batched, c_embed_batched, p_mask, c_mask)[0]
# pair offset embeddings for attention bias
compound_offset_in_batch = c_mask.sum(1)
reduced_inter_edges_batchid, reduced_inter_edge_offsets = reduced_tuple
reduced_row = row[row<col] - reduced_inter_edge_offsets
reduced_col = col[row<col] - reduced_inter_edge_offsets - compound_offset_in_batch[reduced_inter_edges_batchid]
first_part = pair_embed_batched[reduced_inter_edges_batchid, reduced_col, reduced_row] # col: protein, row: ligand
reduced_row = row[row>col] - reduced_inter_edge_offsets - compound_offset_in_batch[reduced_inter_edges_batchid]
reduced_col = col[row>col] - reduced_inter_edge_offsets
second_part = pair_embed_batched[reduced_inter_edges_batchid, reduced_row, reduced_col] # row: protein, col: ligand
for i in range(reduced_inter_edges_batchid.max()+1):
if i == 0:
pair_offset = torch.cat((
first_part[reduced_inter_edges_batchid==i], second_part[reduced_inter_edges_batchid==i]
), dim=0)
else:
new_sample = torch.cat((
first_part[reduced_inter_edges_batchid==i], second_part[reduced_inter_edges_batchid==i]
), dim=0)
pair_offset = torch.cat((pair_offset, new_sample), dim=0)
return new_h, pair_embed_batched, pair_offset
def forward(self, h, edge_index, coord, edge_attr=None, segment_id=None, batch_id=None,
reduced_tuple=None, pair_embed_batched=None, pair_mask=None, LAS_mask=None,
p_p_dist_embed=None, c_c_dist_embed=None):
# Cross-attention
if self.args.add_attn_pair_bias:
h, pair_embed_batched, pair_offset_embed = self.trio_encoder(
h, edge_index, coord, pair_embed_batched=pair_embed_batched, pair_mask=pair_mask,
batch_id=batch_id, segment_id=segment_id, reduced_tuple=reduced_tuple, LAS_mask=LAS_mask,
p_p_dist_embed=p_p_dist_embed, c_c_dist_embed=c_c_dist_embed
)
else:
pair_offset_embed = None
# Interfacial
radial, coord_diff = coord2radial(edge_index, coord, self.args.rm_F_norm, batch_id=batch_id, norm_type=self.args.norm_type)
att_weight, v = self.att_model(h, edge_index, radial, edge_attr, pair_embed=pair_offset_embed)
# print_log(f'att_weight, {torch.isnan(att_weight).sum()}', level='DEBUG')
# print_log(f'v, {torch.isnan(v).sum()}', level='DEBUG')
flat_att_weight = att_weight
att_weight = att_weight.unsqueeze(-1) # [n_edge, 1]
h = self.node_model(h, edge_index, att_weight, v)
coord = self.coord_model(coord, edge_index, coord_diff, att_weight, v)
return h, coord, flat_att_weight
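
# In MC_Att_L the attention weight of edge (i, j) is a scatter-softmax over all
# edges that share the same source node i:
#     alpha_ij = softmax_j( <q_i, k_ij> + b_ij )    (b_ij only with add_attn_pair_bias)
# and node_model / coord_model then apply
#     h_i <- h_i + sum_j alpha_ij * v_ij
#     x_i <- x_i + clamp( sum_j alpha_ij * coord_mlp(v_ij) * (x_i - x_j) )
# where coord_mlp produces one scalar per coordinate channel.
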
class MCAttEGNN(nn.Module):
def __init__(self, args, in_node_nf, hidden_nf, out_node_nf, prompt_nf, n_channel, in_edge_nf=0,
act_fn=nn.SiLU(), n_layers=4, residual=True, dropout=0.1, dense=False,
normalize_coord=None, unnormalize_coord=None, geometry_reg_step_size=0.001):
super().__init__()
'''
:param in_node_nf: Number of features for 'h' at the input
:param hidden_nf: Number of hidden features
:param out_node_nf: Number of features for 'h' at the output
:param n_channel: Number of channels of coordinates
:param in_edge_nf: Number of features for the edge features
:param act_fn: Non-linearity
        :param n_layers: Number of layers for the EGNN
        :param residual: Use residual connections; we recommend not changing this one
        :param dropout: probability of dropout
        :param dense: if dense, the context states of all layers are concatenated
                      and the coordinates are averaged across layers
'''
self.args = args
self.geometry_reg_step_size = geometry_reg_step_size
self.geom_reg_steps = 1
self.hidden_nf = hidden_nf
self.n_layers = n_layers
self.dropout = nn.Dropout(dropout)
self.linear_in = nn.Linear(in_node_nf, self.hidden_nf)
self.dense = dense
self.normalize_coord = normalize_coord
self.unnormalize_coord = unnormalize_coord
if dense:
self.linear_out = nn.Linear(self.hidden_nf * (n_layers + 1), out_node_nf)
else:
self.linear_out = nn.Linear(self.hidden_nf, out_node_nf)
for i in range(0, n_layers):
self.add_module(f'gcl_{i}', MC_E_GCL_Prompt(
args, self.hidden_nf, self.hidden_nf, self.hidden_nf, prompt_nf, n_channel,
edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual, dropout=dropout,
coord_change_maximum=self.normalize_coord(10)
))
# TODO: add parameter for passing edge type to interaction layer
self.add_module(f'att_{i}', MC_Att_L(
args, self.hidden_nf, self.hidden_nf, self.hidden_nf,
n_channel, edges_in_d=0, act_fn=act_fn, dropout=dropout,
coord_change_maximum=self.normalize_coord(10), opm=args.opm, normalize_coord=normalize_coord
))
#self.out_layer = MC_E_GCL(
#args, self.hidden_nf, self.hidden_nf, self.hidden_nf, n_channel,
#edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual,
#coord_change_maximum=self.normalize_coord(10)
#)
self.out_layer = MC_E_GCL_Prompt(
args, self.hidden_nf, self.hidden_nf, self.hidden_nf, prompt_nf, n_channel,
edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual,
coord_change_maximum=self.normalize_coord(10)
)
def forward(self, h, x, ctx_edges, att_edges, LAS_edge_list, batched_complex_coord_LAS,
segment_id=None, batch_id=None, reduced_tuple=None,
pair_embed_batched=None, pair_mask=None, LAS_mask=None,
p_p_dist_embed=None, c_c_dist_embed=None, mask=None,
ctx_edge_attr=None, att_edge_attr=None, return_attention=False, prompt_node=None, prompt_coord=None):
h = self.linear_in(h)
h = self.dropout(h)
x = x.clone()
ctx_states, ctx_coords, atts = [], [], []
for i in range(0, self.n_layers):
h, coord, prompt_node_feat, prompt_coord_feat = self._modules[f'gcl_{i}'](h, ctx_edges, x, batch_id=batch_id, edge_attr=ctx_edge_attr, prompt_node=prompt_node, prompt_coord=prompt_coord)
if self.args.fix_pocket:
x[mask] = coord[mask]
else:
x = coord
ctx_states.append(h)
ctx_coords.append(x)
# attention bias
if self.args.add_attn_pair_bias:
if self.args.explicit_pair_embed:
h, coord, att = self._modules[f'att_{i}'](
h, att_edges, x, edge_attr=att_edge_attr,
segment_id=segment_id, batch_id=batch_id, reduced_tuple=reduced_tuple,
pair_embed_batched=pair_embed_batched, pair_mask=pair_mask, LAS_mask=LAS_mask,
p_p_dist_embed=p_p_dist_embed, c_c_dist_embed=c_c_dist_embed)
else:
h, coord, att = self._modules[f'att_{i}'](
h, att_edges, x, edge_attr=att_edge_attr,
segment_id=segment_id, batch_id=batch_id, reduced_tuple=reduced_tuple)
else:
h, coord, att = self._modules[f'att_{i}'](h, att_edges, x, batch_id=batch_id, edge_attr=att_edge_attr)
if self.args.fix_pocket:
x[mask] = coord[mask]
else:
x = coord
atts.append(att)
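        # LAS-constrained geometry regularization (skipped when rm_LAS_constrained_optim
        # is set): a gradient-style correction that moves squared distances along LAS
        # edges toward the reference distances given by batched_complex_coord_LAS.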
if not self.args.rm_LAS_constrained_optim:
x.squeeze_(1)
batched_complex_coord_LAS.squeeze_(1)
for step in range(self.geom_reg_steps):
LAS_cur_squared = torch.sum(
(x[LAS_edge_list[0]] - x[LAS_edge_list[1]]) ** 2, dim=1)
LAS_true_squared = torch.sum(
(batched_complex_coord_LAS[LAS_edge_list[0]] - batched_complex_coord_LAS[LAS_edge_list[1]]) ** 2,
dim=1)
grad_squared = 2 * (x[LAS_edge_list[0]] - x[LAS_edge_list[1]])
LAS_force = 2 * (LAS_cur_squared - LAS_true_squared)[:, None] * grad_squared
LAS_delta_coord = scatter_add(src=LAS_force, index=LAS_edge_list[1], dim=0,
dim_size=x.shape[0])
x = x + (LAS_delta_coord * self.geometry_reg_step_size) \
.clamp(min=self.normalize_coord(-15), max=self.normalize_coord(15))
x.unsqueeze_(1)
h, coord, _, _ = self.out_layer(h, ctx_edges, x, batch_id=batch_id, edge_attr=ctx_edge_attr, prompt_node=prompt_node, prompt_coord=prompt_coord)
#h, coord = self.out_layer(h, ctx_edges, x, batch_id=batch_id, edge_attr=ctx_edge_attr)
if self.args.fix_pocket:
x[mask] = coord[mask]
else:
x = coord
ctx_states.append(h)
ctx_coords.append(x)
if self.dense:
h = torch.cat(ctx_states, dim=-1)
x = torch.mean(torch.stack(ctx_coords), dim=0)
h = self.dropout(h)
h = self.linear_out(h)
if return_attention:
return h, x, atts
else:
return h, x, prompt_node_feat, prompt_coord_feat
def coord2radial(edge_index, coord, rm_F_norm, batch_id=None, norm_type=None):
row, col = edge_index
coord_diff = coord[row] - coord[col] # [n_edge, n_channel, d]
radial = torch.bmm(coord_diff, coord_diff.transpose(-1, -2)) # [n_edge, n_channel, n_channel]
# normalize radial
if not rm_F_norm:
if norm_type == 'all_sample':
radial = F.normalize(radial, dim=0) # [n_edge, n_channel, n_channel]
elif norm_type == 'per_sample':
edge_batch_id = batch_id[row]
norm_for_each_sample = scatter_sum(src=(radial**2), index=edge_batch_id, dim=0).sqrt()
norm_for_each_edge = norm_for_each_sample[edge_batch_id]
radial = radial / norm_for_each_edge
elif norm_type == '4_sample':
shrink_batch_id = batch_id // 4
edge_batch_id = shrink_batch_id[row]
norm_for_each_sample = scatter_sum(src=(radial**2), index=edge_batch_id, dim=0).sqrt()
norm_for_each_edge = norm_for_each_sample[edge_batch_id]
radial = radial / norm_for_each_edge
return radial, coord_diff
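
# `radial` is the Gram matrix of the per-edge coordinate differences, so it is
# invariant to global rotations/reflections and translations of the coordinates.
# Minimal invariance check (illustrative only):
#   coord = torch.randn(4, 5, 3)
#   R = torch.linalg.qr(torch.randn(3, 3)).Q  # random orthogonal matrix
#   edges = [torch.tensor([0, 1]), torch.tensor([2, 3])]
#   r1, _ = coord2radial(edges, coord, rm_F_norm=True)
#   r2, _ = coord2radial(edges, coord @ R + 1.0, rm_F_norm=True)
#   assert torch.allclose(r1, r2, atol=1e-5)
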
def unsorted_segment_sum(data, segment_ids, num_segments):
'''
:param data: [n_edge, *dimensions]
:param segment_ids: [n_edge]
:param num_segments: [bs * n_node]
'''
expand_dims = tuple(data.shape[1:])
result_shape = (num_segments, ) + expand_dims
for _ in expand_dims:
segment_ids = segment_ids.unsqueeze(-1)
segment_ids = segment_ids.expand(-1, *expand_dims)
result = data.new_full(result_shape, 0) # Init empty result tensor.
result.scatter_add_(0, segment_ids, data)
return result
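
# Example (illustrative): summing three per-edge messages onto two nodes.
#   data = torch.tensor([[1.], [2.], [3.]])
#   unsorted_segment_sum(data, torch.tensor([0, 0, 1]), num_segments=2)
#   # -> [[3.], [3.]]
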
def unsorted_segment_mean(data, segment_ids, num_segments):
'''
:param data: [n_edge, *dimensions]
:param segment_ids: [n_edge]
:param num_segments: [bs * n_node]
'''
expand_dims = tuple(data.shape[1:])
result_shape = (num_segments, ) + expand_dims
for _ in expand_dims:
segment_ids = segment_ids.unsqueeze(-1)
segment_ids = segment_ids.expand(-1, *expand_dims)
result = data.new_full(result_shape, 0) # Init empty result tensor.
count = data.new_full(result_shape, 0)
result.scatter_add_(0, segment_ids, data)
count.scatter_add_(0, segment_ids, torch.ones_like(data))
return result / count.clamp(min=1)
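
# Example (illustrative): as unsorted_segment_sum, but averaged per segment;
# empty segments are protected by count.clamp(min=1).
#   data = torch.tensor([[1.], [2.], [4.]])
#   unsorted_segment_mean(data, torch.tensor([0, 0, 1]), num_segments=2)
#   # -> [[1.5], [4.0]]
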
def get_edges(n_nodes):
rows, cols = [], []
for i in range(n_nodes):
for j in range(n_nodes):
if i != j:
rows.append(i)
cols.append(j)
edges = [rows, cols]
return edges
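
# Example: get_edges(3) enumerates the fully connected graph without self-loops:
#   rows = [0, 0, 1, 1, 2, 2], cols = [1, 2, 0, 2, 0, 1]
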
def get_edges_batch(n_nodes, batch_size):
edges = get_edges(n_nodes)
edge_attr = torch.ones(len(edges[0]) * batch_size, 1)
edges = [torch.LongTensor(edges[0]), torch.LongTensor(edges[1])]
if batch_size == 1:
return edges, edge_attr
elif batch_size > 1:
rows, cols = [], []
for i in range(batch_size):
rows.append(edges[0] + n_nodes * i)
cols.append(edges[1] + n_nodes * i)
edges = [torch.cat(rows), torch.cat(cols)]
return edges, edge_attr
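
# Example: get_edges_batch(2, 2) offsets each sample's edges by n_nodes:
#   rows = [0, 1, 2, 3], cols = [1, 0, 3, 2], edge_attr has shape [4, 1]
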
if __name__ == "__main__":
# Dummy parameters
batch_size = 8
n_nodes = 4
n_feat = 10
x_dim = 3
n_channel = 5
# Dummy variables h, x and fully connected edges
h = torch.randn(batch_size * n_nodes, n_feat)
x = torch.randn(batch_size * n_nodes, n_channel, x_dim)
edges, edge_attr = get_edges_batch(n_nodes, batch_size)
ctx_edges, att_edges = edges, edges
# Initialize EGNN
gnn = MCAttEGNN(in_node_nf=n_feat, hidden_nf=32, out_node_nf=21, n_channel=n_channel)
# Run EGNN
h, x = gnn(h, x, ctx_edges, att_edges)
print(h)
print(x)