Source code for promptbind.models.model

import torch
from torch_geometric.utils import to_dense_batch, to_dense_adj
from torch import nn
from torch.nn import Linear
import torch.nn as nn
from .att_model import EfficientMCAttModel
import torch.nn.functional as F
from promptbind.utils.utils import get_keepNode_tensor, gumbel_softmax_no_random
import random


class Transition_diff_out_dim(torch.nn.Module):
    """LayerNorm followed by a two-layer ReLU MLP with a configurable output width.

    The block acts on the last dimension of its input only: it normalises
    ``embedding_channels`` features, expands them to ``n * embedding_channels``
    and projects to ``out_channels``.  Both linear weights are Xavier-uniform
    initialised with a very small gain (0.001); biases keep the default
    ``nn.Linear`` initialisation.
    """

    def __init__(self, embedding_channels=256, out_channels=256, n=4):
        """Build the block.

        Args:
            embedding_channels: size of the last dimension of the input.
            out_channels: size of the last dimension of the output.
            n: expansion factor of the hidden layer.
        """
        super().__init__()
        hidden_channels = n * embedding_channels
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.linear1 = Linear(embedding_channels, hidden_channels)
        self.linear2 = Linear(hidden_channels, out_channels)
        # Same initialisation order as before so seeded runs stay reproducible.
        for layer in (self.linear1, self.linear2):
            torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

    def forward(self, z):
        """Map ``z`` of shape ``(..., embedding_channels)`` to ``(..., out_channels)``."""
        normed = self.layernorm(z)
        hidden = torch.relu(self.linear1(normed))
        return self.linear2(hidden)
