Source code for gd_dl.data.ML_inference

import os
import time

import argparse
import numpy as np
import torch

from torch_geometric.data import Data
from torch_geometric.data import Batch

import gc

from gd_dl.rerank_model import Rerank_model

# Readout mode handed to Rerank_model (``readout=READOUT``) when the models
# are built in calc_e.
READOUT = 'mean'

# All inference runs on the CPU; models and graph batches are moved here.
device = torch.device('cpu')

# Limit PyTorch to a single intra-op thread and a single inter-op thread.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# NOTE(review): not referenced anywhere in the visible part of this module -
# confirm whether it is still needed.
N_NODE = 29

# Featurisation settings read by calculate_energy:
#   dist_cutoff - largest pair distance for which an edge is created
#   bucket      - bin boundaries passed to torch.bucketize to one-hot encode
#                 the pair distance
#   n_edge      - length of every edge-attribute vector
#   zip         - index into ``all_zip`` of the ligand entry to use
model_1_feature_dict = {'dist_cutoff' : 5.0,
 'bucket' : torch.tensor([1.5,1.9,2.3,2.7,3.1,3.5,4.0,4.5,5.0]),
  'n_edge': 17,'zip':1}

def get_ref_vec(i_atom, neigh_dict, coord_l):
    """Return the reference direction of atom ``i_atom`` and a validity flag.

    The displacements from atom ``i_atom`` to every atom listed in
    ``neigh_dict[i_atom]`` are summed, the sum is negated (so it points away
    from those neighbours) and then scaled to unit length.

    Parameters
    ----------
    i_atom
        Index of the atom of interest; used as key of ``neigh_dict`` and as
        index into ``coord_l``.
    neigh_dict
        Mapping from an atom index to an iterable of neighbour indices.
    coord_l
        Coordinates indexable by atom index; each entry is combined with a
        ``torch.zeros(3)`` accumulator.

    Returns
    -------
    tuple
        ``(direction, flag)``.  ``direction`` is the normalised vector and
        ``flag`` is the Python float ``0.0`` when the length of the
        un-normalised vector is below 1.0, otherwise ``1.0``.
    """
    tiny = 1e-10  # keeps the division finite when the summed vector is zero
    centre = coord_l[i_atom]

    summed = torch.zeros(3)
    for j_atom in neigh_dict[i_atom]:
        summed += coord_l[j_atom] - centre

    away = -summed
    length = torch.sqrt(away.pow(2).sum(-1))
    direction = away / (length + tiny)

    flag = 0.0 if length < 1.0 else 1.0
    return direction, flag
