# conda activate py39
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import hashlib
from tqdm import tqdm
import rdkit.Chem as Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from tqdm import tqdm
import glob
import torch
from torch import nn
from datetime import datetime
import logging
from io import StringIO
import sys
from scipy.spatial.transform import Rotation
[docs]
def compute_RMSD(a, b):
    """Return the root-mean-square deviation between two coordinate tensors.

    The squared element-wise differences are summed over the last axis,
    the per-row sums are averaged over all remaining entries, and the
    square root of that mean is returned.

    Args:
        a: Tensor of coordinates.
        b: Tensor of coordinates that can be subtracted from ``a``.

    Returns:
        A zero-dimensional tensor holding the RMSD.
    """
    delta = a - b
    squared_dist = (delta ** 2).sum(axis=-1)
    mean_squared_dist = squared_dist.mean()
    return torch.sqrt(mean_squared_dist)
from rdkit.Geometry import Point3D
[docs]
def write_with_new_coords(mol, new_coords, toFile):
    """Overwrite the conformer coordinates of ``mol`` and save it as an SDF file.

    The conformer returned by ``mol.GetConformer()`` is modified in place, so
    ``mol`` keeps the new coordinates after the call.

    Args:
        mol: RDKit molecule that has a conformer.
        new_coords: Indexable collection holding one ``(x, y, z)`` triple per
            atom, addressed by atom index.
        toFile: Destination handed to ``Chem.SDWriter``.
    """
    # Update the coordinates before opening the writer so that a failure here
    # (e.g. too few coordinates) does not leave an empty output file behind.
    conf = mol.GetConformer()
    for atom_idx in range(mol.GetNumAtoms()):
        x, y, z = new_coords[atom_idx]
        # Hand Point3D built-in floats whatever scalar type new_coords stores
        # (e.g. float32 array elements or zero-dimensional tensors).
        conf.SetAtomPosition(atom_idx, Point3D(float(x), float(y), float(z)))
    writer = Chem.SDWriter(toFile)
    try:
        # writer.SetKekulize(False)
        writer.write(mol)
    finally:
        # Always close the writer so the output is flushed and released even
        # when writing the molecule raises.
        writer.close()
# new_coords = movable_coords.detach().numpy().astype(np.double)
# write_with_new_coords(mol, new_coords, toFile)
[docs]
def distance_loss_function(epoch, y_pred, x, protein_nodes_xyz, compound_pair_dis_constraint, LAS_distance_constraint_mask=None, mode=0):
    """Compute the loss used to fit compound coordinates to predicted distances.

    Two terms are evaluated:

    * interaction loss: mismatch between ``y_pred`` and the current
      protein-to-compound distances (capped at 10 Å);
    * configuration loss: mismatch between the current intra-compound
      distances and ``compound_pair_dis_constraint``.

    For ``epoch < 500`` only the interaction loss is returned as the total
    loss; afterwards the configuration loss is added with a weight of
    ``5e-3 * (epoch - 500)``, i.e. growing linearly with the epoch.

    Args:
        epoch: Current optimisation step, used for the loss schedule.
        y_pred: Predicted protein-to-compound distances; must be compatible
            with the output of ``torch.cdist(protein_nodes_xyz, x)``.
        x: Current compound coordinates.
        protein_nodes_xyz: Protein node coordinates.
        compound_pair_dis_constraint: Target intra-compound distances; must be
            compatible with the output of ``torch.cdist(x, x)``.
        LAS_distance_constraint_mask: Optional index selecting which
            intra-compound pairs are constrained. It is used directly as a
            tensor index, so it should be a boolean mask (an integer tensor
            would be interpreted as indices). When given, an excluded-volume
            penalty is also added.
        mode: Form of the interaction loss: ``0`` for L1, ``1`` for squared
            error, ``2`` for the square root of the absolute error.

    Returns:
        ``(loss, (interaction_loss, configuration_loss))`` where ``loss`` is a
        tensor and the two components are Python floats.

    Raises:
        ValueError: If ``mode`` is not 0, 1 or 2.
    """
    # Current protein-to-compound distance matrix; distances above 10 Å are
    # not distinguished from each other.
    dis = torch.cdist(protein_nodes_xyz, x)
    dis_clamp = torch.clamp(dis, max=10)
    residual = dis_clamp - y_pred
    if mode == 0:
        interaction_loss = residual.abs().sum()
    elif mode == 1:
        interaction_loss = (residual ** 2).sum()
    elif mode == 2:
        # Probably not a good choice: x^0.5 has an infinite gradient at x=0.
        # The 1e-5 offset is added for numerical stability.
        interaction_loss = ((residual.abs() + 1e-5) ** 0.5).sum()
    else:
        # Fail early instead of hitting an unbound local variable below.
        raise ValueError(f"mode must be 0, 1 or 2, got {mode!r}")

    # Current intra-compound distance matrix and its deviation from the target.
    config_dis = torch.cdist(x, x)
    config_error = (config_dis - compound_pair_dis_constraint).abs()
    if LAS_distance_constraint_mask is not None:
        constrained_term = 1 * config_error[LAS_distance_constraint_mask].sum()
        # Basic excluded volume: compound atoms should be at least 1.22 Å apart.
        # The sum runs over the full matrix, including the diagonal.
        clash_term = 2 * ((1.22 - config_dis).relu()).sum()
        configuration_loss = constrained_term + clash_term
    else:
        configuration_loss = 1 * config_error.sum()

    # In the first 500 epochs only the interaction loss is optimised; after
    # that the configuration loss is blended in with an increasing weight.
    if epoch < 500:
        loss = interaction_loss
    else:
        loss = 1 * (interaction_loss + 5e-3 * (epoch - 500) * configuration_loss)
    return loss, (interaction_loss.item(), configuration_loss.item())
