import torch
from .metrics import *
import numpy as np
import pandas as pd
import scipy.spatial
from torch_geometric.data import HeteroData
from tqdm.auto import tqdm
from torch_scatter import scatter_mean
from rdkit import Chem
from rdkit.Chem import rdMolTransforms
from .post_optim_utils import post_optimize_compound_coords, post_optimize_compound_coords_lbfgs
# from feature_utils import read_mol
import sys
from io import StringIO
[docs]
def read_mol(sdf_fileName, mol2_fileName, verbose=False):
    """Load a ligand molecule, preferring the SDF file and falling back to MOL2.

    RDKit's stderr chatter is captured into a buffer for the duration of the call
    and only printed when ``verbose`` is True.

    Returns:
        (mol, problem): the loaded RDKit Mol (hydrogens removed when sanitization
        succeeded) and a flag that is True when neither file gave a sanitizable
        molecule.
    """
    saved_stderr = sys.stderr
    captured = sys.stderr = StringIO()
    mol, problem = None, True
    # Try the SDF loader first; only fall through to MOL2 when it fails.
    for loader, fileName in ((Chem.MolFromMolFile, sdf_fileName),
                             (Chem.MolFromMol2File, mol2_fileName)):
        mol = loader(fileName, sanitize=False)
        try:
            Chem.SanitizeMol(mol)
            mol = Chem.RemoveHs(mol)
            Chem.MolToSmiles(mol)  # smoke-test that a SMILES can be produced
            problem = False
            break
        except Exception:
            problem = True
    if verbose:
        print(captured.getvalue())
    sys.stderr = saved_stderr
    return mol, problem
[docs]
def read_pdbbind_data(fileName):
    """Parse a PDBbind-style index file into a pandas DataFrame.

    Expected line format (comment lines start with '#')::

        <pdb> <resolution> <year> <affinity> <raw> // ... (<ligand>)

    Returns a DataFrame with columns
    ['pdb', 'resolution', 'year', 'affinity', 'raw', 'ligand'];
    'year' is cast to int and 'affinity' to float.
    """
    rows = []
    with open(fileName) as handle:
        for raw_line in handle:
            if raw_line[0] == '#':
                continue
            # Left of '//' are the space-separated data fields; the ligand code
            # sits in parentheses on the right.
            fields, ligand_part = raw_line.split('//')
            pdb, resolution, year, affinity, raw = fields.strip().split(' ')
            ligand = ligand_part.strip().split('(')[1].split(')')[0]
            rows.append([pdb, resolution, year, affinity, raw, ligand])
    frame = pd.DataFrame(rows, columns=['pdb', 'resolution', 'year', 'affinity', 'raw', 'ligand'])
    frame.year = frame.year.astype(int)
    frame.affinity = frame.affinity.astype(float)
    return frame
[docs]
def compute_dis_between_two_vector(a, b):
    """Return the Euclidean distance between two coordinate vectors."""
    delta = a - b
    return ((delta * delta).sum()) ** 0.5
[docs]
def get_protein_edge_features_and_index(protein_edge_index, protein_edge_s, protein_edge_v, keepNode):
    """Restrict protein edges to the kept (pocket) nodes and re-index them.

    Args:
        protein_edge_index: (2, E) integer array of edge endpoint node indices.
        protein_edge_s: per-edge scalar features, indexable by a boolean edge mask.
        protein_edge_v: per-edge vector features, indexable by a boolean edge mask.
        keepNode: (N,) boolean array marking nodes kept in the pocket.

    Returns:
        (edge_index, edge_s, edge_v): edge indices re-numbered into the compacted
        node ordering (torch.long) plus the corresponding filtered edge features.

    Fixes over the previous version: removed the two dead locals
    (input_edge_list, input_protein_edge_feature_idx) and the
    'new_edge_inex' typo.
    """
    # Position of each kept node in the compacted node list.
    new_node_index = np.cumsum(keepNode) - 1
    # An edge survives only when both of its endpoints are kept.
    keepEdge = keepNode[protein_edge_index].min(axis=0)
    new_edge_index = new_node_index[protein_edge_index]
    input_edge_idx = torch.tensor(new_edge_index[:, keepEdge], dtype=torch.long)
    input_protein_edge_s = protein_edge_s[keepEdge]
    input_protein_edge_v = protein_edge_v[keepEdge]
    return input_edge_idx, input_protein_edge_s, input_protein_edge_v
# During training, this function can add uniform noise (magnitude `add_noise_to_com`, e.g. 5A) to the ligand/pocket center.
[docs]
def get_keepNode(com, protein_node_xyz, n_node, pocket_radius, use_whole_protein,
                 use_compound_com_as_pocket, add_noise_to_com, chosen_pocket_com):
    """Select the protein residues that belong to the binding pocket.

    Args:
        com: (3,) ligand centroid (numpy array).
        protein_node_xyz: (N, 3) residue coordinates (numpy array).
        n_node: number of residues N.
        pocket_radius: residues closer than this to a center are kept.
        use_whole_protein: if True, start from an all-True mask. NOTE: the
            compound-COM branch below can still overwrite it — this mirrors the
            original control flow for compatibility.
        use_compound_com_as_pocket: select residues around the ligand centroid.
        add_noise_to_com: if truthy, add uniform noise in [-a, a] per coordinate
            to each center before selection (training-time augmentation).
        chosen_pocket_com: optional iterable of extra pocket centers whose
            selections are OR-ed into the mask.

    Returns:
        (N,) boolean numpy mask of kept residues.

    Performance fix: distances are computed with vectorized numpy instead of a
    Python loop per residue; results are numerically identical.
    """
    if use_whole_protein:
        keepNode = np.ones(n_node, dtype=bool)
    else:
        keepNode = np.zeros(n_node, dtype=bool)
    # extract node based on compound COM.
    if use_compound_com_as_pocket:
        if add_noise_to_com:  # com is the mean coordinate of the compound
            com = com + add_noise_to_com * (2 * np.random.rand(*com.shape) - 1)
        dis = np.sqrt(((protein_node_xyz - com) ** 2).sum(axis=-1))
        keepNode = dis < pocket_radius
    if chosen_pocket_com is not None:
        for a_com in chosen_pocket_com:
            if add_noise_to_com:
                a_com = a_com + add_noise_to_com * (2 * np.random.rand(*a_com.shape) - 1)
            dis = np.sqrt(((protein_node_xyz - a_com) ** 2).sum(axis=-1))
            keepNode |= dis < pocket_radius
    return keepNode
