Source code for miniworld.utils.util

import numpy as np
import torch
import copy

import scipy.sparse
from scipy.spatial.transform import Rotation

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import math

from miniworld.utils.chemical import *

# Number of residue classes. Used as the leading dimension of every per-class
# lookup table built at import time in this module (tip/torsion indices,
# atom masks, bond counts, kinematic frames).
NUM_CLASSES = 23


def get_train_valid_data_dir_dictionary(data_dir_dictionary):
    """Derive the train / validation variants of a data-directory dictionary.

    Every returned dictionary is an independent deep copy of
    ``data_dir_dictionary`` with an extra ``'ID_list'`` entry; the two PDB
    validation variants additionally overwrite ``'PDB_ID_to_info_dir'``.

    Describes the structure of the data directory dictionary::

        data_dir_dictionary = {
            'PDB_hhr_dir' : '/home/psk6950/practice/MiniWorld/data/hhr_rerefactoring/',
            'pickle_data' : {
                'PDB_msa_dir' : '/home/psk6950/practice/MiniWorld/data/pickle_data/a3m/',
                'PDB_mmcif_dir' : '/home/psk6950/practice/MiniWorld/data/pickle_data/mmcif/'
            },
            'STRING_msa_dir' : '/home/psk6950/practice/MiniWorld/data/STRING/MSA',
            'STRING_template_dir' : '/home/psk6950/practice/MiniWorld/data/STRING/Template',
            'STRING_ID_Seq_path' : '/home/psk6950/practice/MiniWorld/data/STRING_ID_dict.txt',
            'ID_to_source_dir' : '/home/psk6950/practice/MiniWorld/data/new/filter_final_v00/new_ID_to_source_dict.txt',
            'PDB_Monomer_ID_to_info_dir' : '/home/psk6950/practice/MiniWorld/data/valid_PDB_Monomer_ID_list.txt',
            'PDB_Complex_ID_to_info_dir' : '/home/psk6950/practice/MiniWorld/data/valid_PDB_Complex_ID_list.txt',
            'PDB_ID_to_info_dir' : '/home/psk6950/practice/MiniWorld/data/new/filter_final_v00/list_PSK_v04.csv',
            'PDB_ID_full_sequence' : '/home/psk6950/practice/MiniWorld/data/new/filter_final_v00/new_PDB_ID_full_sequence.txt',
            'train_ID_list' : '/home/psk6950/practice/MiniWorld/data/train_ID_list.txt',
            'valid_PDB_ID_list' : '/home/psk6950/practice/MiniWorld/data/valid_PDB_ID_list.txt',
            'valid_STRING_ID_list' : '/home/psk6950/practice/MiniWorld/data/valid_STRING_ID_list.txt',
        }

    Returns:
        A 4-tuple ``(train, valid_PDB_monomer, valid_PDB_complex,
        valid_STRING)`` of dictionaries.
    """

    def _variant(id_list_key, info_key=None):
        # Independent copy so later edits to one variant never leak into another.
        variant = copy.deepcopy(data_dir_dictionary)
        variant['ID_list'] = variant[id_list_key]
        if info_key is not None:
            variant['PDB_ID_to_info_dir'] = variant[info_key]
        return variant

    return (
        _variant('train_ID_list'),
        _variant('valid_PDB_ID_list', 'PDB_Monomer_ID_to_info_dir'),
        _variant('valid_PDB_ID_list', 'PDB_Complex_ID_to_info_dir'),
        _variant('valid_STRING_ID_list'),
    )
def extract_tip_atom_xyz(query_sequence, xyz, extract_atom_number=1):
    """Gather the coordinates of each residue's tip atom(s).

    Args:
        query_sequence: one sequence per batch element; iterating
            ``query_sequence[bb]`` must yield keys usable with ``aa2tipidx``
            (exact element type is defined by the caller -- TODO confirm).
        xyz: coordinates shaped (B, L, n_atoms, 3) or (I, B, L, n_atoms, 3).
        extract_atom_number: 1 to take only the second entry of
            ``aa2tipidx[aa]``, 3 to take all of its entries.

    Returns:
        Tensor shaped (B, L, 1, 3) / (B, L, 3, 3), with a leading ``I``
        dimension when ``xyz`` has one.

    Raises:
        ValueError: if ``xyz`` is not 4-D/5-D or ``extract_atom_number`` is
            neither 1 nor 3 (previously these surfaced as an unrelated
            UnboundLocalError further down).
    """
    if extract_atom_number not in (1, 3):
        raise ValueError(
            f"extract_atom_number must be 1 or 3, got {extract_atom_number!r}"
        )

    n_dims = len(xyz.shape)
    n_iter = None
    if n_dims == 4:
        B, L = xyz.shape[:2]
        gather_dim = 2
    elif n_dims == 5:
        n_iter, B, L = xyz.shape[:3]
        gather_dim = 3
    else:
        raise ValueError(
            f"xyz must have 4 or 5 dimensions, got shape {tuple(xyz.shape)}"
        )

    if extract_atom_number == 1:
        tip_atom_idx = torch.zeros((B, L), device=xyz.device, dtype=torch.long)
        for bb in range(B):
            tip_atom_idx[bb] = torch.tensor(
                [aa2tipidx[aa][1] for aa in query_sequence[bb]], device=xyz.device
            )
        tip_atom_idx = tip_atom_idx.unsqueeze(-1)  # (B, L, 1)
    else:
        tip_atom_idx = torch.zeros((B, L, 3), device=xyz.device, dtype=torch.long)
        for bb in range(B):
            tip_atom_idx[bb] = torch.tensor(
                [aa2tipidx[aa] for aa in query_sequence[bb]], device=xyz.device
            )

    # Repeat the atom index over the xyz axis so it can drive gather().
    tip_atom_idx = tip_atom_idx[:, :, :, None].expand(-1, -1, -1, 3)  # (B, L, 1 or 3, 3)
    if gather_dim == 3:
        tip_atom_idx = tip_atom_idx.unsqueeze(0).expand(n_iter, -1, -1, -1, -1)

    return xyz.gather(gather_dim, tip_atom_idx)
