# Source code for bindingrmsd.data.data_bck

import os, torch, dgl

from rdkit import Chem # type: ignore
from dgl.data import DGLDataset # type: ignore

from .ligand_atom_feature import mol_to_graph
from .protein_atom_feature import get_all_graph, prot_to_graph, pl_to_c_graph

def process_dlg(file_path, only_cluster=False):
    """Split an AutoDock docking result into one RDKit molecule per pose.

    The file is read with meeko. The first molecule meeko builds (the ligand;
    flexible residues are excluded) is copied once per conformer, so every
    returned ``Mol`` carries exactly one pose.

    Args:
        file_path: Path to a ``.dlg`` or ``.pdbqt`` file (extension matched
            case-insensitively).
        only_cluster: Forwarded to meeko as ``only_cluster_leads``.

    Returns:
        Tuple ``(mols, err_tags, names, adg_score)``:
        - ``mols``: single-conformer RDKit molecules.
        - ``err_tags``: all ``0``.
        - ``names``: ``"<file stem>_<pose index>"``.
        - ``adg_score``: one float parsed from each JSON-like property line
          of meeko's SD output.

    Raises:
        ValueError: For an unsupported extension, or when meeko cannot build
            an RDKit molecule for the ligand.
    """
    name = os.path.basename(file_path).split('.')[0]

    # Dispatch on the real extension: a substring test would misclassify
    # paths such as "runs.dlg_out/pose.pdbqt".
    extension = os.path.splitext(file_path)[1].lower()
    if extension not in ('.dlg', '.pdbqt'):
        raise ValueError(f"Unsupported file type: {file_path}")

    # Imported lazily so meeko is only required when docking output is parsed.
    from meeko import PDBQTMolecule
    from meeko import RDKitMolCreate

    pdbqt_mol = PDBQTMolecule.from_file(
        file_path, name=name, is_dlg=(extension == '.dlg'), skip_typing=True
    )
    rdkit_mols = RDKitMolCreate.from_pdbqt_mol(
        pdbqt_mol, only_cluster_leads=only_cluster, keep_flexres=False
    )
    if not rdkit_mols or rdkit_mols[0] is None:
        raise ValueError(f"meeko could not build an RDKit molecule from {file_path}")
    ligand = rdkit_mols[0]

    sdf_string, _ = RDKitMolCreate.write_sd_string(
        pdbqt_mol, only_cluster_leads=only_cluster
    )
    adg_score = []
    for line in sdf_string.split('\n'):
        # Presumably the JSON pose-property line meeko writes per pose.
        if '{' in line:
            # NOTE(review): this takes the value of the 2nd comma-separated
            # "key: value" pair, which is not necessarily the "free_energy"
            # key. Confirm which energy term is intended before changing it.
            words = line.split(',')
            free_energy = words[1].split(':')[1].strip()
            adg_score.append(float(free_energy))

    mols, err_tags, names = [], [], []
    for i, conf in enumerate(ligand.GetConformers()):
        # Copy the template molecule and keep only the i-th pose.
        mol = Chem.Mol(ligand)
        mol.RemoveAllConformers()
        mol.AddConformer(conf, assignId=True)
        mols.append(mol)
        err_tags.append(0)
        names.append(f"{name}_{i}")
    return mols, err_tags, names, adg_score
def process_ligand_file(file_path):
    """Read every molecule from a single ``.sdf`` or ``.mol2`` file.

    Molecules are parsed with ``sanitize=False``. Records RDKit cannot parse
    are kept as ``None`` placeholders, so the output stays aligned with the
    record order in the file.

    Args:
        file_path: Path to an ``.sdf`` or ``.mol2`` file (extension matched
            case-insensitively).

    Returns:
        Tuple ``(ligands, err_tag, ligand_names)`` of equal-length lists:
        - ``ligands``: the RDKit ``Mol`` for each record, or ``None``.
        - ``err_tag``: ``0`` for a parsed record, ``1`` for a failed one.
        - ``ligand_names``: one name per record. A record without a usable
          ``_Name`` (missing, empty or whitespace-only) is named
          ``"<file stem>_<record index>"``.

    Raises:
        ValueError: If the extension is neither ``.sdf`` nor ``.mol2``.
    """
    extension = os.path.splitext(file_path)[-1].lower()
    if extension == '.sdf':
        molecules = Chem.SDMolSupplier(file_path, sanitize=False)
    elif extension == '.mol2':
        with open(file_path, 'r') as f:
            mol2_data = f.read()
        # Every record starts with this tag; text before the first tag is
        # not a molecule and is dropped.
        mol2_blocks = mol2_data.split('@<TRIPOS>MOLECULE')[1:]
        molecules = (
            Chem.MolFromMol2Block('@<TRIPOS>MOLECULE' + block, sanitize=False)
            for block in mol2_blocks
        )
    else:
        raise ValueError(f"Unsupported file type: {extension}")

    base_name = os.path.splitext(os.path.basename(file_path))[0]
    ligands, err_tag, ligand_names = [], [], []
    for idx, mol in enumerate(molecules):
        fallback_name = f"{base_name}_{idx}"
        if mol is None:
            ligands.append(None)
            err_tag.append(1)
            ligand_names.append(fallback_name)
            continue
        ligand_name = mol.GetProp('_Name') if mol.HasProp('_Name') else ''
        # A blank or whitespace-only title is as uninformative as a missing one.
        if not ligand_name.strip():
            ligand_name = fallback_name
        ligands.append(mol)
        err_tag.append(0)
        ligand_names.append(ligand_name)
    return ligands, err_tag, ligand_names