[docs]
def compute_dis_between_two_vector_tensor(a, b):
    """Row-wise Euclidean distance between two tensors (reduced over the last dim)."""
    delta = a - b
    return delta.pow(2).sum(dim=-1).sqrt()
[docs]
def get_keepNode_tensor(protein_node_xyz, pocket_radius, add_noise_to_com, chosen_pocket_com):
    """Tensor variant of the pocket mask: keep residues within `pocket_radius`
    of `chosen_pocket_com`, optionally perturbing the center with uniform noise
    of magnitude `add_noise_to_com` (training-time augmentation).

    Returns a boolean tensor of shape (N,).
    """
    center = chosen_pocket_com
    if add_noise_to_com:
        center = center + add_noise_to_com * (2 * torch.rand_like(center) - 1)
    # Vectorized distance from every protein node to the (possibly noised) center.
    node_dis = torch.sqrt(((protein_node_xyz - center.unsqueeze(0)) ** 2).sum(dim=-1))
    return node_dis < pocket_radius
[docs]
def get_torsions(m):
    """Enumerate one dihedral (4-atom tuple) per rotatable bond of molecule *m*.

    Hydrogens are stripped first. Rotatable bonds are found with a SMARTS that
    requires both bond atoms to be non-terminal, not part of a triple bond, and
    the bond itself to be acyclic. For each matched bond (idx2-idx3), one
    neighbor on each side (idx1, idx4) is picked to define the torsion; only the
    first valid quadruple per bond is kept (note the breaks below).

    Returns:
        list of 4-tuples of atom indices defining torsions.
    """
    m = Chem.RemoveHs(m)
    torsionList = []
    # SMARTS: single, non-ring bond between two non-terminal atoms not in a triple bond.
    torsionSmarts = "[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]"
    torsionQuery = Chem.MolFromSmarts(torsionSmarts)
    matches = m.GetSubstructMatches(torsionQuery)
    for match in matches:
        idx2 = match[0]
        idx3 = match[1]
        bond = m.GetBondBetweenAtoms(idx2, idx3)
        jAtom = m.GetAtomWithIdx(idx2)
        kAtom = m.GetAtomWithIdx(idx3)
        for b1 in jAtom.GetBonds():
            # Skip the central (rotatable) bond itself on the idx2 side.
            if b1.GetIdx() == bond.GetIdx():
                continue
            idx1 = b1.GetOtherAtomIdx(idx2)
            for b2 in kAtom.GetBonds():
                # Skip the central bond and the bond already chosen on the other side.
                if (b2.GetIdx() == bond.GetIdx()) or (b2.GetIdx() == b1.GetIdx()):
                    continue
                idx4 = b2.GetOtherAtomIdx(idx3)
                # skip 3-membered rings
                if idx4 == idx1:
                    continue
                # skip torsions that include hydrogens
                if (m.GetAtomWithIdx(idx1).GetAtomicNum() == 1) or (
                    m.GetAtomWithIdx(idx4).GetAtomicNum() == 1
                ):
                    continue
                # Prefer storing the torsion with the ring-member atom first
                # (tuple reversed when idx4 is in a ring).
                if m.GetAtomWithIdx(idx4).IsInRing():
                    torsionList.append((idx4, idx3, idx2, idx1))
                    break
                else:
                    torsionList.append((idx1, idx2, idx3, idx4))
                    break
            # Only one torsion is recorded per rotatable bond.
            break
    return torsionList
[docs]
def SetDihedral(conf, atom_idx, new_vale):
    """Set the dihedral angle (radians) on conformer *conf* for the four atoms in *atom_idx*."""
    a0, a1, a2, a3 = atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3]
    rdMolTransforms.SetDihedralRad(conf, a0, a1, a2, a3, new_vale)