# TODO: Check whether this sampling is really uniform in space.
def random_rot_trans(xyz, random_noise=20.0):
    """Apply an independent random rotation and translation to each model.

    Args:
        xyz: coordinates shaped (N, L, n_atoms, 3).
        random_noise: scale of the translation; each component is drawn with
            ``torch.rand`` and multiplied by this value.

    Returns:
        Tensor with the same shape as ``xyz``.
    """
    n_models = xyz.shape[0]

    # One random rotation matrix per model, shared by all its atoms.
    rotations = Rotation.random(n_models).as_matrix()
    rot = torch.tensor(rotations, dtype=xyz.dtype).to(xyz.device)
    rotated = torch.einsum('nij,nlaj->nlai', rot, xyz)

    # One translation per model, broadcast over residues and atoms.
    shift = torch.rand(n_models, 1, 1, 3, device=xyz.device) * random_noise
    return rotated + shift
def center_and_realign_missing(xyz, mask_t):
    """Center valid residues on their CA centroid and relocate missing ones.

    A residue counts as valid when its first three atoms are all flagged in
    ``mask_t``. Valid residues are shifted so the mean of their atom-1 (CA)
    coordinates sits at the origin. Each invalid residue is translated by the
    centered CA position of the valid residue closest to it in sequence index.

    Args:
        xyz: coordinates shaped (L, n_atoms, 3).
        mask_t: atom mask shaped (L, n_atoms).

    Returns:
        A new (L, n_atoms, 3) tensor. When no residue is valid there is nothing
        to center on or snap to, so an unchanged copy of ``xyz`` is returned
        (previously this case crashed inside ``argmin``).
    """
    L = xyz.shape[0]
    mask = mask_t[:, :3].all(dim=-1)  # (L,) True where the backbone is complete

    if not bool(mask.any()):
        return xyz.clone()

    # center c.o.m at the origin (1e-5 keeps the division finite)
    center_CA = (mask[..., None] * xyz[:, 1]).sum(dim=0) / (mask[..., None].sum(dim=0) + 1e-5)  # (3)
    xyz = torch.where(mask.view(L, 1, 1), xyz - center_CA.view(1, 1, 3), xyz)

    # move missing residues to the closest valid residues
    exist_in_xyz = torch.where(mask)[0]  # (L_sub,)
    seqmap = (torch.arange(L, device=xyz.device)[:, None] - exist_in_xyz[None, :]).abs()  # (L, L_sub)
    seqmap = torch.argmin(seqmap, dim=-1)  # (L,)
    idx = torch.gather(exist_in_xyz, 0, seqmap)
    offset_CA = torch.gather(xyz[:, 1], 0, idx.reshape(L, 1).expand(-1, 3))
    xyz = torch.where(mask.view(L, 1, 1), xyz, xyz + offset_CA.reshape(L, 1, 3))

    return xyz
def th_ang_v(ab, bc, eps: float = 1e-8):
    """Return (cos, sin) of the angle between two vectors.

    Both inputs are normalized along the last axis (with 1e-8 added under the
    square root); the cosine is clamped to [-1, 1] and the sine is computed as
    ``sqrt(1 - cos^2 + eps)``, so it is always non-negative.

    Returns:
        Tensor whose last axis holds ``(cos_angle, sin_angle)``.
    """

    def _unit(vec):
        length = vec.square().sum(-1, keepdim=True).add(1e-8).sqrt()
        return vec / length

    unit_ab = _unit(ab)
    unit_bc = _unit(bc)

    cos_angle = (unit_ab * unit_bc).sum(-1).clamp(-1, 1)
    sin_angle = (1 - cos_angle.square() + eps).sqrt()
    return torch.stack((cos_angle, sin_angle), -1)
def th_dih_v(ab, bc, cd):
    """Return (cos, sin) of the dihedral defined by three bond vectors.

    The vectors are normalized along the last axis, the normals of the
    (ab, bc) and (bc, cd) planes are built with cross products, and the
    dihedral is expressed through their dot product (cosine) and the
    projection of ``n1 x bc`` onto ``n2`` (sine).

    Returns:
        Tensor whose last axis holds ``(cos_angle, sin_angle)``.
    """

    def _cross(p, q):
        # Broadcast first so inputs with different leading shapes are accepted.
        p, q = torch.broadcast_tensors(p, q)
        return torch.cross(p, q, dim=-1)

    def _unit(vec):
        return vec / vec.square().sum(-1, keepdim=True).add(1e-8).sqrt()

    ab = _unit(ab)
    bc = _unit(bc)
    cd = _unit(cd)

    normal_1 = _unit(_cross(ab, bc))
    normal_2 = _unit(_cross(bc, cd))

    cos_angle = (normal_1 * normal_2).sum(-1)
    sin_angle = (_cross(normal_1, bc) * normal_2).sum(-1)
    return torch.stack((cos_angle, sin_angle), -1)
def th_dih(a, b, c, d):
    """Return (cos, sin) of the dihedral defined by four points.

    Forms the three consecutive difference vectors ``a-b``, ``b-c`` and
    ``c-d`` and delegates to :func:`th_dih_v`.
    """
    first = a - b
    second = b - c
    third = c - d
    return th_dih_v(first, second, third)