class IaBNet_mean_and_pocket_prediction_cls_coords_dependent(torch.nn.Module):
    """Two-stage binding model: pocket prediction, then pocket-level complex modelling.

    Stage A runs ``pocket_pred_model`` on the whole protein together with the
    compound and classifies every residue as pocket / non-pocket.  Stage B runs
    ``complex_model`` on the compound plus a pocket, which is either the pocket
    stored in ``data`` (stage 1) or one re-cut around the predicted pocket
    centre (stage 2 and :meth:`inference`).

    Every sample is laid out as ``[glb_c || compound || glb_p || protein]``
    where ``glb_c`` / ``glb_p`` are learned global tokens.
    """

    def __init__(self, args, embedding_channels=128, pocket_pred_embedding_channels=128):
        """Create all sub-modules.

        Args:
            args: configuration namespace.  Attributes read here:
                ``coordinate_scale``, ``stage_prob``, ``pocket_prompt_nf``,
                ``complex_prompt_nf``, ``pocket_pred_layers``,
                ``pocket_pred_n_iter``, ``mean_layers``, ``n_iter``,
                ``inter_cutoff``, ``intra_cutoff``, ``use_esm2_feat``,
                ``esm2_concat_raw``.  Further attributes are read at call time.
            embedding_channels: feature width used by ``complex_model``.
            pocket_pred_embedding_channels: feature width used by
                ``pocket_pred_model``.
        """
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(embedding_channels)
        self.args = args
        self.coordinate_scale = args.coordinate_scale
        self.normalize_coord = lambda x: x / self.coordinate_scale
        self.unnormalize_coord = lambda x: x * self.coordinate_scale
        self.stage_prob = args.stage_prob
        n_channel = 1  # ligand node has only one coordinate dimension.
        pocket_prompt_nf = args.pocket_prompt_nf
        complex_prompt_nf = args.complex_prompt_nf

        # Learnable prompt components, zero-initialised, shape (1, nf, nf).
        self.pocket_prompt_node_component = nn.Parameter(torch.zeros(1, pocket_prompt_nf, pocket_prompt_nf))
        self.pocket_prompt_coord_component = nn.Parameter(torch.zeros(1, pocket_prompt_nf, pocket_prompt_nf))
        self.complex_prompt_node_component = nn.Parameter(torch.zeros(1, complex_prompt_nf, complex_prompt_nf))
        self.complex_prompt_coord_component = nn.Parameter(torch.zeros(1, complex_prompt_nf, complex_prompt_nf))

        self.pocket_pred_model = EfficientMCAttModel(
            args, pocket_pred_embedding_channels, pocket_pred_embedding_channels, pocket_prompt_nf, n_channel,
            n_edge_feats=0, n_layers=args.pocket_pred_layers, n_iter=args.pocket_pred_n_iter,
            inter_cutoff=args.inter_cutoff, intra_cutoff=args.intra_cutoff,
            normalize_coord=self.normalize_coord, unnormalize_coord=self.unnormalize_coord,
        )
        self.complex_model = EfficientMCAttModel(
            args, embedding_channels, embedding_channels, complex_prompt_nf, n_channel,
            n_edge_feats=0, n_layers=args.mean_layers, n_iter=args.n_iter,
            inter_cutoff=args.inter_cutoff, intra_cutoff=args.intra_cutoff,
            normalize_coord=self.normalize_coord, unnormalize_coord=self.unnormalize_coord,
        )
        self.protein_to_pocket = Transition_diff_out_dim(embedding_channels=embedding_channels, n=4, out_channels=1)

        # global nodes for protein / compound
        self.glb_c = nn.Parameter(torch.ones(1, embedding_channels))
        self.glb_p = nn.Parameter(torch.ones(1, embedding_channels))

        # Width of the raw per-residue protein features.
        # NOTE(review): 1295 looks like 1280 + 15 (both feature sets concatenated) — confirm.
        if args.use_esm2_feat:
            protein_hidden = 1280
        else:
            protein_hidden = 15
        if args.esm2_concat_raw:
            protein_hidden = 1295

        self.protein_linear_whole_protein = nn.Linear(protein_hidden, embedding_channels)
        # Compound feature width is hard-coded to 56.
        self.compound_linear_whole_protein = nn.Linear(56, embedding_channels)

        self.embedding_shrink = nn.Linear(embedding_channels, pocket_pred_embedding_channels)
        self.embedding_enlarge = nn.Linear(pocket_pred_embedding_channels, embedding_channels)

        self.distmap_mlp = nn.Sequential(
            nn.Linear(embedding_channels, embedding_channels),
            nn.ReLU(),
            nn.Linear(embedding_channels, 1))

        # Small-gain Xavier init; order kept stable so seeded runs are reproducible.
        for layer in (
            self.protein_linear_whole_protein,
            self.compound_linear_whole_protein,
            self.embedding_shrink,
            self.embedding_enlarge,
            self.distmap_mlp[0],
            self.distmap_mlp[2],
        ):
            torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _interleave_global_tokens(self, compound_emb, compound_batch, protein_emb, protein_batch, num_samples):
        """Return ``[glb_c || compound_i || glb_p || protein_i]`` stacked over all samples."""
        # TODO self.glb_c and self.glb_p shared?
        pieces = []
        for i in range(num_samples):
            pieces.extend((
                self.glb_c, compound_emb[compound_batch == i],
                self.glb_p, protein_emb[protein_batch == i],
            ))
        # A single concatenation instead of re-allocating the result per sample.
        return torch.cat(pieces, dim=0)

    def _predict_pocket(self, data):
        """Run the whole-protein pocket-prediction pass shared by forward/inference.

        Returns:
            A 9-tuple ``(compound_out, protein_out, protein_out_batched,
            protein_out_mask, pocket_cls_pred, protein_coords_batched,
            protein_coords_mask, prompt_node_feat, prompt_coord_feat)``.
            ``compound_out`` / ``protein_out`` are the non-global node
            embeddings after ``embedding_enlarge``; the ``*_batched`` /
            ``*_mask`` pairs come from ``to_dense_batch`` over
            ``data['protein_whole'].batch``; ``pocket_cls_pred`` holds the
            per-residue pocket logits, zeroed at padded positions.
        """
        compound_batch = data['compound'].batch
        protein_batch_whole = data['protein_whole'].batch
        whole = data['complex_whole_protein']
        complex_batch_whole_protein = whole.batch

        coords = self.normalize_coord(whole.node_coords.unsqueeze(-2))
        coords_LAS = self.normalize_coord(whole.node_coords_LAS.unsqueeze(-2))

        compound_emb = self.compound_linear_whole_protein(data['compound'].node_feats)
        protein_emb = self.protein_linear_whole_protein(data['protein_whole'].node_feats)

        tokens = self._interleave_global_tokens(
            compound_emb, compound_batch, protein_emb, protein_batch_whole,
            int(complex_batch_whole_protein.max()) + 1,
        )
        tokens = self.embedding_shrink(tokens)

        _, complex_out, prompt_node_feat, prompt_coord_feat = self.pocket_pred_model(
            coords,
            tokens,
            batch_id=complex_batch_whole_protein,
            segment_id=whole.segment,
            mask=whole.mask,
            is_global=whole.is_global,
            compound_edge_index=data['complex_whole_protein', 'c2c', 'complex_whole_protein'].edge_index,
            LAS_edge_index=data['complex_whole_protein', 'LAS', 'complex_whole_protein'].edge_index,
            batched_complex_coord_LAS=coords_LAS,
            LAS_mask=None,
            prompt_node=self.pocket_prompt_node_component,
            prompt_coord=self.pocket_prompt_coord_component,
        )
        complex_out = self.embedding_enlarge(complex_out)

        # segment: compound == 0, protein == 1; global tokens are dropped.
        compound_flag = torch.logical_and(whole.segment == 0, ~whole.is_global)
        protein_flag = torch.logical_and(whole.segment == 1, ~whole.is_global)
        compound_out = complex_out[compound_flag]
        protein_out = complex_out[protein_flag]

        protein_out_batched, protein_out_mask = to_dense_batch(protein_out, protein_batch_whole)
        pocket_cls_pred = self.protein_to_pocket(protein_out_batched)
        pocket_cls_pred = pocket_cls_pred.squeeze(-1) * protein_out_mask
        protein_coords_batched, protein_coords_mask = to_dense_batch(data.node_xyz_whole, protein_batch_whole)

        return (
            compound_out, protein_out, protein_out_batched, protein_out_mask,
            pocket_cls_pred, protein_coords_batched, protein_coords_mask,
            prompt_node_feat, prompt_coord_feat,
        )

    def _inference_pocket_center(self, pocket_cls_pred, protein_out_mask, protein_coords_batched):
        """Estimate one pocket centre per sample from the residue logits.

        The centre is the mean coordinate of residues whose sigmoid rounds to 1.
        When no residue is selected, a deterministic gumbel-softmax weighting of
        all residues is used instead.
        """
        pred_pocket_center = torch.zeros((pocket_cls_pred.shape[0], 3), device=pocket_cls_pred.device)
        batch_len = protein_out_mask.sum(dim=1).detach()
        for i, j in enumerate(batch_len):
            pred_index_bool = (pocket_cls_pred.detach()[i][:j].sigmoid().round().int() == 1)
            if pred_index_bool.sum() != 0:
                pred_pocket_center[i] = protein_coords_batched.detach()[i][:j][pred_index_bool].mean(dim=0)
                continue
            # all the prediction is False, use gumbel soft
            pred_index_true = pocket_cls_pred[i][:j].sigmoid().unsqueeze(-1)
            pred_index_prob = torch.cat([1. - pred_index_true, pred_index_true], dim=-1)
            # Same clamp as in forward(): a saturated sigmoid would otherwise
            # give log(0) = -inf and could turn the centre into NaN.
            pred_index_prob = torch.clamp(pred_index_prob, min=1e-6, max=1 - 1e-6)
            pred_index_log_prob = torch.log(pred_index_prob)
            pred_index_one_hot = gumbel_softmax_no_random(
                pred_index_log_prob, tau=self.args.gs_tau, hard=self.args.gs_hard)
            pred_index_one_hot_true = pred_index_one_hot[:, 1].unsqueeze(-1)
            weighted_coords = pred_index_one_hot_true * protein_coords_batched[i][:j]
            # NOTE(review): if the weights still sum to zero (e.g. hard one-hot
            # with every residue predicted negative) this division yields NaN.
            pred_pocket_center[i] = weighted_coords.sum(dim=0) / pred_index_one_hot_true.sum(dim=0)
        return pred_pocket_center

    def _assemble_pocket_complex(self, data, pred_pocket_center, compound_emb,
                                 protein_out_batched, protein_out_mask,
                                 protein_coords_batched, protein_coords_mask):
        """Cut a pocket around each predicted centre and build ``complex_model`` inputs.

        Side effects: overwrites ``node_coords``, ``node_coords_LAS``,
        ``segment``, ``mask`` and ``is_global`` of ``data['complex']`` and the
        ``edge_index`` of the ``('complex', 'c2c', 'complex')`` and
        ``('complex', 'LAS', 'complex')`` edge stores.

        Returns:
            ``(new_samples, batched_complex_coord, batched_complex_coord_LAS,
            complex_batch, pocket_batch, pocket_coords_batched, dis_map,
            keepNode_less_5)`` where ``keepNode_less_5`` counts the samples
            that fell back to the first-100-residues pocket.
        """
        compound_batch = data['compound'].batch
        protein_batch_whole = data['protein_whole'].batch
        device = compound_batch.device
        redocking = self.args.compound_coords_init_mode in ('redocking', 'redocking_no_rotate')

        token_parts, coord_parts, coord_LAS_parts = [], [], []
        segment_parts, mask_parts, is_global_parts = [], [], []
        c2c_parts, LAS_parts = [], []
        complex_batch_parts, pocket_batch_parts = [], []
        pocket_coord_parts, dis_map_parts = [], []
        node_offset = 0  # number of complex nodes emitted by previous samples
        keepNode_less_5 = 0

        for i in range(pred_pocket_center.shape[0]):
            protein_i = data.node_xyz_whole[protein_batch_whole == i].detach()
            keepNode = get_keepNode_tensor(protein_i, self.args.pocket_radius, None, pred_pocket_center[i].detach())
            # TODO Check the case
            if keepNode.sum() < 5:
                # if only include less than 5 residues, simply add first 100 residues.
                keepNode[:100] = True
                keepNode_less_5 += 1

            compound_sel = compound_batch == i
            compound_emb_i = compound_emb[compound_sel]
            pocket_emb = protein_out_batched[i][protein_out_mask[i]][keepNode]
            # node emb
            token_parts.extend((self.glb_c, compound_emb_i, self.glb_p, pocket_emb))

            # Node coords.
            # Ligand coords are re-centred onto the mean of the pocket coords.
            # Pocket coords are from origin protein coords.
            pocket_coords = protein_coords_batched[i][protein_coords_mask[i]][keepNode]
            pocket_coord_parts.append(pocket_coords)
            compound_coords = data['compound'].node_coords[compound_sel]
            origin = torch.zeros((1, 3), device=device)  # coordinate of a global token
            coord_parts.extend((  # [glb_c || compound || glb_p || protein]
                origin,
                compound_coords
                - compound_coords.mean(dim=0).reshape(1, 3)
                + pocket_coords.mean(dim=0).reshape(1, 3),
                origin,
                pocket_coords,
            ))

            if redocking:
                # detach() replaces the former torch.tensor(<tensor>) copy,
                # which emitted a copy-construct warning on every call.
                compound_LAS = compound_coords.detach()
            else:
                compound_LAS = data['compound'].rdkit_coords[compound_sel]
            coord_LAS_parts.extend((  # [glb_c || compound || glb_p || protein]
                origin, compound_LAS, origin, torch.zeros_like(pocket_coords),
            ))

            # masks
            n_protein = pocket_emb.shape[0]
            n_compound = compound_emb_i.shape[0]
            n_nodes = n_protein + n_compound + 2

            segment = torch.zeros(n_nodes, dtype=torch.bool, device=device)
            segment[n_compound + 1:] = True  # compound: 0, protein: 1
            segment_parts.append(segment)

            mask = torch.zeros(n_nodes, dtype=torch.bool, device=device)
            mask[:n_compound + 2] = True  # glb_p can be updated
            mask_parts.append(mask)

            is_global = torch.zeros(n_nodes, dtype=torch.bool, device=device)
            is_global[0] = True
            is_global[n_compound + 1] = True
            is_global_parts.append(is_global)

            # edge_index: per-sample edges are shifted by the nodes emitted so far.
            c2c_store = data['compound_atom_edge_list']
            las_store = data['LAS_edge_list']
            c2c_parts.append(c2c_store.x[c2c_store.batch == i].t() + node_offset)
            LAS_parts.append(las_store.x[las_store.batch == i].t() + node_offset)

            # batch_id
            complex_batch_parts.append(torch.full((n_nodes,), i, dtype=torch.int64, device=device))
            pocket_batch_parts.append(torch.full((n_protein,), i, dtype=torch.int64, device=device))

            # distance map, capped at 10
            dis_map_i = torch.cdist(pocket_coords, compound_coords.to(torch.float32)).flatten()
            dis_map_parts.append(torch.clamp(dis_map_i, max=10))

            node_offset += n_nodes

        # construct inputs
        complex_store = data['complex']
        complex_store.node_coords = torch.cat(coord_parts, dim=0).float()
        complex_store.node_coords_LAS = torch.cat(coord_LAS_parts, dim=0).float()
        complex_store.segment = torch.cat(segment_parts, dim=0)
        complex_store.mask = torch.cat(mask_parts, dim=0)
        complex_store.is_global = torch.cat(is_global_parts, dim=0)
        data['complex', 'c2c', 'complex'].edge_index = torch.cat(c2c_parts, dim=1).to(torch.int64)
        data['complex', 'LAS', 'complex'].edge_index = torch.cat(LAS_parts, dim=1).to(torch.int64)

        new_samples = torch.cat(token_parts, dim=0)
        complex_batch = torch.cat(complex_batch_parts, dim=0)
        pocket_batch = torch.cat(pocket_batch_parts, dim=0)
        dis_map = torch.cat(dis_map_parts, dim=0)
        pocket_coords_batched, _ = to_dense_batch(
            self.normalize_coord(torch.cat(pocket_coord_parts, dim=0)), pocket_batch)
        batched_complex_coord = self.normalize_coord(complex_store.node_coords.unsqueeze(-2))
        batched_complex_coord_LAS = self.normalize_coord(complex_store.node_coords_LAS.unsqueeze(-2))

        return (
            new_samples, batched_complex_coord, batched_complex_coord_LAS,
            complex_batch, pocket_batch, pocket_coords_batched, dis_map, keepNode_less_5,
        )

    def _run_complex_model(self, data, batched_complex_coord, new_samples, complex_batch, batched_complex_coord_LAS):
        """Call ``complex_model`` on the ``data['complex']`` graph with the complex prompts."""
        return self.complex_model(
            batched_complex_coord,
            new_samples,
            batch_id=complex_batch,
            segment_id=data['complex'].segment,
            mask=data['complex'].mask,
            is_global=data['complex'].is_global,
            compound_edge_index=data['complex', 'c2c', 'complex'].edge_index,
            LAS_edge_index=data['complex', 'LAS', 'complex'].edge_index,
            batched_complex_coord_LAS=batched_complex_coord_LAS,
            LAS_mask=None,
            prompt_node=self.complex_prompt_node_component,
            prompt_coord=self.complex_prompt_coord_component,
        )

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def forward(self, data, stage=1, train=False):
        """Predict the pocket, then the compound pose and pocket-compound distances.

        Args:
            data: heterogeneous graph batch.  In stage 2 its ``'complex'`` node
                store and ``c2c`` / ``LAS`` edge stores are overwritten in place.
            stage: stage to run when ``pocket_pred_model`` is not in training
                mode (1: pocket stored in ``data``; 2: pocket re-cut around the
                predicted centre).  In training mode the stage is chosen from
                the predicted-centre error and ``stage_prob`` instead.
            train: enables the ``train_pred_pocket_noise`` jitter in stage 2.

        Returns:
            A 15-tuple: ``compound_coords_out`` (un-normalised), ``compound_batch``,
            ``y_pred`` (MLP distance prediction scaled to 0..10),
            ``y_pred_by_coords`` (distances from predicted coordinates, clamped
            to 0..10), ``pocket_cls_pred``, ``pocket_cls``,
            ``protein_out_mask_whole``, ``protein_coords_batched_whole``,
            ``pred_pocket_center``, ``dis_map``, ``keepNode_less_5``, the pocket
            prompt components, the complex prompt components, the pocket prompt
            features and the complex prompt features (the last four as pairs).

        Raises:
            ValueError: if the resolved stage is neither 1 nor 2.
        """
        keepNode_less_5 = 0
        compound_batch = data['compound'].batch
        pocket_batch = data['pocket'].batch
        complex_batch = data['complex'].batch
        protein_batch_whole = data['protein_whole'].batch

        # Pocket Prediction
        (
            compound_out_whole_protein, protein_out_whole_protein,
            protein_out_batched_whole, protein_out_mask_whole,
            pocket_cls_pred, protein_coords_batched_whole, protein_coords_mask_whole,
            pocket_prompt_node_feat, pocket_prompt_coord_feat,
        ) = self._predict_pocket(data)

        pocket_cls, _ = to_dense_batch(data.pocket_idx, protein_batch_whole)
        pocket_coords_batched, _ = to_dense_batch(self.normalize_coord(data.node_xyz), pocket_batch)

        # Pocket centre as a gumbel-softmax weighted mean of residue coordinates.
        pred_index_true = pocket_cls_pred.sigmoid().unsqueeze(-1)
        pred_index_false = 1. - pred_index_true
        pred_index_prob = torch.cat([pred_index_false, pred_index_true], dim=-1)
        # For training stability
        pred_index_prob = torch.clamp(pred_index_prob, min=1e-6, max=1 - 1e-6)
        pred_index_log_prob = torch.log(pred_index_prob)
        if self.pocket_pred_model.training:
            pred_index_one_hot = F.gumbel_softmax(pred_index_log_prob, tau=self.args.gs_tau, hard=self.args.gs_hard)
        else:
            pred_index_one_hot = gumbel_softmax_no_random(
                pred_index_log_prob, tau=self.args.gs_tau, hard=self.args.gs_hard)
        pred_index_one_hot_true = (pred_index_one_hot[:, :, 1] * protein_out_mask_whole).unsqueeze(-1)
        pred_pocket_center_gumbel = pred_index_one_hot_true * protein_coords_batched_whole
        pred_pocket_center = pred_pocket_center_gumbel.sum(dim=1) / pred_index_one_hot_true.sum(dim=1)

        center_dist_ligand_pocket_batch = torch.norm(data.coords_center - pred_pocket_center, p=2, dim=-1)
        center_dist_mean = center_dist_ligand_pocket_batch.mean(dim=-1)

        # While training, stage 2 is only sampled once the predicted centre is
        # close enough to the ligand centre; otherwise the caller's stage is used.
        if self.pocket_pred_model.training and center_dist_mean < self.args.center_dist_threshold:
            if random.random() < self.stage_prob:
                final_stage = 2
            else:
                final_stage = 1
        elif self.pocket_pred_model.training and center_dist_mean >= self.args.center_dist_threshold:
            final_stage = 1
        else:
            final_stage = stage

        if final_stage == 2:
            # Replace raw feature with pocket prediction output
            # NOTE(review): local_eval jitters the centre with the *training*
            # noise scale, and a second jitter is applied when train is set —
            # confirm this is intended.
            if self.args.local_eval:
                pred_pocket_center += self.args.train_pred_pocket_noise * (2 * torch.rand_like(pred_pocket_center) - 1)
            if self.args.train_pred_pocket_noise and train:
                pred_pocket_center += self.args.train_pred_pocket_noise * (2 * torch.rand_like(pred_pocket_center) - 1)

            (
                new_samples, batched_complex_coord, batched_complex_coord_LAS,
                complex_batch, pocket_batch, pocket_coords_batched, dis_map, keepNode_less_5,
            ) = self._assemble_pocket_complex(
                data, pred_pocket_center, compound_out_whole_protein,
                protein_out_batched_whole, protein_out_mask_whole,
                protein_coords_batched_whole, protein_coords_mask_whole,
            )
        elif final_stage == 1:
            batched_pocket_emb = protein_out_whole_protein[data['pocket'].keepNode]
            batched_complex_coord = self.normalize_coord(data['complex'].node_coords.unsqueeze(-2))
            batched_complex_coord_LAS = self.normalize_coord(data['complex'].node_coords_LAS.unsqueeze(-2))
            new_samples = self._interleave_global_tokens(
                compound_out_whole_protein, compound_batch, batched_pocket_emb, pocket_batch,
                int(complex_batch.max()) + 1,
            )
            dis_map = data.dis_map
        else:
            raise ValueError("stage must be 1 or 2, got {!r}".format(final_stage))

        complex_coords, complex_out, complex_prompt_node_feat, complex_prompt_coord_feat = self._run_complex_model(
            data, batched_complex_coord, new_samples, complex_batch, batched_complex_coord_LAS)

        compound_flag = torch.logical_and(data['complex'].segment == 0, ~data['complex'].is_global)
        protein_flag = torch.logical_and(data['complex'].segment == 1, ~data['complex'].is_global)
        pocket_out = complex_out[protein_flag]
        compound_out = complex_out[compound_flag]
        compound_coords_out = complex_coords[compound_flag].squeeze(-2)

        # to_dense_batch (torch geometric) pads per-sample nodes to (b, n, c).
        pocket_out_batched, pocket_out_mask = to_dense_batch(pocket_out, pocket_batch)
        compound_out_batched, compound_out_mask = to_dense_batch(compound_out, compound_batch)
        compound_coords_out_batched, _ = to_dense_batch(compound_coords_out, compound_batch)

        # get the pair distance of protein and compound
        pocket_com_dis_map = torch.cdist(pocket_coords_batched, compound_coords_out_batched)

        # Assume self.args.distmap_pred == 'mlp':
        pocket_out_batched = self.layernorm(pocket_out_batched)
        compound_out_batched = self.layernorm(compound_out_batched)
        # z of shape, b, protein_length, compound_length, channels.
        z = torch.einsum("bik,bjk->bijk", pocket_out_batched, compound_out_batched)
        z_mask = torch.einsum("bi,bj->bij", pocket_out_mask, compound_out_mask)

        b = self.distmap_mlp(z).squeeze(-1)

        y_pred = b[z_mask]
        y_pred = y_pred.sigmoid() * 10  # normalize to 0 to 10.

        y_pred_by_coords = pocket_com_dis_map[z_mask]
        y_pred_by_coords = self.unnormalize_coord(y_pred_by_coords)
        y_pred_by_coords = torch.clamp(y_pred_by_coords, 0, 10)

        compound_coords_out = self.unnormalize_coord(compound_coords_out)

        return (
            compound_coords_out, compound_batch, y_pred, y_pred_by_coords,
            pocket_cls_pred, pocket_cls, protein_out_mask_whole, protein_coords_batched_whole,
            pred_pocket_center, dis_map, keepNode_less_5,
            (self.pocket_prompt_node_component, self.pocket_prompt_coord_component),
            (self.complex_prompt_node_component, self.complex_prompt_coord_component),
            (pocket_prompt_node_feat, pocket_prompt_coord_feat),
            (complex_prompt_node_feat, complex_prompt_coord_feat),
        )

    def inference(self, data):
        """Predict compound coordinates using the predicted pocket only.

        Unlike :meth:`forward`, the pocket centre is taken from the residues
        whose sigmoid rounds to 1, with a gumbel-softmax fallback when none
        is selected, and no distance map is predicted.  The ``'complex'``
        stores of ``data`` are overwritten and ``data['complex'].batch`` is
        set to the rebuilt batch vector.

        Returns:
            ``(compound_coords_out, compound_batch,
            (pocket_prompt_node_feat, pocket_prompt_coord_feat),
            (complex_prompt_node_feat, complex_prompt_coord_feat))`` with
            ``compound_coords_out`` un-normalised.
        """
        compound_batch = data['compound'].batch

        # Pocket Prediction
        (
            compound_out_whole_protein, _,
            protein_out_batched_whole, protein_out_mask_whole,
            pocket_cls_pred, protein_coords_batched_whole, protein_coords_mask_whole,
            pocket_prompt_node_feat, pocket_prompt_coord_feat,
        ) = self._predict_pocket(data)

        pred_pocket_center = self._inference_pocket_center(
            pocket_cls_pred, protein_out_mask_whole, protein_coords_batched_whole)

        # Replace raw feature with pocket prediction output
        (
            new_samples, batched_complex_coord, batched_complex_coord_LAS,
            complex_batch, _, _, _, _,
        ) = self._assemble_pocket_complex(
            data, pred_pocket_center, compound_out_whole_protein,
            protein_out_batched_whole, protein_out_mask_whole,
            protein_coords_batched_whole, protein_coords_mask_whole,
        )
        data['complex'].batch = complex_batch

        complex_coords, _, complex_prompt_node_feat, complex_prompt_coord_feat = self._run_complex_model(
            data, batched_complex_coord, new_samples, complex_batch, batched_complex_coord_LAS)

        compound_flag = torch.logical_and(data['complex'].segment == 0, ~data['complex'].is_global)
        compound_coords_out = complex_coords[compound_flag].squeeze(-2)
        compound_coords_out = self.unnormalize_coord(compound_coords_out)

        return (
            compound_coords_out, compound_batch,
            (pocket_prompt_node_feat, pocket_prompt_coord_feat),
            (complex_prompt_node_feat, complex_prompt_coord_feat),
        )
def get_model(args, logger, device):
    """Build the model selected by ``args.mode``.

    Args:
        args: configuration namespace; reads ``mode``, ``hidden_size`` and
            ``pocket_pred_hidden_size`` (plus whatever the model reads).
        logger: object exposing ``log_message(str)``.
        device: unused here; kept so existing callers keep working.

    Returns:
        The constructed model (not moved to ``device``).

    Raises:
        ValueError: if ``args.mode`` is not 5.  Previously this path failed
            with an ``UnboundLocalError`` on ``return model``.
    """
    if args.mode != 5:
        raise ValueError(
            "Unsupported args.mode {!r}: only mode 5 (PromptBind) is implemented.".format(args.mode)
        )
    logger.log_message("PromptBind")
    return IaBNet_mean_and_pocket_prediction_cls_coords_dependent(
        args, args.hidden_size, args.pocket_pred_hidden_size)