[docs]
def construct_data_from_graph_gvp_mean(args, protein_node_xyz, protein_seq,
        coords, compound_node_features, input_atom_edge_list,
        input_atom_edge_attr_list, LAS_edge_index, rdkit_coords, compound_coords_init_mode='pocket_center_rdkit', includeDisMap=True, pdb_id=None, group='train', seed=42, data_path=None, contactCutoff=8.0, pocket_radius=20, interactionThresholdDistance=10, compoundMode=1,
        add_noise_to_com=None, use_whole_protein=False, use_compound_com_as_pocket=True, chosen_pocket_com=None, random_rotation=False, pocket_idx_no_noise=True, protein_esm2_feat=None):
    """Assemble one protein--ligand sample as a PyG ``HeteroData`` graph.

    The protein is re-centered at its mean coordinate (the offset is stored in
    ``data.coord_offset``), a pocket is selected around the ligand centroid,
    and ligand initial coordinates are produced according to
    ``compound_coords_init_mode``. Two "complex" views are built: pocket-only
    (``data['complex']``) and whole-protein (``data['complex_whole_protein']``),
    each laid out as [glb_c || compound || glb_p || protein] with segment /
    mask / is_global annotations plus c2c (covalent) and LAS edges.

    Returns:
        (data, input_node_xyz, keepNode): the HeteroData sample, the kept
        pocket residue coordinates, and the boolean pocket mask.
    """
    n_node = protein_node_xyz.shape[0]
    # n_compound_node = coords.shape[0]
    # normalize the protein and ligand coords
    coords_bias = protein_node_xyz.mean(dim=0)
    coords = coords - coords_bias.numpy()
    protein_node_xyz = protein_node_xyz - coords_bias
    # centroid instead of com.
    com = coords.mean(axis=0)
    # Pocket selection: optionally with noise on the center during training.
    if args.train_pred_pocket_noise and group == 'train':
        keepNode = get_keepNode(com, protein_node_xyz.numpy(), n_node, pocket_radius, use_whole_protein,
                                use_compound_com_as_pocket, args.train_pred_pocket_noise, chosen_pocket_com)
    else:
        keepNode = get_keepNode(com, protein_node_xyz.numpy(), n_node, pocket_radius, use_whole_protein,
                                use_compound_com_as_pocket, add_noise_to_com, chosen_pocket_com)
    # Noise-free mask, used as the classification label when pocket_idx_no_noise.
    keepNode_no_noise = get_keepNode(com, protein_node_xyz.numpy(), n_node, pocket_radius, use_whole_protein,
                                     use_compound_com_as_pocket, None, chosen_pocket_com)
    if keepNode.sum() < 5:
        # if only include less than 5 residues, simply add first 100 residues.
        keepNode[:100] = True
    input_node_xyz = protein_node_xyz[keepNode]
    # input_edge_idx, input_protein_edge_s, input_protein_edge_v = get_protein_edge_features_and_index(protein_edge_index, protein_edge_s, protein_edge_v, keepNode)
    # construct heterogeneous graph data.
    data = HeteroData()
    # only if your ligand is real this y_contact is meaningful. Distance map between ligand atoms and protein amino acids.
    dis_map = scipy.spatial.distance.cdist(input_node_xyz.cpu().numpy(), coords)
    # y_contact = dis_map < contactCutoff # contactCutoff is 8A
    if includeDisMap:
        # treat all distance above 10A as the same.
        dis_map[dis_map > interactionThresholdDistance] = interactionThresholdDistance
        data.dis_map = torch.tensor(dis_map, dtype=torch.float).flatten()
    # TODO The difference between contactCutoff and interactionThresholdDistance:
    # contactCutoff is for classification evaluation, interactionThresholdDistance is for distance regression.
    # additional information. keep records.
    data.node_xyz = input_node_xyz
    data.coords = torch.tensor(coords, dtype=torch.float)
    # data.y = torch.tensor(y_contact, dtype=torch.float).flatten() # whether the distance between ligand and protein is less than 8A.
    # pocket information
    if torch.is_tensor(protein_esm2_feat):
        data['pocket'].node_feats = protein_esm2_feat[keepNode]
    else:
        raise ValueError("protein_esm2_feat should be a tensor")
    data['pocket'].keepNode = torch.tensor(keepNode, dtype=torch.bool)
    data['compound'].node_feats = compound_node_features.float()
    data['compound', 'LAS', 'compound'].edge_index = LAS_edge_index
    # complex information
    n_protein = input_node_xyz.shape[0]
    n_protein_whole = protein_node_xyz.shape[0]
    n_compound = compound_node_features.shape[0]
    # use zero coord to init compound
    # data['complex'].node_coords = torch.cat( # [glb_c || compound || glb_p || protein]
    #     (torch.zeros(n_compound + 2, 3), input_node_xyz), dim=0
    # ).float()
    # Local-eval test-time torsion randomization: reload the ligand, randomize
    # every rotatable bond, then apply a random rigid rotation.
    if args.local_eval:
        if group == 'test':
            from accelerate.utils import set_seed
            set_seed(seed)
            pre = args.data_path
            # NOTE(review): mol2 fallback is None here; read_mol would fail on an
            # unreadable SDF — confirm these SDFs are always present and valid.
            mol, _ = read_mol(f"{pre}/renumber_atom_index_same_as_smiles/{pdb_id}.sdf", None)
            rotable_bonds = get_torsions(mol)
            values = 3.1415926 * 2 * np.random.rand(len(rotable_bonds))
            for idx in range(len(rotable_bonds)):
                SetDihedral(mol.GetConformer(), rotable_bonds[idx], values[idx])
            Chem.rdMolTransforms.CanonicalizeConformer(mol.GetConformer())
            rdkit_coords = uniform_random_rotation(mol.GetConformer().GetPositions())
    # Training-time torsion-noise augmentation: same procedure, applied to train samples.
    if args.train_ligand_torsion_noise and group == 'train':
        pre = data_path
        try:
            mol = Chem.MolFromMolFile(f"{pre}/renumber_atom_index_same_as_smiles/{pdb_id}.sdf", sanitize=False)
            try:
                Chem.SanitizeMol(mol)
            except:
                pass
            mol = Chem.RemoveHs(mol)
            # mol, _ = read_mol(f"{pre}/renumber_atom_index_same_as_smiles/{pdb_id}.sdf", None)
        except:
            raise ValueError(f"cannot find {pdb_id}.sdf in {pre}/renumber_atom_index_same_as_smiles/")
        rotable_bonds = get_torsions(mol)
        # np.random.seed(np_seed)
        values = 3.1415926 * 2 * np.random.rand(len(rotable_bonds))
        for idx in range(len(rotable_bonds)):
            SetDihedral(mol.GetConformer(), rotable_bonds[idx], values[idx])
        Chem.rdMolTransforms.CanonicalizeConformer(mol.GetConformer())
        rdkit_coords = uniform_random_rotation(mol.GetConformer().GetPositions())
    # Ligand initial coordinates, per the chosen init mode.
    if compound_coords_init_mode == 'random':
        coords_init = 4 * (2 * torch.rand(coords.shape) - 1)
    elif compound_coords_init_mode == 'perturb_3A':
        coords_init = torch.tensor(coords) + 3 * (2 * torch.rand(coords.shape) - 1)
    elif compound_coords_init_mode == 'perturb_4A':
        coords_init = torch.tensor(coords) + 4 * (2 * torch.rand(coords.shape) - 1)
    elif compound_coords_init_mode == 'perturb_5A':
        coords_init = torch.tensor(coords) + 5 * (2 * torch.rand(coords.shape) - 1)
    elif compound_coords_init_mode == 'compound_center':
        coords_init = torch.tensor(com).reshape(1, 3) + 10 * (2 * torch.rand(coords.shape) - 1)
    elif compound_coords_init_mode == 'pocket_center':
        coords_init = input_node_xyz.mean(dim=0).reshape(1, 3) + 5 * (2 * torch.rand(coords.shape) - 1)
    elif compound_coords_init_mode == 'pocket_center_rdkit':
        # RDKit conformer, re-centered on the pocket center.
        if random_rotation:
            rdkit_coords = torch.tensor(uniform_random_rotation(rdkit_coords))
        else:
            rdkit_coords = torch.tensor(rdkit_coords)
        coords_init = rdkit_coords - rdkit_coords.mean(dim=0).reshape(1, 3) + input_node_xyz.mean(dim=0).reshape(1, 3)
    elif compound_coords_init_mode == 'redocking':
        # Ground-truth pose, randomly rotated and re-centered on the pocket center.
        coords_rot = torch.tensor(uniform_random_rotation(coords))
        coords_init = coords_rot - coords_rot.mean(dim=0).reshape(1, 3) + input_node_xyz.mean(dim=0).reshape(1, 3)
    elif compound_coords_init_mode == 'redocking_no_rotate':
        coords_rot = torch.tensor(coords)
        coords_init = coords_rot - coords_rot.mean(dim=0).reshape(1, 3) + input_node_xyz.mean(dim=0).reshape(1, 3)
    # ground truth ligand and pocket
    data['complex'].node_coords = torch.cat(  # [glb_c || compound || glb_p || protein]
        (
            torch.zeros(1, 3),
            coords_init,
            torch.zeros(1, 3),
            input_node_xyz
        ), dim=0
    ).float()
    # Reference coordinates used by the LAS (local atomic structure) constraint:
    # the true pose for redocking modes, otherwise the RDKit conformer.
    if compound_coords_init_mode == 'redocking' or compound_coords_init_mode == 'redocking_no_rotate':
        data['complex'].node_coords_LAS = torch.cat(  # [glb_c || compound || glb_p || protein]
            (
                torch.zeros(1, 3),
                torch.tensor(coords),
                torch.zeros(1, 3),
                torch.zeros_like(input_node_xyz)
            ), dim=0
        ).float()
    else:
        data['complex'].node_coords_LAS = torch.cat(  # [glb_c || compound || glb_p || protein]
            (
                torch.zeros(1, 3),
                rdkit_coords,
                torch.zeros(1, 3),
                torch.zeros_like(input_node_xyz)
            ), dim=0
        ).float()
    segment = torch.zeros(n_protein + n_compound + 2)
    segment[n_compound+1:] = 1  # compound: 0, protein: 1
    data['complex'].segment = segment  # protein or ligand
    mask = torch.zeros(n_protein + n_compound + 2)
    mask[:n_compound+2] = 1  # glb_p can be updated
    data['complex'].mask = mask.bool()
    is_global = torch.zeros(n_protein + n_compound + 2)
    is_global[0] = 1
    is_global[n_compound+1] = 1
    data['complex'].is_global = is_global.bool()
    # Edge indices are shifted by +1 to account for the leading glb_c node.
    data['complex', 'c2c', 'complex'].edge_index = input_atom_edge_list[:, :2].long().t().contiguous() + 1
    if compound_coords_init_mode == 'redocking' or compound_coords_init_mode == 'redocking_no_rotate':
        # Fully-connected LAS graph for redocking (all compound atom pairs).
        data['complex', 'LAS', 'complex'].edge_index = torch.nonzero(torch.ones(n_compound, n_compound)).t() + 1
    else:
        data['complex', 'LAS', 'complex'].edge_index = LAS_edge_index + 1
    # ground truth ligand and whole protein
    data['complex_whole_protein'].node_coords = torch.cat(  # [glb_c || compound || glb_p || protein]
        (
            torch.zeros(1, 3),
            coords_init - coords_init.mean(dim=0).reshape(1, 3),  # for pocket prediction module, the ligand is centered at the protein center/origin
            torch.zeros(1, 3),
            protein_node_xyz
        ), dim=0
    ).float()
    if compound_coords_init_mode == 'redocking' or compound_coords_init_mode == 'redocking_no_rotate':
        data['complex_whole_protein'].node_coords_LAS = torch.cat(  # [glb_c || compound || glb_p || protein]
            (
                torch.zeros(1, 3),
                torch.tensor(coords),
                torch.zeros(1, 3),
                torch.zeros_like(protein_node_xyz)
            ), dim=0
        ).float()
    else:
        data['complex_whole_protein'].node_coords_LAS = torch.cat(  # [glb_c || compound || glb_p || protein]
            (
                torch.zeros(1, 3),
                rdkit_coords,
                torch.zeros(1, 3),
                torch.zeros_like(protein_node_xyz)
            ), dim=0
        ).float()
    segment = torch.zeros(n_protein_whole + n_compound + 2)
    segment[n_compound+1:] = 1  # compound: 0, protein: 1
    data['complex_whole_protein'].segment = segment  # protein or ligand
    mask = torch.zeros(n_protein_whole + n_compound + 2)
    mask[:n_compound+2] = 1  # glb_p can be updated
    data['complex_whole_protein'].mask = mask.bool()
    is_global = torch.zeros(n_protein_whole + n_compound + 2)
    is_global[0] = 1
    is_global[n_compound+1] = 1
    data['complex_whole_protein'].is_global = is_global.bool()
    data['complex_whole_protein', 'c2c', 'complex_whole_protein'].edge_index = input_atom_edge_list[:, :2].long().t().contiguous() + 1
    if compound_coords_init_mode == 'redocking' or compound_coords_init_mode == 'redocking_no_rotate':
        data['complex_whole_protein', 'LAS', 'complex_whole_protein'].edge_index = torch.nonzero(torch.ones(n_compound, n_compound)).t() + 1
    else:
        data['complex_whole_protein', 'LAS', 'complex_whole_protein'].edge_index = LAS_edge_index + 1
    # for stage 3
    data['compound'].node_coords = coords_init
    data['compound'].rdkit_coords = rdkit_coords
    data['compound_atom_edge_list'].x = (input_atom_edge_list[:, :2].long().contiguous() + 1).clone()
    data['LAS_edge_list'].x = data['complex', 'LAS', 'complex'].edge_index.clone().t()
    # add whole protein information for pocket prediction
    data.node_xyz_whole = protein_node_xyz
    data.coords_center = torch.tensor(com, dtype=torch.float).unsqueeze(0)
    data.seq_whole = protein_seq
    data.coord_offset = coords_bias.unsqueeze(0)
    # save the pocket index for binary classification
    if pocket_idx_no_noise:
        data.pocket_idx = torch.tensor(keepNode_no_noise, dtype=torch.int)
    else:
        data.pocket_idx = torch.tensor(keepNode, dtype=torch.int)
    if torch.is_tensor(protein_esm2_feat):
        data['protein_whole'].node_feats = protein_esm2_feat
    else:
        raise ValueError("protein_esm2_feat should be a tensor")
    return data, input_node_xyz, keepNode
