# Source code for bapred.data.data

import os
import torch, dgl
from collections import defaultdict
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem
RDLogger.DisableLog('rdApp.*')

from dgl.data import DGLDataset

from bapred.data.atom_feature import *

def process_ligand_file(file_path):
    """Parse every molecule record in an ``.sdf`` or ``.mol2`` file with RDKit.

    Args:
        file_path: Path to the ligand file. The extension (compared
            case-insensitively) selects the parser.

    Returns:
        A tuple ``(ligands, err_tag, ligand_names)`` of three index-aligned
        lists with one entry per record:

        * ``ligands``: the RDKit ``Mol``, or ``None`` when RDKit could not
          parse the record;
        * ``err_tag``: ``0`` for a parsed record, ``1`` for a failed one;
        * ``ligand_names``: the molecule's ``_Name`` property, or
          ``"<file stem>_<record index>"`` when that property is empty or the
          record failed to parse.

    Raises:
        ValueError: If the extension is neither ``.sdf`` nor ``.mol2``.
    """
    suffix = os.path.splitext(file_path)[-1].lower()

    if suffix == '.mol2':
        # MOL2 files are split on the molecule header. Any text before the
        # first header is discarded, and each chunk gets its header back
        # before parsing.
        header = '@<TRIPOS>MOLECULE'
        with open(file_path, 'r') as handle:
            chunks = handle.read().split(header)[1:]
        molecules = (Chem.MolFromMol2Block(header + chunk) for chunk in chunks)
    elif suffix == '.sdf':
        molecules = Chem.SDMolSupplier(file_path)
    else:
        raise ValueError(f"Unsupported file type: {suffix}")

    stem = os.path.splitext(os.path.basename(file_path))[0]
    ligands, err_tag, ligand_names = [], [], []
    for position, molecule in enumerate(molecules):
        fallback = f"{stem}_{position}"
        # Failed records are kept as None so indices stay aligned with the file.
        ligands.append(molecule)
        if molecule is None:
            err_tag.append(1)
            ligand_names.append(fallback)
            continue
        err_tag.append(0)
        # An empty _Name is falsy, so the positional fallback name is used.
        ligand_names.append(molecule.GetProp('_Name') or fallback)
    return ligands, err_tag, ligand_names
def load_ligands(file_path):
    """Load ligands from one ligand file or from a ``.txt`` list of files.

    Args:
        file_path: Either an ``.sdf`` / ``.mol2`` file, or a ``.txt`` file
            naming one ligand file per line. In a ``.txt`` list, blank lines
            are skipped and surrounding whitespace is stripped. Listed paths
            are used as written, so relative paths resolve against the
            current working directory.

    Returns:
        A tuple ``(lig_mols, err_tags, lig_names)`` of index-aligned lists, in
        the same format as ``process_ligand_file``. For a ``.txt`` list, the
        per-file results are concatenated in listing order.

    Raises:
        AssertionError: If a ligand file to be parsed does not exist. This
            check is an ``assert`` and is therefore skipped under
            ``python -O``.
        ValueError: If ``file_path`` has an unsupported extension.
    """
    def _read_checked(path):
        # Existence check before parsing. It is kept as an assert so the
        # raised exception type is unchanged for callers.
        assert os.path.isfile(path), f"File not found: {path}"
        return process_ligand_file(path)

    kind = os.path.splitext(file_path)[-1].lower()

    if kind in ['.sdf', '.mol2']:
        mols, tags, names = _read_checked(file_path)
        return mols, tags, names
    if kind != '.txt':
        raise ValueError("Unsupported file type. Use '.txt', '.sdf', or '.mol2'.")

    with open(file_path, 'r') as listing:
        paths = [entry.strip() for entry in listing if entry.strip()]

    mols, tags, names = [], [], []
    for path in paths:
        file_mols, file_tags, file_names = _read_checked(path)
        mols += file_mols
        tags += file_tags
        names += file_names
    return mols, tags, names
