Source code for promptbind.models.att_model

#!/usr/bin/python
# -*- coding:utf-8 -*-
import torch
from torch import masked_fill, nn
import torch.nn.functional as F
import random
from torch_scatter import scatter_sum

from torch_geometric.utils import to_dense_batch
from .egnn import MCAttEGNN
from .model_utils import InteractionModule, RBFDistanceModule


def sequential_and(*tensors):
    res = tensors[0]
    for mat in tensors[1:]:
        res = torch.logical_and(res, mat)
    return res

def sequential_or(*tensors):
    res = tensors[0]
    for mat in tensors[1:]:
        res = torch.logical_or(res, mat)
    return res
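
# Minimal usage sketch (hypothetical `_demo_sequential_helpers`, for
# illustration only): the two helpers fold any number of boolean masks into a
# single selection mask, as done when selecting candidate edges below.
def _demo_sequential_helpers():
    a = torch.tensor([True, True, False, True])
    b = torch.tensor([True, False, False, True])
    c = torch.tensor([True, True, True, False])
    assert torch.equal(sequential_and(a, b, c), torch.tensor([True, False, False, False]))
    assert torch.equal(sequential_or(a, b, c), torch.tensor([True, True, True, True]))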

class ComplexGraph(nn.Module):
    def __init__(self, args, inter_cutoff=10, intra_cutoff=8,
                 normalize_coord=None, unnormalize_coord=None):
        super().__init__()
        self.args = args
        self.inter_cutoff = normalize_coord(inter_cutoff)
        self.intra_cutoff = normalize_coord(intra_cutoff)

    @torch.no_grad()
    def construct_edges(self, X, batch_id, segment_ids, is_global):
        '''
        Memory efficient with complexity of O(Nn) where n is the largest number of nodes in the batch
        '''
        # construct tensors to map between global / local node index
        lengths = scatter_sum(torch.ones_like(batch_id), batch_id)  # [bs]
        N, max_n = batch_id.shape[0], torch.max(lengths)
        offsets = F.pad(torch.cumsum(lengths, dim=0)[:-1], pad=(1, 0), value=0)  # [bs]
        # global node index to local index. lni2gni can be implemented as lni + offsets[batch_id]
        gni = torch.arange(N, device=batch_id.device)
        gni2lni = gni - offsets[batch_id]  # [N]

        ctx_edges, inter_edges = [], []

        # all possible edges (within the same graph)
        # same bid (get rid of self-loop and none edges)
        same_bid = torch.zeros(N, max_n, device=batch_id.device)
        same_bid[(gni, lengths[batch_id] - 1)] = 1
        same_bid = 1 - torch.cumsum(same_bid, dim=-1)
        # shift right and pad 1 to the left
        same_bid = F.pad(same_bid[:, :-1], pad=(1, 0), value=1)
        same_bid[(gni, gni2lni)] = 0  # delete self loop

        row, col = torch.nonzero(same_bid).T  # [2, n_edge_all]
        col = col + offsets[batch_id[row]]  # mapping from local to global node index

        # not global edges
        # is_global = sequential_or(S == self.boa_idx, S == self.boh_idx, S == self.bol_idx)  # [N]
        row_global, col_global = is_global[row], is_global[col]
        not_global_edges = torch.logical_not(torch.logical_or(row_global, col_global))

        # all possible ctx edges: seg==protein, not global
        # segment for compound is 0, for protein is 1
        row_seg, col_seg = segment_ids[row], segment_ids[col]
        select_edges = sequential_and(
            row_seg == col_seg,
            row_seg == 1,
            not_global_edges
        )
        ctx_all_row, ctx_all_col = row[select_edges], col[select_edges]
        # ctx edges
        ctx_edges = _radial_edges(X, torch.stack([ctx_all_row, ctx_all_col]).T, cutoff=self.intra_cutoff)

        # all possible inter edges: not same seg, not global
        select_edges = torch.logical_and(row_seg != col_seg, not_global_edges)
        inter_all_row, inter_all_col = row[select_edges], col[select_edges]
        inter_edges = _radial_edges(X, torch.stack([inter_all_row, inter_all_col]).T, cutoff=self.inter_cutoff)
        if inter_edges.shape[1] == 0:
            inter_edges = torch.tensor([[inter_all_row[0], inter_all_col[0]],
                                        [inter_all_col[0], inter_all_row[0]]], device=inter_all_row.device)
        reduced_inter_edge_batchid = batch_id[inter_edges[0][inter_edges[0] < inter_edges[1]]]
        # # make sure row belongs to compound and col belongs to protein
        # inter_edge_lengths = scatter_sum(torch.ones_like(inter_edges_batchid), inter_edges_batchid)
        reduced_inter_edge_offsets = offsets.gather(-1, reduced_inter_edge_batchid)

        # edges between global and normal nodes
        select_edges = torch.logical_and(row_seg == col_seg, torch.logical_not(not_global_edges))
        global_normal = torch.stack([row[select_edges], col[select_edges]])  # [2, nE]
        # edges between global and global nodes
        select_edges = torch.logical_and(row_global, col_global)  # self-loop has been deleted
        global_global = torch.stack([row[select_edges], col[select_edges]])  # [2, nE]

        # add additional edge to neighbors in 1D sequence (except epitope)
        # select_edges = sequential_and(
        #     torch.logical_or((row - col) == 1, (row - col) == -1),  # adjacent in the graph
        #     not_global_edges,  # not global edges (also ensure the edges are in the same segment)
        #     row_seg != self.ag_seg_id  # not epitope
        # )
        # seq_adj = torch.stack([row[select_edges], col[select_edges]])  # [2, nE]

        # finally construct context edges
        space_edge_num = ctx_edges.shape[1] + global_normal.shape[1] + global_global.shape[1]
        ctx_edges = torch.cat([ctx_edges, global_normal, global_global], dim=1)  # [2, E]
        # ctx_edge_feats = torch.cat(
        #     [torch.zeros(space_edge_num, dtype=torch.float, device=X.device),
        #      torch.ones(seq_adj.shape[1], dtype=torch.float, device=X.device)], dim=0).unsqueeze(-1)

        if self.args.add_attn_pair_bias:
            return ctx_edges, inter_edges, (reduced_inter_edge_batchid, reduced_inter_edge_offsets)
        else:
            return ctx_edges, inter_edges, None

    def forward(self, X, batch_id, segment_id, is_global):
        return self.construct_edges(X, batch_id, segment_id, is_global)
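
# Hedged walk-through (hypothetical `_demo_same_bid_mask`, for illustration
# only) of the O(N*n) candidate-edge mask built in ComplexGraph.construct_edges
# above: each row of `same_bid` flags the local node slots belonging to the
# same graph as that row's node, so all intra-graph pairs are enumerated
# without materialising an N x N matrix.
def _demo_same_bid_mask():
    batch_id = torch.tensor([0, 0, 1, 1, 1])  # two graphs with 2 and 3 nodes
    lengths = scatter_sum(torch.ones_like(batch_id), batch_id)  # tensor([2, 3])
    N, max_n = batch_id.shape[0], int(lengths.max())
    offsets = F.pad(torch.cumsum(lengths, dim=0)[:-1], pad=(1, 0), value=0)
    gni = torch.arange(N)
    gni2lni = gni - offsets[batch_id]
    same_bid = torch.zeros(N, max_n)
    same_bid[(gni, lengths[batch_id] - 1)] = 1
    same_bid = 1 - torch.cumsum(same_bid, dim=-1)
    same_bid = F.pad(same_bid[:, :-1], pad=(1, 0), value=1)  # shift right, pad 1
    same_bid[(gni, gni2lni)] = 0  # drop self loops
    row, col = torch.nonzero(same_bid).T
    col = col + offsets[batch_id[row]]  # local -> global column index
    # row: tensor([0, 1, 2, 2, 3, 3, 4, 4]), col: tensor([1, 0, 3, 4, 2, 4, 2, 3]):
    # exactly the ordered intra-graph pairs, with no cross-graph pairs or self loops.
    return row, col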

def _radial_edges(X, src_dst, cutoff):
    dist = X[:, 0][src_dst]  # [Ef, 2, 3], CA position
    dist = torch.norm(dist[:, 0] - dist[:, 1], dim=-1)  # [Ef]
    src_dst = src_dst[dist <= cutoff]
    src_dst = src_dst.transpose(0, 1)  # [2, Ef]
    return src_dst
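
# Toy sketch (hypothetical `_demo_radial_edges`, for illustration only) of the
# radius filter above: candidate pairs farther apart than the cutoff are
# dropped, and the survivors come back as a [2, E] edge index computed from
# the channel-0 (e.g. CA) coordinates.
def _demo_radial_edges():
    X = torch.tensor([[[0., 0., 0.]], [[0., 0., 1.]], [[0., 0., 5.]]])  # [N=3, n_channel=1, 3]
    candidates = torch.tensor([[0, 1], [0, 2], [1, 2]])  # [Ef, 2]
    edges = _radial_edges(X, candidates, cutoff=2.0)
    # only the pair (0, 1) lies within the cutoff
    assert torch.equal(edges, torch.tensor([[0], [1]]))
    return edges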

class EfficientMCAttModel(nn.Module):
    def __init__(self, args, embed_size, hidden_size, prompt_nf, n_channel, n_edge_feats=0,
                 n_layers=5, dropout=0.1, n_iter=5, dense=False,
                 inter_cutoff=10, intra_cutoff=8, normalize_coord=None, unnormalize_coord=None):
        super().__init__()
        self.n_iter = n_iter
        self.args = args
        self.random_n_iter = args.random_n_iter
        self.gnn = MCAttEGNN(args, embed_size, hidden_size, hidden_size, prompt_nf,
                             n_channel, n_edge_feats, n_layers=n_layers,
                             residual=True, dropout=dropout, dense=dense,
                             normalize_coord=normalize_coord, unnormalize_coord=unnormalize_coord,
                             geometry_reg_step_size=args.geometry_reg_step_size)

        # complex graph features
        self.extract_edges = ComplexGraph(args, inter_cutoff=inter_cutoff, intra_cutoff=intra_cutoff,
                                          normalize_coord=normalize_coord, unnormalize_coord=unnormalize_coord)
        # construct pair embed
        if args.explicit_pair_embed:
            self.inter_layer = InteractionModule(hidden_size, hidden_size, hidden_size,
                                                 rm_layernorm=args.rm_layernorm)
        if args.keep_trig_attn:
            f = normalize_coord
            self.p_p_dist_layer = RBFDistanceModule(rbf_stop=f(32), distance_hidden_dim=hidden_size, num_gaussian=32)
            self.c_c_dist_layer = RBFDistanceModule(rbf_stop=f(16), distance_hidden_dim=hidden_size, num_gaussian=32)

    def forward(self, X, H, batch_id, segment_id, mask, is_global,
                compound_edge_index, LAS_edge_index, batched_complex_coord_LAS,
                LAS_mask=None, prompt_node=None, prompt_coord=None):
        '''
        :param X: [n_all_node, n_channel, 3]
        :param H: [n_all_node, embed_size]
        :param batch_id: [n_all_node]
        '''
        if self.args.keep_trig_attn:
            s_coord = X.squeeze(1).clone().detach()
            c_batch = batch_id[segment_id == 0]
            p_batch = batch_id[segment_id == 1]
            c_coord = s_coord[segment_id == 0]
            p_coord = s_coord[segment_id == 1]
            p_coord_batched, p_coord_mask = to_dense_batch(p_coord, p_batch)  # (B, Np_max, 3)
            c_coord_batched, c_coord_mask = to_dense_batch(c_coord, c_batch)  # (B, Nc_max, 3)
            p_p_dist = torch.cdist(p_coord_batched, p_coord_batched, compute_mode='donot_use_mm_for_euclid_dist')
            c_c_dist = torch.cdist(c_coord_batched, c_coord_batched, compute_mode='donot_use_mm_for_euclid_dist')
            p_p_dist_mask = torch.einsum("...i, ...j->...ij", p_coord_mask, p_coord_mask)
            # c_c_dist_mask = torch.einsum("...i, ...j->...ij", c_coord_mask, c_coord_mask)
            c_c_diag_mask = torch.diag_embed(c_coord_mask)  # (B, Nc, Nc)
            c_c_dist_mask = torch.logical_or(LAS_mask, c_c_diag_mask)
            p_p_dist[~p_p_dist_mask] = 1e6
            c_c_dist[~c_c_dist_mask] = 1e6
            p_p_dist_embed = self.p_p_dist_layer(p_p_dist)
            c_c_dist_embed = self.c_c_dist_layer(c_c_dist)
        else:
            p_p_dist_embed = None
            c_c_dist_embed = None

        if self.args.explicit_pair_embed:
            c_batch = batch_id[segment_id == 0]
            p_batch = batch_id[segment_id == 1]
            c_embed = H[segment_id == 0]
            p_embed = H[segment_id == 1]
            p_embed_batched, p_mask = to_dense_batch(p_embed, p_batch)
            c_embed_batched, c_mask = to_dense_batch(c_embed, c_batch)
            pair_embed_batched, pair_mask = self.inter_layer(p_embed_batched, c_embed_batched,
                                                             p_mask, c_mask)
            pair_embed_batched = pair_embed_batched * pair_mask.to(torch.float).unsqueeze(-1)
        else:
            pair_embed_batched, pair_mask = None, None

        if self.training and self.random_n_iter:
            iter_i = random.randint(1, self.n_iter)
        else:
            iter_i = self.n_iter

        for r in range(iter_i):
            # refine
            if self.args.refine == 'stack':
                with torch.no_grad():
                    ctx_edges, inter_edges, reduced_tuple = self.extract_edges(X, batch_id, segment_id, is_global)
                    ctx_edges = torch.cat((compound_edge_index, ctx_edges), dim=1)
                H, Z, prompt_node_feat, prompt_coord_feat = self.gnn(
                    H, X, ctx_edges, inter_edges, LAS_edge_index, batched_complex_coord_LAS,
                    segment_id=segment_id, batch_id=batch_id, reduced_tuple=reduced_tuple,
                    pair_embed_batched=pair_embed_batched, pair_mask=pair_mask, LAS_mask=LAS_mask,
                    p_p_dist_embed=p_p_dist_embed, c_c_dist_embed=c_c_dist_embed,
                    mask=mask, prompt_node=prompt_node, prompt_coord=prompt_coord)
                X[mask] = Z[mask]
            elif self.args.refine == 'refine_coord':
                if r < iter_i - 1:
                    # intermediate rounds: refine coordinates without tracking gradients
                    with torch.no_grad():
                        ctx_edges, inter_edges, reduced_tuple = self.extract_edges(X, batch_id, segment_id, is_global)
                        ctx_edges = torch.cat((compound_edge_index, ctx_edges), dim=1)
                        _, Z, prompt_node_feat, prompt_coord_feat = self.gnn(
                            H, X, ctx_edges, inter_edges, LAS_edge_index, batched_complex_coord_LAS,
                            segment_id=segment_id, batch_id=batch_id, reduced_tuple=reduced_tuple,
                            pair_embed_batched=pair_embed_batched, pair_mask=pair_mask, LAS_mask=LAS_mask,
                            p_p_dist_embed=p_p_dist_embed, c_c_dist_embed=c_c_dist_embed,
                            mask=mask, prompt_node=prompt_node, prompt_coord=prompt_coord)
                        X[mask] = Z[mask]
                else:
                    # final round: only edge extraction is gradient-free
                    with torch.no_grad():
                        ctx_edges, inter_edges, reduced_tuple = self.extract_edges(X, batch_id, segment_id, is_global)
                        ctx_edges = torch.cat((compound_edge_index, ctx_edges), dim=1)
                    H, Z, prompt_node_feat, prompt_coord_feat = self.gnn(
                        H, X, ctx_edges, inter_edges, LAS_edge_index, batched_complex_coord_LAS,
                        segment_id=segment_id, batch_id=batch_id, reduced_tuple=reduced_tuple,
                        pair_embed_batched=pair_embed_batched, pair_mask=pair_mask, LAS_mask=LAS_mask,
                        p_p_dist_embed=p_p_dist_embed, c_c_dist_embed=c_c_dist_embed,
                        mask=mask, prompt_node=prompt_node, prompt_coord=prompt_coord)
                    X[mask] = Z[mask]

        return X, H, prompt_node_feat, prompt_coord_feat
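
# Minimal self-contained sketch (hypothetical `_demo_refine_schedule` and
# `step_fn`, for illustration only) mirroring the refinement schedule in
# EfficientMCAttModel.forward: training may sample a random number of rounds
# in [1, n_iter], and in 'refine_coord' mode only the final round stays on the
# autograd graph while earlier rounds update the masked coordinates under
# torch.no_grad().
def _demo_refine_schedule(step_fn, X, mask, n_iter=5, training=True, random_n_iter=True):
    iters = random.randint(1, n_iter) if (training and random_n_iter) else n_iter
    for r in range(iters):
        if r < iters - 1:
            with torch.no_grad():  # intermediate rounds: no gradient tracking
                Z = step_fn(X)
                X[mask] = Z[mask]
        else:
            Z = step_fn(X)  # final round keeps gradients
            X[mask] = Z[mask]
    return X

# e.g. _demo_refine_schedule(lambda x: 0.9 * x, torch.randn(6, 3),
#                            torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.bool))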