# The more complicated (non_ideal) version splits the error between CA-N and CA-C,
# giving a more accurate CB position.
# Returns the rigid transformation from the local frame to the global frame.
def rigid_from_3_points(N, Ca, C, non_ideal=False, eps=1e-8):
    """Build a per-residue rotation/translation from backbone N, CA, C atoms.

    The frame is a Gram-Schmidt basis: the first axis points from CA to C, the
    second is the part of CA->N orthogonal to it, the third is their cross
    product. The three axes are the columns of the rotation matrix.

    Args:
        N, Ca, C: coordinates shaped (B, L, 3).
        non_ideal: when True, the frame is additionally rotated in the plane
            of the first two axes by an angle derived from the difference
            between the observed N-CA-C cosine and ``cos_ideal_NCAC``.
        eps: added to norms / under square roots for numerical stability.

    Returns:
        ``(R, Ca)`` where ``R`` is (B, L, 3, 3) and ``Ca`` is returned as the
        translation unchanged.
    """
    n_batch, n_res = N.shape[:2]

    ca_to_c = C - Ca
    ca_to_n = N - Ca

    e1 = ca_to_c / (torch.norm(ca_to_c, dim=-1, keepdim=True) + eps)
    along_e1 = torch.einsum('bli, bli -> bl', e1, ca_to_n)[..., None]
    in_plane = ca_to_n - (along_e1 * e1)
    e2 = in_plane / (torch.norm(in_plane, dim=-1, keepdim=True) + eps)
    e3 = torch.cross(e1, e2, dim=-1)

    R = torch.stack((e1, e2, e3), dim=-1)  # [B,L,3,3] - rotation matrix

    if not non_ideal:
        return R, Ca

    unit_n = ca_to_n / (torch.norm(ca_to_n, dim=-1, keepdim=True) + eps)
    # cosine of current N-CA-C bond angle
    cosref = torch.clamp(torch.sum(e1 * unit_n, dim=-1), min=-1.0, max=1.0)
    costgt = cos_ideal_NCAC.item()
    cos2del = torch.clamp(
        cosref * costgt + torch.sqrt((1 - cosref * cosref) * (1 - costgt * costgt) + eps),
        min=-1.0,
        max=1.0,
    )
    # Half-angle formulas turn cos(2*delta) into cos(delta) / sin(delta).
    cosdel = torch.sqrt(0.5 * (1 + cos2del) + eps)
    sindel = torch.sign(costgt - cosref) * torch.sqrt(1 - 0.5 * (1 + cos2del) + eps)

    # In-plane correction rotation about the third axis.
    Rp = torch.eye(3, device=N.device).repeat(n_batch, n_res, 1, 1)
    Rp[:, :, 0, 0] = cosdel
    Rp[:, :, 0, 1] = -sindel
    Rp[:, :, 1, 0] = sindel
    Rp[:, :, 1, 1] = cosdel

    R = torch.einsum('blij,bljk->blik', R, Rp)
    return R, Ca
def rigid_from_3_points_v2(N, Ca, C, mask_crds, chain_break_list, non_ideal=False, eps=1e-8):
    """Compute chain-level frames and residue frames relative to their chain.

    For every chain, residues with an incomplete backbone are dropped. The
    residue frames are computed with :func:`rigid_from_3_points`, and a chain
    frame is built from a PCA (SVD) of the centered CA coordinates. Each
    principal axis is signed using the first-to-last CA vector, and the second
    and third axes are swapped if needed for a right-handed frame. Residue
    frames are then expressed inside their chain frame.

    Args:
        N, Ca, C: coordinates shaped (B, L, 3).
        mask_crds: atom mask; only the first three entries of its last axis
            are read, so it is presumably (B, L, n_atoms) -- TODO confirm.
        chain_break_list: one mapping per batch element,
            ``{chain_idx: (start, end)}`` with ``end`` inclusive.
        non_ideal, eps: forwarded to :func:`rigid_from_3_points`.

    Returns:
        ``(R_chain, T_chain, R_residue, T_residue)``, each stacked over the
        batch. NOTE(review): masked residues are removed, so stacking only
        works when every batch element keeps the same number of residues.
        A chain with no valid residue is not handled.
    """
    R_chain_list = []
    T_chain_list = []
    R_residue_list = []
    T_residue_list = []

    # ignore residues having missing BB atoms for loss calculation, [B, L]
    mask_BB = ~(mask_crds[:, :, :3].sum(dim=-1) < 3.0)

    for bb, chain_break in enumerate(chain_break_list):
        R_chain_in_list = []
        T_chain_in_list = []
        R_residue_in_list = []
        T_residue_in_list = []

        for start, end in chain_break.values():
            mask_idxs = mask_BB[bb, start:end + 1]  # [L_chain]
            chain_N = N[bb, start:end + 1][mask_idxs]  # [L_chain, 3]
            chain_Ca = Ca[bb, start:end + 1][mask_idxs]  # [L_chain, 3]
            chain_C = C[bb, start:end + 1][mask_idxs]  # [L_chain, 3]

            R_residue_in_chain, T_residue_in_chain = rigid_from_3_points(
                chain_N.unsqueeze(0), chain_Ca.unsqueeze(0), chain_C.unsqueeze(0), non_ideal, eps
            )  # [1, L_chain, 3, 3], [1, L_chain, 3]
            R_residue_in_chain = R_residue_in_chain.squeeze(0)  # [L_chain, 3, 3]
            T_residue_in_chain = T_residue_in_chain.squeeze(0)  # [L_chain, 3]

            N_to_C_vector = chain_Ca[-1, :] - chain_Ca[0, :]  # [3]
            L_chain = chain_N.shape[0]

            # T_chain : CoM of Ca
            T_chain = chain_Ca.mean(dim=0, keepdim=True)  # [1, 3]
            T_chain = T_chain.expand(L_chain, 3)  # [L_chain, 3]

            # Centering
            chain_Ca = chain_Ca - T_chain

            if L_chain > 2:
                # PCA using Ca; only the right singular vectors are needed.
                _, _, V = torch.svd(chain_Ca.unsqueeze(0))  # V: [1, 3, 3]
                e1 = V[0, :, 0]  # [3]
                e2 = V[0, :, 1]  # [3]
                e3 = V[0, :, 2]  # [3]

                # Choose the direction: point every axis along the N->C vector.
                e1 = e1 * torch.sign(torch.sum(e1 * N_to_C_vector, dim=-1))
                e2 = e2 * torch.sign(torch.sum(e2 * N_to_C_vector, dim=-1))
                e3 = e3 * torch.sign(torch.sum(e3 * N_to_C_vector, dim=-1))

                # choose x,y,z. consider right-handed coordinate system
                check_right_handed = torch.sign(
                    torch.sum(e1 * torch.cross(e2, e3, dim=-1), dim=-1)
                )  # (1 or -1)
                if check_right_handed > 0:
                    ex, ey, ez = e1, e2, e3
                else:
                    ex, ey, ez = e1, e3, e2
                R_chain = torch.stack([ex, ey, ez], dim=-1)  # [3, 3]
            else:
                # In this case, we can't calculate PCA. So, we just use identity matrix.
                # (The previous `.unsqueeze(0).expand(3, 3)` always raised, because
                # expand() was given fewer sizes than the tensor has dimensions.)
                R_chain = torch.eye(3, device=N.device, dtype=chain_Ca.dtype)  # [3, 3]

            R_chain = R_chain.unsqueeze(0).expand(L_chain, 3, 3)  # [L_chain, 3, 3]
            R_chain_inv = R_chain.transpose(-2, -1)  # [L_chain, 3, 3]

            R_residue = torch.einsum('lij,ljk->lik', R_chain_inv, R_residue_in_chain)  # [L_chain, 3, 3]
            T_residue = torch.einsum('lij,lj->li', R_chain_inv, T_residue_in_chain - T_chain)  # [L_chain, 3]

            R_chain_in_list.append(R_chain)
            T_chain_in_list.append(T_chain)
            R_residue_in_list.append(R_residue)
            T_residue_in_list.append(T_residue)

        R_chain_list.append(torch.cat(R_chain_in_list, dim=0))  # [L, 3, 3]
        T_chain_list.append(torch.cat(T_chain_in_list, dim=0))
        R_residue_list.append(torch.cat(R_residue_in_list, dim=0))  # [L, 3, 3]
        T_residue_list.append(torch.cat(T_residue_in_list, dim=0))

    R_chain = torch.stack(R_chain_list, dim=0)  # [B, L, 3, 3]
    T_chain = torch.stack(T_chain_list, dim=0)  # [B, L, 3]
    R_residue = torch.stack(R_residue_list, dim=0)  # [B, L, 3, 3]
    T_residue = torch.stack(T_residue_list, dim=0)  # [B, L, 3]

    return R_chain, T_chain, R_residue, T_residue