class BAPredDataset(DGLDataset):
    """Dataset that pairs a single protein with every ligand in a ligand file.

    The protein PDB is parsed once in ``__init__``. For each ligand,
    ``__getitem__`` cuts a pocket out of the protein around that ligand and
    builds three DGL graphs on demand: pocket, ligand and pocket-ligand
    complex. Ligands that RDKit failed to parse get fixed-size placeholder
    graphs and are flagged with ``error = 1``.
    """

    def __init__(self, protein_pdb, ligand_file, train=True):
        """Load the ligands and the protein's atom records.

        Args:
            protein_pdb: Path to the protein PDB file.
            ligand_file: An ``.sdf`` / ``.mol2`` file, or a ``.txt`` file
                listing such files one per line (see ``load_ligands``).
            train: Accepted but not used by this class.
        """
        super(BAPredDataset, self).__init__(name='Protein Ligand Binding Affinity prediction')
        self.lig_mols, self.err_tags, self.lig_names = load_ligands(ligand_file)
        self.prot_atom_line, self.prot_atom_coord = self.get_protein_info(protein_pdb)

    def __getitem__(self, idx):
        """Return ``(gp, gl, gc, error, idx, name)`` for ligand ``idx``.

        ``gp``, ``gl`` and ``gc`` are the pocket, ligand and complex graphs.
        ``error`` is 0 for a parsed ligand, and 1 when placeholder graphs were
        substituted for a ligand that failed to parse.
        """
        name = self.lig_names[idx]
        if self.err_tags[idx] == 0:
            lmol = self.lig_mols[idx]
            pmol = self.get_pocket_with_ligand_in_protein(
                self.prot_atom_line, self.prot_atom_coord, lmol
            )
            gl = self.mol_to_graph(lmol)
            gp = self.mol_to_graph(pmol)
            gc = self.complex_to_graph(pmol, lmol)
            error = 0
        else:
            # Placeholder sizes. The complex has 1002 = 1000 pocket + 2 ligand
            # nodes, mirroring complex_to_graph's npa + nla node count.
            gp = self.prot_dummy_graph(num_nodes=1000)
            gl = self.lig_dummy_graph(num_nodes=2)
            gc = self.comp_dummy_graph(num_nodes=1002)
            error = 1
        return gp, gl, gc, error, idx, name

    def __len__(self):
        """Number of ligand records, including ones that failed to parse."""
        return len(self.lig_mols)

    def lig_dummy_graph(self, num_nodes):
        """Placeholder ligand graph: 10 random edges and zero features.

        The feature widths (57 node features, 13 edge features) are
        hard-coded. Presumably they must match get_atom_feature /
        get_bond_feature output; TODO confirm. ``pos_enc`` is 20 wide, the
        same as the random-walk encoding in ``mol_to_graph``. Coordinates are
        standard-normal noise.
        """
        src = torch.randint(0, num_nodes, (10,))
        dst = torch.randint(0, num_nodes, (10,))
        gl = dgl.graph((src, dst), num_nodes=num_nodes)
        gl.ndata['feats'] = torch.zeros((num_nodes, 57)).float()
        gl.ndata['pos_enc'] = torch.zeros((num_nodes, 20)).float()
        gl.ndata['coord'] = torch.randn((num_nodes, 3)).float()
        gl.edata['feats'] = torch.zeros((10, 13)).float()
        return gl

    def prot_dummy_graph(self, num_nodes):
        """Placeholder pocket graph with the same layout as ``lig_dummy_graph``.

        Unlike the ligand placeholder, coordinates are integers drawn from
        [0, 100), stored as float.
        """
        src = torch.randint(0, num_nodes, (10,))
        dst = torch.randint(0, num_nodes, (10,))
        gp = dgl.graph((src, dst), num_nodes=num_nodes)
        gp.ndata['feats'] = torch.zeros((num_nodes, 57)).float()
        gp.ndata['pos_enc'] = torch.zeros((num_nodes, 20)).float()
        gp.ndata['coord'] = torch.randint(0, 100, (num_nodes, 3)).float()
        gp.edata['feats'] = torch.zeros((10, 13)).float()
        return gp

    def comp_dummy_graph(self, num_nodes):
        """Placeholder complex graph: 10 random edges, 25-wide zero edge
        features, zero distances, and integer coordinates in [0, 100)."""
        src = torch.randint(0, num_nodes, (10,))
        dst = torch.randint(0, num_nodes, (10,))
        gc = dgl.graph((src, dst), num_nodes=num_nodes)
        gc.ndata['coord'] = torch.randint(0, 100, (num_nodes, 3)).float()
        gc.edata['feats'] = torch.zeros((10, 25)).float()
        gc.edata['distance'] = torch.zeros((10, 1)).float()
        return gc

    def get_protein_info(self, prot_pdb):
        """Read heavy-atom ATOM/HETATM records from a PDB file.

        A record is kept when:

        * it starts with ``ATOM`` or ``HETA`` (i.e. HETATM);
        * the first two characters of the atom-name field (columns 13-14)
          contain no ``'H'``, which is meant to drop hydrogens;
        * the residue name (columns 18-20) is not ``HOH``, so waters are
          dropped.

        NOTE(review): the ``'H'`` test also drops heavy atoms whose name
        begins with H in columns 13-14 (e.g. ``HG`` for mercury).

        Returns:
            ``(prot_atom_line, prot_atom_coord)``. These are the raw kept lines
            (newlines included) and the matching ``[x, y, z]`` floats, parsed
            from the fixed PDB columns 31-54.
        """
        prot_atom_line = []
        prot_atom_coord = []
        # The context manager closes the file (the original leaked the handle).
        with open(prot_pdb) as handle:
            for line in handle:
                if (
                    line[0:4] in ['ATOM', 'HETA']
                    and 'H' not in line[12:14]
                    and 'HOH' not in line[17:20]
                ):
                    prot_atom_line.append(line)
                    prot_atom_coord.append(
                        [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                    )
        return prot_atom_line, prot_atom_coord

    def get_pocket_with_ligand_in_protein(self, prot_atom_line, prot_atom_coord, lig_mol, cutoff=8.0):
        """Extract the protein residues near ``lig_mol`` as an RDKit molecule.

        A residue is selected when any of its atoms lies strictly closer than
        ``cutoff`` to any ligand atom, using the ligand's first conformer.
        Every atom line of a selected residue is kept, and the result is
        parsed with ``Chem.MolFromPDBBlock(..., sanitize=False)``.

        Args:
            prot_atom_line: PDB atom lines, as returned by ``get_protein_info``.
            prot_atom_coord: Matching ``[x, y, z]`` coordinates.
            lig_mol: RDKit ligand with at least one conformer.
            cutoff: Distance threshold. It defaults to 8.0, the previously
                hard-coded value, and is presumably in Angstrom (PDB units).

        NOTE(review): residues are keyed by (chain ID, residue number) and
        ignore the insertion code (column 27). Residues such as 100 and 100A
        are therefore selected together.
        """
        lig_atom_coord = torch.tensor(lig_mol.GetConformers()[0].GetPositions()).float()
        prot_coord = torch.tensor(prot_atom_coord).float()
        pl_distance = torch.cdist(prot_coord, lig_atom_coord)

        # Row indices of (protein atom, ligand atom) pairs under the cutoff.
        # These repeat once per close ligand atom, so they are deduplicated
        # into a set instead of being scanned with `in` on the tensor for
        # every protein atom.
        near_atoms = set(torch.where(pl_distance < cutoff)[0].tolist())

        # Map chain ID (column 22) to the residue numbers (columns 23-26) to keep.
        select_residue = defaultdict(set)
        for idx in near_atoms:
            line = prot_atom_line[idx]
            select_residue[line[21]].add(int(line[22:26]))

        # Keep every atom line of a selected residue, in file order.
        pocket_lines = [
            line
            for line in prot_atom_line
            if int(line[22:26]) in select_residue.get(line[21], ())
        ]
        mol = Chem.MolFromPDBBlock(''.join(pocket_lines), sanitize=False)
        # Chem.AssignAtomChiralTagsFromStructure(mol)
        return mol

    def mol_to_graph(self, mol):
        """Build a molecular DGL graph from ``mol``'s atoms and bonds.

        Node data:

        * ``feats``: from ``get_atom_feature``;
        * ``coord``: from ``get_mol_coordinate``;
        * ``pos_enc``: a 20-step random-walk positional encoding.

        Edges are the nonzero entries of ``get_bond_feature(mol)``, with those
        entries stored as edge ``feats``. The helpers come from
        ``bapred.data.atom_feature``.
        """
        n = mol.GetNumAtoms()
        coord = get_mol_coordinate(mol)
        h = get_atom_feature(mol)
        # Assumes get_bond_feature returns a dense (n, n, F) tensor, so the
        # sparse indices are (src, dst) atom pairs. TODO confirm.
        adj = get_bond_feature(mol).to_sparse(sparse_dim=2)
        u, v = adj.indices()[0], adj.indices()[1]
        e = adj.values()

        g = dgl.graph((u, v), num_nodes=n)
        g.ndata['feats'] = h
        g.ndata['coord'] = coord
        g.edata['feats'] = e
        g.ndata['pos_enc'] = dgl.random_walk_pe(g, 20)
        return g

    def complex_to_graph(self, pmol, lmol):
        """Build the pocket-ligand interaction graph.

        Nodes are the pocket atoms (ids ``0..npa-1``) followed by the ligand
        atoms (ids ``npa..npa+nla-1``). Every protein/ligand atom pair closer
        than 5 becomes an edge, added in both directions with identical
        features. The features are interaction features concatenated with
        distance features, and the raw pair distance is stored under
        ``edata['distance']``.
        """
        pcoord = get_mol_coordinate(pmol)
        lcoord = get_mol_coordinate(lmol)
        ccoord = torch.cat([pcoord, lcoord])
        npa = pmol.GetNumAtoms()
        nla = lmol.GetNumAtoms()

        distance = torch.cdist(pcoord, lcoord)
        u, v = torch.where(distance < 5)  ### u - src protein node, v - dst ligand node
        distance = distance[u, v].unsqueeze(-1)

        interact_feature = get_interact_feature(pmol, lmol, u, v)
        distance_feature = get_distance_feature(distance).squeeze(-1)
        e = torch.cat([interact_feature, distance_feature], dim=1)

        # Mirror every edge: protein->ligand first, then ligand->protein.
        # Ligand indices are shifted by npa into the complex's node numbering.
        e = torch.cat([e, e])
        distance = torch.cat([distance, distance])
        u, v = torch.cat([u, v + npa]), torch.cat([v + npa, u])

        g = dgl.graph((u, v), num_nodes=npa + nla)
        g.ndata['coord'] = ccoord
        g.edata['feats'] = e
        g.edata['distance'] = distance
        return g