Source code for promptbind.train_promptbind

import numpy as np
import pandas as pd
import os
from tqdm.auto import tqdm

import torch

from promptbind.data.data import get_data
from torch_geometric.loader import DataLoader
from promptbind.utils.metrics import *
from promptbind.utils.utils import *
from datetime import datetime
from promptbind.utils.logging_utils import Logger
import sys
import argparse
from torch.utils.data import RandomSampler
import random
from torch_scatter import scatter_mean
from promptbind.utils.metrics_to_tsb import metrics_runtime_no_prefix
from torch.utils.tensorboard import SummaryWriter
# from torch.nn.utils import clip_grad_norm_
from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs
from accelerate.utils import set_seed


[docs] def Seed_everything(seed=42): random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False
if __name__ == '__main__': parser = argparse.ArgumentParser(description='FABind model training.') parser.add_argument("-m", "--mode", type=int, default=0, help="mode specify the model to use.") parser.add_argument("-d", "--data", type=str, default="0", help="data specify the data to use. \ 0 for re-docking, 1 for self-docking.") parser.add_argument('--seed', type=int, default=128, help="seed to use.") parser.add_argument("--gs-tau", type=float, default=1, help="Tau for the temperature-based softmax.") parser.add_argument("--gs-hard", action='store_true', default=False, help="Hard mode for gumbel softmax.") parser.add_argument("--batch_size", type=int, default=8, help="batch size.") parser.add_argument("--restart", type=str, default=None, help="continue the training from the model we saved from scratch.") parser.add_argument("--reload", type=str, default=None, help="continue the training from the model we saved.") parser.add_argument("--addNoise", type=str, default=None, help="shift the location of the pocket center in each training sample \ such that the protein pocket encloses a slightly different space.") pair_interaction_mask = parser.add_mutually_exclusive_group() # use_equivalent_native_y_mask is probably a better choice. pair_interaction_mask.add_argument("--use_y_mask", action='store_true', default=False, help="mask the pair interaction during pair interaction loss evaluation based on data.real_y_mask. \ real_y_mask=True if it's the native pocket that ligand binds to.") pair_interaction_mask.add_argument("--use_equivalent_native_y_mask", action='store_true', default=False, help="mask the pair interaction during pair interaction loss evaluation based on data.equivalent_native_y_mask. \ real_y_mask=True if most of the native interaction between ligand and protein happen inside this pocket.") parser.add_argument("--use_affinity_mask", type=int, default=0, help="mask affinity in loss evaluation based on data.real_affinity_mask") parser.add_argument("--affinity_loss_mode", type=int, default=1, help="define which affinity loss function to use.") parser.add_argument("--pred_dis", type=int, default=1, help="pred distance map or predict contact map.") parser.add_argument("--posweight", type=int, default=8, help="pos weight in pair contact loss, not useful if args.pred_dis=1") parser.add_argument("--relative_k", type=float, default=0.01, help="adjust the strength of the affinity loss head relative to the pair interaction loss.") parser.add_argument("-r", "--relative_k_mode", type=int, default=0, help="define how the relative_k changes over epochs") parser.add_argument("--resultFolder", type=str, default="./result", help="information you want to keep a record.") parser.add_argument("--label", type=str, default="", help="information you want to keep a record.") parser.add_argument("--use-whole-protein", action='store_true', default=False, help="currently not used.") parser.add_argument("--data-path", type=str, default="/PDBbind_data/pdbbind2020", help="Data path.") parser.add_argument("--exp-name", type=str, default="", help="data path.") parser.add_argument("--tqdm-interval", type=float, default=0.1, help="tqdm bar update interval") parser.add_argument("--lr", type=float, default=0.0001) parser.add_argument("--pocket-coord-huber-delta", type=float, default=3.0) parser.add_argument("--coord-loss-function", type=str, default='SmoothL1', choices=['MSE', 'SmoothL1']) parser.add_argument("--coord-loss-weight", type=float, default=1.0) parser.add_argument("--pair-distance-loss-weight", type=float, default=1.0) parser.add_argument("--pair-distance-distill-loss-weight", type=float, default=1.0) parser.add_argument("--pocket-cls-loss-weight", type=float, default=1.0) parser.add_argument("--pocket-distance-loss-weight", type=float, default=0.05) parser.add_argument("--pocket-cls-loss-func", type=str, default='bce') # parser.add_argument("--warm-mae-thr", type=float, default=5.0) parser.add_argument("--use-compound-com-cls", action='store_true', default=False, help="only use real pocket to run pocket classification task") parser.add_argument("--compound-coords-init-mode", type=str, default="pocket_center_rdkit", choices=['pocket_center_rdkit', 'pocket_center', 'compound_center', 'perturb_3A', 'perturb_4A', 'perturb_5A', 'random']) parser.add_argument('--trig-layers', type=int, default=1) parser.add_argument('--distmap-pred', type=str, default='mlp', choices=['mlp', 'trig']) parser.add_argument('--mean-layers', type=int, default=3) parser.add_argument('--n-iter', type=int, default=5) parser.add_argument('--inter-cutoff', type=float, default=10.0) parser.add_argument('--intra-cutoff', type=float, default=8.0) parser.add_argument('--refine', type=str, default='refine_coord', choices=['stack', 'refine_coord']) parser.add_argument('--coordinate-scale', type=float, default=5.0) parser.add_argument('--geometry-reg-step-size', type=float, default=0.001) parser.add_argument('--lr-scheduler', type=str, default="constant", choices=['constant', 'poly_decay', 'cosine_decay', 'cosine_decay_restart', 'exp_decay']) parser.add_argument('--add-attn-pair-bias', action='store_true', default=False) parser.add_argument('--explicit-pair-embed', action='store_true', default=False) parser.add_argument('--opm', action='store_true', default=False) parser.add_argument('--add-cross-attn-layer', action='store_true', default=False) parser.add_argument('--rm-layernorm', action='store_true', default=False) parser.add_argument('--keep-trig-attn', action='store_true', default=False) parser.add_argument('--pocket-radius', type=float, default=20.0) parser.add_argument('--rm-LAS-constrained-optim', action='store_true', default=False) parser.add_argument('--rm-F-norm', action='store_true', default=False) parser.add_argument('--norm-type', type=str, default="per_sample", choices=['per_sample', '4_sample', 'all_sample']) # parser.add_argument("--only-predicted-pocket-mae-thr", type=float, default=3.0) parser.add_argument('--noise-for-predicted-pocket', type=float, default=5.0) parser.add_argument('--test-random-rotation', action='store_true', default=False) parser.add_argument('--random-n-iter', action='store_true', default=False) parser.add_argument('--clip-grad', action='store_true', default=False) # one batch actually contains 20000 samples, not the size of training set parser.add_argument("--sample-n", type=int, default=0, help="number of samples in one epoch.") parser.add_argument('--fix-pocket', action='store_true', default=False) parser.add_argument('--pocket-idx-no-noise', action='store_true', default=False) parser.add_argument('--ablation-no-attention', action='store_true', default=False) parser.add_argument('--ablation-no-attention-with-cross-attn', action='store_true', default=False) parser.add_argument('--redocking', action='store_true', default=False) parser.add_argument('--redocking-no-rotate', action='store_true', default=False) parser.add_argument("--pocket-pred-layers", type=int, default=1, help="number of layers for pocket pred model.") parser.add_argument('--pocket-pred-n-iter', type=int, default=1, help="number of iterations for pocket pred model.") parser.add_argument('--use-esm2-feat', action='store_true', default=False) parser.add_argument("--center-dist-threshold", type=float, default=4.0) parser.add_argument("--mixed-precision", type=str, default='no', choices=['no', 'fp16']) parser.add_argument('--disable-tqdm', action='store_true', default=False) parser.add_argument('--log-interval', type=int, default=50) parser.add_argument('--optim', type=str, default='adamw', choices=['adam', 'adamw']) parser.add_argument("--warmup-epochs", type=int, default=0, help="used in combination with relative_k_mode.") parser.add_argument("--total-epochs", type=int, default=400, help="option to switch training data after certain epochs.") parser.add_argument('--disable-validate', action='store_true', default=False) parser.add_argument('--disable-tensorboard', action='store_true', default=False) parser.add_argument("--hidden-size", type=int, default=256) parser.add_argument("--weight-decay", type=float, default=0.01) parser.add_argument("--stage-prob", type=float, default=0.25) parser.add_argument("--pocket-pred-hidden-size", type=int, default=128) parser.add_argument("--local-eval", action='store_true', default=False) parser.add_argument("--train-ligand-torsion-noise", action='store_true', default=False) parser.add_argument("--train-pred-pocket-noise", type=float, default=0.0) parser.add_argument('--esm2-concat-raw', action='store_true', default=False) parser.add_argument('--pocket-prompt-nf', type=int, default=64, help="") parser.add_argument('--complex-prompt-nf', type=int, default=64, help="") parser.add_argument("--ckpt", type=str, default="ckpt/best_model.bin") args = parser.parse_args() ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator(kwargs_handlers=[ddp_kwargs], mixed_precision=args.mixed_precision) set_seed(args.seed) Seed_everything(seed=args.seed) pre = f"{args.resultFolder}/{args.exp_name}" if accelerator.is_main_process: os.system(f"mkdir -p {pre}/models") os.system(f"mkdir -p {pre}/metrics") if not args.disable_tensorboard: tsb_runtime_dir = f"{pre}/tsb_runtime" os.system(f"mkdir -p {tsb_runtime_dir}") train_writer = SummaryWriter(log_dir=f'{tsb_runtime_dir}/train') valid_writer = SummaryWriter(log_dir=f'{tsb_runtime_dir}/valid') test_writer = SummaryWriter(log_dir=f'{tsb_runtime_dir}/test') test_writer_use_predicted_pocket = SummaryWriter(log_dir=f'{tsb_runtime_dir}/test_use_predicted_pocket') accelerator.wait_for_everyone() timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M") logger = Logger(accelerator=accelerator, log_path=f'{pre}/{timestamp}.log') logger.log_message(f"{' '.join(sys.argv)}") # torch.set_num_threads(1) # # ----------without this, I could get 'RuntimeError: received 0 items of ancdata'----------- torch.multiprocessing.set_sharing_strategy('file_system') # train, valid, test: only native pocket. train_after_warm_up, all_pocket_test include all other pockets(protein center and P2rank result) if args.redocking: args.compound_coords_init_mode = "redocking" elif args.redocking_no_rotate: args.redocking = True args.compound_coords_init_mode = "redocking_no_rotate" train, valid, test= get_data(args, logger, addNoise=args.addNoise, use_whole_protein=args.use_whole_protein, compound_coords_init_mode=args.compound_coords_init_mode, pre=args.data_path) logger.log_message(f"data point train: {len(train)}, valid: {len(valid)}, test: {len(test)}") num_workers = 0 if args.sample_n > 0: sampler = RandomSampler(train, replacement=True, num_samples=args.sample_n) train_loader = DataLoader(train, batch_size=args.batch_size, follow_batch=['x', 'compound_pair'], sampler=sampler, pin_memory=False, num_workers=num_workers) # sampler_update = RandomSampler(train_update, replacement=True, num_samples=args.sample_n) # train_update_loader = DataLoader(train_update, batch_size=args.batch_size, follow_batch=['x', 'compound_pair'], sampler=sampler_update, pin_memory=False, num_workers=num_workers) else: train_loader = DataLoader(train, batch_size=args.batch_size, follow_batch=['x', 'compound_pair'], shuffle=True, pin_memory=False, num_workers=num_workers) # train_update_loader = DataLoader(train_update, batch_size=args.batch_size, follow_batch=['x', 'compound_pair'], shuffle=True, pin_memory=False, num_workers=num_workers) # sampler2 = RandomSampler(train_after_warm_up, replacement=True, num_samples=args.sample_n) # train_after_warm_up_loader = DataLoader(train_after_warm_up, batch_size=args.batch_size, follow_batch=['x', 'compound_pair'], sampler=sampler2, pin_memory=False, num_workers=num_workers) # valid_batch_size = test_batch_size = 4 valid_loader = DataLoader(valid, batch_size=args.batch_size, follow_batch=['x', 'compound_pair'], shuffle=False, pin_memory=False, num_workers=num_workers) test_loader = DataLoader(test, batch_size=args.batch_size, follow_batch=['x', 'compound_pair'], shuffle=False, pin_memory=False, num_workers=num_workers) # valid_update_loader = DataLoader(valid_update, batch_size=args.batch_size, follow_batch=['x', 'compound_pair'], shuffle=False, pin_memory=False, num_workers=num_workers) # test_update_loader = DataLoader(test_update, batch_size=args.batch_size, follow_batch=['x', 'compound_pair'], shuffle=False, pin_memory=False, num_workers=num_workers) # not used # all_pocket_test_loader = DataLoader(all_pocket_test, batch_size=2, follow_batch=['x', 'compound_pair'], shuffle=False, pin_memory=False, num_workers=4) # import model is put here due to an error related to torch.utils.data.ConcatDataset after importing torchdrug. from promptbind.models.model import * device = 'cuda' model = get_model(args, logger, device) model.load_state_dict(torch.load(args.ckpt), strict=False) for param in model.parameters() : param.requires_grad = True set_seed(args.seed) Seed_everything(seed=args.seed) if args.optim == "adam": optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) elif args.optim == "adamw": optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) last_epoch = -1 steps_per_epoch = len(train_loader) total_training_steps = args.total_epochs * len(train_loader) if args.lr_scheduler == "constant": scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=args.total_epochs*len(train_loader), last_epoch=last_epoch) elif args.lr_scheduler == "poly_decay": scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=args.total_epochs*len(train_loader), last_epoch=last_epoch) elif args.lr_scheduler == "exp_decay": scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995, last_epoch=last_epoch) elif args.lr_scheduler == "cosine_decay": scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.total_epochs*len(train_loader), eta_min=1e-5, last_epoch=last_epoch) elif args.lr_scheduler == "cosine_decay_restart": scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, eta_min=0.0001, last_epoch=last_epoch) ( model, optimizer, scheduler, train_loader, ) = accelerator.prepare( model, optimizer, scheduler, train_loader, ) output_last_epoch_dir = f"{pre}/models/epoch_last" if os.path.exists(output_last_epoch_dir) and os.path.exists(os.path.join(output_last_epoch_dir, "pytorch_model.bin")): # ckpt = os.path.join(args.resultFolder, args.exp_name, 'models', "epoch_last.pt") # model_ckpt, opt_ckpt, model_args, last_epoch = torch.load(ckpt) # model.load_state_dict(model_ckpt, strict=True) # optimizer.load_state_dict(opt_ckpt) accelerator.load_state(output_last_epoch_dir) last_epoch = round(scheduler.state_dict()['last_epoch'] / steps_per_epoch) - 1 logger.log_message(f'Load model from epoch: {last_epoch}') # TODO Future debug when needed # if args.restart: # model_ckpt, opt_ckpt, model_args, last_epoch = torch.load(args.restart) # model.load_state_dict(model_ckpt, strict=True) # optimizer.load_state_dict(opt_ckpt) # last_epoch = -1 # elif args.reload: # model_ckpt, opt_ckpt, model_args, last_epoch = torch.load(args.reload) # model.load_state_dict(model_ckpt, strict=True) # optimizer.load_state_dict(opt_ckpt) if args.pred_dis: criterion = nn.MSELoss() pred_dis = True else: criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(args.posweight)) if args.coord_loss_function == 'MSE': com_coord_criterion = nn.MSELoss() elif args.coord_loss_function == 'SmoothL1': com_coord_criterion = nn.SmoothL1Loss() if args.pocket_cls_loss_func == 'bce': pocket_cls_criterion = nn.BCEWithLogitsLoss(reduction='mean') pocket_coord_criterion = nn.HuberLoss(delta=args.pocket_coord_huber_delta) # metrics_list = [] # valid_metrics_list = [] # test_metrics_list = [] # test_metrics_stage2_list = [] best_auroc = 0 best_f1_1 = 0 epoch_not_improving = 0 logger.log_message(f"Total epochs: {args.total_epochs}") logger.log_message(f"Total training steps: {total_training_steps}") test_metrics = {"pocket_cls_accuracy":[], "rmsd":[], "rmsd < 2A":[], "rmsd < 5A":[], "rmsd 25%":[], "rmsd 50%":[], "rmsd 75%":[], "centroid_dis":[], "centroid_dis < 2A":[], "centroid_dis < 5A":[], "centroid_dis 25%":[], "centroid_dis 50%":[], "centroid_dis 75%":[]} model.train() for epoch in range(last_epoch+1, args.total_epochs): os.system(f"mkdir -p {pre}/prompt_components/epoch-{epoch}") y_list = [] y_pred_list = [] com_coord_list = [] com_coord_pred_list = [] rmsd_list = [] rmsd_2A_list = [] rmsd_5A_list = [] centroid_dis_list = [] centroid_dis_2A_list = [] centroid_dis_5A_list = [] pocket_coord_list = [] pocket_coord_pred_list = [] # pocket_coord_pred_for_update_list = [] pocket_cls_list = [] pocket_cls_pred_list = [] pocket_cls_pred_round_list = [] protein_len_list = [] # pdb_list = [] count = 0 skip_count = 0 batch_loss = 0.0 batch_by_pred_loss = 0.0 batch_distill_loss = 0.0 com_coord_batch_loss = 0.0 pocket_cls_batch_loss = 0.0 pocket_coord_batch_loss = 0.0 keepNode_less_5_count = 0 if args.disable_tqdm: data_iter = train_loader else: data_iter = tqdm(train_loader, mininterval=args.tqdm_interval, disable=not accelerator.is_main_process) for batch_id, data in enumerate(data_iter, start=1): optimizer.zero_grad() # Denote num_atom as N, num_amino_acid_of_pocket as M, num_amino_acid_of_protein as L # com_coord_pred: [B x N, 3] # y_pred, y_pred_by_coord: [B, N x M] # pocket_cls_pred, protein_out_mask_whole: [B, L] # p_coords_batched_whole: [B, L, 3] # pred_pocket_center: [B, 3] com_coord_pred, compound_batch, y_pred, y_pred_by_coord, pocket_cls_pred, pocket_cls, protein_out_mask_whole, p_coords_batched_whole, pred_pocket_center, dis_map, keepNode_less_5, pocket_prompt_component, complex_prompt_component, pocket_prompt_feat, complex_prompt_feat = model(data, train=True) # y = data.y if y_pred.isnan().any() or com_coord_pred.isnan().any() or pocket_cls_pred.isnan().any() or pred_pocket_center.isnan().any() or y_pred_by_coord.isnan().any(): print(f"nan occurs in epoch {epoch}") continue com_coord = data.coords pocket_cls_loss = args.pocket_cls_loss_weight * pocket_cls_criterion(pocket_cls_pred, pocket_cls.float()) * (protein_out_mask_whole.numel() / protein_out_mask_whole.sum()) pocket_coord_loss = args.pocket_distance_loss_weight * pocket_coord_criterion(pred_pocket_center, data.coords_center) contact_loss = args.pair_distance_loss_weight * criterion(y_pred, dis_map) if len(dis_map) > 0 else torch.tensor([0]) contact_by_pred_loss = args.pair_distance_loss_weight * criterion(y_pred_by_coord, dis_map) if len(dis_map) > 0 else torch.tensor([0]) contact_distill_loss = args.pair_distance_distill_loss_weight * criterion(y_pred_by_coord, y_pred) if len(y_pred) > 0 else torch.tensor([0]) com_coord_loss = args.coord_loss_weight * com_coord_criterion(com_coord_pred, com_coord) if len(com_coord) > 0 else torch.tensor([0]) sd = ((com_coord_pred.detach() - com_coord) ** 2).sum(dim=-1) rmsd = scatter_mean(sd, index=compound_batch, dim=0).sqrt().detach().cpu() centroid_pred = scatter_mean(src=com_coord_pred, index=compound_batch, dim=0) centroid_true = scatter_mean(src=com_coord, index=compound_batch, dim=0) centroid_dis = (centroid_pred - centroid_true).norm(dim=-1) loss = com_coord_loss + \ contact_loss + contact_by_pred_loss + contact_distill_loss + \ pocket_cls_loss + \ pocket_coord_loss accelerator.backward(loss) if args.clip_grad: # clip_grad_norm_(model.parameters(), max_norm=1.0, error_if_nonfinite=True) if accelerator.sync_gradients: accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() scheduler.step() batch_loss += len(y_pred)*contact_loss.cpu().item() batch_by_pred_loss += len(y_pred_by_coord)*contact_by_pred_loss.cpu().item() batch_distill_loss += len(y_pred_by_coord)*contact_distill_loss.cpu().item() com_coord_batch_loss += len(com_coord_pred)*com_coord_loss.cpu().item() pocket_cls_batch_loss += len(pocket_cls_pred)*pocket_cls_loss.cpu().item() pocket_coord_batch_loss += len(pred_pocket_center)*pocket_coord_loss.cpu().item() keepNode_less_5_count += keepNode_less_5 y_list.append(dis_map.detach().cpu()) y_pred_list.append(y_pred.detach().cpu()) com_coord_list.append(com_coord) com_coord_pred_list.append(com_coord_pred.detach().cpu()) rmsd_list.append(rmsd.detach().cpu()) rmsd_2A_list.append((rmsd.detach().cpu() < 2).float()) rmsd_5A_list.append((rmsd.detach().cpu() < 5).float()) centroid_dis_list.append(centroid_dis.detach().cpu()) centroid_dis_2A_list.append((centroid_dis.detach().cpu() < 2).float()) centroid_dis_5A_list.append((centroid_dis.detach().cpu() < 5).float()) batch_len = protein_out_mask_whole.sum(dim=1).detach().cpu() protein_len_list.append(batch_len) pocket_coord_pred_list.append(pred_pocket_center) pocket_coord_list.append(data.coords_center) # use hard to calculate acc and skip samples for i, j in enumerate(batch_len): count += 1 pocket_cls_list.append(pocket_cls.detach().cpu()[i][:j]) pocket_cls_pred_list.append(pocket_cls_pred.detach().cpu()[i][:j].sigmoid()) pocket_cls_pred_round_list.append(pocket_cls_pred.detach().cpu()[i][:j].sigmoid().round().int()) pred_index_bool = (pocket_cls_pred.detach().cpu()[i][:j].sigmoid().round().int() == 1) if pred_index_bool.sum() == 0: # all the prediction is False, skip skip_count += 1 if batch_id % args.log_interval == 0: stats_dict = {} stats_dict['step'] = batch_id stats_dict['lr'] = optimizer.param_groups[0]['lr'] stats_dict['contact_loss'] = contact_loss.cpu().item() stats_dict['contact_by_pred_loss'] = contact_by_pred_loss.cpu().item() stats_dict['contact_distill_loss'] = contact_distill_loss.cpu().item() stats_dict['com_coord_loss'] = com_coord_loss.cpu().item() stats_dict['pocket_cls_loss'] = pocket_cls_loss.cpu().item() stats_dict['pocket_coord_loss'] = pocket_coord_loss.cpu().item() logger.log_stats(stats_dict, epoch, args, prefix='train') torch.cuda.empty_cache() y = torch.cat(y_list) y_pred = torch.cat(y_pred_list) # y, y_pred = accelerator.gather((y, y_pred)) com_coord = torch.cat(com_coord_list) com_coord_pred = torch.cat(com_coord_pred_list) # com_coord, com_coord_pred = accelerator.gather((com_coord, com_coord_pred)) rmsd = torch.cat(rmsd_list) rmsd_2A = torch.cat(rmsd_2A_list) rmsd_5A = torch.cat(rmsd_5A_list) # rmsd, rmsd_2A, rmsd_5A = accelerator.gather((rmsd, rmsd_2A, rmsd_5A)) rmsd_25 = torch.quantile(rmsd, 0.25) rmsd_50 = torch.quantile(rmsd, 0.50) rmsd_75 = torch.quantile(rmsd, 0.75) centroid_dis = torch.cat(centroid_dis_list) centroid_dis_2A = torch.cat(centroid_dis_2A_list) centroid_dis_5A = torch.cat(centroid_dis_5A_list) # centroid_dis, centroid_dis_2A, centroid_dis_5A = accelerator.gather((centroid_dis, centroid_dis_2A, centroid_dis_5A)) centroid_dis_25 = torch.quantile(centroid_dis, 0.25) centroid_dis_50 = torch.quantile(centroid_dis, 0.50) centroid_dis_75 = torch.quantile(centroid_dis, 0.75) pocket_cls = torch.cat(pocket_cls_list) pocket_cls_pred = torch.cat(pocket_cls_pred_list) pocket_cls_pred_round = torch.cat(pocket_cls_pred_round_list) pocket_coord_pred = torch.cat(pocket_coord_pred_list) pocket_coord = torch.cat(pocket_coord_list) protein_len = torch.cat(protein_len_list) # pocket_cls, pocket_cls_pred, pocket_cls_pred_round, pocket_coord_pred, pocket_coord, protein_len = accelerator.gather( # (pocket_cls, pocket_cls_pred, pocket_cls_pred_round, pocket_coord_pred, pocket_coord, protein_len) # ) # count *= accelerator.num_processes # skip_count *= accelerator.num_processes # batch_loss *= accelerator.num_processes # batch_by_pred_loss *= accelerator.num_processes # batch_distill_loss *= accelerator.num_processes # com_coord_batch_loss *= accelerator.num_processes # pocket_cls_batch_loss *= accelerator.num_processes # pocket_coord_batch_loss *= accelerator.num_processes # keepNode_less_5_count *= accelerator.num_processes pocket_cls_accuracy = (pocket_cls_pred_round == pocket_cls).sum().cpu().item() / len(pocket_cls_pred_round) metrics = {"samples": count, "skip_samples": skip_count, "keepNode < 5": keepNode_less_5_count} metrics.update({"contact_loss":batch_loss/len(y_pred), "contact_by_pred_loss":batch_by_pred_loss/len(y_pred), "contact_distill_loss": batch_distill_loss/len(y_pred)}) metrics.update({"com_coord_huber_loss": com_coord_batch_loss/len(com_coord_pred)}) metrics.update({"rmsd": rmsd.mean().cpu().item(), "rmsd < 2A": rmsd_2A.mean().cpu().item(), "rmsd < 5A": rmsd_5A.mean().cpu().item()}) metrics.update({"rmsd 25%": rmsd_25.cpu().item(), "rmsd 50%": rmsd_50.cpu().item(), "rmsd 75%": rmsd_75.cpu().item()}) metrics.update({"centroid_dis": centroid_dis.mean().cpu().item(), "centroid_dis < 2A": centroid_dis_2A.mean().cpu().item(), "centroid_dis < 5A": centroid_dis_5A.mean().cpu().item()}) metrics.update({"centroid_dis 25%": centroid_dis_25.cpu().item(), "centroid_dis 50%": centroid_dis_50.cpu().item(), "centroid_dis 75%": centroid_dis_75.cpu().item()}) metrics.update({"pocket_cls_bce_loss": pocket_cls_batch_loss / len(pocket_coord_pred)}) metrics.update({"pocket_coord_mse_loss": pocket_coord_batch_loss / len(pocket_coord_pred)}) metrics.update({"pocket_cls_accuracy": pocket_cls_accuracy}) metrics.update(pocket_metrics(pocket_coord_pred, pocket_coord)) # logger.log_message(f"epoch {epoch:<4d}, train, " + print_metrics(metrics)) logger.log_stats(metrics, epoch, args, prefix="Train") if accelerator.is_main_process and not args.disable_tensorboard: metrics_runtime_no_prefix(metrics, train_writer, epoch) pocket_prompt_node_component = pocket_prompt_component[0].squeeze(0).detach().cpu().numpy() pocket_prompt_coord_component = pocket_prompt_component[1].squeeze(0).detach().cpu().numpy() complex_prompt_node_component = complex_prompt_component[0].squeeze(0).detach().cpu().numpy() complex_prompt_coord_component = complex_prompt_component[1].squeeze(0).detach().cpu().numpy() np.save(f"{pre}/prompt_components/epoch-{epoch}/pocket_prompt_node_components.npy", pocket_prompt_node_component) np.save(f"{pre}/prompt_components/epoch-{epoch}/pocket_prompt_coord_components.npy", pocket_prompt_coord_component) np.save(f"{pre}/prompt_components/epoch-{epoch}/complex_prompt_node_components.npy", complex_prompt_node_component) np.save(f"{pre}/prompt_components/epoch-{epoch}/complex_prompt_coord_components.npy", complex_prompt_coord_component) accelerator.wait_for_everyone() # metrics_list.append(metrics) # release memory y, y_pred = None, None com_coord, com_coord_pred = None, None rmsd, rmsd_2A, rmsd_5A = None, None, None centroid_dis, centroid_dis_2A, centroid_dis_5A = None, None, None pocket_cls, pocket_cls_pred, pocket_cls_pred_round, pocket_coord_pred, pocket_coord, protein_len = None, None, None, None, None, None model.eval() # TODO check and think # use_y_mask = args.use_equivalent_native_y_mask or args.use_y_mask use_y_mask = False logger.log_message(f"Begin validation") if accelerator.is_main_process: if not args.disable_validate: metrics, _, _ = evaluate_mean_pocket_cls_coord_multi_task(accelerator, args, valid_loader, model, com_coord_criterion, criterion, pocket_cls_criterion, pocket_coord_criterion, args.relative_k, device, pred_dis=pred_dis, use_y_mask=use_y_mask, stage=1) # valid_metrics_list.append(metrics) # logger.log_message(f"epoch {epoch:<4d}, valid, " + print_metrics(metrics)) logger.log_stats(metrics, epoch, args, prefix="Valid") metrics_runtime_no_prefix(metrics, valid_writer, epoch) logger.log_message(f"Begin test") if accelerator.is_main_process: #metrics, _, _ = evaluate_mean_pocket_cls_coord_multi_task(accelerator, args, test_loader, accelerator.unwrap_model(model), com_coord_criterion, criterion, pocket_cls_criterion, pocket_coord_criterion, args.relative_k, #accelerator.device, pred_dis=pred_dis, use_y_mask=use_y_mask, stage=1) # test_metrics_list.append(metrics) # logger.log_message(f"epoch {epoch:<4d}, test, " + print_metrics(metrics)) #logger.log_stats(metrics, epoch, args, prefix="Test") #if not args.disable_tensorboard: #metrics_runtime_no_prefix(metrics, test_writer, epoch) metrics, _, _ = evaluate_mean_pocket_cls_coord_multi_task(accelerator, args, test_loader, accelerator.unwrap_model(model), com_coord_criterion, criterion, pocket_cls_criterion, pocket_coord_criterion, args.relative_k, accelerator.device, pred_dis=pred_dis, use_y_mask=use_y_mask, stage=2) # test_metrics_stage2_list.append(metrics) # logger.log_message(f"epoch {epoch:<4d}, testp, " + print_metrics(metrics)) test_metrics["pocket_cls_accuracy"].append(metrics["pocket_cls_accuracy"]) test_metrics["rmsd"].append(metrics["rmsd"]), test_metrics["rmsd < 2A"].append(metrics["rmsd < 2A"]), test_metrics["rmsd < 5A"].append(metrics["rmsd < 5A"]) test_metrics["rmsd 25%"].append(metrics["rmsd 25%"]), test_metrics["rmsd 50%"].append(metrics["rmsd 50%"]), test_metrics["rmsd 75%"].append(metrics["rmsd 75%"]) test_metrics["centroid_dis"].append(metrics["centroid_dis"]), test_metrics["centroid_dis < 2A"].append(metrics["centroid_dis < 2A"]), test_metrics["centroid_dis < 5A"].append(metrics["centroid_dis < 5A"]) test_metrics["centroid_dis 25%"].append(metrics["centroid_dis 25%"]), test_metrics["centroid_dis 50%"].append(metrics["centroid_dis 50%"]), test_metrics["centroid_dis 75%"].append(metrics["centroid_dis 75%"]) pd.DataFrame(data=test_metrics).to_csv(f"{pre}/metrics/test_metrics.csv", index=False) logger.log_stats(metrics, epoch, args, prefix="Test_pp") if not args.disable_tensorboard: metrics_runtime_no_prefix(metrics, test_writer_use_predicted_pocket, epoch) # ckpt = (model.state_dict(), optimizer.state_dict(), args, epoch) # torch.save(ckpt, f"{pre}/models/epoch_{epoch}.pt") # torch.save(ckpt, f"{pre}/models/epoch_last.pt") output_dir = f"{pre}/models/epoch_{epoch}" accelerator.save_state(output_dir=output_dir) accelerator.save_state(output_dir=output_last_epoch_dir) accelerator.wait_for_everyone()