def cal_integrated_frame(R_chain, T_chain, R_residue, T_residue):
    """Compose chain frames with chain-relative residue frames.

    Args:
        R_chain: (B, L, 3, 3) chain rotations.
        T_chain: (B, L, 3) chain translations.
        R_residue: (B, L, 3, 3) residue rotations inside the chain frame.
        T_residue: (B, L, 3) residue translations inside the chain frame.

    Returns:
        ``(R_out, T_out)`` with ``R_out = R_chain @ R_residue`` and
        ``T_out = R_chain @ T_residue + T_chain``.
    """
    rotated_offset = torch.einsum('blij,blj->bli', R_chain, T_residue)  # [B, L, 3]
    combined_rotation = torch.einsum('blij,bljk->blik', R_chain, R_residue)  # [B, L, 3, 3]
    combined_translation = rotated_offset + T_chain
    return combined_rotation, combined_translation
def centering_RT(R, T, mask):
    """Shift masked-in translations so their centroid sits at the origin.

    Args:
        R: (B, L, 3, 3) rotations; returned untouched.
        T: (B, L, 3) translations.
        mask: boolean mask that is reshaped with ``mask.view(B, L, 1)``, i.e.
            one entry per residue. NOTE(review): the original comment listed
            it as [B, L, 3], which that reshape would reject -- confirm.

    Returns:
        ``(R, T_centered)``; entries where ``mask`` is False keep their
        original translation.
    """
    B, L = R.shape[:2]

    # Centering: masked mean over the residue axis (1e-5 avoids dividing by 0).
    weights = mask[..., None]
    center = (weights * T).sum(dim=1) / (weights.sum(dim=1) + 1e-5)  # [B, 3]

    selected = mask.view(B, L, 1)
    shifted = T - center.view(B, 1, 3)
    return R, torch.where(selected, shifted, T)
# process ideal frames
def make_frame(X, Y):
    """Build an orthonormal frame from two 1-D vectors.

    ``X`` is normalized to give the first axis, ``Y`` is made orthogonal to it
    and normalized to give the second, and their cross product gives the third.

    Args:
        X, Y: 1-D tensors of length 3 (``torch.dot`` requires 1-D inputs).

    Returns:
        (3, 3) tensor whose columns are the three axes.
    """
    Xn = X / torch.linalg.norm(X)
    Y = Y - torch.dot(Y, Xn) * Xn
    Yn = Y / torch.linalg.norm(Y)
    # Pass dim explicitly: calling torch.cross without it is deprecated.
    Z = torch.cross(Xn, Yn, dim=-1)
    Zn = Z / torch.linalg.norm(Z)

    return torch.stack((Xn, Yn, Zn), dim=-1)
def get_Cb(xyz):
    """Reconstruct a CB position from backbone N, CA and C coordinates.

    Args:
        xyz: tensor indexed as ``xyz[:, :, atom]`` with atoms 0, 1, 2 being
            N, CA and C respectively.

    Returns:
        CB coordinates, one 3-vector per ``xyz[:, :]`` entry, computed as a
        fixed linear combination of the N->CA vector, the CA->C vector and
        their cross product, offset by CA.
    """
    n_atom, ca_atom, c_atom = xyz[:, :, 0], xyz[:, :, 1], xyz[:, :, 2]

    # recreate Cb given N,Ca,C
    n_to_ca = ca_atom - n_atom
    ca_to_c = c_atom - ca_atom
    normal = torch.cross(n_to_ca, ca_to_c, dim=-1)

    return -0.58273431*normal + 0.56802827*n_to_ca - 0.54067466*ca_to_c + ca_atom
def cross_product_matrix(u):
    """Return the skew-symmetric cross-product matrix of each vector in ``u``.

    Args:
        u: tensor shaped (B, L, 3).

    Returns:
        (B, L, 3, 3) tensor ``M`` (default float dtype, on ``u``'s device)
        such that ``M @ v`` equals ``u x v``.
    """
    B, L = u.shape[:2]
    ux, uy, uz = u[..., 0], u[..., 1], u[..., 2]

    skew = torch.zeros((B, L, 3, 3), device=u.device)
    skew[:, :, 1, 0] = uz
    skew[:, :, 0, 1] = -uz
    skew[:, :, 0, 2] = uy
    skew[:, :, 2, 0] = -uy
    skew[:, :, 2, 1] = ux
    skew[:, :, 1, 2] = -ux
    return skew