def load_ligands(file_path):
    """Load ligands from one structure file or from a list of files.

    Args:
        file_path: One of:
            - an ``.sdf`` or ``.mol2`` file, read directly;
            - a ``.txt`` file listing one ligand file path per line. Blank
              lines are skipped, and each listed path is resolved as given
              (relative to the current working directory).

    Returns:
        Tuple ``(lig_mols, err_tags, lig_names)`` as produced by
        ``process_ligand_file``. For a ``.txt`` list, the per-file results
        are concatenated in listing order.

    Raises:
        FileNotFoundError: If a ligand file (or the ``.txt`` list) is missing.
        ValueError: If ``file_path`` has an unsupported extension.
    """
    def process_single_file(path):
        # Explicit check rather than ``assert``, so it survives ``python -O``.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"File not found: {path}")
        return process_ligand_file(path)

    file_extension = os.path.splitext(file_path)[-1].lower()
    if file_extension in ('.sdf', '.mol2'):
        return process_single_file(file_path)
    if file_extension != '.txt':
        raise ValueError("Unsupported file type. Use '.txt', '.sdf', or '.mol2'.")

    with open(file_path, 'r') as f:
        listed_paths = [line.strip() for line in f if line.strip()]

    lig_mols, err_tags, lig_names = [], [], []
    for path in listed_paths:
        file_ligands, file_err_tag, file_ligand_names = process_single_file(path)
        lig_mols.extend(file_ligands)
        err_tags.extend(file_err_tag)
        lig_names.extend(file_ligand_names)
    return lig_mols, err_tags, lig_names
class PoseSelectionDataset(DGLDataset):
    """Dataset pairing one protein with every ligand pose from a ligand file.

    Each item is ``(gp, gl, gc, error, name)``:
    - ``gp``, ``gl``, ``gc`` are the three graphs returned by
      ``get_all_graph``. Presumably these are the protein, ligand and
      protein-ligand complex graphs; confirm against ``get_all_graph``.
    - ``error`` is ``1`` when the pose could not be featurized and a dummy
      ligand graph was substituted.
    - ``name`` is the ligand name from ``load_ligands``.
    """

    def __init__(self, protein_pdb, ligand_file):
        """Load all ligands and build the protein graph once.

        Args:
            protein_pdb: Passed to ``prot_to_graph``. Presumably a PDB file
                path; confirm.
            ligand_file: ``.sdf``, ``.mol2``, or a ``.txt`` list of such
                files, as accepted by ``load_ligands``.
        """
        super().__init__(name='Protein Ligand Binding conformation RMSD prediction')
        self.lig_mols, self.err_tags, self.lig_names = load_ligands(ligand_file)
        self.gp = prot_to_graph(protein_pdb)

    def __getitem__(self, idx):
        """Return the graphs, error flag and name for ligand ``idx``.

        If featurizing the ligand fails (including when the molecule failed
        to parse and is ``None``), a small dummy ligand graph is used instead
        and the error flag is forced to ``1``. This keeps the pose flagged as
        invalid downstream.
        """
        error = self.err_tags[idx]
        name = self.lig_names[idx]
        try:
            gl = mol_to_graph(self.lig_mols[idx])
            gp, gl, gc = get_all_graph(self.gp, gl)
        except Exception:
            # Best-effort fallback so batching can continue; Ctrl-C and
            # SystemExit still propagate.
            gl = self.lig_dummy_graph(num_nodes=3)
            gp, gl, gc = get_all_graph(self.gp, gl)
            error = 1
        return gp, gl, gc, error, name

    def __len__(self):
        """Number of ligand entries, including ones that failed to parse."""
        return len(self.lig_mols)

    def lig_dummy_graph(self, num_nodes):
        """Build a placeholder ligand graph with random edges and zero features.

        Edge endpoints and coordinates come from torch's global RNG, so the
        result is not deterministic unless torch is seeded.

        Args:
            num_nodes: Number of nodes in the placeholder graph.

        Returns:
            A ``dgl`` graph with 10 random edges and ``feat`` / ``pos_enc`` /
            ``coord`` node data plus ``feat`` edge data.
        """
        num_edges = 10
        src = torch.randint(0, num_nodes, (num_edges,))
        dst = torch.randint(0, num_nodes, (num_edges,))
        gl = dgl.graph((src, dst), num_nodes=num_nodes)
        # NOTE(review): widths 57 / 20 / 13 are hard-coded and presumably
        # must match what mol_to_graph produces; confirm when features change.
        gl.ndata['feat'] = torch.zeros((num_nodes, 57)).float()
        gl.ndata['pos_enc'] = torch.zeros((num_nodes, 20)).float()
        gl.ndata['coord'] = torch.randn((num_nodes, 3)).float()
        gl.edata['feat'] = torch.zeros((num_edges, 13)).float()
        return gl
if __name__ == "__main__":
    # Small CLI for inspecting how a docking file is split into poses.
    import argparse

    parser = argparse.ArgumentParser(
        description="Split an AutoDock .dlg/.pdbqt file into per-pose molecules."
    )
    parser.add_argument('--dlg', type=str, required=True,
                        help="path to a .dlg or .pdbqt file")
    parser.add_argument('--only_cluster', action='store_true',
                        help="keep only cluster-lead poses")
    args = parser.parse_args()

    mols, err_tags, names, scores = process_dlg(args.dlg, args.only_cluster)
    print(f"{len(mols)} pose(s) parsed from {args.dlg}")
    for i, (pose_name, err) in enumerate(zip(names, err_tags)):
        # The score list may be shorter than the pose list; don't drop poses.
        score = scores[i] if i < len(scores) else 'n/a'
        status = 'ok' if err == 0 else 'error'
        print(f"{pose_name}\t{status}\t{score}")