[docs]
def distance_optimize_compound_coords(coords, y_pred, protein_nodes_xyz,
                        compound_pair_dis_constraint, total_epoch=5000, loss_function=distance_loss_function, LAS_distance_constraint_mask=None, mode=0, show_progress=False):
    """Optimise compound coordinates by gradient descent on ``loss_function``.

    Coordinates are initialised randomly around the mean of
    ``protein_nodes_xyz`` and updated with Adam (learning rate 0.1) for
    ``total_epoch`` steps.

    Args:
        coords: Reference compound coordinates. Only their shape is used for
            the initialisation; the values are used to track the RMSD.
        y_pred: Predicted protein-to-compound distances, forwarded to
            ``loss_function``.
        protein_nodes_xyz: Protein node coordinates; ``mean(axis=0)`` must
            reshape to ``(1, 3)``.
        compound_pair_dis_constraint: Target intra-compound distances,
            forwarded to ``loss_function``.
        total_epoch: Number of optimisation steps.
        loss_function: Callable with the signature of
            ``distance_loss_function`` returning ``(loss, details)``.
        LAS_distance_constraint_mask: Optional mask forwarded to
            ``loss_function``.
        mode: Interaction-loss mode forwarded to ``loss_function``.
        show_progress: Wrap the epoch loop in ``tqdm`` when true.

    Returns:
        ``(x, loss_list, rmsd_list)``: the optimised coordinates (still
        requiring grad), the loss evaluated before each update, and the RMSD
        to ``coords`` measured after each update.
    """
    # Pocket centre.
    c_pred = protein_nodes_xyz.mean(axis=0)
    # Random initialisation centred on the pocket centre: every coordinate is
    # shifted by a value drawn uniformly from -5..5 Å. The offsets are created
    # on the same device as the protein coordinates so the sum below does not
    # mix devices.
    offsets = 5 * (2 * torch.rand(coords.shape, device=c_pred.device) - 1)
    x = offsets + c_pred.reshape(1, 3).detach()
    x.requires_grad = True
    optimizer = torch.optim.Adam([x], lr=0.1)
    # optimizer = torch.optim.LBFGS([x], lr=0.01)
    loss_list = []
    rmsd_list = []
    it = tqdm(range(total_epoch)) if show_progress else range(total_epoch)
    for epoch in it:
        optimizer.zero_grad()
        # Only the total loss is needed here; the per-term breakdown is unused.
        loss, _ = loss_function(epoch, y_pred, x, protein_nodes_xyz, compound_pair_dis_constraint, LAS_distance_constraint_mask=LAS_distance_constraint_mask, mode=mode)
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
        rmsd = compute_RMSD(coords, x.detach())
        rmsd_list.append(rmsd.item())
    return x, loss_list, rmsd_list
[docs]
def get_info_pred_distance(coords, y_pred, protein_nodes_xyz, compound_pair_dis_constraint, n_repeat=1, LAS_distance_constraint_mask=None, mode=0, show_progress=False):
    """Run the coordinate optimisation ``n_repeat`` times and tabulate the results.

    Each repeat calls ``distance_optimize_compound_coords`` with a fresh
    random initialisation and records the final RMSD, final loss and final
    coordinates.

    Args:
        coords: Reference compound coordinates.
        y_pred: Predicted protein-to-compound distances.
        protein_nodes_xyz: Protein node coordinates.
        compound_pair_dis_constraint: Target intra-compound distances.
        n_repeat: Number of independent optimisation runs.
        LAS_distance_constraint_mask: Optional mask forwarded to the optimiser.
        mode: Interaction-loss mode forwarded to the optimiser.
        show_progress: Show a ``tqdm`` bar over the repeats (the inner
            optimisation loop never shows one).

    Returns:
        A ``pandas.DataFrame`` with columns ``repeat``, ``rmsd``, ``loss`` and
        ``coords`` (a NumPy array per row).
    """
    info = []
    it = tqdm(range(n_repeat)) if show_progress else range(n_repeat)
    for repeat in it:
        x, loss_list, rmsd_list = distance_optimize_compound_coords(coords, y_pred, protein_nodes_xyz,
                                compound_pair_dis_constraint, LAS_distance_constraint_mask=LAS_distance_constraint_mask, mode=mode, show_progress=False)
        rmsd = rmsd_list[-1]
        final_coords = x.detach().cpu().numpy()
        try:
            final_loss = float(loss_list[-1])
        except (IndexError, TypeError, ValueError):
            # Fall back to 0 when the last loss is missing or not convertible.
            final_loss = 0
        info.append([repeat, rmsd, final_loss, final_coords])
    info = pd.DataFrame(info, columns=['repeat', 'rmsd', 'loss', 'coords'])
    return info
