Source code for miniworld.utils.util_module

import torch.nn as nn
from miniworld.utils.util import *
from miniworld.utils.chemical import aa2num



[docs] def init_lecun_normal(module, scale=1.0): def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2): normal = torch.distributions.normal.Normal(0, 1) alpha = (a - mu) / sigma beta = (b - mu) / sigma alpha_normal_cdf = normal.cdf(torch.tensor(alpha)) p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8) x = mu + sigma * np.sqrt(2) * torch.erfinv(v) x = torch.clamp(x, a, b) return x def sample_truncated_normal(shape, scale=1.0): stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in return stddev * truncated_normal(torch.rand(shape)) module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) ) return module
[docs] def init_lecun_normal_param(weight, scale=1.0): def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2): normal = torch.distributions.normal.Normal(0, 1) alpha = (a - mu) / sigma beta = (b - mu) / sigma alpha_normal_cdf = normal.cdf(torch.tensor(alpha)) p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8) x = mu + sigma * np.sqrt(2) * torch.erfinv(v) x = torch.clamp(x, a, b) return x def sample_truncated_normal(shape, scale=1.0): stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in return stddev * truncated_normal(torch.rand(shape)) weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) ) return weight
# for gradient checkpointing
[docs] def create_custom_forward(module, **kwargs): def custom_forward(*inputs): return module(*inputs, **kwargs) return custom_forward
[docs] def get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
[docs] class Dropout(nn.Module): # Dropout entire row or column def __init__(self, broadcast_dim=None, p_drop=0.15): super(Dropout, self).__init__() # give ones with probability of 1-p_drop / zeros with p_drop self.sampler = torch.distributions.bernoulli.Bernoulli(torch.tensor([1-p_drop])) self.broadcast_dim=broadcast_dim self.p_drop=p_drop
[docs] def forward(self, x): if not self.training: # no drophead during evaluation mode return x shape = list(x.shape) if not self.broadcast_dim == None: shape[self.broadcast_dim] = 1 mask = self.sampler.sample(shape).to(x.device).view(shape) x = mask * x / (1.0 - self.p_drop) return x
[docs] def rbf(D, D_min=0.0, D_count=64, D_sigma=0.5): # Distance radial basis function D_max = D_min + (D_count-1) * D_sigma D_mu = torch.linspace(D_min, D_max, D_count).to(D.device) D_mu = D_mu[None,:] D_expand = torch.unsqueeze(D, -1) RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2) return RBF
[docs] def get_seqsep(idx): ''' Input: - idx: residue indices of given sequence (B,L) Output: - seqsep: sequence separation feature with sign (B, L, L, 1) Sergey found that having sign in seqsep features helps a little ''' seqsep = idx[:,None,:] - idx[:,:,None] sign = torch.sign(seqsep) neigh = torch.abs(seqsep) neigh[neigh > 1] = 0.0 # if bonded -- 1.0 / else 0.0 neigh = sign * neigh return neigh.unsqueeze(-1)
[docs] def get_topk(D, sep, top_k=64, kmin=32): B, L = D.shape[:2] # get top_k neighbors D_neigh, E_idx = torch.topk(D, min(top_k, L-1), largest=False) # shape of E_idx: (B, L, top_k) topk_matrix = torch.zeros((B, L, L), device=D.device) topk_matrix.scatter_(2, E_idx, 1.0) # put an edge if any of the 3 conditions are met: # 1) |i-j| <= kmin (connect sequentially adjacent residues) # 2) top_k neighbors cond = torch.logical_or(topk_matrix > 0.0, sep < kmin) b,i,j = torch.where(cond) return b, i, j
[docs] def make_rotX(angs, eps=1e-6): B,L = angs.shape[:2] NORM = torch.linalg.norm(angs, dim=-1) + eps RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1) RTs[:,:,1,1] = angs[:,:,0]/NORM RTs[:,:,1,2] = -angs[:,:,1]/NORM RTs[:,:,2,1] = angs[:,:,1]/NORM RTs[:,:,2,2] = angs[:,:,0]/NORM return RTs
# rotate about the z axis
[docs] def make_rotZ(angs, eps=1e-6): B,L = angs.shape[:2] NORM = torch.linalg.norm(angs, dim=-1) + eps RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1) RTs[:,:,0,0] = angs[:,:,0]/NORM RTs[:,:,0,1] = -angs[:,:,1]/NORM RTs[:,:,1,0] = angs[:,:,1]/NORM RTs[:,:,1,1] = angs[:,:,0]/NORM return RTs
# rotate about an arbitrary axis
[docs] def make_rot_axis(angs, u, eps=1e-6): B,L = angs.shape[:2] NORM = torch.linalg.norm(angs, dim=-1) + eps RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1) ct = angs[:,:,0]/NORM st = angs[:,:,1]/NORM u0 = u[:,:,0] u1 = u[:,:,1] u2 = u[:,:,2] RTs[:,:,0,0] = ct+u0*u0*(1-ct) RTs[:,:,0,1] = u0*u1*(1-ct)-u2*st RTs[:,:,0,2] = u0*u2*(1-ct)+u1*st RTs[:,:,1,0] = u0*u1*(1-ct)+u2*st RTs[:,:,1,1] = ct+u1*u1*(1-ct) RTs[:,:,1,2] = u1*u2*(1-ct)-u0*st RTs[:,:,2,0] = u0*u2*(1-ct)-u1*st RTs[:,:,2,1] = u1*u2*(1-ct)+u0*st RTs[:,:,2,2] = ct+u2*u2*(1-ct) return RTs
[docs] class ComputeAllAtomCoords(nn.Module): def __init__(self): super(ComputeAllAtomCoords, self).__init__() #self.base_indices = nn.Parameter(base_indices, requires_grad=False) #self.RTs_in_base_frame = nn.Parameter(RTs_by_torsion, requires_grad=False) #self.xyzs_in_base_frame = nn.Parameter(xyzs_in_base_frame, requires_grad=False)
[docs] def forward(self, seq, xyz, alphas, non_ideal=False, use_H=True): B,L = xyz.shape[:2] Rs, Ts = rigid_from_3_points(xyz[...,0,:],xyz[...,1,:],xyz[...,2,:], non_ideal=non_ideal) RTF0 = torch.eye(4).repeat(B,L,1,1).to(device=Rs.device) # bb RTF0[:,:,:3,:3] = Rs RTF0[:,:,:3,3] = Ts # omega RTF1 = torch.einsum( 'brij,brjk,brkl->bril', RTF0, self.RTs_in_base_frame[seq,0,:], make_rotX(alphas[:,:,0,:])) # phi RTF2 = torch.einsum( 'brij,brjk,brkl->bril', RTF0, self.RTs_in_base_frame[seq,1,:], make_rotX(alphas[:,:,1,:])) # psi RTF3 = torch.einsum( 'brij,brjk,brkl->bril', RTF0, self.RTs_in_base_frame[seq,2,:], make_rotX(alphas[:,:,2,:])) # CB bend basexyzs = self.xyzs_in_base_frame[seq] NCr = 0.5*(basexyzs[:,:,2,:3]+basexyzs[:,:,0,:3]) CAr = (basexyzs[:,:,1,:3]) CBr = (basexyzs[:,:,4,:3]) CBrotaxis1 = (CBr-CAr).cross(NCr-CAr) CBrotaxis1 /= torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True)+1e-8 # CB twist NCp = basexyzs[:,:,2,:3] - basexyzs[:,:,0,:3] NCpp = NCp - torch.sum(NCp*NCr, dim=-1, keepdim=True)/ torch.sum(NCr*NCr, dim=-1, keepdim=True) * NCr CBrotaxis2 = (CBr-CAr).cross(NCpp) CBrotaxis2 /= torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True)+1e-8 CBrot1 = make_rot_axis(alphas[:,:,7,:], CBrotaxis1 ) CBrot2 = make_rot_axis(alphas[:,:,8,:], CBrotaxis2 ) RTF8 = torch.einsum( 'brij,brjk,brkl->bril', RTF0, CBrot1,CBrot2) # chi1 + CG bend RTF4 = torch.einsum( 'brij,brjk,brkl,brlm->brim', RTF8, self.RTs_in_base_frame[seq,3,:], make_rotX(alphas[:,:,3,:]), make_rotZ(alphas[:,:,9,:])) # chi2 RTF5 = torch.einsum( 'brij,brjk,brkl->bril', RTF4, self.RTs_in_base_frame[seq,4,:],make_rotX(alphas[:,:,4,:])) # chi3 RTF6 = torch.einsum( 'brij,brjk,brkl->bril', RTF5,self.RTs_in_base_frame[seq,5,:],make_rotX(alphas[:,:,5,:])) # chi4 RTF7 = torch.einsum( 'brij,brjk,brkl->bril', RTF6,self.RTs_in_base_frame[seq,6,:],make_rotX(alphas[:,:,6,:])) RTframes = torch.stack(( RTF0,RTF1,RTF2,RTF3,RTF4,RTF5,RTF6,RTF7,RTF8 ),dim=2) xyzs = torch.einsum( 'brtij,brtj->brti', RTframes.gather(2,self.base_indices[seq][...,None,None].repeat(1,1,1,4,4)), basexyzs ) if use_H: return RTframes, xyzs[...,:3] else: return RTframes, xyzs[...,:14,:3]
[docs] class XYZConverter(nn.Module): def __init__(self): super(XYZConverter, self).__init__() self.register_buffer("torsion_indices", torsion_indices) self.register_buffer("torsion_can_flip", torsion_can_flip) self.register_buffer("ref_angles", reference_angles) self.register_buffer("tip_indices", tip_indices) self.register_buffer("base_indices", base_indices) self.register_buffer("RTs_in_base_frame", RTs_by_torsion) self.register_buffer("xyzs_in_base_frame", xyzs_in_base_frame)
[docs] def compute_all_atom(self, seq, xyz, alphas, non_ideal=True, use_H=True): B,L = xyz.shape[:2] seq = seq.to(torch.long) Rs, Ts = rigid_from_3_points(xyz[...,0,:],xyz[...,1,:],xyz[...,2,:], non_ideal=non_ideal) RTF0 = torch.eye(4).repeat(B,L,1,1).to(device=Rs.device) # bb RTF0[:,:,:3,:3] = Rs RTF0[:,:,:3,3] = Ts # omega RTF1 = torch.einsum( 'brij,brjk,brkl->bril', RTF0, self.RTs_in_base_frame[seq,0,:], make_rotX(alphas[:,:,0,:])) # phi RTF2 = torch.einsum( 'brij,brjk,brkl->bril', RTF0, self.RTs_in_base_frame[seq,1,:], make_rotX(alphas[:,:,1,:])) # psi RTF3 = torch.einsum( 'brij,brjk,brkl->bril', RTF0, self.RTs_in_base_frame[seq,2,:], make_rotX(alphas[:,:,2,:])) # CB bend basexyzs = self.xyzs_in_base_frame[seq] NCr = 0.5*(basexyzs[:,:,2,:3]+basexyzs[:,:,0,:3]) CAr = (basexyzs[:,:,1,:3]) CBr = (basexyzs[:,:,4,:3]) CBrotaxis1 = (CBr-CAr).cross(NCr-CAr) CBrotaxis1 /= torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True)+1e-8 # CB twist NCp = basexyzs[:,:,2,:3] - basexyzs[:,:,0,:3] NCpp = NCp - torch.sum(NCp*NCr, dim=-1, keepdim=True)/ torch.sum(NCr*NCr, dim=-1, keepdim=True) * NCr CBrotaxis2 = (CBr-CAr).cross(NCpp) CBrotaxis2 /= torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True)+1e-8 CBrot1 = make_rot_axis(alphas[:,:,7,:], CBrotaxis1 ) CBrot2 = make_rot_axis(alphas[:,:,8,:], CBrotaxis2 ) RTF8 = torch.einsum( 'brij,brjk,brkl->bril', RTF0, CBrot1,CBrot2) # chi1 + CG bend RTF4 = torch.einsum( 'brij,brjk,brkl,brlm->brim', RTF8, self.RTs_in_base_frame[seq,3,:], make_rotX(alphas[:,:,3,:]), make_rotZ(alphas[:,:,9,:])) # chi2 RTF5 = torch.einsum( 'brij,brjk,brkl->bril', RTF4, self.RTs_in_base_frame[seq,4,:],make_rotX(alphas[:,:,4,:])) # chi3 RTF6 = torch.einsum( 'brij,brjk,brkl->bril', RTF5,self.RTs_in_base_frame[seq,5,:],make_rotX(alphas[:,:,5,:])) # chi4 RTF7 = torch.einsum( 'brij,brjk,brkl->bril', RTF6,self.RTs_in_base_frame[seq,6,:],make_rotX(alphas[:,:,6,:])) RTframes = torch.stack(( RTF0,RTF1,RTF2,RTF3,RTF4,RTF5,RTF6,RTF7,RTF8 ),dim=2) xyzs = torch.einsum( 'brtij,brtj->brti', RTframes.gather(2,self.base_indices[seq][...,None,None].repeat(1,1,1,4,4)), basexyzs ) if use_H: return RTframes, xyzs[...,:3] else: return RTframes, xyzs[...,:14,:3]
[docs] def get_tor_mask(self, seq, mask_in=None): B,L = seq.shape[:2] tors_mask = torch.ones((B,L,10), dtype=torch.bool, device=seq.device) tors_mask[...,3:7] = self.torsion_indices[seq,:,-1] > 0 tors_mask[:,0,1] = False tors_mask[:,-1,0] = False # mask for additional angles tors_mask[:,:,7] = seq!=aa2num['GLY'] tors_mask[:,:,8] = seq!=aa2num['GLY'] tors_mask[:,:,9] = torch.logical_and( seq!=aa2num['GLY'], seq!=aa2num['ALA'] ) tors_mask[:,:,9] = torch.logical_and( tors_mask[:,:,9], seq!=aa2num['UNK'] ) tors_mask[:,:,9] = torch.logical_and( tors_mask[:,:,9], seq!=aa2num['MAS'] ) if mask_in != None: # mask for missing atoms # chis ti0 = torch.gather(mask_in,2,self.torsion_indices[seq,:,0]) ti1 = torch.gather(mask_in,2,self.torsion_indices[seq,:,1]) ti2 = torch.gather(mask_in,2,self.torsion_indices[seq,:,2]) ti3 = torch.gather(mask_in,2,self.torsion_indices[seq,:,3]) #is_valid = torch.stack((ti0, ti1, ti2, ti3), dim=-2).all(dim=-1) # bug.... (2023 Feb 24 fixed) is_valid = torch.stack((ti0, ti1, ti2, ti3), dim=-1).all(dim=-1) tors_mask[...,3:7] = torch.logical_and(tors_mask[...,3:7], is_valid) tors_mask[:,:,7] = torch.logical_and(tors_mask[:,:,7], mask_in[:,:,4]) # CB exist? tors_mask[:,:,8] = torch.logical_and(tors_mask[:,:,8], mask_in[:,:,4]) # CB exist? tors_mask[:,:,9] = torch.logical_and(tors_mask[:,:,9], mask_in[:,:,5]) # XG exist? return tors_mask
[docs] def get_torsions(self, xyz_in, seq, mask_in=None): B,L = xyz_in.shape[:2] tors_mask = self.get_tor_mask(seq, mask_in) # torsions to restrain to 0 or 180degree tors_planar = torch.zeros((B, L, 10), dtype=torch.bool, device=xyz_in.device) tors_planar[:,:,5] = (seq == aa2num['TYR']) # TYR chi 3 should be planar # idealize given xyz coordinates before computing torsion angles xyz = xyz_in.clone() Rs, Ts = rigid_from_3_points(xyz[...,0,:],xyz[...,1,:],xyz[...,2,:]) Nideal = torch.tensor([-0.5272, 1.3593, 0.000], device=xyz_in.device) Cideal = torch.tensor([1.5233, 0.000, 0.000], device=xyz_in.device) xyz[...,0,:] = torch.einsum('brij,j->bri', Rs, Nideal) + Ts xyz[...,2,:] = torch.einsum('brij,j->bri', Rs, Cideal) + Ts torsions = torch.zeros( (B,L,10,2), device=xyz.device ) # avoid undefined angles for H generation torsions[:,0,1,0] = 1.0 torsions[:,-1,0,0] = 1.0 # omega torsions[:,:-1,0,:] = th_dih(xyz[:,:-1,1,:],xyz[:,:-1,2,:],xyz[:,1:,0,:],xyz[:,1:,1,:]) # phi torsions[:,1:,1,:] = th_dih(xyz[:,:-1,2,:],xyz[:,1:,0,:],xyz[:,1:,1,:],xyz[:,1:,2,:]) # psi torsions[:,:,2,:] = -1 * th_dih(xyz[:,:,0,:],xyz[:,:,1,:],xyz[:,:,2,:],xyz[:,:,3,:]) # chis ti0 = torch.gather(xyz,2,self.torsion_indices[seq,:,0,None].repeat(1,1,1,3)) ti1 = torch.gather(xyz,2,self.torsion_indices[seq,:,1,None].repeat(1,1,1,3)) ti2 = torch.gather(xyz,2,self.torsion_indices[seq,:,2,None].repeat(1,1,1,3)) ti3 = torch.gather(xyz,2,self.torsion_indices[seq,:,3,None].repeat(1,1,1,3)) torsions[:,:,3:7,:] = th_dih(ti0,ti1,ti2,ti3) # CB bend NC = 0.5*( xyz[:,:,0,:3] + xyz[:,:,2,:3] ) CA = xyz[:,:,1,:3] CB = xyz[:,:,4,:3] t = th_ang_v(CB-CA,NC-CA) t0 = self.ref_angles[seq][...,0,:] torsions[:,:,7,:] = torch.stack( (torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]), dim=-1 ) # CB twist NCCA = NC-CA NCp = xyz[:,:,2,:3] - xyz[:,:,0,:3] NCpp = NCp - torch.sum(NCp*NCCA, dim=-1, keepdim=True)/ torch.sum(NCCA*NCCA, dim=-1, keepdim=True) * NCCA t = th_ang_v(CB-CA,NCpp) t0 = self.ref_angles[seq][...,1,:] torsions[:,:,8,:] = torch.stack( (torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]), dim=-1 ) # CG bend CG = xyz[:,:,5,:3] t = th_ang_v(CG-CB,CA-CB) t0 = self.ref_angles[seq][...,2,:] torsions[:,:,9,:] = torch.stack( (torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]), dim=-1 ) tors_mask *= (~torch.isnan(torsions[...,0])) tors_mask *= (~torch.isnan(torsions[...,1])) torsions = torch.nan_to_num(torsions) # alt chis torsions_alt = torsions.clone() torsions_alt[self.torsion_can_flip[seq,:]] *= -1 return torsions, torsions_alt, tors_mask, tors_planar