# writepdb
def writepdb(filename, atoms, seq, Ls, idx_pdb=None, bfacts=None):
    """Write atom coordinates to ``filename`` as PDB ``ATOM`` records.

    Args:
        filename: output path; opened in text write mode (truncates).
        atoms: coordinates; after ``.cpu().squeeze(0)`` indexed as
            ``[residue, atom, xyz]`` with 14 or 27 atoms per residue.
        seq: residue class indices; after ``.cpu().squeeze(0)`` iterated one
            residue at a time and used to index ``aa2long`` / ``num2aa``.
        Ls: list of chain lengths. It is repeated ``seq.shape[0] // sum(Ls)``
            times; the chain id advances each time a chain length is reached.
        idx_pdb: accepted for compatibility; defaulted but not otherwise used.
        bfacts: optional per-residue B-factors, clamped to [0, 100];
            defaults to zeros.

    Raises:
        AssertionError: if the atom dimension is neither 14 nor 27.

    NOTE(review): atom-name literals and the record format string rely on PDB
    fixed-width columns; confirm their internal spacing against
    ``chemical.aa2long`` before relying on the output layout.
    """
    # Context manager guarantees the handle is closed, including on errors
    # (the previous version never closed it).
    with open(filename, "w") as f:
        ctr = 1
        scpu = seq.cpu().squeeze(0)
        atomscpu = atoms.cpu().squeeze(0)

        L = sum(Ls)
        O = seq.shape[0] // L
        Ls = Ls * O

        if bfacts is None:
            bfacts = torch.zeros(atomscpu.shape[0])
        if idx_pdb is None:
            idx_pdb = 1 + torch.arange(atomscpu.shape[0])

        Bfacts = torch.clamp(bfacts.cpu(), 0, 100)
        chn_idx, res_idx = 0, 0
        for i, s in enumerate(scpu):
            natoms = atomscpu.shape[-2]
            if (natoms != 14 and natoms != 27):
                print ('bad size!', atoms.shape)
                assert(False)

            atms = aa2long[s]

            # his protonation state hack
            if (s == 8 and torch.linalg.norm(atomscpu[i, 9, :] - atomscpu[i, 5, :]) < 1.7):
                atms = (
                    " N "," CA "," C "," O "," CB "," CG "," NE2"," CD2"," CE1"," ND1",
                    None, None, None, None," H "," HA ","1HB ","2HB "," HD2"," HE1",
                    " HD1", None, None, None, None, None, None)  # his_d

            for j, atm_j in enumerate(atms):
                # Skip padding slots, atoms beyond the stored count, and NaN coordinates.
                if (j < natoms and atm_j is not None and not torch.isnan(atomscpu[i, j, :]).any()):
                    f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
                        "ATOM", ctr, atm_j, num2aa[s],
                        PDB_CHAIN_IDS[chn_idx], res_idx+1, atomscpu[i,j,0], atomscpu[i,j,1], atomscpu[i,j,2],
                        1.0, Bfacts[i] ) )
                    ctr += 1

            res_idx += 1
            if (chn_idx < len(Ls) and res_idx == Ls[chn_idx]):
                chn_idx += 1
                res_idx = 0
# resolve tip atom indices tip_indices = torch.full((NUM_CLASSES,3), -1) for i in range(NUM_CLASSES): tip_atm = aa2tip[i] atm_long = aa2long[i] for j in range(3): if tip_atm[j] == None: continue tip_indices[i,j] = atm_long.index(tip_atm[j]) # resolve torsion indices torsion_indices = torch.full((NUM_CLASSES,4,4),0) torsion_can_flip = torch.full((NUM_CLASSES,10),False,dtype=torch.bool) for i in range(NUM_CLASSES): i_l, i_a = aa2long[i], aa2longalt[i] for j in range(4): if torsions[i][j] is None: continue for k in range(4): a = torsions[i][j][k] torsion_indices[i,j,k] = i_l.index(a) if (i_l.index(a) != i_a.index(a)): torsion_can_flip[i,3+j] = True ##bb tors never flip # HIS is a special case torsion_can_flip[8,4]=False # build the mapping from atoms in the full rep (Nx27) to the "alternate" rep allatom_mask = torch.zeros((NUM_CLASSES,27), dtype=torch.bool) long2alt = torch.zeros((NUM_CLASSES,27), dtype=torch.long) for i in range(NUM_CLASSES): i_l, i_lalt = aa2long[i], aa2longalt[i] for j,a in enumerate(i_l): if (a is None): long2alt[i,j] = j else: long2alt[i,j] = i_lalt.index(a) allatom_mask[i,j] = True # bond graph traversal num_bonds = torch.zeros((NUM_CLASSES,27,27), dtype=torch.long) for i in range(NUM_CLASSES): num_bonds_i = np.zeros((27,27)) for (bnamei,bnamej) in aabonds[i]: bi,bj = aa2long[i].index(bnamei),aa2long[i].index(bnamej) num_bonds_i[bi,bj] = 1 num_bonds_i = scipy.sparse.csgraph.shortest_path (num_bonds_i,directed=False) num_bonds_i[num_bonds_i>=4] = 4 num_bonds[i,...] = torch.tensor(num_bonds_i) # hbond scoring parameters
def donorHs(D, bonds, atoms):
    """List the hydrogens bonded to donor atom ``D``.

    Args:
        D: name of the donor atom.
        bonds: iterable of ``(atom_name, atom_name)`` pairs.
        atoms: sequence of atom names; positions 14 and above are treated as
            hydrogens.

    Returns:
        Positions in ``atoms`` of every hydrogen bonded to ``D``, in bond order.

    Raises:
        AssertionError: if no bonded hydrogen is found.
    """
    hydrogens = []
    for first, second in bonds:
        # Check both orientations of the bond, first-as-donor before second-as-donor.
        for this_end, other_end in ((first, second), (second, first)):
            if this_end != D:
                continue
            other_idx = atoms.index(other_end)
            if other_idx >= 14:  # the partner atom is a hydrogen
                hydrogens.append(other_idx)

    assert (len(hydrogens) > 0)
    return hydrogens
