import os
import torch
import dgl
from rdkit import Chem # type: ignore
from meeko import PDBQTMolecule, RDKitMolCreate
from dgl.data import DGLDataset # type: ignore
from .ligand_atom_feature import mol_to_graph
from .protein_atom_feature import get_all_graph, prot_to_graph, pl_to_c_graph
def _process_dlg_pdbqt(file_path, is_dlg):
"""Helper function to process .dlg and .pdbqt files."""
name = os.path.basename(file_path).split('.')[0]
pdbqt_mol = PDBQTMolecule.from_file(
file_path, name=name, is_dlg=is_dlg, skip_typing=True
)
rdkit_mols = RDKitMolCreate.from_pdbqt_mol(
pdbqt_mol, only_cluster_leads=False, keep_flexres=False
)
sdf_string, _ = RDKitMolCreate.write_sd_string(pdbqt_mol, only_cluster_leads=False)
adg_score = []
for line in sdf_string.split('\n'):
if '{' in line:
words = line.split(',')
free_energy = words[1].split(':')[1].strip()
adg_score.append(float(free_energy))
mols, err_tags, names = [], [], []
for i, conf in enumerate(rdkit_mols[0].GetConformers()):
mol = Chem.Mol(rdkit_mols[0])
if mol is None:
mols.append(None)
err_tags.append(1)
else:
mol.RemoveAllConformers()
mol.AddConformer(conf, assignId=True)
mol = Chem.RemoveHs(mol)
mols.append(mol)
err_tags.append(0)
names.append(f"{name}_{i}")
return mols, err_tags, names, adg_score
def _process_sdf(file_path):
"""Helper function to process .sdf files."""
supplier = Chem.SDMolSupplier(file_path, sanitize=False)
return _process_supplier(supplier, file_path)
def _process_mol2(file_path):
"""Helper function to process .mol2 files"""
with open(file_path, 'r') as f:
mol2_data = f.read()
mol2_blocks = mol2_data.split('@<TRIPOS>MOLECULE')
supplier = (
Chem.MolFromMol2Block('@<TRIPOS>MOLECULE' + block, sanitize=False)
for block in mol2_blocks[1:]
)
return _process_supplier(supplier, file_path)
def _process_supplier(supplier, file_path):
"""Common logic for processing SDF and Mol2 suppliers."""
ligands, err_tag, ligand_names = [], [], []
base_name = os.path.splitext(os.path.basename(file_path))[0]
for idx, mol in enumerate(supplier):
if mol is not None:
mol = Chem.RemoveHs(mol)
ligands.append(mol)
err_tag.append(0)
ligand_name = mol.GetProp('_Name') if mol.HasProp('_Name') and mol.GetProp('_Name').strip() else f"{base_name}_{idx}"
ligand_names.append(ligand_name)
else:
ligands.append(None)
err_tag.append(1)
ligand_names.append(f"{base_name}_err_{idx}")
return ligands, err_tag, ligand_names, [float('nan')] * len(ligands)
[docs]
def process_ligand_file(file_path):
"""Processes a single ligand file (.dlg, .pdbqt, .sdf, .mol2)."""
extension = os.path.splitext(file_path)[-1].lower()
if extension == '.dlg':
return _process_dlg_pdbqt(file_path, is_dlg=True)
elif extension == '.pdbqt':
return _process_dlg_pdbqt(file_path, is_dlg=False)
elif extension == '.sdf':
return _process_sdf(file_path)
elif extension == '.mol2':
return _process_mol2(file_path)
else:
raise ValueError(f"Unsupported file type: {extension}")
[docs]
def load_ligands(file_path):
"""Loads ligands from a file or a list of files."""
file_extension = os.path.splitext(file_path)[-1].lower()
if file_extension == '.txt':
with open(file_path, 'r') as f:
lines = [line.strip() for line in f if line.strip()]
lig_mols, err_tags, lig_names, adg_scores = [], [], [], []
for line in lines:
assert os.path.isfile(line), f"File not found: {line}"
file_ligands, file_err_tag, file_ligand_names, file_adg_score = process_ligand_file(line)
lig_mols.extend(file_ligands)
err_tags.extend(file_err_tag)
lig_names.extend(file_ligand_names)
adg_scores.extend(file_adg_score)
return lig_mols, err_tags, lig_names, adg_scores
elif file_extension in ['.sdf', '.mol2', '.dlg', '.pdbqt']:
return process_ligand_file(file_path)
else:
raise ValueError("Unsupported file type. Use '.txt', '.sdf', '.mol2', '.dlg', or '.pdbqt'.")
[docs]
class PoseSelectionDataset(DGLDataset):
def __init__(self, protein_pdb, ligand_file):
super(PoseSelectionDataset, self).__init__(name='Protein Ligand Binding Conformation RMSD Prediction')
self.lig_mols, self.err_tags, self.lig_names, self.adg_scores = load_ligands(ligand_file)
self.gp = prot_to_graph(protein_pdb)
def __getitem__(self, idx):
try:
mol = self.lig_mols[idx]
gl = mol_to_graph(mol)
gp, gl, gc = get_all_graph(self.gp, gl)
error = self.err_tags[idx]
name = self.lig_names[idx]
adg_score = self.adg_scores[idx] if idx < len(self.adg_scores) else float('nan')
except:
gl = self.lig_dummy_graph(num_nodes=3)
gp, gl, gc = get_all_graph(self.gp, gl)
error = self.err_tags[idx]
name = self.lig_names[idx]
adg_score = float('nan')
return gp, gl, gc, error, name, adg_score
def __len__(self):
return len(self.lig_mols)
[docs]
def lig_dummy_graph(self, num_nodes):
src = torch.randint(0, num_nodes, (10,))
dst = torch.randint(0, num_nodes, (10,))
gl = dgl.graph((src, dst), num_nodes=num_nodes)
gl.ndata['feat'] = torch.zeros((num_nodes, 57)).float()
gl.ndata['pos_enc'] = torch.zeros((num_nodes, 20)).float()
gl.ndata['coord'] = torch.randn((num_nodes, 3)).float()
gl.edata['feat'] = torch.zeros((10, 13)).float()
return gl
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()