import torch.nn as nn
from miniworld.utils.util import *
from miniworld.utils.chemical import aa2num
[docs]
def init_lecun_normal(module, scale=1.0):
    """Re-initialize ``module.weight`` with a truncated LeCun-normal sample.

    Values are drawn from a standard normal truncated to ``[-2, 2]`` and
    rescaled so that the weight variance is ``scale / fan_in``, where
    ``fan_in`` is taken to be the last dimension of the weight.

    Args:
        module: object exposing a ``weight`` tensor (e.g. ``nn.Linear``).
        scale: variance scaling factor applied on top of ``1 / fan_in``.

    Returns:
        The same ``module``; its ``weight`` is replaced by a new
        ``torch.nn.Parameter`` on the device/dtype of the previous weight.
    """
    def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
        # Inverse-CDF sampling: map uniform samples into the CDF interval
        # [cdf(alpha), cdf(beta)] and push them through the normal quantile.
        normal = torch.distributions.normal.Normal(0, 1)
        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma
        alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
        p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
        # keep the erfinv argument strictly inside (-1, 1) to avoid +/-inf
        v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
        x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
        return torch.clamp(x, a, b)

    def sample_truncated_normal(shape, scale=1.0):
        # 0.8796... is the std of a unit normal truncated to [-2, 2];
        # dividing by it restores the requested variance after truncation.
        stddev = np.sqrt(scale / shape[-1]) / .87962566103423978  # shape[-1] = fan_in
        return stddev * truncated_normal(torch.rand(shape))

    old_weight = module.weight
    # forward the caller's ``scale`` so it actually affects the sample
    sample = sample_truncated_normal(old_weight.shape, scale=scale)
    module.weight = torch.nn.Parameter(
        sample.to(device=old_weight.device, dtype=old_weight.dtype)
    )
    return module
[docs]
def init_lecun_normal_param(weight, scale=1.0):
    """Return a new parameter shaped like ``weight`` with LeCun-normal values.

    Values are drawn from a standard normal truncated to ``[-2, 2]`` and
    rescaled so that the variance is ``scale / fan_in``, where ``fan_in`` is
    taken to be the last dimension of ``weight``.

    Args:
        weight: tensor (or parameter) whose shape, device and dtype are reused.
        scale: variance scaling factor applied on top of ``1 / fan_in``.

    Returns:
        A fresh ``torch.nn.Parameter``; the input ``weight`` is not modified.
    """
    def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
        # Inverse-CDF sampling: map uniform samples into the CDF interval
        # [cdf(alpha), cdf(beta)] and push them through the normal quantile.
        normal = torch.distributions.normal.Normal(0, 1)
        alpha = (a - mu) / sigma
        beta = (b - mu) / sigma
        alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
        p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
        # keep the erfinv argument strictly inside (-1, 1) to avoid +/-inf
        v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
        x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
        return torch.clamp(x, a, b)

    def sample_truncated_normal(shape, scale=1.0):
        # 0.8796... is the std of a unit normal truncated to [-2, 2];
        # dividing by it restores the requested variance after truncation.
        stddev = np.sqrt(scale / shape[-1]) / .87962566103423978  # shape[-1] = fan_in
        return stddev * truncated_normal(torch.rand(shape))

    # forward the caller's ``scale`` so it actually affects the sample
    sample = sample_truncated_normal(weight.shape, scale=scale)
    return torch.nn.Parameter(sample.to(device=weight.device, dtype=weight.dtype))
# for gradient checkpointing
[docs]
def create_custom_forward(module, **kwargs):
    """Wrap ``module`` so it can be invoked with positional inputs only.

    The keyword arguments given here are captured and supplied on every call,
    which suits APIs (such as gradient checkpointing) that forward only
    positional arguments to the wrapped callable.

    Returns:
        A function ``f(*inputs)`` equivalent to ``module(*inputs, **kwargs)``.
    """
    def custom_forward(*inputs):
        result = module(*inputs, **kwargs)
        return result

    return custom_forward