[docs]
def post_optim_mol(args, accelerator, data, com_coord_pred, compound_batch, LAS_tmp, rigid=False):
    """Post-optimize each predicted ligand pose in the batch, in place.

    For every ligand i in the batch, runs the configured post-optimizer
    ('adam' or 'lbfgs') against the RDKit reference conformer, re-centers the
    optimized pose on the original predicted centroid, and writes it back into
    ``com_coord_pred``. When ``rigid`` is True the LAS edge constraints are
    disabled (pure rigid alignment); otherwise ``LAS_tmp[i]`` supplies them.

    Args:
        args: namespace with post_optim, post_optim_mode, post_optim_lr,
            post_optim_epoch, post_optim_lbfgs_iter.
        accelerator: provides the device the result is moved back to.
        data: batched HeteroData; ``data[i]['compound'].rdkit_coords`` is the reference.
        com_coord_pred: (total_atoms, 3) predicted coordinates, modified in place.
        compound_batch: (total_atoms,) ligand id per atom.
        LAS_tmp: per-ligand LAS edge_index tensors.
        rigid: disable LAS constraints when True.

    Returns:
        None (``com_coord_pred`` is updated in place).

    Fix: the rigid branch previously called ``predict_coord.to(accelerator.device)``
    without assigning the result (a no-op); the move is now assigned, matching the
    non-rigid branch. The two branches are also merged — they differed only in
    the LAS_edge_index argument.
    """
    # NOTE(review): hard-coded optimization device — presumably assumes CUDA is
    # available whenever post-optimization is enabled; confirm.
    post_optim_device = 'cuda'
    for i in range(compound_batch.max().item() + 1):
        i_mask = (compound_batch == i)
        com_coord_pred_i = com_coord_pred[i_mask]
        com_coord_i = data[i]['compound'].rdkit_coords
        # Keep the predicted centroid so the optimized pose can be re-anchored to it.
        com_coord_pred_center_i = com_coord_pred_i.mean(dim=0).reshape(1, 3)
        if args.post_optim is None:
            continue
        # Rigid alignment uses no LAS constraints; flexible uses this ligand's edges.
        las_edge_index = None if rigid else LAS_tmp[i].to(post_optim_device)
        if args.post_optim == 'adam':
            predict_coord = post_optimize_compound_coords(
                reference_compound_coords=com_coord_i.to(post_optim_device),
                predict_compound_coords=com_coord_pred_i.to(post_optim_device),
                LAS_edge_index=las_edge_index,
                mode=args.post_optim_mode,
                lr=args.post_optim_lr,
                total_epoch=args.post_optim_epoch,
            )
        if args.post_optim == 'lbfgs':
            predict_coord = post_optimize_compound_coords_lbfgs(
                reference_compound_coords=com_coord_i.to(post_optim_device),
                predict_compound_coords=com_coord_pred_i.to(post_optim_device),
                LAS_edge_index=las_edge_index,
                mode=args.post_optim_mode,
                lr=args.post_optim_lr,
                total_iter=args.post_optim_lbfgs_iter,
                total_epoch=args.post_optim_epoch,
            )
        # BUG FIX: previously the rigid path dropped the result of .to(...).
        predict_coord = predict_coord.to(accelerator.device)
        predict_coord = predict_coord - predict_coord.mean(dim=0).reshape(1, 3) + com_coord_pred_center_i
        com_coord_pred[i_mask] = predict_coord
    return
