# Source code for gd_dl.data.single_ML_rerank

#!/usr/bin/env python


import os

import time

import argparse
import torch

from gd_dl.rerank_model import Rerank_model
from gd_dl.data.ML_inference import load_files, calculate_energy

# Readout mode passed as ``readout=`` to every Rerank_model built in this module.
READOUT = 'mean'

# calc_e moves models to, and maps checkpoints onto, this device.
device = torch.device('cpu')

# Limit PyTorch intra-op parallelism to a single thread.
# NOTE(review): presumably so many single-ligand jobs can run side by side — confirm.
torch.set_num_threads(1)

# NOTE(review): neither constant is read by calc_e or main; kept for
# importers. Their meaning is not established here — confirm before changing.
REP_DIST = 1.5
N_NODE = 29

# Per-model featurisation settings. Each dict is handed to calculate_energy,
# and its 'n_edge' entry also sizes the model's edge input (edge_dim_in).
# presumably 'dist_cutoff' is a pair-distance cutoff and 'bucket' holds the
# distance bin edges in the same units; meaning of 'zip' unknown here —
# TODO confirm against calculate_energy.
# NOTE(review): in both dicts n_edge == len(bucket) + 8; if the edge width
# depends on the bucket count, change them together — confirm.
model_1_feature_dict = dict(
    dist_cutoff=5.0,
    bucket=torch.tensor([1.5, 1.9, 2.3, 2.7, 3.1, 3.5, 4.0, 4.5, 5.0]),
    n_edge=17,
    zip=1,
)
model_2_feature_dict = dict(
    dist_cutoff=6.0,
    bucket=torch.tensor([1.5, 1.9, 2.3, 2.7, 3.1, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0]),
    n_edge=19,
    zip=2,
)

def _sampling_model_path(rerank_model_fn):
    """Return the sampling-model checkpoint path paired with a rerank-model path.

    Only the file-name component is rewritten: ``rerank_model.pt`` becomes
    ``sampling_model.pt``. The directory prefix is kept byte-for-byte, so a
    directory whose name happens to contain ``rerank_model.pt`` is not
    altered. If the file name does not contain ``rerank_model.pt``, the path
    is returned unchanged. In that case both models would be loaded from the
    same file.
    """
    name = os.path.basename(rerank_model_fn)
    prefix = rerank_model_fn[:len(rerank_model_fn) - len(name)]
    return prefix + name.replace('rerank_model.pt', 'sampling_model.pt')


def _load_model(fn_model, feature_dict):
    """Build a ligand-only Rerank_model for *feature_dict* and load weights from *fn_model*.

    The edge input width is taken from ``feature_dict['n_edge']``. The model
    is moved to ``device``, the checkpoint is mapped onto ``device``, and the
    model is returned in eval mode.
    """
    model = Rerank_model(node_dim_hidden=64, edge_dim_hidden=32,
                         edge_dim_in=feature_dict['n_edge'],
                         readout=READOUT, ligand_only=True).to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(fn_model, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


def calc_e(args):
    """Score a ligand's poses with two models and write their averaged energies.

    Args:
        args: parsed CLI namespace. Reads ``infile_pre_ML`` (input file for
            ``load_files``), ``mol2_prefix`` (``<prefix>.mol2`` is the ligand
            file) and ``load_model`` (path to the rerank checkpoint; the
            sampling checkpoint is derived from it).

    Side effects:
        Writes one ``'%10.3f'``-formatted line per averaged energy to
        ``<mol2_prefix>.mol2.th1`` and prints a status line.
    """
    pre_ML_fn = args.infile_pre_ML
    mol_fn = args.mol2_prefix + '.mol2'

    # model_1 uses the sampling checkpoint, model_2 the rerank checkpoint
    # given on the command line; each has its own feature settings.
    fn_model_2 = args.load_model
    fn_model_1 = _sampling_model_path(fn_model_2)
    model_1 = _load_model(fn_model_1, model_1_feature_dict)
    model_2 = _load_model(fn_model_2, model_2_feature_dict)

    all_zip, coord_l = load_files(pre_ML_fn, mol_fn)
    with torch.no_grad():
        energy_s_1 = calculate_energy(all_zip, coord_l, model_1, model_1_feature_dict)
        energy_s_2 = calculate_energy(all_zip, coord_l, model_2, model_2_feature_dict)
        # Ensemble by simple element-wise mean.
        # presumably both are 1-D tensors with one entry per pose — confirm.
        energy = (energy_s_1 + energy_s_2) / 2

    out_fn = args.mol2_prefix + '.mol2.th1'
    with open(out_fn, 'w') as f:
        for energy_ in energy:
            f.write('%10.3f\n' % (energy_))
    print('%s is working!' % mol_fn)
def main(args):
    """Run ``calc_e`` on the parsed CLI arguments and print the wall-clock time taken."""
    started = time.time()
    calc_e(args)
    finished = time.time()
    elapsed = round(finished - started, 2)
    print("Time for inference:", elapsed, " (s)")
if __name__ == '__main__':
    # Command-line entry point: every option is a required string.
    cli = argparse.ArgumentParser()
    for flag, help_text in (
        ('--infile_pre_ML', 'torch_file_for_ML'),
        ('--mol2_prefix', 'Prefix of a mol2 file'),
        ('--load_model', 'saved_model'),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    main(cli.parse_args())