[docs]
def get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = []
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return nn.ModuleList(clones)
[docs]
class Dropout(nn.Module):
    """Dropout whose mask can be shared along one dimension.

    With ``broadcast_dim`` set, the mask has size 1 along that dimension, so
    entire rows/columns of the input are kept or dropped together. Kept values
    are rescaled by ``1 / (1 - p_drop)``.
    """

    def __init__(self, broadcast_dim=None, p_drop=0.15):
        super().__init__()
        self.broadcast_dim = broadcast_dim
        self.p_drop = p_drop
        # Bernoulli over "keep": 1 with probability 1 - p_drop, 0 with p_drop
        keep_prob = torch.tensor([1 - p_drop])
        self.sampler = torch.distributions.bernoulli.Bernoulli(keep_prob)

    def forward(self, x):
        """Apply the (optionally broadcast) dropout mask; identity in eval mode."""
        if not self.training:
            return x

        mask_shape = list(x.shape)
        if self.broadcast_dim is not None:
            # a singleton axis lets the mask broadcast over that dimension
            mask_shape[self.broadcast_dim] = 1

        # the sampler appends its own batch axis of size 1; view() removes it
        keep = self.sampler.sample(mask_shape).to(x.device).view(mask_shape)
        return keep * x / (1.0 - self.p_drop)
[docs]
def rbf(D, D_min=0.0, D_count=64, D_sigma=0.5):
    """Expand distances into Gaussian radial-basis features.

    ``D_count`` centers are placed ``D_sigma`` apart starting at ``D_min``;
    each output channel is ``exp(-((D - center) / D_sigma) ** 2)``.

    Args:
        D: distance tensor (at least 1-D).
        D_min: position of the first center.
        D_count: number of centers / output channels.
        D_sigma: spacing between centers and width of each Gaussian.

    Returns:
        Tensor with one extra trailing axis of size ``D_count``.
    """
    last_center = D_min + (D_count - 1) * D_sigma
    centers = torch.linspace(D_min, last_center, D_count).to(D.device)
    # (..., 1) against (1, D_count) broadcasts to (..., D_count)
    scaled_offset = (D.unsqueeze(-1) - centers[None, :]) / D_sigma
    return torch.exp(-(scaled_offset ** 2))
[docs]
def get_seqsep(idx):
    '''
    Signed sequence-adjacency feature.

    Input:
        - idx: residue indices of given sequence (B,L)
    Output:
        - seqsep: (B, L, L, 1); entry [b,i,j] is idx[b,j]-idx[b,i] when its
          magnitude is at most 1 (bonded neighbours / self) and 0 otherwise.
    The sign is kept because it was found to help a little.
    '''
    offset = idx.unsqueeze(1) - idx.unsqueeze(2)
    bonded = offset.abs()
    bonded[bonded > 1] = 0.0  # only directly adjacent residues survive
    signed = offset.sign() * bonded
    return signed.unsqueeze(-1)
[docs]
def get_topk(D, sep, top_k=64, kmin=32):
    """Select graph edges from a pairwise matrix.

    An edge (b, i, j) is emitted when either
      1) j is among the ``min(top_k, L-1)`` smallest entries of ``D[b, i]``, or
      2) ``sep[b, i, j] < kmin``.

    Args:
        D: (B, L, L) tensor; smaller values rank first (presumably pairwise
            distances -- confirm against the caller).
        sep: tensor broadcastable to (B, L, L), compared against ``kmin``
            (presumably absolute sequence separation -- confirm).
        top_k: maximum number of nearest entries kept per row.
        kmin: exclusive threshold on ``sep``.

    Returns:
        Tuple ``(b, i, j)`` of 1-D index tensors, one entry per selected edge.
    """
    B, L = D.shape[:2]

    # indices of the k smallest entries along the last axis: (B, L, k)
    n_neigh = min(top_k, L - 1)
    _, nearest = torch.topk(D, n_neigh, largest=False)

    # mark the selected columns in a dense (B, L, L) indicator
    is_nearest = torch.zeros((B, L, L), device=D.device)
    is_nearest.scatter_(2, nearest, 1.0)

    edge = torch.logical_or(is_nearest > 0.0, sep < kmin)
    b, i, j = torch.where(edge)
    return b, i, j
