import torch
import pandas as pd
from tqdm import tqdm
from dgl.dataloading import GraphDataLoader
from bindingrmsd.data.data import PoseSelectionDataset
from bindingrmsd.model.model import PredictionRMSD
# Inference entry point.
def inference(protein_pdb, ligand_file, output, batch_size, model_path, device='cpu'):
    """Score every ligand pose with the two trained models and write a TSV table.

    Two ``PredictionRMSD`` networks are restored from ``{model_path}/reg.pth``
    and ``{model_path}/bce.pth``. The output of the first is written as-is to
    the ``pRMSD`` column; the output of the second is passed through a sigmoid
    and written to the ``Is_Above_2A`` column. Poses whose ``error`` flag
    equals 1 get NaN in both columns.

    Args:
        protein_pdb: Receptor file, forwarded to ``PoseSelectionDataset``.
        ligand_file: Ligand pose file, forwarded to ``PoseSelectionDataset``.
        output: Destination path of the tab-separated result table.
        batch_size: Number of poses per batch.
        model_path: Directory containing ``reg.pth`` and ``bce.pth``.
        device: Device (string or ``torch.device``) the models run on.

    Returns:
        None. The results are written to ``output``.
    """
    dataset = PoseSelectionDataset(
        protein_pdb=protein_pdb,
        ligand_file=ligand_file,
    )

    # Page-locked host memory only speeds up host-to-GPU copies, so it is not
    # requested when the models run on the CPU.
    use_cuda = torch.device(device).type == 'cuda'
    loader = GraphDataLoader(
        dataset, batch_size=batch_size, shuffle=False, pin_memory=use_cuda
    )

    def _load_model(checkpoint_path):
        """Build a PredictionRMSD, restore its weights and switch to eval mode."""
        model = PredictionRMSD(57, 256, 13, 25, 20, 4, 0).to(device)
        # map_location lets checkpoints saved from a GPU be restored on a
        # CPU-only host (and vice versa) instead of failing at load time.
        checkpoint = torch.load(
            checkpoint_path, map_location=device, weights_only=True
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model

    rmsd_model = _load_model(f'{model_path}/reg.pth')
    prob_model = _load_model(f'{model_path}/bce.pth')

    results = {
        "Name": [],
        "pRMSD": [],
        "Is_Above_2A": [],
        "ADG_Score": [],
    }

    # The context manager closes the progress bar even if a batch raises.
    with torch.no_grad(), tqdm(total=len(loader.dataset), unit='ligand') as progress_bar:
        for bgp, bgl, bgc, error, names, adg_score in loader:
            bgp, bgl, bgc = bgp.to(device), bgl.to(device), bgc.to(device)

            rmsd = rmsd_model(bgp, bgl, bgc).view(-1)
            prob = torch.sigmoid(prob_model(bgp, bgl, bgc).view(-1))

            # Poses flagged with error == 1 are reported as NaN rather than
            # as a model output.
            failed = error == 1
            rmsd[failed] = float('nan')
            prob[failed] = float('nan')

            results["Name"].extend(names)
            results["pRMSD"].extend(rmsd.tolist())
            results["Is_Above_2A"].extend(prob.tolist())
            results["ADG_Score"].extend(adg_score.tolist())

            progress_bar.update(len(names))

    df = pd.DataFrame(results).round(4)
    df.to_csv(output, sep='\t', na_rep='NaN', index=False)
if __name__ == "__main__":
    # Command-line entry point: parse options, pick the device, run inference.
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--protein_pdb', default='./1KLT_rec.pdb', help='receptor .pdb')
    parser.add_argument('-l', '--ligand_file', default='./chk.sdf', help='ligand .sdf .txt .mol2 .dlg .pdbqt')
    parser.add_argument('-o', '--output', default='./result.csv', help='result output file')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size')
    parser.add_argument('--ncpu', default=4, type=int, help="cpu worker number")
    parser.add_argument('--device', type=str, default='cuda', help='choose device: cpu or cuda')
    parser.add_argument('--model_path', type=str, default='./save', help='model weight path')
    args = parser.parse_args()

    # Limit the number of CPU threads used by OpenMP, MKL and torch itself.
    os.environ["OMP_NUM_THREADS"] = str(args.ncpu)
    os.environ["MKL_NUM_THREADS"] = str(args.ncpu)
    torch.set_num_threads(args.ncpu)

    # Any value other than 'cpu' requests the GPU; fall back to the CPU when
    # CUDA is not available on this machine.
    if args.device == 'cpu':
        device = torch.device("cpu")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        print("gpu is not available, run on cpu")
        device = torch.device("cpu")

    inference(
        protein_pdb=args.protein_pdb,
        ligand_file=args.ligand_file,
        output=args.output,
        batch_size=args.batch_size,
        model_path=args.model_path,
        # Pass the resolved device, not the raw CLI string: otherwise the CPU
        # fallback above is ignored and a CUDA-less host crashes.
        device=device,
    )