Source code for promptbind.data.data

import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Dataset
from promptbind.utils.utils import construct_data_from_graph_gvp_mean
import lmdb
import pickle

class FABindDataSet(Dataset):
    """Dataset of protein-compound samples read from preprocessed files.

    The sample table and the RDKit coordinate lookup are loaded with
    ``torch.load``; protein entries, compound entries and (optionally) ESM2
    features are read on demand from read-only LMDB environments. All file
    locations come from ``self.processed_paths`` (see
    ``processed_file_names``).

    Args:
        root: Dataset root directory, forwarded to the base ``Dataset``.
        data, protein_dict, compound_dict: Stored on the instance before the
            base initialiser runs; they are overwritten afterwards by the
            objects loaded from the processed files.
        proteinMode: Only ``0`` is supported by ``get``.
        compoundMode: Forwarded to ``construct_data_from_graph_gvp_mean``.
        add_noise_to_com: Noise value used for 'train' rows whose
            ``use_compound_com`` flag is truthy.
        pocket_radius, contactCutoff, compound_coords_init_mode, seed, args:
            Forwarded to ``construct_data_from_graph_gvp_mean``.
        predDis: Forwarded as ``includeDisMap``.
        use_whole_protein: Fallback used when a row has no
            ``use_whole_protein`` column.
        pre: Forwarded as ``data_path``.
        transform, pre_transform, pre_filter: Forwarded to the base
            ``Dataset``.
        noise_for_predicted_pocket: Noise value used for 'train' rows whose
            ``use_compound_com`` flag is falsy.
        test_random_rotation: Enables random rotation for 'test' rows.
        pocket_idx_no_noise: Forwarded to
            ``construct_data_from_graph_gvp_mean``.
        use_esm2_feat: When true, the ESM2 LMDB is opened and its per-protein
            entry is passed along as ``protein_esm2_feat``.
    """

    def __init__(
        self,
        root,
        data=None,
        protein_dict=None,
        compound_dict=None,
        proteinMode=0,
        compoundMode=1,
        add_noise_to_com=None,
        pocket_radius=20,
        contactCutoff=8.0,
        predDis=True,
        args=None,
        use_whole_protein=False,
        compound_coords_init_mode=None,
        seed=42,
        pre=None,
        transform=None,
        pre_transform=None,
        pre_filter=None,
        noise_for_predicted_pocket=5.0,
        test_random_rotation=False,
        pocket_idx_no_noise=True,
        use_esm2_feat=False,
    ):
        self.data = data
        self.protein_dict = protein_dict
        self.compound_dict = compound_dict
        # this will call the process function to save the data, protein_dict and compound_dict
        super().__init__(root, transform, pre_transform, pre_filter)
        print(self.processed_paths)

        # Every LMDB store is opened with the same read-only settings.
        lmdb_read_kwargs = dict(
            readonly=True,
            max_readers=1,
            lock=False,
            readahead=False,
            meminit=False,
        )
        self.data = torch.load(self.processed_paths[0])
        self.compound_rdkit_coords = torch.load(self.processed_paths[3])
        self.protein_dict = lmdb.open(self.processed_paths[1], **lmdb_read_kwargs)
        self.compound_dict = lmdb.open(self.processed_paths[2], **lmdb_read_kwargs)
        if use_esm2_feat:
            self.protein_esm2_feat = lmdb.open(self.processed_paths[4], **lmdb_read_kwargs)

        self.compound_coords_init_mode = compound_coords_init_mode
        self.add_noise_to_com = add_noise_to_com
        self.noise_for_predicted_pocket = noise_for_predicted_pocket
        self.proteinMode = proteinMode
        self.compoundMode = compoundMode
        self.pocket_radius = pocket_radius
        self.contactCutoff = contactCutoff
        self.predDis = predDis
        self.use_whole_protein = use_whole_protein
        self.test_random_rotation = test_random_rotation
        self.pocket_idx_no_noise = pocket_idx_no_noise
        self.use_esm2_feat = use_esm2_feat
        self.seed = seed
        self.args = args
        self.pre = pre

    @property
    def processed_file_names(self):
        """File names whose order matches the ``processed_paths`` indices used in ``__init__``."""
        return [
            'data.pt',
            'protein_1d_3d.lmdb',
            'compound_LAS_edge_index.lmdb',
            'compound_rdkit_coords.pt',
            'esm2_t33_650M_UR50D.lmdb',
        ]

    @staticmethod
    def _load_lmdb_entry(env, key, store_name):
        """Return the unpickled value stored under ``key`` in LMDB ``env``.

        Raises:
            KeyError: If ``key`` is absent. Without this check the ``None``
                returned by ``txn.get`` would reach ``pickle.loads`` and fail
                with an unrelated ``TypeError``.
        """
        with env.begin() as txn:
            raw = txn.get(key.encode())
            if raw is None:
                raise KeyError(f"{key!r} not found in {store_name} LMDB")
            # NOTE(security): pickle.loads executes arbitrary code from the
            # payload; only open LMDB files produced by a trusted pipeline.
            return pickle.loads(raw)

    def len(self):
        """Return the number of rows in the sample table."""
        return len(self.data)

    def get(self, idx):
        """Build the graph data object for row ``idx`` of the sample table.

        Args:
            idx: Positional row index into ``self.data`` (used with ``iloc``).

        Returns:
            The first element returned by
            ``construct_data_from_graph_gvp_mean``, with ``pdb`` and ``group``
            attributes attached.

        Raises:
            ValueError: If ``self.proteinMode`` is not 0. No other mode
                builds a sample, so the original code failed later with an
                ``UnboundLocalError``.
            KeyError: If the protein or compound key is missing from its
                LMDB store.
        """
        if self.proteinMode != 0:
            raise ValueError(
                f"Unsupported proteinMode {self.proteinMode!r}; only proteinMode=0 is implemented"
            )

        line = self.data.iloc[idx]
        pocket_com = line['pocket_com']
        use_compound_com = line['use_compound_com']
        # Optional columns fall back to the dataset-level setting / 'train'.
        use_whole_protein = line['use_whole_protein'] if "use_whole_protein" in line.index else self.use_whole_protein
        group = line['group'] if "group" in line.index else 'train'

        # Noise is applied to training rows only; which value is used depends
        # on the row's use_compound_com flag.
        if group != 'train':
            add_noise_to_com = None
        elif use_compound_com:
            add_noise_to_com = self.add_noise_to_com
        else:
            add_noise_to_com = self.noise_for_predicted_pocket

        # Training rows are always rotated; test rows only when enabled.
        random_rotation = bool(
            group == 'train' or (group == 'test' and self.test_random_rotation)
        )

        protein_name = line['protein_name']  # pdb id
        protein_node_xyz, protein_seq = self._load_lmdb_entry(
            self.protein_dict, protein_name, 'protein'
        )
        if self.use_esm2_feat:
            protein_esm2_feat = self._load_lmdb_entry(
                self.protein_esm2_feat, protein_name, 'esm2 feature'
            )
        else:
            protein_esm2_feat = None

        name = line['compound_name']
        rdkit_coords = self.compound_rdkit_coords[name]
        # compound embedding from torchdrug
        (
            coords,
            compound_node_features,
            input_atom_edge_list,
            input_atom_edge_attr_list,
            _pair_dis_distribution,  # stored in the entry but not used here
            LAS_edge_index,
        ) = self._load_lmdb_entry(self.compound_dict, name, 'compound')

        # Only the first of the three returned values is used.
        data, _input_node_list, _keep_node = construct_data_from_graph_gvp_mean(
            self.args,
            protein_node_xyz,
            protein_seq,
            coords,
            compound_node_features,
            input_atom_edge_list,
            input_atom_edge_attr_list,
            LAS_edge_index,
            rdkit_coords,
            compound_coords_init_mode=self.compound_coords_init_mode,
            contactCutoff=self.contactCutoff,
            includeDisMap=self.predDis,
            pocket_radius=self.pocket_radius,
            add_noise_to_com=add_noise_to_com,
            use_whole_protein=use_whole_protein,
            pdb_id=name,
            group=group,
            seed=self.seed,
            data_path=self.pre,
            use_compound_com_as_pocket=use_compound_com,
            chosen_pocket_com=pocket_com,
            compoundMode=self.compoundMode,
            random_rotation=random_rotation,
            pocket_idx_no_noise=self.pocket_idx_no_noise,
            protein_esm2_feat=protein_esm2_feat,
        )

        data.pdb = line['pdb'] if "pdb" in line.index else f'smiles_{idx}'
        data.group = group
        return data