[docs]
def make_rotX(angs, eps=1e-6):
    """Build homogeneous 4x4 rotations about the x axis.

    Args:
        angs: (B, L, 2) tensor whose last axis holds an un-normalized
            (cos, sin) pair; it is divided by its norm plus ``eps``.
        eps: added to the norm to avoid division by zero.

    Returns:
        (B, L, 4, 4) tensor; identity except for the y/z rotation block.
    """
    B, L = angs.shape[:2]
    norm = torch.linalg.norm(angs, dim=-1) + eps
    cos_t = angs[:, :, 0] / norm
    sin_t = angs[:, :, 1] / norm

    RTs = torch.eye(4, device=angs.device).repeat(B, L, 1, 1)
    RTs[:, :, 1, 1] = cos_t
    RTs[:, :, 1, 2] = -sin_t
    RTs[:, :, 2, 1] = sin_t
    RTs[:, :, 2, 2] = cos_t
    return RTs
# rotate about the z axis
[docs]
def make_rotZ(angs, eps=1e-6):
    """Build homogeneous 4x4 rotations about the z axis.

    Args:
        angs: (B, L, 2) tensor whose last axis holds an un-normalized
            (cos, sin) pair; it is divided by its norm plus ``eps``.
        eps: added to the norm to avoid division by zero.

    Returns:
        (B, L, 4, 4) tensor; identity except for the x/y rotation block.
    """
    B, L = angs.shape[:2]
    norm = torch.linalg.norm(angs, dim=-1) + eps
    cos_t = angs[:, :, 0] / norm
    sin_t = angs[:, :, 1] / norm

    RTs = torch.eye(4, device=angs.device).repeat(B, L, 1, 1)
    RTs[:, :, 0, 0] = cos_t
    RTs[:, :, 0, 1] = -sin_t
    RTs[:, :, 1, 0] = sin_t
    RTs[:, :, 1, 1] = cos_t
    return RTs
# rotate about an arbitrary axis
[docs]
def make_rot_axis(angs, u, eps=1e-6):
    """Build homogeneous 4x4 rotations about per-element axes ``u``.

    The 3x3 block is the axis-angle (Rodrigues) rotation matrix.

    Args:
        angs: (B, L, 2) tensor whose last axis holds an un-normalized
            (cos, sin) pair; it is divided by its norm plus ``eps``.
        u: (B, L, 3) rotation axes; used as given (not re-normalized here).
        eps: added to the norm to avoid division by zero.

    Returns:
        (B, L, 4, 4) tensor with the rotation in the upper-left 3x3 block.
    """
    B, L = angs.shape[:2]
    norm = torch.linalg.norm(angs, dim=-1) + eps
    ct = angs[:, :, 0] / norm
    st = angs[:, :, 1] / norm
    ux, uy, uz = u[:, :, 0], u[:, :, 1], u[:, :, 2]

    RTs = torch.eye(4, device=angs.device).repeat(B, L, 1, 1)

    # diagonal terms
    RTs[:, :, 0, 0] = ct + ux * ux * (1 - ct)
    RTs[:, :, 1, 1] = ct + uy * uy * (1 - ct)
    RTs[:, :, 2, 2] = ct + uz * uz * (1 - ct)

    # off-diagonal terms (symmetric part +/- antisymmetric part)
    RTs[:, :, 0, 1] = ux * uy * (1 - ct) - uz * st
    RTs[:, :, 1, 0] = ux * uy * (1 - ct) + uz * st
    RTs[:, :, 0, 2] = ux * uz * (1 - ct) + uy * st
    RTs[:, :, 2, 0] = ux * uz * (1 - ct) - uy * st
    RTs[:, :, 1, 2] = uy * uz * (1 - ct) - ux * st
    RTs[:, :, 2, 1] = uy * uz * (1 - ct) + ux * st
    return RTs
