Source code for promptbind.test_promptbind

import argparse
import os
import sys

import numpy as np
import torch
from torch import nn
import yaml
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed
from promptbind.data.data import get_data
from promptbind.models.model import *
from safetensors.torch import load_model
from torch_geometric.loader import DataLoader
from promptbind.utils.logging_utils import Logger
from promptbind.utils.metrics import *
from promptbind.utils.utils import *


class PromptBindInference:
    def __init__(self, config_path='args.yml'):
        """
        Initializes the class with the given configuration path.

        Args:
            config_path (str): Path to the configuration file. Defaults to 'args.yml'.

        Attributes:
            config_path (str): Path to the configuration file.
            args (argparse.Namespace): Arguments loaded from the configuration file.
            accelerator: Accelerator setup for the model.
            logger: Logger setup for logging information.
            device (str): Device to be used for computation, set to 'cuda'.
            model: Loaded model.
            criterion: Main criterion for model evaluation.
            com_coord_criterion: Criterion for center of mass coordinates.
            pocket_cls_criterion: Criterion for pocket classification.
            pocket_coord_criterion: Criterion for pocket coordinates.
            test_loader: Data loader for test data.
            test_unseen_loader: Data loader for unseen test data.
        """
        self.config_path = config_path
        self.args = self.load_args()
        self.accelerator = self.setup_accelerator()
        self.logger = self.setup_logger()
        self.device = 'cuda'
        self.model = self.load_model()
        (self.criterion, self.com_coord_criterion,
         self.pocket_cls_criterion, self.pocket_coord_criterion) = self.setup_criterions()
        self.test_loader, self.test_unseen_loader = self.setup_data_loaders()

    def load_args(self):
        with open(self.config_path, 'r') as f:
            args_dict = yaml.safe_load(f)
        combined_args_dict = {**args_dict['config'], **args_dict['args']}

        prompt_nf = combined_args_dict.get('prompt_nf', '')
        combined_args_dict['pocket_prompt_nf'] = prompt_nf
        combined_args_dict['complex_prompt_nf'] = prompt_nf
        if 'exp_name' not in combined_args_dict:
            combined_args_dict['exp_name'] = f"test_prompt_{prompt_nf}"
        if 'ckpt' not in combined_args_dict:
            combined_args_dict['ckpt'] = f"pretrained/prompt_{prompt_nf}/best/model.safetensors"

        args = argparse.Namespace(**combined_args_dict)
        set_seed(args.seed)
        return args
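
    # Config sketch (illustrative, not shipped with this module): load_args expects a
    # YAML file with 'config' and 'args' mappings that are merged into one namespace,
    # so the split between the two sections is not significant. Only keys that are
    # read elsewhere in this class are listed; the values are assumptions.
    #
    #   config:
    #     prompt_nf: 128
    #     mixed_precision: "no"
    #   args:
    #     seed: 42
    #     batch_size: 4
    #     resultFolder: ./results
    #     data_path: ./data
    #     pred_dis: true
    #     posweight: 8.0
    #     coord_loss_function: SmoothL1
    #     pocket_cls_loss_func: bce
    #     pocket_coord_huber_delta: 1.0
    #     relative_k: 0.01
    #     redocking: false
    #     redocking_no_rotate: false
    #     addNoise: 0
    #     use_whole_protein: false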

    def setup_accelerator(self):
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        accelerator = Accelerator(kwargs_handlers=[ddp_kwargs],
                                  mixed_precision=self.args.mixed_precision)
        return accelerator

    def setup_logger(self):
        pre = f"{self.args.resultFolder}/{self.args.exp_name}"
        os.makedirs(pre, exist_ok=True)
        logger = Logger(accelerator=self.accelerator, log_path=f'{pre}/test.log')
        logger.log_message(f"{' '.join(sys.argv)}")
        return logger

    def load_model(self):
        model = get_model(self.args, self.logger, self.device)
        model = self.accelerator.prepare(model)
        load_model(model, self.args.ckpt)
        return model
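
    # Note: load_model here is safetensors.torch.load_model, which restores the tensors
    # of a .safetensors checkpoint in place into an existing module. Standalone sketch
    # (names and paths are illustrative only):
    #
    #   from safetensors.torch import load_model, save_model
    #   net = torch.nn.Linear(4, 2)
    #   save_model(net, "net.safetensors")   # write the weights to disk
    #   load_model(net, "net.safetensors")   # load them back into the same module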

    def setup_criterions(self):
        if self.args.pred_dis:
            criterion = nn.MSELoss()
        else:
            criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(self.args.posweight))

        if self.args.coord_loss_function == 'MSE':
            com_coord_criterion = nn.MSELoss()
        elif self.args.coord_loss_function == 'SmoothL1':
            com_coord_criterion = nn.SmoothL1Loss()

        if self.args.pocket_cls_loss_func == 'bce':
            pocket_cls_criterion = nn.BCEWithLogitsLoss(reduction='mean')

        pocket_coord_criterion = nn.HuberLoss(delta=self.args.pocket_coord_huber_delta)

        return criterion, com_coord_criterion, pocket_cls_criterion, pocket_coord_criterion
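
    # Loss sketch (illustrative, dummy tensors): the pocket classification criterion
    # consumes raw logits against float labels, and the pocket coordinate criterion is
    # a Huber loss whose delta bounds the quadratic region, e.g.:
    #
    #   logits, labels = torch.randn(5), torch.randint(0, 2, (5,)).float()
    #   cls_loss = nn.BCEWithLogitsLoss(reduction='mean')(logits, labels)
    #   coords, target = torch.randn(5, 3), torch.randn(5, 3)
    #   coord_loss = nn.HuberLoss(delta=1.0)(coords, target)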

    def setup_data_loaders(self):
        if self.args.redocking:
            self.args.compound_coords_init_mode = "redocking"
        elif self.args.redocking_no_rotate:
            self.args.redocking = True
            self.args.compound_coords_init_mode = "redocking_no_rotate"

        train, valid, test = get_data(self.args, self.logger, addNoise=self.args.addNoise,
                                      use_whole_protein=self.args.use_whole_protein,
                                      compound_coords_init_mode=self.args.compound_coords_init_mode,
                                      pre=self.args.data_path)
        self.logger.log_message(f"data point train: {len(train)}, valid: {len(valid)}, test: {len(test)}")

        num_workers = 0
        test_loader = DataLoader(test, batch_size=self.args.batch_size,
                                 follow_batch=['x', 'compound_pair'], shuffle=False,
                                 pin_memory=False, num_workers=num_workers)

        with open('split_pdb_id/unseen_test_index') as f:
            test_unseen_pdb_list = [line.strip() for line in f]
        test_unseen_index = test.data.query(
            "(group == 'test') and (pdb in @test_unseen_pdb_list)").index.values
        test_unseen_index_for_select = np.array(
            [np.where(test._indices == i) for i in test_unseen_index]).reshape(-1)
        test_unseen = test.index_select(test_unseen_index_for_select)
        test_unseen_loader = DataLoader(test_unseen, batch_size=self.args.batch_size,
                                        follow_batch=['x', 'compound_pair'], shuffle=False,
                                        pin_memory=False, num_workers=num_workers)

        return test_loader, test_unseen_loader
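
    # Batching note: follow_batch=['x', 'compound_pair'] asks the PyG DataLoader to
    # record assignment vectors for those attributes, so each batch also carries e.g.
    # an 'x_batch' tensor mapping every row of 'x' back to its sample. Minimal sketch
    # with plain Data objects (illustrative, not this project's data class):
    #
    #   from torch_geometric.data import Data
    #   loader = DataLoader([Data(x=torch.randn(3, 2)), Data(x=torch.randn(2, 2))],
    #                       batch_size=2, follow_batch=['x'])
    #   next(iter(loader)).x_batch  # tensor([0, 0, 0, 1, 1])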

    def run_inference(self):
        """
        Runs the inference process for the model.

        This method sets the model to evaluation mode, logs the beginning of the test,
        and, if the current process is the main process, evaluates the model using the
        provided evaluation function. The evaluation metrics are then logged. The
        method waits for all processes to complete before finishing.

        Returns:
            None
        """
        self.model.eval()
        self.logger.log_message("Begin test")
        if self.accelerator.is_main_process:
            metrics, _, _ = evaluate_mean_pocket_cls_coord_multi_task(
                self.accelerator, self.args, self.test_unseen_loader,
                self.accelerator.unwrap_model(self.model),
                self.com_coord_criterion, self.criterion,
                self.pocket_cls_criterion, self.pocket_coord_criterion,
                self.args.relative_k, self.accelerator.device,
                pred_dis=self.args.pred_dis, use_y_mask=False, stage=2)
            self.logger.log_stats(metrics, 0, self.args, prefix="Test_unseen")
        self.accelerator.wait_for_everyone()

if __name__ == '__main__':
    infer = PromptBindInference(config_path='options/test_args.yml')
    infer.run_inference()
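
# Usage sketch (assumption, not part of the source): since the class builds an
# Accelerator, the script can be started directly with `python test_promptbind.py`
# or, for multi-GPU inference, with `accelerate launch test_promptbind.py`, provided
# options/test_args.yml points at the data, checkpoint, and result folders.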