import numpy as np
import torch
import copy
import scipy.sparse
from scipy.spatial.transform import Rotation
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import math
from miniworld.utils.chemical import *
NUM_CLASSES = 23
def get_train_valid_data_dir_dictionary(data_dir_dictionary):
"""Describes the structure of the data directory dictionary::
data_dir_dictionary = {
'PDB_hhr_dir' : '/home/psk6950/practice/MiniWorld/data/hhr_rerefactoring/',
'pickle_data' : {
'PDB_msa_dir' : '/home/psk6950/practice/MiniWorld/data/pickle_data/a3m/',
'PDB_mmcif_dir' : '/home/psk6950/practice/MiniWorld/data/pickle_data/mmcif/'
},
'STRING_msa_dir' : '/home/psk6950/practice/MiniWorld/data/STRING/MSA',
'STRING_template_dir' : '/home/psk6950/practice/MiniWorld/data/STRING/Template',
'STRING_ID_Seq_path' : '/home/psk6950/practice/MiniWorld/data/STRING_ID_dict.txt',
'ID_to_source_dir' : '/home/psk6950/practice/MiniWorld/data/new/filter_final_v00/new_ID_to_source_dict.txt',
'PDB_Monomer_ID_to_info_dir' : '/home/psk6950/practice/MiniWorld/data/valid_PDB_Monomer_ID_list.txt',
'PDB_Complex_ID_to_info_dir' : '/home/psk6950/practice/MiniWorld/data/valid_PDB_Complex_ID_list.txt',
'PDB_ID_to_info_dir' : '/home/psk6950/practice/MiniWorld/data/new/filter_final_v00/list_PSK_v04.csv',
'PDB_ID_full_sequence' : '/home/psk6950/practice/MiniWorld/data/new/filter_final_v00/new_PDB_ID_full_sequence.txt',
'train_ID_list' : '/home/psk6950/practice/MiniWorld/data/train_ID_list.txt',
'valid_PDB_ID_list' : '/home/psk6950/practice/MiniWorld/data/valid_PDB_ID_list.txt',
'valid_STRING_ID_list' : '/home/psk6950/practice/MiniWorld/data/valid_STRING_ID_list.txt',
}
"""
train_data_dir_dictionary = copy.deepcopy(data_dir_dictionary)
valid_PDB_monomer_data_dir_dictionary = copy.deepcopy(data_dir_dictionary)
valid_PDB_complex_data_dir_dictionary = copy.deepcopy(data_dir_dictionary)
valid_STRING_data_dir_dictionary = copy.deepcopy(data_dir_dictionary)
train_data_dir_dictionary['ID_list'] = train_data_dir_dictionary['train_ID_list']
valid_PDB_monomer_data_dir_dictionary['ID_list'] = valid_PDB_monomer_data_dir_dictionary['valid_PDB_ID_list']
valid_PDB_complex_data_dir_dictionary['ID_list'] = valid_PDB_complex_data_dir_dictionary['valid_PDB_ID_list']
valid_PDB_monomer_data_dir_dictionary['PDB_ID_to_info_dir'] = valid_PDB_monomer_data_dir_dictionary['PDB_Monomer_ID_to_info_dir']
valid_PDB_complex_data_dir_dictionary['PDB_ID_to_info_dir'] = valid_PDB_complex_data_dir_dictionary['PDB_Complex_ID_to_info_dir']
valid_STRING_data_dir_dictionary['ID_list'] = valid_STRING_data_dir_dictionary['valid_STRING_ID_list']
return train_data_dir_dictionary, valid_PDB_monomer_data_dir_dictionary, valid_PDB_complex_data_dir_dictionary, valid_STRING_data_dir_dictionary
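# Illustrative usage sketch (hypothetical helper, not part of the original module): only the
# keys actually read by get_train_valid_data_dir_dictionary are shown, with placeholder paths.
def _demo_split_data_dirs():
    data_dir_dictionary = {
        'train_ID_list': '/path/to/train_ID_list.txt',
        'valid_PDB_ID_list': '/path/to/valid_PDB_ID_list.txt',
        'valid_STRING_ID_list': '/path/to/valid_STRING_ID_list.txt',
        'PDB_Monomer_ID_to_info_dir': '/path/to/valid_PDB_Monomer_ID_list.txt',
        'PDB_Complex_ID_to_info_dir': '/path/to/valid_PDB_Complex_ID_list.txt',
    }
    train, valid_mono, valid_cplx, valid_string = get_train_valid_data_dir_dictionary(data_dir_dictionary)
    # each split gets its own ID_list; the PDB validation splits also get a remapped info file
    assert train['ID_list'] == data_dir_dictionary['train_ID_list']
    assert valid_mono['PDB_ID_to_info_dir'] == data_dir_dictionary['PDB_Monomer_ID_to_info_dir']
    assert valid_string['ID_list'] == data_dir_dictionary['valid_STRING_ID_list']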
# TODO: is this really uniform in space?
def random_rot_trans(xyz, random_noise=20.0):
# xyz: (N, L, 27, 3)
N, L = xyz.shape[:2]
# sample a uniform random rotation matrix for each of the N models
R_mat = torch.tensor(Rotation.random(N).as_matrix(), dtype=xyz.dtype).to(xyz.device)
xyz = torch.einsum('nij,nlaj->nlai', R_mat, xyz) + torch.rand(N,1,1,3, device=xyz.device)*random_noise
return xyz
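# Illustrative sketch (hypothetical helper, not part of the original module): random_rot_trans
# applies one rigid rotation + translation per model, so pairwise CA-CA distances are preserved.
def _demo_random_rot_trans():
    xyz = torch.randn(2, 16, 27, 3)                       # (N models, L residues, 27 atoms, 3)
    xyz_aug = random_rot_trans(xyz, random_noise=20.0)
    d0 = torch.cdist(xyz[:, :, 1], xyz[:, :, 1])          # CA-CA distances before
    d1 = torch.cdist(xyz_aug[:, :, 1], xyz_aug[:, :, 1])  # and after augmentation
    assert torch.allclose(d0, d1, atol=1e-4)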
def center_and_realign_missing(xyz, mask_t):
# xyz: (L, 27, 3)
# mask_t: (L, 27)
L = xyz.shape[0]
mask = mask_t[:,:3].all(dim=-1) # True where N, CA and C are all present (L)
# center the CA center of mass of observed residues at the origin
center_CA = (mask[...,None]*xyz[:,1]).sum(dim=0) / (mask[...,None].sum(dim=0) + 1e-5) # (3)
xyz = torch.where(mask.view(L,1,1), xyz - center_CA.view(1, 1, 3), xyz)
# move missing residues to the closest valid residues
exist_in_xyz = torch.where(mask)[0] # L_sub
seqmap = (torch.arange(L, device=xyz.device)[:,None] - exist_in_xyz[None,:]).abs() # (L, Lsub)
seqmap = torch.argmin(seqmap, dim=-1) # L
idx = torch.gather(exist_in_xyz, 0, seqmap)
offset_CA = torch.gather(xyz[:,1], 0, idx.reshape(L,1).expand(-1,3))
xyz = torch.where(mask.view(L,1,1), xyz, xyz + offset_CA.reshape(L,1,3))
return xyz
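# Illustrative sketch (hypothetical helper, not part of the original module): after
# center_and_realign_missing, the CA centre of mass of the observed residues sits at
# (approximately) the origin, and missing residues are parked near their closest observed neighbour.
def _demo_center_and_realign_missing():
    L = 8
    xyz = torch.randn(L, 27, 3)
    mask_t = torch.ones(L, 27, dtype=torch.bool)
    mask_t[3] = False                                     # pretend residue 3 is unresolved
    out = center_and_realign_missing(xyz, mask_t)
    observed = mask_t[:, :3].all(dim=-1)
    com = out[observed, 1].mean(dim=0)                    # CA centre of mass of observed residues
    assert com.abs().max() < 1e-4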
def th_ang_v(ab,bc,eps:float=1e-8):
def th_norm(x,eps:float=1e-8):
return x.square().sum(-1,keepdim=True).add(eps).sqrt()
def th_N(x,alpha:float=0):
return x/th_norm(x).add(alpha)
ab, bc = th_N(ab),th_N(bc)
cos_angle = torch.clamp( (ab*bc).sum(-1), -1, 1)
sin_angle = torch.sqrt(1-cos_angle.square() + eps)
dih = torch.stack((cos_angle,sin_angle),-1)
return dih
def th_dih_v(ab,bc,cd):
def th_cross(a,b):
a,b = torch.broadcast_tensors(a,b)
return torch.cross(a,b, dim=-1)
def th_norm(x,eps:float=1e-8):
return x.square().sum(-1,keepdim=True).add(eps).sqrt()
def th_N(x,alpha:float=0):
return x/th_norm(x).add(alpha)
ab, bc, cd = th_N(ab),th_N(bc),th_N(cd)
n1 = th_N( th_cross(ab,bc) )
n2 = th_N( th_cross(bc,cd) )
sin_angle = (th_cross(n1,bc)*n2).sum(-1)
cos_angle = (n1*n2).sum(-1)
dih = torch.stack((cos_angle,sin_angle),-1)
return dih
def th_dih(a,b,c,d):
return th_dih_v(a-b,b-c,c-d)
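# Illustrative sketch (hypothetical helper, not part of the original module): th_dih returns
# (cos, sin) of the dihedral a-b-c-d. The points below put the planes a-b-c and b-c-d at
# right angles, so |cos| ~ 0 and |sin| ~ 1.
def _demo_th_dih():
    a = torch.tensor([0., 1., 0.])
    b = torch.tensor([0., 0., 0.])
    c = torch.tensor([1., 0., 0.])
    d = torch.tensor([1., 0., 1.])
    cos_sin = th_dih(a, b, c, d)                          # shape (2,): (cos, sin)
    assert torch.allclose(cos_sin.abs(), torch.tensor([0., 1.]), atol=1e-3)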
# With non_ideal=True, a more complicated version splits the error between the CA-N and CA-C directions (giving a more accurate CB position).
# Returns the rigid transformation (rotation R, translation Ca) from the local residue frame to the global frame.
def rigid_from_3_points(N, Ca, C, non_ideal=False, eps=1e-8):
#N, Ca, C - [B,L, 3]
#R - [B,L, 3, 3], det(R)=1, inv(R) = R.T, R is a rotation matrix
B,L = N.shape[:2]
v1 = C-Ca
v2 = N-Ca
e1 = v1/(torch.norm(v1, dim=-1, keepdim=True)+eps)
u2 = v2-(torch.einsum('bli, bli -> bl', e1, v2)[...,None]*e1)
e2 = u2/(torch.norm(u2, dim=-1, keepdim=True)+eps)
e3 = torch.cross(e1, e2, dim=-1)
R = torch.cat([e1[...,None], e2[...,None], e3[...,None]], axis=-1) #[B,L,3,3] - rotation matrix
if non_ideal:
v2 = v2/(torch.norm(v2, dim=-1, keepdim=True)+eps)
cosref = torch.clamp( torch.sum(e1*v2, dim=-1), min=-1.0, max=1.0) # cosine of current N-CA-C bond angle
costgt = cos_ideal_NCAC.item()
cos2del = torch.clamp( cosref*costgt + torch.sqrt((1-cosref*cosref)*(1-costgt*costgt)+eps), min=-1.0, max=1.0 )
cosdel = torch.sqrt(0.5*(1+cos2del)+eps)
sindel = torch.sign(costgt-cosref) * torch.sqrt(1-0.5*(1+cos2del)+eps)
Rp = torch.eye(3, device=N.device).repeat(B,L,1,1)
Rp[:,:,0,0] = cosdel
Rp[:,:,0,1] = -sindel
Rp[:,:,1,0] = sindel
Rp[:,:,1,1] = cosdel
R = torch.einsum('blij,bljk->blik', R,Rp)
return R, Ca
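# Illustrative sketch (hypothetical helper, not part of the original module): the frame built by
# rigid_from_3_points is a proper rotation (R R^T = I, det R = 1) and the returned translation is
# CA itself. The coordinates below are roughly ideal backbone geometry with CA at the origin.
def _demo_rigid_from_3_points():
    N  = torch.tensor([[[-0.572, 1.337, 0.000]]])         # [B=1, L=1, 3]
    Ca = torch.tensor([[[ 0.000, 0.000, 0.000]]])
    C  = torch.tensor([[[ 1.526, 0.000, 0.000]]])
    R, T = rigid_from_3_points(N, Ca, C)
    I = torch.eye(3).expand(1, 1, 3, 3)
    assert torch.allclose(R @ R.transpose(-1, -2), I, atol=1e-5)
    assert torch.allclose(torch.det(R), torch.ones(1, 1), atol=1e-5)
    assert torch.equal(T, Ca)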
def rigid_from_3_points_v2(N, Ca, C, mask_crds, chain_break_list, non_ideal=False, eps=1e-8):
# N, Ca, C - [B,L, 3]
# mask_crds - [B, L, n_atoms] atom mask (only the first three backbone atoms are used)
# chain_break_list - list of length B of {chain_idx: (start, end)} dictionaries
B, L = N.shape[:2]
R_chain_list = []
T_chain_list = []
R_residue_list = []
T_residue_list = []
mask_BB = ~(mask_crds[:,:,:3].sum(dim=-1) < 3.0) # ignore residues having missing BB atoms for loss calculation, [B, L]
for bb, chain_break in enumerate(chain_break_list):
R_chain_in_list = []
T_chain_in_list = []
R_residue_in_list = []
T_residue_in_list = []
for chain_idx, (start, end) in chain_break.items():
mask_idxs = mask_BB[bb,start:end+1] # [L_chain]
chain_N = N[bb,start:end+1] # [L_chain, 3]
chain_N = chain_N[mask_idxs] # [L_chain, 3]
chain_Ca = Ca[bb,start:end+1] # [L_chain, 3]
chain_Ca = chain_Ca[mask_idxs] # [L_chain, 3]
chain_C = C[bb,start:end+1] # [L_chain, 3]
chain_C = chain_C[mask_idxs] # [L_chain, 3]
R_residue_in_chain, T_residue_in_chain = rigid_from_3_points(chain_N.unsqueeze(0), chain_Ca.unsqueeze(0), chain_C.unsqueeze(0), non_ideal, eps) # [1, L_chain, 3, 3], [1, L_chain, 3]
R_residue_in_chain = R_residue_in_chain.squeeze(0) # [L_chain, 3, 3]
T_residue_in_chain = T_residue_in_chain.squeeze(0) # [L_chain, 3]
N_to_C_vector = chain_Ca[-1,:] - chain_Ca[0,:] # [3]
L_chain = chain_N.shape[0]
# T_chain : CoM of Ca
T_chain = chain_Ca.mean(dim=0, keepdim=True) # [1, 3]
T_chain = T_chain.expand(L_chain, 3) # [L_chain, 3]
# Centering
chain_Ca = chain_Ca - T_chain
if L_chain > 2:
# PCA using Ca
U, S, V = torch.svd(chain_Ca.unsqueeze(0)) # U: [1, L_chain, 3], S: [1, 3], V: [1, 3, 3]
e1 = V[0,:,0] # [3]
e2 = V[0,:,1] # [3]
e3 = V[0,:,2] # [3]
# Choose the direction
sign1 = torch.sign(torch.sum(e1 * N_to_C_vector, dim=-1))
sign2 = torch.sign(torch.sum(e2 * N_to_C_vector, dim=-1))
sign3 = torch.sign(torch.sum(e3 * N_to_C_vector, dim=-1))
e1 = e1 * sign1 # [3]
e2 = e2 * sign2 # [3]
e3 = e3 * sign3 # [3]
# choose x,y,z. consider right-handed coordinate system
check_right_handed = torch.sign(torch.sum(e1 * torch.cross(e2, e3, dim=-1), dim=-1)) # (1 or -1)
if check_right_handed > 0:
ex = e1
ey = e2
ez = e3
else:
ex = e1
ey = e3
ez = e2
R_chain = torch.stack([ex, ey, ez], dim=-1) # [3, 3]
else :
# too few residues for a meaningful PCA, so fall back to the identity rotation
R_chain = torch.eye(3, device=N.device) # [3, 3]
R_chain = R_chain.unsqueeze(0).expand(L_chain, 3, 3) # [L_chain, 3, 3]
R_chain_inv = R_chain.transpose(-2, -1) # [L_chain, 3, 3]
R_residue = torch.einsum('lij,ljk->lik', R_chain_inv, R_residue_in_chain) # [L_chain, 3, 3]
T_residue = torch.einsum('lij,lj->li', R_chain_inv, T_residue_in_chain - T_chain) # [L_chain, 3]
R_chain_in_list.append(R_chain)
T_chain_in_list.append(T_chain)
R_residue_in_list.append(R_residue)
T_residue_in_list.append(T_residue)
R_chain_in = torch.cat(R_chain_in_list, dim=0) # [L, 3, 3]
T_chain_in = torch.cat(T_chain_in_list, dim=0)
R_residue_in = torch.cat(R_residue_in_list, dim=0) # [L, 3, 3]
T_residue_in = torch.cat(T_residue_in_list, dim=0)
R_chain_list.append(R_chain_in)
T_chain_list.append(T_chain_in)
R_residue_list.append(R_residue_in)
T_residue_list.append(T_residue_in)
R_chain = torch.stack(R_chain_list, dim=0) # [B, L, 3, 3]
T_chain = torch.stack(T_chain_list, dim=0) # [B, L, 3]
R_residue = torch.stack(R_residue_list, dim=0) # [B, L, 3, 3]
T_residue = torch.stack(T_residue_list, dim=0) # [B, L, 3]
return R_chain, T_chain, R_residue, T_residue
def cal_integrated_frame(R_chain, T_chain, R_residue, T_residue):
# R_chain : [B, L, 3, 3]
# T_chain : [B, L, 3]
# R_residue : [B, L, 3, 3]
# T_residue : [B, L, 3]
R_out = torch.einsum('blij,bljk->blik', R_chain, R_residue) # [B, L, 3, 3]
T_out = torch.einsum('blij,blj->bli', R_chain, T_residue) + T_chain # [B, L, 3]
return R_out, T_out
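# Illustrative sketch (hypothetical helper, not part of the original module): for a fully
# observed single chain, composing the chain and residue frames with cal_integrated_frame
# recovers the plain per-residue frames from rigid_from_3_points (rotation R, translation CA).
def _demo_chain_residue_frames():
    B, L = 1, 6
    N, Ca, C = torch.randn(B, L, 3), torch.randn(B, L, 3), torch.randn(B, L, 3)
    mask_crds = torch.ones(B, L, 27)                      # every atom marked as observed
    chain_break_list = [{0: (0, L - 1)}]                  # one chain spanning all residues
    R_chain, T_chain, R_residue, T_residue = rigid_from_3_points_v2(N, Ca, C, mask_crds, chain_break_list)
    R_int, T_int = cal_integrated_frame(R_chain, T_chain, R_residue, T_residue)
    R_ref, T_ref = rigid_from_3_points(N, Ca, C)
    assert torch.allclose(R_int, R_ref, atol=1e-4)
    assert torch.allclose(T_int, T_ref, atol=1e-4)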
def centering_RT(R, T, mask):
# R : [B, L, 3, 3]
# T : [B, L, 3]
# mask : [B, L]
B, L = R.shape[:2]
# Centering
CoM = (mask[...,None]*T).sum(dim=1) / (mask[...,None].sum(dim=1) + 1e-5) # [B, 3]
T = torch.where(mask.view(B,L,1), T - CoM.view(B,1,3), T)
return R, T
# process ideal frames
def make_frame(X, Y):
Xn = X / torch.linalg.norm(X)
Y = Y - torch.dot(Y, Xn) * Xn
Yn = Y / torch.linalg.norm(Y)
Z = torch.cross(Xn, Yn, dim=-1)
Zn = Z / torch.linalg.norm(Z)
return torch.stack((Xn,Yn,Zn), dim=-1)
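# Illustrative sketch (hypothetical helper, not part of the original module): make_frame
# returns an orthonormal frame whose first column is X normalised.
def _demo_make_frame():
    X = torch.tensor([1., 1., 0.])
    Y = torch.tensor([0., 1., 0.])
    F = make_frame(X, Y)
    assert torch.allclose(F.T @ F, torch.eye(3), atol=1e-5)
    assert torch.allclose(F[:, 0], X / X.norm(), atol=1e-5)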
def get_Cb(xyz):
N = xyz[:,:,0]
Ca = xyz[:,:,1]
C = xyz[:,:,2]
# recreate Cb given N,Ca,C
b = Ca - N
c = C - Ca
a = torch.cross(b, c, dim=-1)
Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
return Cb
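# Illustrative sketch (hypothetical helper, not part of the original module): get_Cb places a
# virtual CB roughly 1.5 A from CA using only the N, CA and C coordinates (approximately ideal
# backbone geometry is assumed below).
def _demo_get_Cb():
    xyz = torch.zeros(1, 1, 27, 3)
    xyz[0, 0, 0] = torch.tensor([-0.572, 1.337, 0.000])   # N
    xyz[0, 0, 1] = torch.tensor([ 0.000, 0.000, 0.000])   # CA
    xyz[0, 0, 2] = torch.tensor([ 1.526, 0.000, 0.000])   # C
    Cb = get_Cb(xyz)                                      # (1, 1, 3)
    d = (Cb[0, 0] - xyz[0, 0, 1]).norm()
    assert 1.3 < d.item() < 1.7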
def cross_product_matrix(u):
B, L = u.shape[:2]
matrix = torch.zeros((B, L, 3, 3), device=u.device)
matrix[:,:,0,1] = -u[...,2]
matrix[:,:,0,2] = u[...,1]
matrix[:,:,1,0] = u[...,2]
matrix[:,:,1,2] = -u[...,0]
matrix[:,:,2,0] = -u[...,1]
matrix[:,:,2,1] = u[...,0]
return matrix
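# Illustrative sketch (hypothetical helper, not part of the original module): the skew-symmetric
# matrix built by cross_product_matrix implements the cross product, i.e. M(u) @ v == u x v.
def _demo_cross_product_matrix():
    u = torch.randn(2, 5, 3)
    v = torch.randn(2, 5, 3)
    M = cross_product_matrix(u)                           # (2, 5, 3, 3)
    lhs = torch.einsum('blij,blj->bli', M, v)
    rhs = torch.cross(u, v, dim=-1)
    assert torch.allclose(lhs, rhs, atol=1e-5)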
# writepdb
def writepdb(filename, atoms, seq, Ls, idx_pdb=None, bfacts=None):
f = open(filename,"w")
ctr = 1
scpu = seq.cpu().squeeze(0)
atomscpu = atoms.cpu().squeeze(0)
L = sum(Ls)
O = scpu.shape[0]//L # number of repeated copies described by Ls
Ls = Ls * O
if bfacts is None:
bfacts = torch.zeros(atomscpu.shape[0])
if idx_pdb is None:
idx_pdb = 1 + torch.arange(atomscpu.shape[0])
Bfacts = torch.clamp( bfacts.cpu(), 0, 100)
chn_idx, res_idx = 0, 0
for i,s in enumerate(scpu):
natoms = atomscpu.shape[-2]
assert natoms in (14, 27), f'bad size! {atoms.shape}'
atms = aa2long[s]
# his protonation state hack
if (s==8 and torch.linalg.norm( atomscpu[i,9,:]-atomscpu[i,5,:] ) < 1.7):
atms = (
" N "," CA "," C "," O "," CB "," CG "," NE2"," CD2"," CE1"," ND1",
None, None, None, None," H "," HA ","1HB ","2HB "," HD2"," HE1",
" HD1", None, None, None, None, None, None) # his_d
for j,atm_j in enumerate(atms):
if (j<natoms and atm_j is not None and not torch.isnan(atomscpu[i,j,:]).any()):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, atm_j, num2aa[s],
PDB_CHAIN_IDS[chn_idx], res_idx+1, atomscpu[i,j,0], atomscpu[i,j,1], atomscpu[i,j,2],
1.0, Bfacts[i] ) )
ctr += 1
res_idx += 1
if (chn_idx < len(Ls) and res_idx == Ls[chn_idx]):
chn_idx += 1
res_idx = 0
f.close()
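# Illustrative sketch (hypothetical helper, not part of the original module): writing a small
# random (non-physical) structure. Residue class 0 is assumed to map to a standard amino acid
# in the chemical tables (num2aa / aa2long); the output only exercises the writer.
def _demo_writepdb(path='demo.pdb'):
    L = 5
    seq = torch.zeros(L, dtype=torch.long)                # (L,) residue class indices
    xyz = torch.randn(L, 27, 3)                           # random coordinates, 27-atom representation
    writepdb(path, xyz, seq, Ls=[L])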
# resolve tip atom indices
tip_indices = torch.full((NUM_CLASSES,3), -1)
for i in range(NUM_CLASSES):
tip_atm = aa2tip[i]
atm_long = aa2long[i]
for j in range(3):
if tip_atm[j] is None:
continue
tip_indices[i,j] = atm_long.index(tip_atm[j])
# resolve torsion indices
torsion_indices = torch.full((NUM_CLASSES,4,4),0)
torsion_can_flip = torch.full((NUM_CLASSES,10),False,dtype=torch.bool)
for i in range(NUM_CLASSES):
i_l, i_a = aa2long[i], aa2longalt[i]
for j in range(4):
if torsions[i][j] is None:
continue
for k in range(4):
a = torsions[i][j][k]
torsion_indices[i,j,k] = i_l.index(a)
if (i_l.index(a) != i_a.index(a)):
torsion_can_flip[i,3+j] = True ##bb tors never flip
# HIS is a special case
torsion_can_flip[8,4]=False
# build the mapping from atoms in the full rep (Nx27) to the "alternate" rep
allatom_mask = torch.zeros((NUM_CLASSES,27), dtype=torch.bool)
long2alt = torch.zeros((NUM_CLASSES,27), dtype=torch.long)
for i in range(NUM_CLASSES):
i_l, i_lalt = aa2long[i], aa2longalt[i]
for j,a in enumerate(i_l):
if (a is None):
long2alt[i,j] = j
else:
long2alt[i,j] = i_lalt.index(a)
allatom_mask[i,j] = True
# bond graph traversal
num_bonds = torch.zeros((NUM_CLASSES,27,27), dtype=torch.long)
for i in range(NUM_CLASSES):
num_bonds_i = np.zeros((27,27))
for (bnamei,bnamej) in aabonds[i]:
bi,bj = aa2long[i].index(bnamei),aa2long[i].index(bnamej)
num_bonds_i[bi,bj] = 1
num_bonds_i = scipy.sparse.csgraph.shortest_path (num_bonds_i,directed=False)
num_bonds_i[num_bonds_i>=4] = 4
num_bonds[i,...] = torch.tensor(num_bonds_i)
# hbond scoring parameters
def donorHs(D,bonds,atoms):
dHs = []
for (i,j) in bonds:
if (i==D):
idx_j = atoms.index(j)
if (idx_j>=14): # if atom j is a hydrogen
dHs.append(idx_j)
if (j==D):
idx_i = atoms.index(i)
if (idx_i>=14): # if atom i is a hydrogen
dHs.append(idx_i)
assert (len(dHs)>0)
return dHs
def generate_random_xyz(seq, random_noise = 20.0, use_all_model = False, use_random_model=True):
# seq : (Model, L)
# initialize random xyz from ideal_coords in chemical.py
# and then apply random rotation and translation using random_rot_trans
M, L = seq.shape
if use_all_model :
initial_xyz = torch.zeros((M, L, 27, 3))
atom_mask = torch.zeros((M, L, 27), dtype=torch.bool)
random_xyz = torch.zeros((M, L, 27, 3))
for ii in range(M):
for jj in range(L):
for kk, atom_xyz_list in enumerate(ideal_coords[seq[ii,jj].item()]):
x,y,z = atom_xyz_list[-1]
initial_xyz[ii,jj,kk,:] = torch.tensor([x,y,z])
atom_mask[ii,jj,kk] = True
random_xyz[ii:ii+1,:,:,:] = random_rot_trans(initial_xyz[ii:ii+1,:,:,:], random_noise)
else :
model_idx = torch.randint(M, (1,)) if use_random_model else 0
initial_xyz = torch.zeros((1, L, 27, 3))
atom_mask = torch.zeros((1, L, 27), dtype=torch.bool)
random_xyz = torch.zeros((1, L, 27, 3))
for jj in range(L):
for kk, atom_xyz_list in enumerate(ideal_coords[seq[model_idx,jj].item()]):
x,y,z = atom_xyz_list[-1]
initial_xyz[0,jj,kk,:] = torch.tensor([x,y,z])
atom_mask[0,jj,kk] = True
random_xyz = random_rot_trans(initial_xyz, random_noise) # (1, L, 27, 3)
return random_xyz, atom_mask # (M, L, 27, 3) or (1, L, 27, 3) | (M, L, 27) or (1, L, 27)
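# Illustrative sketch (hypothetical helper, not part of the original module): one model of a
# 20-residue sequence placed with ideal residue geometry, then randomly rotated and translated.
# Residue class 0 is assumed to have an entry in ideal_coords.
def _demo_generate_random_xyz():
    seq = torch.zeros(1, 20, dtype=torch.long)            # (Model=1, L=20)
    xyz, atom_mask = generate_random_xyz(seq, random_noise=10.0)
    assert xyz.shape == (1, 20, 27, 3)
    assert atom_mask.shape == (1, 20, 27)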
def generate_initial_xyz(seq, chain_break, random_noise = 20.0, use_all_model = False, use_random_model=True):
# seq : (Model, L)
# initialize random xyz from ideal_coords in chemical.py
# and then apply random rotation and translation using random_rot_trans
M, L = seq.shape
if use_all_model :
initial_xyz = torch.zeros((M, L, 27, 3))
atom_mask = torch.zeros((M, L, 27), dtype=torch.bool)
random_xyz = torch.zeros((M, L, 27, 3))
for ii in range(M):
for jj in range(L):
for kk, atom_xyz_list in enumerate(ideal_coords[seq[ii,jj].item()]):
x,y,z = atom_xyz_list[-1]
initial_xyz[ii,jj,kk,:] = torch.tensor([x,y,z])
atom_mask[ii,jj,kk] = True
else :
model_idx = torch.randint(M, (1,)) if use_random_model else 0
initial_xyz = torch.zeros((1, L, 27, 3))
atom_mask = torch.zeros((1, L, 27), dtype=torch.bool)
random_xyz = torch.zeros((1, L, 27, 3))
for jj in range(L):
for kk, atom_xyz_list in enumerate(ideal_coords[seq[model_idx,jj].item()]):
x,y,z = atom_xyz_list[-1]
initial_xyz[0,jj,kk,:] = torch.tensor([x,y,z])
atom_mask[0,jj,kk] = True
for chain_idx, (start, end) in chain_break.items():
chain_xyz = initial_xyz[:,start:(end+1),:,:] # (M, L_chain, 27, 3)
random_xyz[:,start:(end+1),:,:] = random_rot_trans(chain_xyz, random_noise)
return random_xyz, atom_mask # (M, L, 27, 3) or (1, L, 27, 3) | (M, L, 27) or (1, L, 27)
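# Illustrative sketch (hypothetical helper, not part of the original module): generate_initial_xyz
# places each chain with its own random rigid transform, here for two chains of four residues.
# Residue class 0 is assumed to have an entry in ideal_coords.
def _demo_generate_initial_xyz():
    seq = torch.zeros(1, 8, dtype=torch.long)
    chain_break = {0: (0, 3), 1: (4, 7)}
    xyz, atom_mask = generate_initial_xyz(seq, chain_break, random_noise=10.0)
    assert xyz.shape == (1, 8, 27, 3)
    assert atom_mask.shape == (1, 8, 27)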
def generate_symmetric_operation(symmetry = "C4"):
if symmetry == "C4":
# rotate 90 degree around z-axis
return torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=torch.float)
elif symmetry == "C3":
# rotate 120 degree around z-axis
angle = 120 * (math.pi / 180) # converting degrees to radians
return torch.tensor([
[math.cos(angle), -math.sin(angle), 0],
[math.sin(angle), math.cos(angle), 0],
[0, 0, 1]
], dtype=torch.float)
else :
raise ValueError("Not implemented yet")
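# Illustrative sketch (hypothetical helper, not part of the original module): applying the C4
# operation four times (or the C3 operation three times) returns to the identity.
def _demo_symmetric_operation():
    R4 = generate_symmetric_operation("C4")
    R3 = generate_symmetric_operation("C3")
    assert torch.allclose(torch.linalg.matrix_power(R4, 4), torch.eye(3), atol=1e-5)
    assert torch.allclose(torch.linalg.matrix_power(R3, 3), torch.eye(3), atol=1e-5)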
@torch.no_grad()
def generate_symmetric_xyz(seq, chain_break, random_noise = 20.0, use_random_model=True, symmetry = "C4"):
# seq : (Model, L)
# initialize random xyz from ideal_coords in chemical.py
# and then apply random rotation and translation using random_rot_trans
M, L = seq.shape
model_idx = torch.randint(M, (1,)) if use_random_model else 0
initial_xyz = torch.zeros((1, L, 27, 3))
atom_mask = torch.zeros((1, L, 27), dtype=torch.bool)
symmetric_xyz = torch.zeros((1, L, 27, 3))
for jj in range(L):
for kk, atom_xyz_list in enumerate(ideal_coords[seq[model_idx,jj].item()]):
x,y,z = atom_xyz_list[-1]
initial_xyz[0,jj,kk,:] = torch.tensor([x,y,z])
atom_mask[0,jj,kk] = True
initial_xyz = initial_xyz + torch.randn(1,1,3).repeat(1,L,1).unsqueeze(-2)
rotate_matrix = torch.eye(3) # (3, 3)
unit_rotation_matrix = generate_symmetric_operation(symmetry) # (3, 3)
first = True
for chain_idx, (start, end) in chain_break.items():
default_xyz = initial_xyz[:,start:(end+1),:,:] # (M, L_chain, 27, 3)
default_xyz = torch.einsum('mlcj,ij->mlci', default_xyz, rotate_matrix) # (M, L_chain, 27, 3)
symmetric_xyz[:,start:(end+1),:,:] = default_xyz
rotate_matrix = torch.einsum('ij,jk->ik', rotate_matrix, unit_rotation_matrix) # (3, 3)
return symmetric_xyz, atom_mask # (M, L, 27, 3) or (1, L, 27, 3) | (M, L, 27) or (1, L, 27)
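# Illustrative sketch (hypothetical helper, not part of the original module): two chains of a C4
# target are related by the 90-degree rotation about z, so corresponding CA atoms keep the same
# distance to the symmetry axis. Residue class 0 is assumed to have an entry in ideal_coords.
def _demo_generate_symmetric_xyz():
    seq = torch.zeros(1, 8, dtype=torch.long)
    chain_break = {0: (0, 3), 1: (4, 7)}
    xyz, atom_mask = generate_symmetric_xyz(seq, chain_break, symmetry="C4")
    r0 = xyz[0, 0:4, 1, :2].norm(dim=-1)                  # CA distance to the z-axis, chain 0
    r1 = xyz[0, 4:8, 1, :2].norm(dim=-1)                  # CA distance to the z-axis, chain 1
    assert torch.allclose(r0, r1, atol=1e-4)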
# kinematic parameters
base_indices = torch.full((NUM_CLASSES,27),0, dtype=torch.long)
xyzs_in_base_frame = torch.ones((NUM_CLASSES,27,4))
RTs_by_torsion = torch.eye(4).repeat(NUM_CLASSES,7,1,1)
reference_angles = torch.ones((NUM_CLASSES,3,2))
for i in range(NUM_CLASSES):
i_l = aa2long[i]
for name, base, coords in ideal_coords[i]:
idx = i_l.index(name)
base_indices[i,idx] = base
xyzs_in_base_frame[i,idx,:3] = torch.tensor(coords)
# omega frame
RTs_by_torsion[i,0,:3,:3] = torch.eye(3)
RTs_by_torsion[i,0,:3,3] = torch.zeros(3)
# phi frame
RTs_by_torsion[i,1,:3,:3] = make_frame(
xyzs_in_base_frame[i,0,:3] - xyzs_in_base_frame[i,1,:3],
torch.tensor([1.,0.,0.])
)
RTs_by_torsion[i,1,:3,3] = xyzs_in_base_frame[i,0,:3]
# psi frame
RTs_by_torsion[i,2,:3,:3] = make_frame(
xyzs_in_base_frame[i,2,:3] - xyzs_in_base_frame[i,1,:3],
xyzs_in_base_frame[i,1,:3] - xyzs_in_base_frame[i,0,:3]
)
RTs_by_torsion[i,2,:3,3] = xyzs_in_base_frame[i,2,:3]
# chi1 frame
if torsions[i][0] is not None:
a0,a1,a2 = torsion_indices[i,0,0:3]
RTs_by_torsion[i,3,:3,:3] = make_frame(
xyzs_in_base_frame[i,a2,:3]-xyzs_in_base_frame[i,a1,:3],
xyzs_in_base_frame[i,a0,:3]-xyzs_in_base_frame[i,a1,:3],
)
RTs_by_torsion[i,3,:3,3] = xyzs_in_base_frame[i,a2,:3]
# chi2~4 frame
for j in range(1,4):
if torsions[i][j] is not None:
a2 = torsion_indices[i,j,2]
if ((i==18 and j==2) or (i==8 and j==2)): # TYR CZ-OH & HIS CE1-HE1 a special case
a0,a1 = torsion_indices[i,j,0:2]
RTs_by_torsion[i,3+j,:3,:3] = make_frame(
xyzs_in_base_frame[i,a2,:3]-xyzs_in_base_frame[i,a1,:3],
xyzs_in_base_frame[i,a0,:3]-xyzs_in_base_frame[i,a1,:3] )
else:
RTs_by_torsion[i,3+j,:3,:3] = make_frame(
xyzs_in_base_frame[i,a2,:3],
torch.tensor([-1.,0.,0.]), )
RTs_by_torsion[i,3+j,:3,3] = xyzs_in_base_frame[i,a2,:3]
# CB/CG angles
NCr = 0.5*(xyzs_in_base_frame[i,0,:3]+xyzs_in_base_frame[i,2,:3])
CAr = xyzs_in_base_frame[i,1,:3]
CBr = xyzs_in_base_frame[i,4,:3]
CGr = xyzs_in_base_frame[i,5,:3]
reference_angles[i,0,:]=th_ang_v(CBr-CAr,NCr-CAr)
NCp = xyzs_in_base_frame[i,2,:3]-xyzs_in_base_frame[i,0,:3]
NCpp = NCp - torch.dot(NCp,NCr)/ torch.dot(NCr,NCr) * NCr
reference_angles[i,1,:]=th_ang_v(CBr-CAr,NCpp)
reference_angles[i,2,:]=th_ang_v(CGr,torch.tensor([-1.,0.,0.]))
def draw_attn(attn,save_path):
import matplotlib.pyplot as plt
# attn : (B, L, L, n_head)
attn_max = attn.max()
attn_min = attn.min()
attn = attn[0]
n_head = attn.shape[-1]
# draw all heads in one figure
fig, axs = plt.subplots(1, n_head, figsize=(n_head*3, 3))
for i in range(n_head):
axs[i].imshow(attn[:,:,i].detach().cpu().numpy(), cmap='hot', interpolation='nearest')
# show heatmap at the right
plt.colorbar(axs[0].imshow(attn[:,:,0].detach().cpu().numpy(), cmap='hot', interpolation='nearest'), ax=axs, orientation='horizontal')
# if os.path.exists(save_path):
# idx = 0
# while os.path.exists(save_path):
# save_path = save_path.split('.')[0]
# save_path = save_path + f'.{idx}.png'
# idx += 1
print(save_path)
plt.savefig(save_path)
plt.close()
def draw_attn_multirow(attn, num_row, save_path):
import matplotlib.pyplot as plt
# attn: (B, L, L, n_head)
attn = attn[0]
n_head = attn.shape[-1]
# Calculate the number of columns needed if n_head isn't divisible evenly by num_row
num_col = -(-n_head // num_row) # Ceiling division
# Draw all heads in one figure
fig, axs = plt.subplots(num_row, num_col, figsize=(num_col*3, 3*num_row))
axs = axs.flatten() # Flatten the array for easier indexing
for i in range(n_head):
ax = axs[i]
im = ax.imshow(attn[:, :, i].detach().cpu().numpy(), cmap='hot', interpolation='nearest')
ax.axis('off') # Turn off the axis to make it cleaner since the location isn't important here
# Show heatmap at the right of the last plot
plt.colorbar(im, ax=axs, orientation='horizontal', fraction=0.046, pad=0.04)
plt.savefig(save_path)
plt.close()
def draw_pae(logit_pae, mask, saving_path):
# logit_pae : (B, 64, L, L) 0~32
# mask : (B, L, L)
pae = torch.nn.functional.softmax(logit_pae, dim=1) # (B, 64, L, L)
nbin = pae.shape[1]
bin_value = 0.5 * torch.arange(nbin, dtype=torch.float32, device=pae.device)
expected_error = torch.einsum('bnij, n -> bij', pae, bin_value) # (B, L, L)
expected_error = expected_error * mask if mask is not None else expected_error
# Define the colormap
cdict = {
'red': [(0.0, 0.0, 0.0), # Blue at 0
(0.5, 1.0, 1.0), # White at 0.5 (corresponds to 16 if scaled 0-32)
(1.0, 1.0, 1.0)], # Red at 1
'green': [(0.0, 0.0, 0.0), # No green at 0
(0.5, 1.0, 1.0), # White (full green) at 0.5
(1.0, 0.0, 0.0)], # No green at 1
'blue': [(0.0, 1.0, 1.0), # Blue at 0
(0.5, 1.0, 1.0), # White (full blue) at 0.5
(1.0, 0.0, 0.0)] # No blue at 1
}
# black for expected_error = 0
cdict['red'].insert(0, (0.0, 0.0, 0.0))
cdict['green'].insert(0, (0.0, 0.0, 0.0))
cdict['blue'].insert(0, (0.0, 0.0, 0.0))
custom_cmap = mcolors.LinearSegmentedColormap('custom_cmap', cdict)
# Normalize the expected_error values to 0-1 scale
norm = mcolors.Normalize(vmin=0, vmax=32)
# Draw all batch pae in one figure
B = expected_error.shape[0]
if B>1:
fig, axs = plt.subplots(1, B, figsize=(B*3, 3))
axs = axs.flatten()
for i in range(B):
im = axs[i].imshow(expected_error[i].detach().cpu().numpy(), cmap=custom_cmap, norm=norm)
axs[i].axis('off')
else :
fig, axs = plt.subplots(1, 1, figsize=(3, 3))
im = axs.imshow(expected_error[0].detach().cpu().numpy(), cmap=custom_cmap, norm=norm)
axs.axis('off')
# Show heatmap at the right of the last plot
plt.colorbar(im, ax=axs, orientation='horizontal', fraction=0.046, pad=0.04)
plt.savefig(saving_path)
plt.close()