[docs]
def _sanitize_and_strip_hs(mol):
    """Sanitize ``mol``, remove hydrogens and check it converts to SMILES.

    Returns ``(mol, problem)`` where ``mol`` is the molecule as far as it was
    processed and ``problem`` is True if any step raised.
    """
    try:
        Chem.SanitizeMol(mol)
        mol = Chem.RemoveHs(mol)
        # The SMILES string itself is not needed; the call only verifies that
        # the molecule can be serialised.
        Chem.MolToSmiles(mol)
    except Exception:
        # Deliberately broad: any failure simply marks the molecule as
        # problematic so the caller can fall back or report it.
        return mol, True
    return mol, False


def read_mol(sdf_fileName, mol2_fileName, verbose=False):
    """Load a molecule from an SDF file, falling back to a MOL2 file.

    The SDF file is read first. If sanitizing it, removing hydrogens or
    converting it to SMILES fails, the MOL2 file is read and processed the
    same way. While the files are processed ``sys.stderr`` is redirected to
    an in-memory buffer; it is always restored before returning or raising.

    Args:
        sdf_fileName: Path of the SDF/MOL file tried first.
        mol2_fileName: Path of the MOL2 file used as a fallback.
        verbose: Print the captured stderr output when true.

    Returns:
        ``(mol, problem)``: the molecule from the last attempt and a flag that
        is True when that attempt also failed.
    """
    Chem.WrapLogs()
    stderr = sys.stderr
    sio = sys.stderr = StringIO()
    try:
        mol = Chem.MolFromMolFile(sdf_fileName, sanitize=False)
        mol, problem = _sanitize_and_strip_hs(mol)
        if problem:
            mol = Chem.MolFromMol2File(mol2_fileName, sanitize=False)
            mol, problem = _sanitize_and_strip_hs(mol)
    finally:
        # Restore stderr even if reading a file raises, so later error output
        # of the process is not silently swallowed by the buffer.
        sys.stderr = stderr
    if verbose:
        print(sio.getvalue())
    return mol, problem
[docs]
def binarize(x):
    """Map every strictly positive entry of ``x`` to 1 and all others to 0.

    The result has the same shape, dtype and device as ``x``.
    """
    is_positive = x > 0
    ones = torch.ones_like(x)
    zeros = torch.zeros_like(x)
    return torch.where(is_positive, ones, zeros)
# Convert an adjacency matrix into an n-hop connectivity matrix.
[docs]
def n_hops_adj(adj, n_hops):
    """Label every node pair with the hop count at which it becomes connected.

    ``reach[k]`` is the binarized matrix of pairs connected within ``k`` hops
    (``reach[0]`` is the identity). The difference between consecutive levels
    marks pairs that first become connected at exactly ``k`` hops, and those
    entries are set to ``k`` in the result. Pairs not connected within
    ``n_hops`` and the diagonal stay 0.

    Args:
        adj: Square adjacency tensor (presumably 0/1 valued; verify against
            the caller).
        n_hops: Largest hop count to consider.

    Returns:
        A tensor shaped and typed like ``adj`` holding the hop counts.
    """
    identity = torch.eye(adj.size(0), dtype=torch.long, device=adj.device)
    # Adding the identity keeps already-reached pairs reachable when the
    # matrix is multiplied by itself.
    reach = [identity, binarize(adj + identity)]
    for _ in range(2, n_hops + 1):
        reach.append(binarize(reach[-1] @ reach[1]))

    extend_mat = torch.zeros_like(adj)
    for hop in range(1, n_hops + 1):
        newly_reached = reach[hop] - reach[hop - 1]
        extend_mat += newly_reached * hop
    return extend_mat
[docs]
def get_LAS_distance_constraint_mask(mol):
    """Build a 0/1 mask of atom pairs whose distance should be constrained.

    A pair is marked when the two atoms are within two bonds of each other
    or when both belong to the same ring reported by ``Chem.GetSymmSSSR``.
    The diagonal is left at 0.

    Args:
        mol: RDKit molecule.

    Returns:
        A square tensor of zeros and ones with one row and column per atom.
        It is not a boolean tensor; convert it before using it as a boolean
        index.
    """
    # Bond adjacency extended to all pairs at most two hops apart.
    bond_adj = torch.from_numpy(Chem.GetAdjacencyMatrix(mol))
    extend_adj = n_hops_adj(bond_adj, 2)
    # Additionally mark every pair of distinct atoms sharing a ring.
    for ring in Chem.GetSymmSSSR(mol):
        for i in ring:
            for j in ring:
                if i != j:
                    extend_adj[i][j] += 1
    # Collapse hop counts and ring increments into a 0/1 mask.
    return binarize(extend_adj)