Source code for promptbind.utils.fabind_inference_dataset

from torch_geometric.data import Dataset
import pandas as pd
from tqdm import tqdm
import os
from .inference_pdb_utils import extract_protein_structure, extract_esm_feature
from .inference_mol_utils import read_smiles, extract_torchdrug_feature_from_mol, generate_conformation
from torch_geometric.data import HeteroData
import torch


class InferenceDataset(Dataset):
    """PyG dataset of (ligand SMILES, protein) pairs used at inference time.

    Pairs are listed in ``index_csv``: one header row followed by
    ``smiles,pdb`` rows. This class does no feature extraction itself; it only
    loads files previously written into ``preprocess_dir``:

    * ``processed_protein.pt`` -- a 2-tuple ``(protein_feature,
      protein_structure)``; both are indexed by a row's ``pdb`` value.
    * ``mol/mol_{i}.pt`` -- a 2-tuple ``(mol, molecule_info)`` for the i-th
      non-blank data row (0-based, header excluded).

    Rows whose molecule file or protein entry cannot be loaded are reported
    on stdout and skipped, so ``len()`` may be smaller than the number of
    CSV rows.
    """

    def __init__(self, index_csv, pdb_file_dir, preprocess_dir):
        """Load every readable pair listed in ``index_csv``.

        Args:
            index_csv: Path to the CSV listing ``smiles,pdb`` pairs.
            pdb_file_dir: Not used by this constructor; kept so existing
                callers keep working.
            preprocess_dir: Directory containing ``processed_protein.pt`` and
                the ``mol/mol_{i}.pt`` files.

        Raises:
            ValueError: If a non-blank data row of ``index_csv`` does not
                have exactly two comma-separated fields.
        """
        super().__init__(None, None, None, None)
        pairs = self._read_index(index_csv)

        # NOTE(review): torch.load unpickles arbitrary Python objects -- only
        # point this at preprocessing output from a trusted source.
        self.protein_feature, self.protein_structure = torch.load(
            os.path.join(preprocess_dir, 'processed_protein.pt'))

        self.data = []
        for i, (smiles, pdb) in enumerate(tqdm(pairs)):
            # Molecule files are keyed by the row's position in the CSV.
            try:
                mol, molecule_info = torch.load(
                    os.path.join(preprocess_dir, 'mol', f'mol_{i}.pt'))
            except Exception as err:
                # Best-effort skip of unreadable rows; catching Exception (not
                # a bare except) lets KeyboardInterrupt/SystemExit propagate.
                print('\nFailed to read molecule id ', i, ' We are skipping it.', f'({err!r})')
                continue

            try:
                protein_structure = self.protein_structure[pdb]
                protein_esm_feature = self.protein_feature[pdb]
            except Exception as err:
                print('\nFailed to read protein pdb ', pdb, ' We are skipping it.', f'({err!r})')
                continue

            self.data.append({
                'protein_esm_feature': protein_esm_feature,
                'protein_structure': protein_structure,
                'molecule': mol,
                'molecule_smiles': smiles,
                'molecule_info': molecule_info,
            })

    @staticmethod
    def _read_index(index_csv):
        """Return ``[(smiles, pdb), ...]`` from ``index_csv``.

        The first line is treated as a header and skipped. Blank lines (for
        example a trailing newline at end of file) are ignored, so they do
        not consume a molecule index.

        Raises:
            ValueError: If a non-blank data row does not split into exactly
                two comma-separated fields.
        """
        pairs = []
        with open(index_csv, 'r') as f:
            next(f, None)  # header row
            for line_no, raw_line in enumerate(f, start=2):
                line = raw_line.strip()
                if not line:
                    continue
                fields = line.split(',')
                if len(fields) != 2:
                    raise ValueError(
                        f'{index_csv}:{line_no}: expected 2 comma-separated '
                        f'fields (smiles,pdb), got {len(fields)}')
                pairs.append((fields[0], fields[1]))
        return pairs

    def len(self):
        """Return the number of successfully loaded pairs."""
        return len(self.data)

    def get(self, idx):
        """Build the ``HeteroData`` sample for the ``idx``-th loaded pair.

        Both the protein and the ligand are translated to the origin. The
        protein centroid that was subtracted is stored in ``coord_offset``
        (shape ``(1, 3)``).

        The ``complex_whole_protein`` node set is laid out as
        ``[glb_c || compound atoms || glb_p || protein nodes]``, which is why
        compound edge indices are shifted by +1 (index 0 is ``glb_c``).
        """
        input_dict = self.data[idx]
        structure = input_dict['protein_structure']
        # Column 1 of the per-residue coordinate array is used as the node
        # position. NOTE(review): presumably the C-alpha atom -- confirm
        # against the coordinate layout written at preprocessing time.
        protein_node_xyz = torch.tensor(structure['coords'])[:, 1]
        protein_seq = structure['seq']
        protein_esm_feature = input_dict['protein_esm_feature']
        smiles = input_dict['molecule_smiles']
        rdkit_coords, compound_node_features, input_atom_edge_list, LAS_edge_index = input_dict['molecule_info']

        n_protein_whole = protein_node_xyz.shape[0]
        n_compound = compound_node_features.shape[0]
        n_total = n_protein_whole + n_compound + 2  # +2 for glb_c and glb_p

        data = HeteroData()
        protein_center = protein_node_xyz.mean(dim=0)
        data.coord_offset = protein_center.unsqueeze(0)
        protein_node_xyz = protein_node_xyz - protein_center
        coords_init = rdkit_coords - rdkit_coords.mean(dim=0)

        # compound graph
        data['compound'].node_feats = compound_node_features.float()
        data['compound', 'LAS', 'compound'].edge_index = LAS_edge_index
        data['compound'].node_coords = coords_init
        data['compound'].rdkit_coords = coords_init
        data['compound'].smiles = smiles
        data['compound_atom_edge_list'].x = (input_atom_edge_list[:, :2].long().contiguous() + 1).clone()
        data['LAS_edge_list'].x = (LAS_edge_index + 1).clone().t()

        data.node_xyz_whole = protein_node_xyz
        data.seq_whole = protein_seq
        data.idx = idx
        data.uid = structure['name']
        data.mol = input_dict['molecule']

        # complex whole graph: [glb_c || compound || glb_p || protein]
        data['complex_whole_protein'].node_coords = torch.cat(
            (
                torch.zeros(1, 3),
                # for pocket prediction module, the ligand is centered at the protein center/origin
                coords_init - coords_init.mean(dim=0),
                torch.zeros(1, 3),
                protein_node_xyz,
            ),
            dim=0,
        ).float()
        # Uncentered RDKit ligand coordinates; global and protein slots are zero.
        data['complex_whole_protein'].node_coords_LAS = torch.cat(
            (
                torch.zeros(1, 3),
                rdkit_coords,
                torch.zeros(1, 3),
                torch.zeros_like(protein_node_xyz),
            ),
            dim=0,
        ).float()

        # Segment id: 0 for glb_c + compound atoms, 1 for glb_p + protein nodes.
        segment = torch.zeros(n_total)
        segment[n_compound + 1:] = 1
        data['complex_whole_protein'].segment = segment

        # True for glb_c, compound atoms and glb_p (glb_p can be updated);
        # False for protein nodes.
        mask = torch.zeros(n_total)
        mask[:n_compound + 2] = 1
        data['complex_whole_protein'].mask = mask.bool()

        # Flags the two global nodes: glb_c at 0 and glb_p right after the compound.
        is_global = torch.zeros(n_total)
        is_global[0] = 1
        is_global[n_compound + 1] = 1
        data['complex_whole_protein'].is_global = is_global.bool()

        data['complex_whole_protein', 'c2c', 'complex_whole_protein'].edge_index = input_atom_edge_list[:, :2].long().t().contiguous() + 1
        data['complex_whole_protein', 'LAS', 'complex_whole_protein'].edge_index = LAS_edge_index + 1
        data['protein_whole'].node_feats = protein_esm_feature
        return data