def generate_random_xyz(seq, random_noise=20.0, use_all_model=False, use_random_model=True):
    """Create randomly placed ideal-geometry coordinates for a sequence.

    Each residue is filled from ``ideal_coords`` (the last element of every
    entry is its xyz triple, stored in enumeration order), then every model is
    passed through :func:`random_rot_trans`.

    Args:
        seq: residue class indices shaped (M, L).
        random_noise: translation scale forwarded to ``random_rot_trans``.
        use_all_model: build coordinates for all M rows instead of one.
        use_random_model: when a single row is used, pick it at random
            (otherwise row 0).

    Returns:
        ``(random_xyz, atom_mask)`` shaped (M, L, 27, 3) / (M, L, 27) when
        ``use_all_model`` is set, else (1, L, 27, 3) / (1, L, 27).
    """
    M, L = seq.shape

    if use_all_model:
        model_indices = list(range(M))
    else:
        model_indices = [torch.randint(M, (1,)) if use_random_model else 0]

    n_out = len(model_indices)
    initial_xyz = torch.zeros((n_out, L, 27, 3))
    atom_mask = torch.zeros((n_out, L, 27), dtype=torch.bool)
    random_xyz = torch.zeros((n_out, L, 27, 3))

    for out_idx, model_idx in enumerate(model_indices):
        for jj in range(L):
            # The use_all_model path used to index each entry element-wise and
            # could never unpack a coordinate; both paths now share this loop.
            for kk, atom_entry in enumerate(ideal_coords[seq[model_idx, jj].item()]):
                x, y, z = atom_entry[-1]
                initial_xyz[out_idx, jj, kk, :] = torch.tensor([x, y, z])
                atom_mask[out_idx, jj, kk] = True
        # One call per model so each gets its own rotation and translation.
        random_xyz[out_idx:out_idx + 1] = random_rot_trans(
            initial_xyz[out_idx:out_idx + 1], random_noise
        )

    return random_xyz, atom_mask  # (M, L, 27, 3) or (1, L, 27, 3) | (M, L, 27) or (1, L, 27)
def generate_initial_xyz(seq, chain_break, random_noise=20.0, use_all_model=False, use_random_model=True):
    """Create ideal-geometry coordinates, randomly placed chain by chain.

    Each residue is filled from ``ideal_coords`` (the last element of every
    entry is its xyz triple, stored in enumeration order). Every chain segment
    is then passed separately through :func:`random_rot_trans`.

    Args:
        seq: residue class indices shaped (M, L).
        chain_break: mapping ``{chain_idx: (start, end)}`` with ``end``
            inclusive. Residues outside every range stay zero in the output.
        random_noise: translation scale forwarded to ``random_rot_trans``.
        use_all_model: build coordinates for all M rows instead of one.
        use_random_model: when a single row is used, pick it at random
            (otherwise row 0).

    Returns:
        ``(random_xyz, atom_mask)`` shaped (M, L, 27, 3) / (M, L, 27) when
        ``use_all_model`` is set, else (1, L, 27, 3) / (1, L, 27).
    """
    M, L = seq.shape

    if use_all_model:
        model_indices = list(range(M))
    else:
        model_indices = [torch.randint(M, (1,)) if use_random_model else 0]

    n_out = len(model_indices)
    initial_xyz = torch.zeros((n_out, L, 27, 3))
    atom_mask = torch.zeros((n_out, L, 27), dtype=torch.bool)
    random_xyz = torch.zeros((n_out, L, 27, 3))

    for out_idx, model_idx in enumerate(model_indices):
        for jj in range(L):
            # The use_all_model path used to index each entry element-wise and
            # could never unpack a coordinate; both paths now share this loop.
            for kk, atom_entry in enumerate(ideal_coords[seq[model_idx, jj].item()]):
                x, y, z = atom_entry[-1]
                initial_xyz[out_idx, jj, kk, :] = torch.tensor([x, y, z])
                atom_mask[out_idx, jj, kk] = True

    for start, end in chain_break.values():
        chain_xyz = initial_xyz[:, start:(end + 1), :, :]  # (M, L_chain, 27, 3)
        random_xyz[:, start:(end + 1), :, :] = random_rot_trans(chain_xyz, random_noise)

    return random_xyz, atom_mask  # (M, L, 27, 3) or (1, L, 27, 3) | (M, L, 27) or (1, L, 27)
def generate_symmetric_operation(symmetry="C4"):
    """Return the generator rotation of a cyclic symmetry group.

    Args:
        symmetry: ``"C<n>"`` with a positive integer ``n`` (e.g. ``"C2"``,
            ``"C3"``, ``"C4"``).

    Returns:
        (3, 3) float tensor rotating by ``360 / n`` degrees around the z-axis.

    Raises:
        ValueError: for any other symmetry label.
    """
    order = 0
    if isinstance(symmetry, str) and symmetry[:1] == "C":
        digits = symmetry[1:]
        if digits.isascii() and digits.isdigit():
            order = int(digits)
    if order < 1:
        raise ValueError("Not implemented yet")

    angle = (360.0 / order) * (math.pi / 180)  # converting degrees to radians

    def _snap(value):
        # Remove floating-point residue so exact rotations (C1, C2, C4) come
        # out as exact 0 / +-1 entries.
        return 0.0 if abs(value) < 1e-12 else value

    cos_a = _snap(math.cos(angle))
    sin_a = _snap(math.sin(angle))
    neg_sin_a = _snap(-math.sin(angle))

    return torch.tensor([
        [cos_a, neg_sin_a, 0],
        [sin_a, cos_a, 0],
        [0, 0, 1],
    ], dtype=torch.float)