def get_data(args, logger, addNoise=None, use_whole_protein=False, compound_coords_init_mode='pocket_center_rdkit', pre="/PDBbind_data/pdbbind2020"):
    """Load the dataset under ``pre`` and return its train/valid/test subsets.

    Args:
        args: Namespace-like object. Reads ``data``, ``pocket_radius``,
            ``noise_for_predicted_pocket``, ``test_random_rotation``,
            ``pocket_idx_no_noise``, ``use_esm2_feat`` and ``seed``; it is
            also passed through to the dataset.
        logger: Object exposing ``log_message(str)``.
        addNoise: Converted with ``float`` when truthy, otherwise no noise
            value is passed.
        use_whole_protein: Forwarded to ``FABindDataSet``.
        compound_coords_init_mode: Forwarded to ``FABindDataSet``.
        pre: Data root; the dataset is read from ``f"{pre}/dataset"``.

    Returns:
        Tuple ``(train, valid, test)`` obtained by indexing the dataset with
        the row indices of each group.

    Raises:
        ValueError: If ``args.data`` is not ``"0"``. That is the only mode
            that builds a dataset, and the original code failed at the
            final ``return`` with an ``UnboundLocalError``.
    """
    if args.data != "0":
        raise ValueError(f"Unsupported args.data value {args.data!r}; only \"0\" is implemented")

    logger.log_message("Loading dataset")
    logger.log_message("compound feature based on torchdrug")
    logger.log_message("protein feature based on esm2")
    add_noise_to_com = float(addNoise) if addNoise else None

    new_dataset = FABindDataSet(
        f"{pre}/dataset",
        add_noise_to_com=add_noise_to_com,
        use_whole_protein=use_whole_protein,
        compound_coords_init_mode=compound_coords_init_mode,
        pocket_radius=args.pocket_radius,
        noise_for_predicted_pocket=args.noise_for_predicted_pocket,
        test_random_rotation=args.test_random_rotation,
        pocket_idx_no_noise=args.pocket_idx_no_noise,
        use_esm2_feat=args.use_esm2_feat,
        seed=args.seed,
        pre=pre,
        args=args,
    )

    # load compound features extracted using torchdrug.
    # c_length: number of atoms in the compound
    # This filter may cause some samples to be filtered out. So the actual
    # number of samples is less than that in the original papers.
    table = new_dataset.data
    train_rows = table.query(
        "c_length < 100 and native_num_contact > 5 and group =='train' and use_compound_com"
    ).reset_index(drop=True)
    valid_test_rows = table.query(
        "(group == 'valid' or group == 'test') and use_compound_com"
    ).reset_index(drop=True)

    # Replace the sample table with the filtered rows so that the positional
    # indices computed below line up with what the dataset serves.
    filtered = pd.concat([train_rows, valid_test_rows], axis=0).reset_index(drop=True)
    new_dataset.data = filtered

    train = new_dataset[filtered.query("group =='train'").index.values]
    valid = new_dataset[filtered.query("group =='valid'").index.values]
    test = new_dataset[filtered.query("group =='test'").index.values]
    return train, valid, test