promptbind.test_promptbind.PromptBindInference

class promptbind.test_promptbind.PromptBindInference(config_path='args.yml')
__init__(config_path='args.yml')

Initializes an instance from the configuration file at the given path.

Parameters:

config_path (str) – Path to the configuration file. Defaults to ‘args.yml’.

config_path

Path to the configuration file.

Type:

str

args

Arguments loaded from the configuration file.

Type:

argparse.Namespace
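The args attribute is documented as an argparse.Namespace built from the configuration file. The sketch below shows one common way to do that conversion. It assumes the YAML file has already been parsed into a Python dict (for example with PyYAML's yaml.safe_load). The helper name to_namespace and the config keys are hypothetical and are not taken from promptbind. The recursion handles nested sections, but it is not known whether args.yml actually contains any.

```python
import argparse
from typing import Any


def to_namespace(value: Any) -> Any:
    """Hypothetical helper: recursively turn parsed-YAML dicts into Namespaces."""
    if isinstance(value, dict):
        # Convert nested mappings too, so attribute access can be chained: args.a.b
        return argparse.Namespace(**{k: to_namespace(v) for k, v in value.items()})
    if isinstance(value, list):
        # Keep lists as lists, but convert any dicts they contain
        return [to_namespace(v) for v in value]
    # Scalars (numbers, strings, booleans, None) are returned unchanged
    return value


# Hypothetical config contents, standing in for the output of yaml.safe_load
config = {"batch_size": 4, "data": {"root": "data/", "splits": ["test", "unseen"]}}
args = to_namespace(config)
print(args.batch_size, args.data.root)  # 4 data/
```

A flat config needs only argparse.Namespace(**config). The recursive form is useful only if the file groups options into sections.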

accelerator

Accelerator configured for running the model.

logger

Logger used to record information during execution.

device

Device used for computation; set to ‘cuda’.

Type:

str

model

Loaded model.

criterion

Main criterion for model evaluation.

com_coord_criterion

Criterion for center-of-mass coordinates.

pocket_cls_criterion

Criterion for pocket classification.

pocket_coord_criterion

Criterion for pocket coordinates.

test_loader

Data loader for test data.

test_unseen_loader

Data loader for unseen test data.

Methods

__init__([config_path])

Initializes an instance from the configuration file at the given path.

load_args()

load_model()

run_inference()

Runs the inference process for the model.

setup_accelerator()

setup_criterions()

setup_data_loaders()

setup_logger()
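The attribute and method listings suggest that construction is split into separate setup steps, and that run_inference is then called on the fully prepared object. The skeleton below is a hypothetical outline of that structure. Every step only records its name, and the call order in __init__ is an assumption inferred from the order of the documented attributes, not from the promptbind source. The class name InferenceSkeleton and its calls list are illustrative only.

```python
from typing import List


class InferenceSkeleton:
    """Hypothetical outline of PromptBindInference's setup flow.

    Each method only records that it ran. The real methods load the config,
    model and data. The order used in __init__ is an assumption.
    """

    def __init__(self, config_path: str = "args.yml") -> None:
        self.config_path = config_path
        self.calls: List[str] = []
        self.load_args()           # real class: would populate self.args
        self.setup_accelerator()   # real class: would populate self.accelerator
        self.setup_logger()        # real class: would populate self.logger
        self.device = "cuda"       # documented fixed value; where it is set is assumed
        self.load_model()          # real class: would populate self.model
        self.setup_criterions()    # real class: would populate the four criterion attributes
        self.setup_data_loaders()  # real class: would populate test_loader and test_unseen_loader

    def load_args(self) -> None:
        self.calls.append("load_args")

    def setup_accelerator(self) -> None:
        self.calls.append("setup_accelerator")

    def setup_logger(self) -> None:
        self.calls.append("setup_logger")

    def load_model(self) -> None:
        self.calls.append("load_model")

    def setup_criterions(self) -> None:
        self.calls.append("setup_criterions")

    def setup_data_loaders(self) -> None:
        self.calls.append("setup_data_loaders")

    def run_inference(self) -> None:
        # Real class: runs the model over the prepared data loaders
        self.calls.append("run_inference")


runner = InferenceSkeleton()
runner.run_inference()
print(runner.calls[-1])  # run_inference
```

With the promptbind package installed, the documented entry points would presumably be used the same way: construct PromptBindInference('args.yml'), then call run_inference(). This skeleton mirrors only that shape and performs no real work.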