@torch.no_grad()
def generate_symmetric_xyz(seq, chain_break, random_noise=20.0, use_random_model=True, symmetry="C4"):
    """Create ideal-geometry coordinates arranged by a symmetry operation.

    One row of ``seq`` is filled from ``ideal_coords`` and shifted by a single
    random offset. Chains are then visited in ``chain_break`` order: the first
    is left as is, and each following chain is rotated by one more application
    of the symmetry generator.

    Args:
        seq: residue class indices shaped (M, L).
        chain_break: mapping ``{chain_idx: (start, end)}`` with ``end``
            inclusive.
        random_noise: accepted for signature parity; not used here.
        use_random_model: pick the row of ``seq`` at random (otherwise row 0).
        symmetry: label forwarded to :func:`generate_symmetric_operation`.

    Returns:
        ``(symmetric_xyz, atom_mask)`` shaped (1, L, 27, 3) / (1, L, 27).
    """
    M, L = seq.shape
    row = torch.randint(M, (1,)) if use_random_model else 0

    base_xyz = torch.zeros((1, L, 27, 3))
    atom_mask = torch.zeros((1, L, 27), dtype=torch.bool)
    symmetric_xyz = torch.zeros((1, L, 27, 3))

    for res_idx in range(L):
        for atom_idx, atom_entry in enumerate(ideal_coords[seq[row, res_idx].item()]):
            x, y, z = atom_entry[-1]
            base_xyz[0, res_idx, atom_idx, :] = torch.tensor([x, y, z])
            atom_mask[0, res_idx, atom_idx] = True

    # The same random offset is applied to every residue and atom.
    global_shift = torch.randn(1, 1, 3)
    base_xyz = base_xyz + global_shift.view(1, 1, 1, 3)

    accumulated = torch.eye(3)  # (3, 3)
    generator = generate_symmetric_operation(symmetry)  # (3, 3)

    for _, (start, end) in chain_break.items():
        segment = base_xyz[:, start:(end + 1), :, :]  # (M, L_chain, 27, 3)
        symmetric_xyz[:, start:(end + 1), :, :] = torch.einsum('mlcj,ij->mlci', segment, accumulated)
        accumulated = torch.einsum('ij,jk->ik', accumulated, generator)  # (3, 3)

    return symmetric_xyz, atom_mask  # (M, L, 27, 3) or (1, L, 27, 3) | (M, L, 27) or (1, L, 27)
# kinematic parameters
# base_indices[i, a]:        the `base` value listed for atom a in ideal_coords[i].
# xyzs_in_base_frame[i, a]:  homogeneous (x, y, z, 1) ideal coordinates of atom a.
# RTs_by_torsion[i, t]:      4x4 transform for torsion slot t
#                            (0 omega, 1 phi, 2 psi, 3 chi1, 4-6 chi2~4).
# reference_angles[i, n]:    (cos, sin) pairs from th_ang_v for the CB/CG angles.
base_indices = torch.full((NUM_CLASSES,27),0, dtype=torch.long)
xyzs_in_base_frame = torch.ones((NUM_CLASSES,27,4))
RTs_by_torsion = torch.eye(4).repeat(NUM_CLASSES,7,1,1)
reference_angles = torch.ones((NUM_CLASSES,3,2))

for i in range(NUM_CLASSES):
    i_l = aa2long[i]
    # Store every ideal atom at the slot given by its name in aa2long[i].
    for name, base, coords in ideal_coords[i]:
        idx = i_l.index(name)
        base_indices[i,idx] = base
        xyzs_in_base_frame[i,idx,:3] = torch.tensor(coords)

    # omega frame
    RTs_by_torsion[i,0,:3,:3] = torch.eye(3)
    RTs_by_torsion[i,0,:3,3] = torch.zeros(3)

    # phi frame
    RTs_by_torsion[i,1,:3,:3] = make_frame(
        xyzs_in_base_frame[i,0,:3] - xyzs_in_base_frame[i,1,:3],
        torch.tensor([1.,0.,0.])
    )
    RTs_by_torsion[i,1,:3,3] = xyzs_in_base_frame[i,0,:3]

    # psi frame
    RTs_by_torsion[i,2,:3,:3] = make_frame(
        xyzs_in_base_frame[i,2,:3] - xyzs_in_base_frame[i,1,:3],
        xyzs_in_base_frame[i,1,:3] - xyzs_in_base_frame[i,0,:3]
    )
    RTs_by_torsion[i,2,:3,3] = xyzs_in_base_frame[i,2,:3]

    # chi1 frame
    if torsions[i][0] is not None:
        a0,a1,a2 = torsion_indices[i,0,0:3]
        RTs_by_torsion[i,3,:3,:3] = make_frame(
            xyzs_in_base_frame[i,a2,:3]-xyzs_in_base_frame[i,a1,:3],
            xyzs_in_base_frame[i,a0,:3]-xyzs_in_base_frame[i,a1,:3],
        )
        RTs_by_torsion[i,3,:3,3] = xyzs_in_base_frame[i,a2,:3]

    # chi2~4 frame
    for j in range(1,4):
        if torsions[i][j] is not None:
            a2 = torsion_indices[i,j,2]
            if ((i==18 and j==2) or (i==8 and j==2)): # TYR CZ-OH & HIS CE1-HE1 a special case
                a0,a1 = torsion_indices[i,j,0:2]
                RTs_by_torsion[i,3+j,:3,:3] = make_frame(
                    xyzs_in_base_frame[i,a2,:3]-xyzs_in_base_frame[i,a1,:3],
                    xyzs_in_base_frame[i,a0,:3]-xyzs_in_base_frame[i,a1,:3] )
            else:
                RTs_by_torsion[i,3+j,:3,:3] = make_frame(
                    xyzs_in_base_frame[i,a2,:3],
                    torch.tensor([-1.,0.,0.]), )
            RTs_by_torsion[i,3+j,:3,3] = xyzs_in_base_frame[i,a2,:3]

    # CB/CG angles
    # NCr is the midpoint of atoms 0 and 2 (N and C); atoms 1, 4, 5 are CA, CB, CG.
    NCr = 0.5*(xyzs_in_base_frame[i,0,:3]+xyzs_in_base_frame[i,2,:3])
    CAr = xyzs_in_base_frame[i,1,:3]
    CBr = xyzs_in_base_frame[i,4,:3]
    CGr = xyzs_in_base_frame[i,5,:3]
    reference_angles[i,0,:]=th_ang_v(CBr-CAr,NCr-CAr)
    NCp = xyzs_in_base_frame[i,2,:3]-xyzs_in_base_frame[i,0,:3]
    # Component of the N->C vector orthogonal to NCr.
    NCpp = NCp - torch.dot(NCp,NCr)/ torch.dot(NCr,NCr) * NCr
    reference_angles[i,1,:]=th_ang_v(CBr-CAr,NCpp)
    reference_angles[i,2,:]=th_ang_v(CGr,torch.tensor([-1.,0.,0.]))
