Source code for gd_dl.data.all_ML_rerank

import os
import time

import argparse
import torch

from gd_dl.rerank_model import Rerank_model
from gd_dl.data.ML_inference import load_files, calculate_energy

READOUT = 'mean'

device = torch.device('cpu')

torch.set_num_threads(1)
#torch.set_num_interop_threads(1)

REP_DIST = 1.5
N_NODE = 29

model_1_feature_dict = {'dist_cutoff' : 5.0,
 'bucket' : torch.tensor([1.5,1.9,2.3,2.7,3.1,3.5,4.0,4.5,5.0]),
  'n_edge': 17,'zip':1}
model_2_feature_dict = {'dist_cutoff' : 6.0,
 'bucket' : torch.tensor([1.5,1.9,2.3,2.7,3.1,3.5,4.0,4.5,5.0,5.5,6.0]),
  'n_edge': 19,'zip':2}

[docs] def calc_e(args): pre_ML_fn = args.infile_pre_ML mol_fn = os.path.join(args.mol2_prefix+'.mol2') rerank_model = args.load_model fn_sampling_model = rerank_model.replace('rerank_model.pt','sampling_model.pt') fn_rerank_model = rerank_model fn_sampling_model_list = [fn_sampling_model, fn_sampling_model[:-3] + '_0.pt', fn_sampling_model[:-3] + '_1.pt'] sampling_model_list = [] for fn_model_i in fn_sampling_model_list: tmp_model = Rerank_model(node_dim_hidden=64, edge_dim_hidden=32, readout=READOUT, ligand_only=True).to(device) checkpoint = torch.load(fn_model_i, map_location=device) tmp_model.load_state_dict(checkpoint['model_state_dict']) tmp_model.eval() sampling_model_list.append(tmp_model) rerank_model = Rerank_model(node_dim_hidden=64, edge_dim_hidden=32, edge_dim_in=model_2_feature_dict['n_edge'], readout=READOUT,ligand_only=True).to(device) checkpoint = torch.load(fn_rerank_model, map_location=device) rerank_model.load_state_dict(checkpoint['model_state_dict']) rerank_model.eval() all_zip, coord_l = load_files(pre_ML_fn, mol_fn) with torch.no_grad(): energy = 0.0 for model_i in sampling_model_list: energy += calculate_energy(all_zip, coord_l, model_i, model_1_feature_dict) energy += calculate_energy(all_zip, coord_l, rerank_model, model_2_feature_dict) energy /= len(fn_sampling_model_list) + 1 out_fn = os.path.join(args.mol2_prefix+'.mol2.th1') with open(out_fn, 'w') as f: for energy_ in energy: f.write('%10.3f\n'%(energy_)) print('%s is working!'%mol_fn) return
[docs] def main(args): st = time.time() calc_e(args) et = time.time() print ("Time for inference:", round(et-st, 2), " (s)") return
if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--infile_pre_ML', type=str, required=True, help='torch_file_for_ML') parser.add_argument('--mol2_prefix', type=str, required=True, help='Prefix of a mol2 file') parser.add_argument('--load_model', type=str, required=True, help='saved_model') args = parser.parse_args() main(args)