[docs]
def evaluate_mean_pocket_cls_coord_multi_task(accelerator, args, data_loader, model, com_coord_criterion, criterion, pocket_cls_criterion, pocket_coord_criterion, relative_k, device, pred_dis=False, info=None, saveFileName=None, use_y_mask=False, skip_y_metrics_evaluation=False, stage=1):
    """Evaluate the multi-task model over a data loader.

    Accumulates pairwise-distance, coordinate, pocket-classification and
    pocket-center losses, plus docking metrics (RMSD / centroid distance and
    their <2A / <5A rates and quartiles) and pocket-center metrics.

    Returns:
        (metrics dict, pocket_prompt_feat, complex_prompt_feat,
         per-sample predicted coords, per-sample coord offsets,
         pdb id list, mol id list).

    NOTE(review): the prompt features returned come from the LAST successfully
    processed batch; if every batch raises, they are undefined — confirm intended.
    """
    y_list = []
    y_pred_list = []
    com_coord_list = []
    com_coord_pred_list = []
    # contain the ground truth for classiifcation(may not all)
    pocket_coord_list = []
    pocket_coord_pred_list = []
    # contain the ground truth for regression(all)
    pocket_coord_direct_list = []
    pocket_coord_pred_direct_list = []
    pocket_cls_list = []
    pocket_cls_pred_list = []
    # protein_len_list = []
    # real_y_mask_list = []
    rmsd_list = []
    rmsd_2A_list = []
    rmsd_5A_list = []
    centroid_dis_list = []
    centroid_dis_2A_list = []
    centroid_dis_5A_list = []
    pdb_list = []
    com_coord_pred_per_sample_list = []
    com_coord_offset_per_sample_list = []
    mol_list = []
    skip_count = 0
    count = 0
    batch_loss = 0.0
    batch_by_pred_loss = 0.0
    batch_distill_loss = 0.0
    com_coord_batch_loss = 0.0
    pocket_cls_batch_loss = 0.0
    pocket_coord_direct_batch_loss = 0.0
    keepNode_less_5_count = 0
    # Progress bar only on the main process (when an accelerator is supplied).
    if args.disable_tqdm:
        data_iter = data_loader
    else:
        if accelerator is not None:
            data_iter = tqdm(data_loader, mininterval=args.tqdm_interval, disable=not accelerator.is_main_process)
        else:
            data_iter = tqdm(data_loader, mininterval=args.tqdm_interval)
    for data in data_iter:
        try:
            data = data.to(device)
            LAS_tmp = []
            for i in range(len(data)):
                LAS_tmp.append(data[i]['compound', 'LAS', 'compound'].edge_index.detach().clone())
            with torch.no_grad():
                com_coord_pred, compound_batch, y_pred, y_pred_by_coord, pocket_cls_pred, pocket_cls, protein_out_mask_whole, p_coords_batched_whole, pocket_coord_pred_direct, dis_map, keepNode_less_5, pocket_prompt_component, complex_prompt_component, pocket_prompt_feat, complex_prompt_feat = model(data, stage=stage)
            # y = data.y
            com_coord = data.coords
            # Per-ligand RMSD: per-atom squared deviation, averaged within each
            # ligand via scatter_mean, then square-rooted.
            sd = ((com_coord_pred - com_coord) ** 2).sum(dim=-1)
            rmsd = scatter_mean(src=sd, index=compound_batch, dim=0).sqrt()
            centroid_pred = scatter_mean(src=com_coord_pred, index=compound_batch, dim=0)
            centroid_true = scatter_mean(src=com_coord, index=compound_batch, dim=0)
            centroid_dis = (centroid_pred - centroid_true).norm(dim=-1)
            if pred_dis:
                contact_loss = args.pair_distance_loss_weight * criterion(y_pred, dis_map) if len(dis_map) > 0 else torch.tensor([0])
                contact_by_pred_loss = args.pair_distance_loss_weight * criterion(y_pred_by_coord, dis_map) if len(dis_map) > 0 else torch.tensor([0])
                contact_distill_loss = args.pair_distance_distill_loss_weight * criterion(y_pred_by_coord, y_pred) if len(y_pred) > 0 else torch.tensor([0])
            else:
                # NOTE(review): this branch leaves contact_by_pred_loss /
                # contact_distill_loss undefined, which would raise below —
                # presumably pred_dis is always True here; confirm.
                contact_loss = criterion(y_pred, dis_map) if len(dis_map) > 0 else torch.tensor([0])
                y_pred = y_pred.sigmoid()
            pocket_cls_loss = args.pocket_cls_loss_weight * pocket_cls_criterion(pocket_cls_pred, pocket_cls.float())
            pocket_coord_direct_loss = args.pocket_distance_loss_weight * pocket_coord_criterion(pocket_coord_pred_direct, data.coords_center)
            com_coord_loss = args.coord_loss_weight * com_coord_criterion(com_coord_pred, com_coord)
            # Losses are accumulated weighted by element count so the final
            # division yields a per-element average.
            batch_loss += len(y_pred)*contact_loss.item()
            batch_by_pred_loss += len(y_pred_by_coord)*contact_by_pred_loss.item()
            batch_distill_loss += len(y_pred_by_coord)*contact_distill_loss.item()
            com_coord_batch_loss += len(com_coord_pred)*com_coord_loss.item()
            pocket_cls_batch_loss += len(pocket_cls_pred)*pocket_cls_loss.item()
            pocket_coord_direct_batch_loss += len(pocket_coord_pred_direct)*pocket_coord_direct_loss.item()
            keepNode_less_5_count += keepNode_less_5
            y_list.append(dis_map)
            y_pred_list.append(y_pred.detach())
            com_coord_list.append(com_coord)
            com_coord_pred_list.append(com_coord_pred.detach())
            rmsd_list.append(rmsd.detach())
            rmsd_2A_list.append((rmsd.detach() < 2).float())
            rmsd_5A_list.append((rmsd.detach() < 5).float())
            centroid_dis_list.append(centroid_dis.detach())
            centroid_dis_2A_list.append((centroid_dis.detach() < 2).float())
            centroid_dis_5A_list.append((centroid_dis.detach() < 5).float())
            batch_len = protein_out_mask_whole.sum(dim=1).detach()
            # protein_len_list.append(batch_len)
            pocket_coord_pred_direct_list.append(pocket_coord_pred_direct.detach())
            pocket_coord_direct_list.append(data.coords_center)
            # Per-protein pocket classification bookkeeping (j = protein length).
            for i, j in enumerate(batch_len):
                count += 1
                pdb_list.append(data.pdb[i])
                pocket_cls_list.append(pocket_cls.detach()[i][:j])
                pocket_cls_pred_list.append(pocket_cls_pred.detach()[i][:j].sigmoid().round().int())
                pred_index_bool = (pocket_cls_pred.detach()[i][:j].sigmoid().round().int() == 1)
                if pred_index_bool.sum() != 0:
                    # Pocket center = mean coordinate of residues classified as pocket.
                    pred_pocket_center = p_coords_batched_whole.detach()[i][:j][pred_index_bool].mean(dim=0).unsqueeze(0)
                    pocket_coord_pred_list.append(pred_pocket_center)
                    pocket_coord_list.append(data.coords_center[i].unsqueeze(0))
                else: # all the prediction is False, skip
                    skip_count += 1
                    # Fallback: soft pocket center weighted by deterministic
                    # Gumbel-softmax probabilities, so a prediction is still recorded.
                    pred_index_true = pocket_cls_pred[i][:j].sigmoid().unsqueeze(-1)
                    pred_index_false = 1. - pred_index_true
                    pred_index_prob = torch.cat([pred_index_false, pred_index_true], dim=-1)
                    pred_index_log_prob = torch.log(pred_index_prob)
                    pred_index_one_hot = gumbel_softmax_no_random(pred_index_log_prob, tau=args.gs_tau, hard=False)
                    pred_index_one_hot_true = pred_index_one_hot[:, 1].unsqueeze(-1)
                    pred_pocket_center_gumbel = pred_index_one_hot_true * p_coords_batched_whole[i][:j]
                    pred_pocket_center_gumbel_mean = pred_pocket_center_gumbel.sum(dim=0) / pred_index_one_hot_true.sum(dim=0)
                    pocket_coord_pred_list.append(pred_pocket_center_gumbel_mean.unsqueeze(0).detach())
                    pocket_coord_list.append(data.coords_center[i].unsqueeze(0))
            # Per-sample predicted coordinates (and offsets) for later pose export.
            for i in range(len(data)):
                i_mask = (compound_batch == i)
                com_coord_pred_i = com_coord_pred[i_mask]
                com_coord_pred_per_sample_list.append(com_coord_pred_i.cpu())
                com_coord_offset_per_sample_list.append(data[i].coord_offset.cpu())
                mol_list.append(data[i].pdb)
            torch.cuda.empty_cache()
        except Exception as e:
            # NOTE(review): broad catch keeps evaluation running over bad samples
            # but can hide real errors; only the message is printed.
            print(e)
            continue
    # Aggregate across batches.
    y = torch.cat(y_list)
    y_pred = torch.cat(y_pred_list)
    com_coord = torch.cat(com_coord_list)
    com_coord_pred = torch.cat(com_coord_pred_list)
    rmsd = torch.cat(rmsd_list)
    rmsd_2A = torch.cat(rmsd_2A_list)
    rmsd_5A = torch.cat(rmsd_5A_list)
    rmsd_25 = torch.quantile(rmsd, 0.25)
    rmsd_50 = torch.quantile(rmsd, 0.50)
    rmsd_75 = torch.quantile(rmsd, 0.75)
    centroid_dis = torch.cat(centroid_dis_list)
    centroid_dis_2A = torch.cat(centroid_dis_2A_list)
    centroid_dis_5A = torch.cat(centroid_dis_5A_list)
    centroid_dis_25 = torch.quantile(centroid_dis, 0.25)
    centroid_dis_50 = torch.quantile(centroid_dis, 0.50)
    centroid_dis_75 = torch.quantile(centroid_dis, 0.75)
    pocket_cls = torch.cat(pocket_cls_list)
    pocket_cls_pred = torch.cat(pocket_cls_pred_list)
    if len(pocket_coord_pred_list) > 0:
        pocket_coord_pred = torch.cat(pocket_coord_pred_list)
        pocket_coord = torch.cat(pocket_coord_list)
    pocket_coord_pred_direct = torch.cat(pocket_coord_pred_direct_list)
    pocket_coord_direct = torch.cat(pocket_coord_direct_list)
    pocket_cls_accuracy = (pocket_cls_pred == pocket_cls).sum().item() / len(pocket_cls_pred)
    metrics = {"samples": count, "skip_samples": skip_count, "keepNode < 5": keepNode_less_5_count}
    metrics.update({"contact_loss": batch_loss/len(y_pred), "contact_by_pred_loss": batch_by_pred_loss/len(y_pred)})
    metrics.update({"com_coord_huber_loss": com_coord_batch_loss/len(com_coord_pred)})
    # Final evaluation metrics
    metrics.update({"rmsd": rmsd.mean().cpu().item(), "rmsd < 2A": rmsd_2A.mean().cpu().item(), "rmsd < 5A": rmsd_5A.mean().cpu().item()})
    metrics.update({"rmsd 25%": rmsd_25.cpu().item(), "rmsd 50%": rmsd_50.cpu().item(), "rmsd 75%": rmsd_75.cpu().item()})
    metrics.update({"centroid_dis": centroid_dis.mean().cpu().item(), "centroid_dis < 2A": centroid_dis_2A.mean().cpu().item(), "centroid_dis < 5A": centroid_dis_5A.mean().item()})
    metrics.update({"centroid_dis 25%": centroid_dis_25.cpu().item(), "centroid_dis 50%": centroid_dis_50.cpu().item(), "centroid_dis 75%": centroid_dis_75.cpu().item()})
    metrics.update({"pocket_cls_bce_loss": pocket_cls_batch_loss / len(pocket_cls_pred_list)})
    metrics.update({"pocket_coord_mse_loss": pocket_coord_direct_batch_loss / len(pocket_coord_pred_direct)})
    metrics.update({"pocket_cls_accuracy": pocket_cls_accuracy})
    if len(pocket_coord_pred_list) > 0:
        metrics.update(pocket_metrics(pocket_coord_pred, pocket_coord))
    return (
        metrics, pocket_prompt_feat, complex_prompt_feat,
        com_coord_pred_per_sample_list, com_coord_offset_per_sample_list,
        pdb_list, mol_list
    )