def draw_attn(attn, save_path):
    """Save one heat map per attention head, side by side, to ``save_path``.

    Args:
        attn: attention tensor shaped (B, L, L, n_head); only batch element 0
            is drawn.
        save_path: output image path; it is also printed.
    """
    import matplotlib.pyplot as plt

    heads = attn[0].detach().cpu().numpy()  # (L, L, n_head)
    n_head = heads.shape[-1]

    # squeeze=False always yields an array of axes, so a single head works too
    # (with the default, subplots returns a bare Axes that cannot be indexed).
    fig, axs = plt.subplots(1, n_head, figsize=(n_head*3, 3), squeeze=False)
    axs = axs.flatten()

    images = [
        axs[i].imshow(heads[:, :, i], cmap='hot', interpolation='nearest')
        for i in range(n_head)
    ]

    # One shared horizontal colorbar, scaled from the first head's image.
    plt.colorbar(images[0], ax=axs, orientation='horizontal')

    print(save_path)
    plt.savefig(save_path)
    plt.close()
def draw_attn_multirow(attn, num_row, save_path):
    """Save one heat map per attention head on a grid with ``num_row`` rows.

    Args:
        attn: attention tensor shaped (B, L, L, n_head); only batch element 0
            is drawn.
        num_row: number of grid rows; columns are ``ceil(n_head / num_row)``.
        save_path: output image path.
    """
    import matplotlib.pyplot as plt

    attn = attn[0]
    n_head = attn.shape[-1]

    # Calculate the number of columns needed if n_head isn't divisible evenly by num_row
    num_col = -(-n_head // num_row)  # Ceiling division

    # squeeze=False always yields a 2-D array of axes, so a 1x1 grid works too
    # (with the default, subplots returns a bare Axes without .flatten()).
    fig, axs = plt.subplots(num_row, num_col, figsize=(num_col*3, 3*num_row), squeeze=False)
    axs = axs.flatten()  # Flatten the array for easier indexing

    for i in range(n_head):
        ax = axs[i]
        im = ax.imshow(attn[:, :, i].detach().cpu().numpy(), cmap='hot', interpolation='nearest')
        ax.axis('off')  # Turn off the axis to make it cleaner since the location isn't important here

    # Shared horizontal colorbar, scaled from the last head's image.
    plt.colorbar(im, ax=axs, orientation='horizontal', fraction=0.046, pad=0.04)

    plt.savefig(save_path)
    plt.close()
def draw_pae(logit_pae, mask, saving_path):
    """Save a heat map of the expected aligned error for every batch element.

    Args:
        logit_pae: logits shaped (B, n_bins, L, L); softmax is taken over the
            bin axis and bin ``k`` is given the value ``0.5 * k``.
        mask: optional (B, L, L) multiplier applied to the expected error;
            ``None`` disables masking.
        saving_path: output image path.
    """
    probs = torch.nn.functional.softmax(logit_pae, dim=1)  # (B, n_bins, L, L)
    n_bins = probs.shape[1]
    bin_values = 0.5 * torch.arange(n_bins, dtype=torch.float32, device=probs.device)
    expected_error = torch.einsum('bnij, n -> bij', probs, bin_values)  # (B, L, L)
    if mask is not None:
        expected_error = expected_error * mask

    # Blue -> white -> red colormap, with a leading black anchor at 0 so that
    # expected_error = 0 is drawn black.
    black_anchor = (0.0, 0.0, 0.0)
    cdict = {
        'red': [
            black_anchor,
            (0.0, 0.0, 0.0),  # Blue at 0
            (0.5, 1.0, 1.0),  # White at 0.5 (corresponds to 16 if scaled 0-32)
            (1.0, 1.0, 1.0),  # Red at 1
        ],
        'green': [
            black_anchor,
            (0.0, 0.0, 0.0),  # No green at 0
            (0.5, 1.0, 1.0),  # White (full green) at 0.5
            (1.0, 0.0, 0.0),  # No green at 1
        ],
        'blue': [
            black_anchor,
            (0.0, 1.0, 1.0),  # Blue at 0
            (0.5, 1.0, 1.0),  # White (full blue) at 0.5
            (1.0, 0.0, 0.0),  # No blue at 1
        ],
    }
    custom_cmap = mcolors.LinearSegmentedColormap('custom_cmap', cdict)

    # Normalize the expected_error values to 0-1 scale
    norm = mcolors.Normalize(vmin=0, vmax=32)

    # Draw all batch pae in one figure
    panels = expected_error.detach().cpu().numpy()
    n_panels = expected_error.shape[0]
    if n_panels > 1:
        fig, axs = plt.subplots(1, n_panels, figsize=(n_panels*3, 3))
        axs = axs.flatten()
        for panel_idx in range(n_panels):
            im = axs[panel_idx].imshow(panels[panel_idx], cmap=custom_cmap, norm=norm)
            axs[panel_idx].axis('off')
    else:
        fig, axs = plt.subplots(1, 1, figsize=(3, 3))
        im = axs.imshow(panels[0], cmap=custom_cmap, norm=norm)
        axs.axis('off')

    # Shared horizontal colorbar, scaled from the last panel's image.
    plt.colorbar(im, ax=axs, orientation='horizontal', fraction=0.046, pad=0.04)

    plt.savefig(saving_path)
    plt.close()