[docs]
class ComputeAllAtomCoords(nn.Module):
    """Build per-residue rigid frames and atom coordinates from torsions.

    The backbone frame of every residue is derived from its first three atoms;
    side-chain frames are obtained by chaining tabulated per-residue-type
    transforms with rotations given by ``alphas``.
    """

    def __init__(self):
        super().__init__()
        # forward() reads these lookup tables from ``self``, so they must be
        # attached here. Non-persistent buffers follow .to()/.cuda() without
        # adding keys to the state_dict.
        self.register_buffer("base_indices", base_indices, persistent=False)
        self.register_buffer("RTs_in_base_frame", RTs_by_torsion, persistent=False)
        self.register_buffer("xyzs_in_base_frame", xyzs_in_base_frame, persistent=False)

    def forward(self, seq, xyz, alphas, non_ideal=False, use_H=True):
        """Compute frames and atom coordinates.

        Args:
            seq: (B, L) integer residue types used to index the lookup tables.
            xyz: (B, L, natoms, 3) coordinates; only atoms 0, 1, 2 are read
                (N, CA, C presumably -- confirm against rigid_from_3_points).
            alphas: (B, L, >=10, 2) un-normalized (cos, sin) pairs; index 0-2
                omega/phi/psi, 3-6 chi1-4, 7 CB bend, 8 CB twist, 9 CG bend.
            non_ideal: forwarded to ``rigid_from_3_points``.
            use_H: when False only the first 14 atoms are returned.

        Returns:
            ``(RTframes, xyzs)`` where ``RTframes`` is (B, L, 9, 4, 4) and
            ``xyzs`` holds the 3-D coordinates of each atom.
        """
        B, L = xyz.shape[:2]
        seq = seq.to(torch.long)  # lookup tables are indexed by residue type

        Rs, Ts = rigid_from_3_points(
            xyz[..., 0, :], xyz[..., 1, :], xyz[..., 2, :], non_ideal=non_ideal)

        # bb: homogeneous backbone frame
        RTF0 = torch.eye(4).repeat(B, L, 1, 1).to(device=Rs.device)
        RTF0[:, :, :3, :3] = Rs
        RTF0[:, :, :3, 3] = Ts

        # omega
        RTF1 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF0, self.RTs_in_base_frame[seq, 0, :], make_rotX(alphas[:, :, 0, :]))

        # phi
        RTF2 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF0, self.RTs_in_base_frame[seq, 1, :], make_rotX(alphas[:, :, 1, :]))

        # psi
        RTF3 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF0, self.RTs_in_base_frame[seq, 2, :], make_rotX(alphas[:, :, 2, :]))

        # CB bend
        basexyzs = self.xyzs_in_base_frame[seq]
        NCr = 0.5 * (basexyzs[:, :, 2, :3] + basexyzs[:, :, 0, :3])
        CAr = basexyzs[:, :, 1, :3]
        CBr = basexyzs[:, :, 4, :3]
        # dim=-1 is explicit: without it the first size-3 axis is used, which
        # is the wrong axis whenever B or L happens to equal 3.
        CBrotaxis1 = torch.cross(CBr - CAr, NCr - CAr, dim=-1)
        CBrotaxis1 /= torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True) + 1e-8

        # CB twist: rotate about the axis normal to (CB-CA) and the N->C
        # direction with its component along NCr projected out
        NCp = basexyzs[:, :, 2, :3] - basexyzs[:, :, 0, :3]
        NCpp = NCp - torch.sum(NCp * NCr, dim=-1, keepdim=True) / torch.sum(NCr * NCr, dim=-1, keepdim=True) * NCr
        CBrotaxis2 = torch.cross(CBr - CAr, NCpp, dim=-1)
        CBrotaxis2 /= torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True) + 1e-8

        CBrot1 = make_rot_axis(alphas[:, :, 7, :], CBrotaxis1)
        CBrot2 = make_rot_axis(alphas[:, :, 8, :], CBrotaxis2)

        RTF8 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF0, CBrot1, CBrot2)

        # chi1 + CG bend
        RTF4 = torch.einsum(
            'brij,brjk,brkl,brlm->brim',
            RTF8,
            self.RTs_in_base_frame[seq, 3, :],
            make_rotX(alphas[:, :, 3, :]),
            make_rotZ(alphas[:, :, 9, :]))

        # chi2
        RTF5 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF4, self.RTs_in_base_frame[seq, 4, :], make_rotX(alphas[:, :, 4, :]))

        # chi3
        RTF6 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF5, self.RTs_in_base_frame[seq, 5, :], make_rotX(alphas[:, :, 5, :]))

        # chi4
        RTF7 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF6, self.RTs_in_base_frame[seq, 6, :], make_rotX(alphas[:, :, 6, :]))

        RTframes = torch.stack((
            RTF0, RTF1, RTF2, RTF3, RTF4, RTF5, RTF6, RTF7, RTF8
        ), dim=2)

        # pick, for every atom, the frame it is defined in and transform its
        # tabulated homogeneous coordinate
        xyzs = torch.einsum(
            'brtij,brtj->brti',
            RTframes.gather(2, self.base_indices[seq][..., None, None].repeat(1, 1, 1, 4, 4)), basexyzs
        )

        if use_H:
            return RTframes, xyzs[..., :3]
        else:
            return RTframes, xyzs[..., :14, :3]