[docs]
@torch.no_grad()
def evaluate_mean_pocket_cls_coord_pocket_pred(args, data_loader, model, com_coord_criterion, criterion, pocket_cls_criterion, pocket_coord_criterion, relative_k, device, pred_dis=False, info=None, saveFileName=None, use_y_mask=False, skip_y_metrics_evaluation=False, stage=1):
    """Evaluate only the pocket-prediction head of the model.

    Accumulates pocket-classification BCE and pocket-center regression losses
    plus pocket-center metrics, and returns a metrics dict.

    Fix: the Gumbel-softmax fallback path previously appended a numpy array
    without the leading batch dim to ``pocket_coord_pred_list``; ``torch.cat``
    over that list would then fail whenever the fallback fired. It now appends
    a detached (1, 3) tensor, consistent with
    ``evaluate_mean_pocket_cls_coord_multi_task``.
    """
    # contain the ground truth for classiifcation(may not all)
    pocket_coord_list = []
    pocket_coord_pred_list = []
    # contain the ground truth for regression(all)
    pocket_coord_direct_list = []
    pocket_coord_pred_direct_list = []
    pocket_cls_list = []
    pocket_cls_pred_list = []
    pdb_list = []
    skip_count = 0
    count = 0
    pocket_cls_batch_loss = 0.0
    pocket_coord_direct_batch_loss = 0.0
    keepNode_less_5_count = 0
    for data in tqdm(data_loader, mininterval=args.tqdm_interval):
        data = data.to(device)
        pocket_cls_pred, pocket_cls, protein_out_mask_whole, p_coords_batched_whole, pocket_coord_pred_direct, keepNode_less_5, prompt_component = model(data, stage=stage)
        pocket_cls_loss = args.pocket_cls_loss_weight * pocket_cls_criterion(pocket_cls_pred, pocket_cls.float())
        pocket_coord_direct_loss = args.pocket_distance_loss_weight * pocket_coord_criterion(pocket_coord_pred_direct, data.coords_center)
        # Weighted by element count so the final division is a per-element average.
        pocket_cls_batch_loss += len(pocket_cls_pred)*pocket_cls_loss.item()
        pocket_coord_direct_batch_loss += len(pocket_coord_pred_direct)*pocket_coord_direct_loss.item()
        keepNode_less_5_count += keepNode_less_5
        batch_len = protein_out_mask_whole.sum(dim=1).detach()
        # protein_len_list.append(batch_len)
        pocket_coord_pred_direct_list.append(pocket_coord_pred_direct.detach())
        pocket_coord_direct_list.append(data.coords_center)
        # Per-protein pocket classification bookkeeping (j = protein length).
        for i, j in enumerate(batch_len):
            count += 1
            pdb_list.append(data.pdb[i])
            pocket_cls_list.append(pocket_cls.detach()[i][:j])
            pocket_cls_pred_list.append(pocket_cls_pred.detach()[i][:j].sigmoid().round().int())
            pred_index_bool = (pocket_cls_pred.detach()[i][:j].sigmoid().round().int() == 1)
            if pred_index_bool.sum() != 0:
                # Pocket center = mean coordinate of residues classified as pocket.
                pred_pocket_center = p_coords_batched_whole.detach()[i][:j][pred_index_bool].mean(dim=0).unsqueeze(0)
                pocket_coord_pred_list.append(pred_pocket_center)
                pocket_coord_list.append(data.coords_center[i].unsqueeze(0))
            else: # all the prediction is False, skip
                skip_count += 1
                # Fallback: soft pocket center weighted by deterministic
                # Gumbel-softmax probabilities.
                pred_index_true = pocket_cls_pred[i][:j].sigmoid().unsqueeze(-1)
                pred_index_false = 1. - pred_index_true
                pred_index_prob = torch.cat([pred_index_false, pred_index_true], dim=-1)
                pred_index_log_prob = torch.log(pred_index_prob)
                pred_index_one_hot = gumbel_softmax_no_random(pred_index_log_prob, tau=args.gs_tau, hard=False)
                pred_index_one_hot_true = pred_index_one_hot[:, 1].unsqueeze(-1)
                pred_pocket_center_gumbel = pred_index_one_hot_true * p_coords_batched_whole[i][:j]
                pred_pocket_center_gumbel_mean = pred_pocket_center_gumbel.sum(dim=0) / pred_index_one_hot_true.sum(dim=0)
                # BUG FIX: keep a (1, 3) tensor (was .detach().cpu().numpy(),
                # which broke the torch.cat below).
                pocket_coord_pred_list.append(pred_pocket_center_gumbel_mean.unsqueeze(0).detach())
                pocket_coord_list.append(data.coords_center[i].unsqueeze(0))
    # real_y_mask = torch.cat(real_y_mask_list)
    pocket_cls = torch.cat(pocket_cls_list)
    pocket_cls_pred = torch.cat(pocket_cls_pred_list)
    if len(pocket_coord_pred_list) > 0:
        pocket_coord_pred = torch.cat(pocket_coord_pred_list)
        pocket_coord = torch.cat(pocket_coord_list)
    pocket_coord_pred_direct = torch.cat(pocket_coord_pred_direct_list)
    pocket_coord_direct = torch.cat(pocket_coord_direct_list)
    pocket_cls_accuracy = (pocket_cls_pred == pocket_cls).sum().item() / len(pocket_cls_pred)
    metrics = {"samples": count, "skip_samples": skip_count, "keepNode < 5": keepNode_less_5_count}
    metrics.update({"pocket_cls_bce_loss": pocket_cls_batch_loss / len(pocket_cls_pred_list)})
    metrics.update({"pocket_coord_mse_loss": pocket_coord_direct_batch_loss / len(pocket_coord_pred_direct)})
    metrics.update({"pocket_cls_accuracy": pocket_cls_accuracy})
    if len(pocket_coord_pred_list) > 0:
        metrics.update(pocket_metrics(pocket_coord_pred, pocket_coord))
    return metrics
[docs]
def gumbel_softmax_no_random(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> torch.Tensor:
    """Deterministic Gumbel-softmax: softmax(logits / tau) with NO Gumbel noise.

    When ``hard`` is True, returns a one-hot vector in the forward pass with a
    straight-through gradient via the soft probabilities. ``eps`` is unused and
    kept for signature compatibility with ``torch.nn.functional.gumbel_softmax``.
    """
    y_soft = (logits / tau).softmax(dim)
    if not hard:
        # Reparametrization trick: just the (temperature-scaled) softmax.
        return y_soft
    # Straight through: one-hot forward, soft gradients backward.
    top_index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, top_index, 1.0)
    return y_hard - y_soft.detach() + y_soft