Source code for bapred.data.data

import os
import torch, dgl
from collections import defaultdict
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem
RDLogger.DisableLog('rdApp.*')

from dgl.data import DGLDataset

from meeko import PDBQTMolecule, RDKitMolCreate


from bapred.data.atom_feature import *



def _process_dlg_pdbqt(file_path, is_dlg, only_cluster_leads=True):
    """Helper function to process .dlg and .pdbqt files."""
    name = os.path.basename(file_path).split('.')[0]
    pdbqt_mol = PDBQTMolecule.from_file(
        file_path, name=name, is_dlg=is_dlg, skip_typing=True
    )
    rdkit_mols = RDKitMolCreate.from_pdbqt_mol(
        pdbqt_mol, only_cluster_leads=only_cluster_leads, keep_flexres=False
    )
    sdf_string, _ = RDKitMolCreate.write_sd_string(pdbqt_mol, only_cluster_leads=only_cluster_leads)

    # Pull the docking score for each pose out of the dict-like property line
    # that Meeko writes into the SD string for every saved conformer.
    adg_score = []
    for line in sdf_string.split('\n'):
        if '{' in line:
            words = line.split(',')
            free_energy = words[1].split(':')[1].strip()
            adg_score.append(float(free_energy))

    mols, err_tags, names = [], [], []
    template = rdkit_mols[0]
    if template is None:
        # Meeko could not rebuild an RDKit molecule from this file; flag it as an error.
        return [None], [1], [f"{name}_0"], adg_score
    for i, conf in enumerate(template.GetConformers()):
        # Copy the template and keep only the i-th docked pose as its conformer.
        mol = Chem.Mol(template)
        mol.RemoveAllConformers()
        mol.AddConformer(conf, assignId=True)
        mol = Chem.RemoveHs(mol)
        mols.append(mol)
        err_tags.append(0)
        names.append(f"{name}_{i}")
    return mols, err_tags, names, adg_score
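
# Illustrative sketch of the helper's behaviour (the file name is hypothetical):
# a DLG with three cluster leads for a ligand named 'lig' would yield three
# single-conformer RDKit mols, err_tags == [0, 0, 0], names == ['lig_0', 'lig_1',
# 'lig_2'], and the AutoDock free energies parsed from the SD string.
#
#     mols, err_tags, names, scores = _process_dlg_pdbqt('lig.dlg', is_dlg=True)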


def _process_sdf(file_path):
    """Helper function to process .sdf files."""
    supplier = Chem.SDMolSupplier(file_path)
    return _process_supplier(supplier, file_path)

def _process_mol2(file_path):
    """Helper function to process .mol2 files"""
    with open(file_path, 'r') as f:
        mol2_data = f.read()
    mol2_blocks = mol2_data.split('@<TRIPOS>MOLECULE')
    supplier = (
        Chem.MolFromMol2Block('@<TRIPOS>MOLECULE' + block)
        for block in mol2_blocks[1:]
    )
    return _process_supplier(supplier, file_path)

def _process_supplier(supplier, file_path):
    """Common logic for processing SDF and Mol2 suppliers."""
    ligands, err_tag, ligand_names = [], [], []
    base_name = os.path.splitext(os.path.basename(file_path))[0]

    for idx, mol in enumerate(supplier):
        if mol is not None:
            mol = Chem.RemoveHs(mol)
            ligands.append(mol)
            err_tag.append(0)
            ligand_names.append(f"{base_name}_{idx}")
        else:
            ligands.append(None)
            err_tag.append(1)
            ligand_names.append(f"{base_name}_{idx}")

    return ligands, err_tag, ligand_names, [float('nan')] * len(ligands)


def process_ligand_file(file_path, only_cluster_leads=True):
    """Processes a single ligand file (.dlg, .pdbqt, .sdf, .mol2)."""
    extension = os.path.splitext(file_path)[-1].lower()
    if extension == '.dlg':
        return _process_dlg_pdbqt(file_path, is_dlg=True, only_cluster_leads=only_cluster_leads)
    elif extension == '.pdbqt':
        return _process_dlg_pdbqt(file_path, is_dlg=False, only_cluster_leads=only_cluster_leads)
    elif extension == '.sdf':
        return _process_sdf(file_path)
    elif extension == '.mol2':
        return _process_mol2(file_path)
    else:
        raise ValueError(f"Unsupported file type: {extension}")


def load_ligands(file_path, only_cluster_leads=True):
    """Loads ligands from a file or a list of files."""
    file_extension = os.path.splitext(file_path)[-1].lower()
    if file_extension == '.txt':
        with open(file_path, 'r') as f:
            lines = [line.strip() for line in f if line.strip()]
        lig_mols, err_tags, lig_names, scores = [], [], [], []
        for line in lines:
            assert os.path.isfile(line), f"File not found: {line}"
            file_ligands, file_err_tag, file_ligand_names, file_scores = process_ligand_file(
                line, only_cluster_leads=only_cluster_leads
            )
            lig_mols.extend(file_ligands)
            err_tags.extend(file_err_tag)
            lig_names.extend(file_ligand_names)
            scores.extend(file_scores)
        return lig_mols, err_tags, lig_names, scores
    elif file_extension in ['.sdf', '.mol2', '.dlg', '.pdbqt']:
        return process_ligand_file(file_path, only_cluster_leads=only_cluster_leads)
    else:
        raise ValueError("Unsupported file type. Use '.txt', '.sdf', '.mol2', '.dlg', or '.pdbqt'.")

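
# Example usage (illustrative only; the paths below are hypothetical):
#
#     mols, err_tags, names, scores = load_ligands('poses.sdf')
#     # or, for a batch of files listed one path per line in a text file:
#     mols, err_tags, names, scores = load_ligands('ligand_files.txt', only_cluster_leads=True)
#
# Entries with err_tags[i] == 1 are molecules RDKit could not parse; scores holds
# AutoDock free energies for .dlg/.pdbqt inputs and NaN for .sdf/.mol2 inputs.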

class BAPredDataset(DGLDataset):
    def __init__(self, protein_pdb, ligand_file, train=True, only_cluster_leads=True):
        super(BAPredDataset, self).__init__(name='Protein Ligand Binding Affinity prediction')
        self.lig_mols, self.err_tags, self.lig_names, _ = load_ligands(
            ligand_file, only_cluster_leads=only_cluster_leads
        )
        self.prot_atom_line, self.prot_atom_coord = self.get_protein_info(protein_pdb)

    def __getitem__(self, idx):
        name = self.lig_names[idx]
        if self.err_tags[idx] == 0:
            lmol = self.lig_mols[idx]
            pmol = self.get_pocket_with_ligand_in_protein(
                self.prot_atom_line, self.prot_atom_coord, lmol
            )
            gl = self.mol_to_graph(lmol)
            gp = self.mol_to_graph(pmol)
            gc = self.complex_to_graph(pmol, lmol)
            error = 0
        else:
            gp = self.prot_dummy_graph(num_nodes=1000)
            gl = self.lig_dummy_graph(num_nodes=2)
            gc = self.comp_dummy_graph(num_nodes=1002)
            error = 1
        return gp, gl, gc, error, idx, name

    def __len__(self):
        return len(self.lig_mols)

    def lig_dummy_graph(self, num_nodes):
        src = torch.randint(0, num_nodes, (10,))
        dst = torch.randint(0, num_nodes, (10,))
        gl = dgl.graph((src, dst), num_nodes=num_nodes)
        gl.ndata['feats'] = torch.zeros((num_nodes, 57)).float()
        gl.ndata['pos_enc'] = torch.zeros((num_nodes, 20)).float()
        gl.ndata['coord'] = torch.randn((num_nodes, 3)).float()
        gl.edata['feats'] = torch.zeros((10, 13)).float()
        return gl

    def prot_dummy_graph(self, num_nodes):
        src = torch.randint(0, num_nodes, (10,))
        dst = torch.randint(0, num_nodes, (10,))
        gp = dgl.graph((src, dst), num_nodes=num_nodes)
        gp.ndata['feats'] = torch.zeros((num_nodes, 57)).float()
        gp.ndata['pos_enc'] = torch.zeros((num_nodes, 20)).float()
        gp.ndata['coord'] = torch.randint(0, 100, (num_nodes, 3)).float()
        gp.edata['feats'] = torch.zeros((10, 13)).float()
        return gp

    def comp_dummy_graph(self, num_nodes):
        src = torch.randint(0, num_nodes, (10,))
        dst = torch.randint(0, num_nodes, (10,))
        gc = dgl.graph((src, dst), num_nodes=num_nodes)
        gc.ndata['coord'] = torch.randint(0, 100, (num_nodes, 3)).float()
        gc.edata['feats'] = torch.zeros((10, 25)).float()
        gc.edata['distance'] = torch.zeros((10, 1)).float()
        return gc

    def get_protein_info(self, prot_pdb):
        prot_atom_line = []
        prot_atom_coord = []
        for line in open(prot_pdb).readlines():
            if line[0:4] in ['ATOM', 'HETA'] and 'H' not in line[12:14] and 'HOH' not in line[17:20]:
                prot_atom_line.append(line)
                prot_atom_coord.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
        return prot_atom_line, prot_atom_coord

    def get_pocket_with_ligand_in_protein(self, prot_atom_line, prot_atom_coord, lig_mol):
        lig_atom_coord = torch.tensor(lig_mol.GetConformers()[0].GetPositions()).float()
        prot_atom_coord = torch.tensor(prot_atom_coord).float()
        pl_distance = torch.cdist(prot_atom_coord, lig_atom_coord)
        # Protein atoms with any ligand atom closer than 8 A define the pocket.
        select_index = torch.where(pl_distance < 8)[0]
        # Keep every atom of a residue (chain id, residue number) that has at
        # least one atom inside the cutoff.
        select_residue = defaultdict(set)
        for idx, line in enumerate(prot_atom_line):
            if idx in select_index:
                select_residue[line[21]].add(int(line[22:26]))
        total_lines = ''
        for line in prot_atom_line:
            if int(line[22:26]) in select_residue[line[21]]:
                total_lines += line
        mol = Chem.MolFromPDBBlock(total_lines)
        return mol

    def mol_to_graph(self, mol):
        n = mol.GetNumAtoms()
        coord = get_mol_coordinate(mol)
        h = get_atom_feature(mol)
        adj = get_bond_feature(mol).to_sparse(sparse_dim=2)
        u = adj.indices()[0]
        v = adj.indices()[1]
        e = adj.values()
        g = dgl.DGLGraph()
        g.add_nodes(n)
        g.add_edges(u, v)
        g.ndata['feats'] = h
        g.ndata['coord'] = coord
        g.edata['feats'] = e
        g.ndata['pos_enc'] = dgl.random_walk_pe(g, 20)
        return g

    def complex_to_graph(self, pmol, lmol):
        pcoord = get_mol_coordinate(pmol)
        lcoord = get_mol_coordinate(lmol)
        ccoord = torch.cat([pcoord, lcoord])
        npa = pmol.GetNumAtoms()
        nla = lmol.GetNumAtoms()
        distance = torch.cdist(pcoord, lcoord)
        u, v = torch.where(distance < 5)  # u - src protein node, v - dst ligand node
        distance = distance[u, v].unsqueeze(-1)
        interact_feature = get_interact_feature(pmol, lmol, u, v)
        distance_feature = get_distance_feature(distance).squeeze(-1)
        e = torch.cat([interact_feature, distance_feature], dim=1)
        e = torch.cat([e, e])
        distance = torch.cat([distance, distance])
        u, v = torch.cat([u, v + npa]), torch.cat([v + npa, u])
        g = dgl.DGLGraph()
        g.add_nodes(npa + nla)
        g.add_edges(u, v)
        g.ndata['coord'] = ccoord
        g.edata['feats'] = e
        g.edata['distance'] = distance
        return g
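

# Example usage (a minimal sketch; 'protein.pdb' and 'docked.dlg' are hypothetical
# paths, and the collate function below is only one possible way to batch the
# returned DGL graphs):
#
#     from torch.utils.data import DataLoader
#
#     dataset = BAPredDataset('protein.pdb', 'docked.dlg')
#
#     def collate(samples):
#         gps, gls, gcs, errors, idxs, names = map(list, zip(*samples))
#         return dgl.batch(gps), dgl.batch(gls), dgl.batch(gcs), errors, idxs, names
#
#     loader = DataLoader(dataset, batch_size=4, collate_fn=collate)
#     for gp, gl, gc, errors, idxs, names in loader:
#         ...  # forward pass through the scoring model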