[docs]
class XYZConverter(nn.Module):
    """Convert between all-atom coordinates and torsion-angle representations.

    Per-residue-type lookup tables are held as buffers so they move with the
    module across devices.
    """

    def __init__(self):
        super().__init__()
        self.register_buffer("torsion_indices", torsion_indices)
        self.register_buffer("torsion_can_flip", torsion_can_flip)
        self.register_buffer("ref_angles", reference_angles)
        self.register_buffer("tip_indices", tip_indices)
        self.register_buffer("base_indices", base_indices)
        self.register_buffer("RTs_in_base_frame", RTs_by_torsion)
        self.register_buffer("xyzs_in_base_frame", xyzs_in_base_frame)

    def compute_all_atom(self, seq, xyz, alphas, non_ideal=True, use_H=True):
        """Compute rigid frames and atom coordinates from torsions.

        Args:
            seq: (B, L) residue types; cast to long before indexing.
            xyz: (B, L, natoms, 3) coordinates; only atoms 0, 1, 2 are read
                (N, CA, C presumably -- confirm against rigid_from_3_points).
            alphas: (B, L, >=10, 2) un-normalized (cos, sin) pairs; index 0-2
                omega/phi/psi, 3-6 chi1-4, 7 CB bend, 8 CB twist, 9 CG bend.
            non_ideal: forwarded to ``rigid_from_3_points``.
            use_H: when False only the first 14 atoms are returned.

        Returns:
            ``(RTframes, xyzs)`` where ``RTframes`` is (B, L, 9, 4, 4) and
            ``xyzs`` holds the 3-D coordinates of each atom.
        """
        B, L = xyz.shape[:2]
        seq = seq.to(torch.long)

        Rs, Ts = rigid_from_3_points(
            xyz[..., 0, :], xyz[..., 1, :], xyz[..., 2, :], non_ideal=non_ideal)

        # bb: homogeneous backbone frame
        RTF0 = torch.eye(4).repeat(B, L, 1, 1).to(device=Rs.device)
        RTF0[:, :, :3, :3] = Rs
        RTF0[:, :, :3, 3] = Ts

        # omega
        RTF1 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF0, self.RTs_in_base_frame[seq, 0, :], make_rotX(alphas[:, :, 0, :]))

        # phi
        RTF2 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF0, self.RTs_in_base_frame[seq, 1, :], make_rotX(alphas[:, :, 1, :]))

        # psi
        RTF3 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF0, self.RTs_in_base_frame[seq, 2, :], make_rotX(alphas[:, :, 2, :]))

        # CB bend
        basexyzs = self.xyzs_in_base_frame[seq]
        NCr = 0.5 * (basexyzs[:, :, 2, :3] + basexyzs[:, :, 0, :3])
        CAr = basexyzs[:, :, 1, :3]
        CBr = basexyzs[:, :, 4, :3]
        # dim=-1 is explicit: without it the first size-3 axis is used, which
        # is the wrong axis whenever B or L happens to equal 3.
        CBrotaxis1 = torch.cross(CBr - CAr, NCr - CAr, dim=-1)
        CBrotaxis1 /= torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True) + 1e-8

        # CB twist: rotate about the axis normal to (CB-CA) and the N->C
        # direction with its component along NCr projected out
        NCp = basexyzs[:, :, 2, :3] - basexyzs[:, :, 0, :3]
        NCpp = NCp - torch.sum(NCp * NCr, dim=-1, keepdim=True) / torch.sum(NCr * NCr, dim=-1, keepdim=True) * NCr
        CBrotaxis2 = torch.cross(CBr - CAr, NCpp, dim=-1)
        CBrotaxis2 /= torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True) + 1e-8

        CBrot1 = make_rot_axis(alphas[:, :, 7, :], CBrotaxis1)
        CBrot2 = make_rot_axis(alphas[:, :, 8, :], CBrotaxis2)

        RTF8 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF0, CBrot1, CBrot2)

        # chi1 + CG bend
        RTF4 = torch.einsum(
            'brij,brjk,brkl,brlm->brim',
            RTF8,
            self.RTs_in_base_frame[seq, 3, :],
            make_rotX(alphas[:, :, 3, :]),
            make_rotZ(alphas[:, :, 9, :]))

        # chi2
        RTF5 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF4, self.RTs_in_base_frame[seq, 4, :], make_rotX(alphas[:, :, 4, :]))

        # chi3
        RTF6 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF5, self.RTs_in_base_frame[seq, 5, :], make_rotX(alphas[:, :, 5, :]))

        # chi4
        RTF7 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF6, self.RTs_in_base_frame[seq, 6, :], make_rotX(alphas[:, :, 6, :]))

        RTframes = torch.stack((
            RTF0, RTF1, RTF2, RTF3, RTF4, RTF5, RTF6, RTF7, RTF8
        ), dim=2)

        # pick, for every atom, the frame it is defined in and transform its
        # tabulated homogeneous coordinate
        xyzs = torch.einsum(
            'brtij,brtj->brti',
            RTframes.gather(2, self.base_indices[seq][..., None, None].repeat(1, 1, 1, 4, 4)), basexyzs
        )

        if use_H:
            return RTframes, xyzs[..., :3]
        else:
            return RTframes, xyzs[..., :14, :3]

    def get_tor_mask(self, seq, mask_in=None):
        """Return a (B, L, 10) boolean mask of which torsions are defined.

        Channels: 0 omega, 1 phi, 2 psi, 3-6 chi1-4, 7 CB bend, 8 CB twist,
        9 CG bend.

        Args:
            seq: (B, L) integer residue types.
            mask_in: optional (B, L, natoms) boolean atom-presence mask; when
                given, torsions involving missing atoms are masked out.
        """
        B, L = seq.shape[:2]
        tors_mask = torch.ones((B, L, 10), dtype=torch.bool, device=seq.device)
        # a chi exists when the last atom index of its quadruple is positive
        tors_mask[..., 3:7] = self.torsion_indices[seq, :, -1] > 0
        tors_mask[:, 0, 1] = False   # first residue: phi undefined
        tors_mask[:, -1, 0] = False  # last residue: omega undefined

        # mask for additional angles
        tors_mask[:, :, 7] = seq != aa2num['GLY']
        tors_mask[:, :, 8] = seq != aa2num['GLY']
        tors_mask[:, :, 9] = torch.logical_and(seq != aa2num['GLY'], seq != aa2num['ALA'])
        tors_mask[:, :, 9] = torch.logical_and(tors_mask[:, :, 9], seq != aa2num['UNK'])
        tors_mask[:, :, 9] = torch.logical_and(tors_mask[:, :, 9], seq != aa2num['MAS'])

        if mask_in is not None:
            # mask for missing atoms
            # chis: look up presence of each of the four atoms defining a chi
            ti0 = torch.gather(mask_in, 2, self.torsion_indices[seq, :, 0])
            ti1 = torch.gather(mask_in, 2, self.torsion_indices[seq, :, 1])
            ti2 = torch.gather(mask_in, 2, self.torsion_indices[seq, :, 2])
            ti3 = torch.gather(mask_in, 2, self.torsion_indices[seq, :, 3])
            # stack on a new trailing axis so .all() reduces over the four
            # atoms of each chi rather than over the chis themselves
            is_valid = torch.stack((ti0, ti1, ti2, ti3), dim=-1).all(dim=-1)
            tors_mask[..., 3:7] = torch.logical_and(tors_mask[..., 3:7], is_valid)
            tors_mask[:, :, 7] = torch.logical_and(tors_mask[:, :, 7], mask_in[:, :, 4])  # CB exist?
            tors_mask[:, :, 8] = torch.logical_and(tors_mask[:, :, 8], mask_in[:, :, 4])  # CB exist?
            tors_mask[:, :, 9] = torch.logical_and(tors_mask[:, :, 9], mask_in[:, :, 5])  # XG exist?

        return tors_mask

    def get_torsions(self, xyz_in, seq, mask_in=None):
        """Measure the 10 torsion-like angles of every residue.

        Args:
            xyz_in: (B, L, natoms, 3) coordinates; not modified.
            seq: (B, L) integer residue types.
            mask_in: optional atom-presence mask forwarded to ``get_tor_mask``.

        Returns:
            ``(torsions, torsions_alt, tors_mask, tors_planar)`` where
            ``torsions`` is (B, L, 10, 2) (2-vectors as returned by ``th_dih``
            / derived from ``th_ang_v``; presumably (cos, sin) -- confirm),
            ``torsions_alt`` has the flippable torsions negated,
            ``tors_mask`` is the (B, L, 10) validity mask with NaN entries
            removed, and ``tors_planar`` flags torsions restrained to
            0 or 180 degrees.
        """
        B, L = xyz_in.shape[:2]

        tors_mask = self.get_tor_mask(seq, mask_in)

        # torsions to restrain to 0 or 180degree
        tors_planar = torch.zeros((B, L, 10), dtype=torch.bool, device=xyz_in.device)
        tors_planar[:, :, 5] = (seq == aa2num['TYR'])  # TYR chi 3 should be planar

        # idealize given xyz coordinates before computing torsion angles:
        # atoms 0 and 2 are replaced by fixed positions in the backbone frame
        xyz = xyz_in.clone()
        Rs, Ts = rigid_from_3_points(xyz[..., 0, :], xyz[..., 1, :], xyz[..., 2, :])
        Nideal = torch.tensor([-0.5272, 1.3593, 0.000], device=xyz_in.device)
        Cideal = torch.tensor([1.5233, 0.000, 0.000], device=xyz_in.device)
        xyz[..., 0, :] = torch.einsum('brij,j->bri', Rs, Nideal) + Ts
        xyz[..., 2, :] = torch.einsum('brij,j->bri', Rs, Cideal) + Ts

        torsions = torch.zeros((B, L, 10, 2), device=xyz.device)

        # avoid undefined angles for H generation
        torsions[:, 0, 1, 0] = 1.0
        torsions[:, -1, 0, 0] = 1.0

        # omega
        torsions[:, :-1, 0, :] = th_dih(xyz[:, :-1, 1, :], xyz[:, :-1, 2, :], xyz[:, 1:, 0, :], xyz[:, 1:, 1, :])
        # phi
        torsions[:, 1:, 1, :] = th_dih(xyz[:, :-1, 2, :], xyz[:, 1:, 0, :], xyz[:, 1:, 1, :], xyz[:, 1:, 2, :])
        # psi
        torsions[:, :, 2, :] = -1 * th_dih(xyz[:, :, 0, :], xyz[:, :, 1, :], xyz[:, :, 2, :], xyz[:, :, 3, :])

        # chis: gather the four atoms defining each chi, (B, L, 4, 3) each
        ti0 = torch.gather(xyz, 2, self.torsion_indices[seq, :, 0, None].repeat(1, 1, 1, 3))
        ti1 = torch.gather(xyz, 2, self.torsion_indices[seq, :, 1, None].repeat(1, 1, 1, 3))
        ti2 = torch.gather(xyz, 2, self.torsion_indices[seq, :, 2, None].repeat(1, 1, 1, 3))
        ti3 = torch.gather(xyz, 2, self.torsion_indices[seq, :, 3, None].repeat(1, 1, 1, 3))
        torsions[:, :, 3:7, :] = th_dih(ti0, ti1, ti2, ti3)

        # CB bend: measured angle combined with the per-residue reference
        NC = 0.5 * (xyz[:, :, 0, :3] + xyz[:, :, 2, :3])
        CA = xyz[:, :, 1, :3]
        CB = xyz[:, :, 4, :3]
        t = th_ang_v(CB - CA, NC - CA)
        t0 = self.ref_angles[seq][..., 0, :]
        torsions[:, :, 7, :] = torch.stack(
            (torch.sum(t * t0, dim=-1), t[..., 0] * t0[..., 1] - t[..., 1] * t0[..., 0]),
            dim=-1)

        # CB twist
        NCCA = NC - CA
        NCp = xyz[:, :, 2, :3] - xyz[:, :, 0, :3]
        NCpp = NCp - torch.sum(NCp * NCCA, dim=-1, keepdim=True) / torch.sum(NCCA * NCCA, dim=-1, keepdim=True) * NCCA
        t = th_ang_v(CB - CA, NCpp)
        t0 = self.ref_angles[seq][..., 1, :]
        torsions[:, :, 8, :] = torch.stack(
            (torch.sum(t * t0, dim=-1), t[..., 0] * t0[..., 1] - t[..., 1] * t0[..., 0]),
            dim=-1)

        # CG bend
        CG = xyz[:, :, 5, :3]
        t = th_ang_v(CG - CB, CA - CB)
        t0 = self.ref_angles[seq][..., 2, :]
        torsions[:, :, 9, :] = torch.stack(
            (torch.sum(t * t0, dim=-1), t[..., 0] * t0[..., 1] - t[..., 1] * t0[..., 0]),
            dim=-1)

        # NaN components mark torsions that could not be computed: drop them
        # from the mask and zero them in the output
        tors_mask *= (~torch.isnan(torsions[..., 0]))
        tors_mask *= (~torch.isnan(torsions[..., 1]))
        torsions = torch.nan_to_num(torsions)

        # alt chis: negated copy for torsions flagged as flippable
        torsions_alt = torsions.clone()
        torsions_alt[self.torsion_can_flip[seq, :]] *= -1

        return torsions, torsions_alt, tors_mask, tors_planar