def load_files(pre_ML_fn, mol_fn):
    """Load the preprocessed feature file and the heavy-atom pose coordinates.

    Parameters
    ----------
    pre_ML_fn
        Path of the file read with ``torch.load``.
    mol_fn
        Path of a (multi-)mol2 file.  Every ``@<TRIPOS>ATOM`` section yields
        one pose.

    Returns
    -------
    tuple
        ``(all_zip, coord_l)`` where ``all_zip`` is whatever object was stored
        in ``pre_ML_fn`` and ``coord_l`` is a float32 tensor built from the
        per-pose lists of ``[x, y, z]`` (columns 3-5 of each atom record).
        Atoms whose type column (6th field) starts with ``'H'`` are skipped.

    Raises
    ------
    IndexError
        If a non-blank atom record has fewer than six fields.
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so pre_ML_fn must come from a trusted source.
    all_zip = torch.load(pre_ML_fn, weights_only=False)

    header_tag = '@<TRIPOS>'
    poses = []
    heavy_atoms = []
    in_atom_section = False

    with open(mol_fn) as handle:
        for line in handle:
            if line.startswith(header_tag):
                # Any section header closes a running ATOM section, so a pose
                # is stored even when no BOND section follows it.
                if in_atom_section:
                    poses.append(heavy_atoms)
                    in_atom_section = False
                # Exact match: other headers that merely contain "ATOM" must
                # not be mistaken for the coordinate section.
                if line[len(header_tag):].strip() == 'ATOM':
                    heavy_atoms = []
                    in_atom_section = True
                continue

            if not in_atom_section:
                continue

            fields = line.split()
            if not fields:
                # Blank line inside the ATOM section.
                continue
            if fields[5][0] == 'H':
                continue
            heavy_atoms.append(fields[2:5])

    # The file may end while an ATOM section is still open.
    if in_atom_section:
        poses.append(heavy_atoms)

    # numpy converts the coordinate strings to float32 here.
    coord = np.array(poses, dtype=np.float32)
    coord_l = torch.from_numpy(coord)
    return all_zip, coord_l
def _pair_edge_attr(dist, bucket, n_edge, ref_a, ref_b):
    """Return the edge-attribute vector of one non-covalent atom pair.

    ``dist`` is one-hot encoded through ``torch.bucketize`` starting at slot 6
    (the first six slots stay zero here - presumably they are reserved for the
    precomputed covalent edges; confirm against the preprocessing code).  The
    last two slots hold the dot product of the two reference directions and a
    1.0 marker, and are only filled when both validity flags equal 1.0.
    """
    attr = torch.zeros(n_edge)
    attr[6 + torch.bucketize(dist, bucket)] = 1.0
    vec_a, flag_a = ref_a
    vec_b, flag_b = ref_b
    if flag_a == 1.0 and flag_b == 1.0:
        attr[-2] = torch.dot(vec_a, vec_b)
        attr[-1] = 1.0
    return attr


def calculate_energy(all_zip, coord_l, model, feature_dict, batch_size=50):
    """Evaluate ``model`` on one graph per ligand pose and return the outputs.

    For every pose in ``coord_l`` a graph is assembled from three edge groups:
    the precomputed edges taken from ``all_zip`` (``cov_edge_*``), ligand-ligand
    pairs within the distance cutoff whose atom-index pair is not in ``bond``,
    and ligand-receptor pairs within the cutoff.  Every edge is then duplicated
    in the reverse direction.  Graphs are evaluated in batches of
    ``batch_size``; the last, possibly smaller, batch is evaluated as well so
    that every pose receives a value.

    Parameters
    ----------
    all_zip
        Indexable object.  ``all_zip[0]`` unpacks to
        ``(coord_r, atom_type_r, ref_vec_r, ref_vec_size_r)`` and
        ``all_zip[feature_dict['zip']]`` unpacks to ``(x, cov_edge_index_list,
        cov_edge_attr_list, n_atom, atm_idx_list, bond, neigh_dict)``.
    coord_l
        Pose coordinates; ``coord_l[i]`` holds the coordinates of pose ``i``.
    model
        Callable invoked as ``model(batch, n_atom)``.
    feature_dict
        Mapping with the keys ``'dist_cutoff'``, ``'bucket'``, ``'n_edge'`` and
        ``'zip'``.
    batch_size
        Number of graphs per model call (default 50).

    Returns
    -------
    torch.Tensor
        ``torch.cat`` of the per-batch model outputs, in pose order.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1')

    dist_cutoff = feature_dict['dist_cutoff']
    bucket = feature_dict['bucket']
    n_edge = feature_dict['n_edge']
    # Receptor distances are compared in squared form; hoisted out of the loops.
    cutoff_sq = dist_cutoff**2

    coord_r, atom_type_r, ref_vec_r, ref_vec_size_r = all_zip[0]
    (x, cov_edge_index_list, cov_edge_attr_list, n_atom,
     atm_idx_list, bond, neigh_dict) = all_zip[feature_dict['zip']]

    protein_indice = torch.arange(len(coord_r))
    coord_r = coord_r.unsqueeze(0)
    n_model = len(coord_l)

    energy_s = []
    graph_s = []
    for i_model in range(n_model):
        coord_l_i = coord_l[i_model]

        # Reference direction and validity flag of every ligand atom.
        ref_vec_dict = {
            i_atom: get_ref_vec(i_atom, neigh_dict, coord_l_i)
            for i_atom in range(n_atom)
        }

        # Ligand-ligand distance matrix via broadcasting.
        y1 = coord_l_i.unsqueeze(0)
        y2 = coord_l_i.unsqueeze(1)
        dm_ll = torch.sqrt((y1 - y2).pow(2).sum(-1))

        # Ligand-ligand pairs within the cutoff that are not listed in `bond`.
        non_edge_index_list = []
        non_edge_attr_list = []
        for i_atom in range(n_atom):
            for j_atom in range(i_atom):
                tmp_dist = dm_ll[i_atom, j_atom]
                if (tmp_dist <= dist_cutoff
                        and (atm_idx_list[j_atom], atm_idx_list[i_atom]) not in bond):
                    non_edge_index_list.append(
                        torch.tensor([j_atom, i_atom], dtype=torch.long))
                    non_edge_attr_list.append(
                        _pair_edge_attr(tmp_dist, bucket, n_edge,
                                        ref_vec_dict[i_atom],
                                        ref_vec_dict[j_atom]))

        # Ligand-receptor pairs; distances stay squared until an edge is built.
        dm_pl = (y2 - coord_r).pow(2).sum(-1)
        # NOTE(review): the keys below are 0-dim tensors produced by iterating
        # `per_prot_indice`.  torch.Tensor hashes by object identity, so the
        # membership test appears to be true for every contact and a receptor
        # atom touching several ligand atoms gets one node per contact.  This
        # is kept unchanged because altering it would change the graphs fed to
        # the trained model - confirm against the training featurisation
        # before "fixing" it.
        prot_node_dict = {}
        prot_node_list = []
        pl_edge_index_list = []
        pl_edge_attr_list = []
        for i_atom in range(n_atom):
            per_dist = dm_pl[i_atom]
            mask = torch.le(per_dist, cutoff_sq)
            tmp_dist_s = per_dist[mask]
            per_prot_indice = protein_indice[mask]
            for idx, prot_idx in enumerate(per_prot_indice):
                if prot_idx not in prot_node_dict:
                    # Receptor nodes are numbered after the ligand atoms.
                    prot_node_dict[prot_idx] = n_atom + len(prot_node_dict)
                    prot_node_list.append(atom_type_r[prot_idx])
                pl_edge_index_list.append(
                    torch.tensor([i_atom, prot_node_dict[prot_idx]],
                                 dtype=torch.long))
                tmp_dist = tmp_dist_s[idx]**0.5
                pl_edge_attr_list.append(
                    _pair_edge_attr(tmp_dist, bucket, n_edge,
                                    ref_vec_dict[i_atom],
                                    (ref_vec_r[prot_idx],
                                     ref_vec_size_r[prot_idx])))

        if len(prot_node_list) > 0:
            tot_x = torch.cat((x, torch.stack(prot_node_list)), dim=0)
        else:
            tot_x = x

        tot_edge_index = torch.stack(
            cov_edge_index_list + non_edge_index_list + pl_edge_index_list)
        tot_edge_index = tot_edge_index.t().contiguous()
        tot_edge_attr = torch.stack(
            cov_edge_attr_list + non_edge_attr_list + pl_edge_attr_list)

        # Add the reverse of every edge, reusing the same attributes.
        row, col = tot_edge_index
        row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
        tot_edge_index = torch.stack([row, col], dim=0)
        tot_edge_attr = torch.cat([tot_edge_attr, tot_edge_attr], dim=0)

        graph_s.append(
            Data(x=tot_x, edge_index=tot_edge_index, edge_attr=tot_edge_attr))

        # Evaluate on a full batch, and also on the final partial batch so
        # that trailing poses are not dropped.
        if len(graph_s) == batch_size or i_model == n_model - 1:
            batch = Batch.from_data_list(graph_s)
            batch = batch.to(device=device)
            energy_s.append(model(batch, n_atom))
            del batch
            graph_s = []
            gc.collect()

    energy_s = torch.cat(energy_s)
    return energy_s
def calc_e(args):
    """Score all poses of a mol2 file with a three-model ensemble.

    Three checkpoints are loaded: ``args.load_model`` itself and its two
    siblings whose names are formed by inserting ``_0`` / ``_1`` before a
    ``.pt`` extension.  The outputs of ``calculate_energy`` for the three
    models are averaged and written, one ``%10.3f`` value per line, to
    ``<args.mol2_prefix>.mol2.th1``.

    Parameters
    ----------
    args
        Object providing the attributes ``infile_pre_ML`` (preprocessed
        feature file), ``mol2_prefix`` (pose file path without ``.mol2``) and
        ``load_model`` (path of the first checkpoint).
    """
    pre_ML_fn = args.infile_pre_ML
    mol_fn = args.mol2_prefix + '.mol2'
    out_fn = args.mol2_prefix + '.mol2.th1'

    fn_model = args.load_model
    # splitext removes the real extension, whatever its length, instead of
    # blindly cutting the last three characters; identical for '*.pt' names.
    stem = os.path.splitext(fn_model)[0]
    fn_model_list = [fn_model, stem + '_0.pt', stem + '_1.pt']

    model_list = []
    for fn_model_i in fn_model_list:
        tmp_model = Rerank_model(node_dim_hidden=64,
                                 edge_dim_hidden=32,
                                 readout=READOUT,
                                 ligand_only=True).to(device)
        checkpoint = torch.load(fn_model_i, map_location=device,
                                weights_only=True)
        tmp_model.load_state_dict(checkpoint['model_state_dict'])
        tmp_model.eval()
        model_list.append(tmp_model)

    all_zip, coord_l = load_files(pre_ML_fn, mol_fn)

    with torch.no_grad():
        # Starts as a float; the first addition turns it into a fresh tensor,
        # so the later in-place updates never touch a model output.
        energy_s = 0.0
        for model_i in model_list:
            energy_s += calculate_energy(all_zip, coord_l, model_i,
                                         model_1_feature_dict)
        energy_s /= len(model_list)

    print('%s is working!'%mol_fn)

    with open(out_fn, 'w') as f:
        f.writelines('%10.3f\n' % energy_ for energy_ in energy_s)
    return
def main(args):
    """Run ``calc_e(args)`` and print the elapsed wall-clock time in seconds."""
    started = time.time()
    calc_e(args)
    elapsed = time.time() - started
    print("Time for inference:", round(elapsed, 2), " (s)")
if __name__ == '__main__':
    # Command-line entry point: three mandatory string options, then main().
    cli = argparse.ArgumentParser()
    for flag, description in (
        ('--infile_pre_ML', 'torch_file_for_ML'),
        ('--mol2_prefix', 'Prefix of a mol2 file'),
        ('--load_model', 'saved_model'),
    ):
        cli.add_argument(flag, type=str, required=True, help=description)
    args = cli.parse_args()
    main(args)