import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum  # these names may also be provided by the star import below
import torch.utils.checkpoint as checkpoint
from miniworld.models_MiniWorld_v1_5_use_interaction.Attention_module import *
from miniworld.utils.kinematics import normQ, Qs2Rs
# Components for three-track blocks
# 1. MSA -> MSA update (biased attention. bias from pair & structure)
# 2. Pair -> Pair update (biased attention. bias from structure)
# 3. MSA -> Pair update (extract coevolution signal)
# 4. Str -> Str update (node from MSA, edge from Pair)
# Module contains the building blocks of the iterative (three-track) simulator
class SeqSep(nn.Module):
# Add relative positional encoding to pair features
def __init__(self, d_model, minpos=-32, maxpos=32):
super(SeqSep, self).__init__()
self.minpos = minpos
self.maxpos = maxpos
self.nbin = abs(minpos)+maxpos+1
self.emb = nn.Embedding(self.nbin, d_model)
def forward(self, idx):
# idx: (B, L)
B, L = idx.shape[:2]
bins = torch.arange(self.minpos, self.maxpos, device=idx.device)
seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L) signed sequence separation
#
ib = torch.bucketize(seqsep, bins).long() # (B, L, L)
emb = self.emb(ib) #(B, L, L, d_model)
return emb
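# Illustrative sketch (not part of the model): how the bucketized relative-position
# embedding above behaves. Toy shapes and the Embedding size are assumptions chosen
# to mirror the defaults of SeqSep (minpos=-32, maxpos=32 -> 65 bins).
def _demo_seqsep_bucketize():
    idx = torch.arange(5)[None]                    # (B=1, L=5) residue indices
    seqsep = idx[:, None, :] - idx[:, :, None]     # (1, 5, 5) signed separations
    bins = torch.arange(-32, 32)                   # same bin boundaries as SeqSep
    ib = torch.bucketize(seqsep, bins).long()      # bin index in [0, 64]
    emb = nn.Embedding(65, 8)(ib)                  # (1, 5, 5, 8) pair features
    return emb.shape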
# Update MSA with biased self-attention. bias from Pair & Str
class MSAPairStr2MSA(nn.Module):
def __init__(self, d_msa=256, d_pair=128, n_head=8, d_state=16, d_rbf=64,
d_hidden=32, p_drop=0.15, use_global_attn=False):
super(MSAPairStr2MSA, self).__init__()
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_state = nn.LayerNorm(d_state)
self.proj_state = nn.Linear(d_state, d_msa)
self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
self.row_attn = MSARowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
n_head=n_head, d_hidden=d_hidden)
self.col_attn = MSAColAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
self.ff = FeedForwardLayer(d_msa, 4, p_drop=p_drop)
# Do proper initialization
self.reset_parameter()
def reset_parameter(self):
# initialize weights to a normal distribution
self.proj_state = init_lecun_normal(self.proj_state)
# initialize bias to zeros
nn.init.zeros_(self.proj_state.bias)
def forward(self, msa, pair, state, symmids, symmsub):
"""Update MSA with biased self-attention, using bias from Pair & Str.
:param msa: MSA feature tensor.
:type msa: torch.Tensor of shape (B, N, L, d_msa)
:param pair: Pair feature tensor.
:type pair: torch.Tensor of shape (B, L, L, d_pair)
:param state: Updated node features from the SE(3) layer.
:type state: torch.Tensor of shape (B, L, d_state)
:param symmids: Symmetry IDs.
:type symmids: torch.Tensor
:param symmsub: Symmetry sub-units.
:type symmsub: torch.Tensor
:return: Updated MSA feature tensor.
:rtype: torch.Tensor of shape (B, N, L, d_msa)
"""
B, N, L, _ = msa.shape
# prepare input bias feature by combining pair & coordinate info
pair = self.norm_pair(pair)
#
# update query sequence feature (first sequence in the MSA) with feedback (state) from the SE(3) layer
state = self.norm_state(state)
state = self.proj_state(state).reshape(B, 1, L, -1)
msa = msa.type_as(state)
msa = msa.index_add(1, torch.tensor([0,], device=state.device), state) # (B, N, L, d_msa)
#
# Apply row/column attention to msa & transform
msa = msa + self.drop_row(self.row_attn(msa, pair)) # (B, N, L, d_msa)
msa = msa + self.col_attn(msa) # (B, N, L, d_msa)
msa = msa + self.ff(msa)
return msa
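# Illustrative sketch (not part of the model): the index_add in MSAPairStr2MSA.forward
# writes the projected state features only into the query row (row 0) of the MSA stack.
# Shapes below are assumptions for demonstration only.
def _demo_state_into_query_row():
    B, N, L, d = 1, 4, 6, 8
    msa = torch.zeros(B, N, L, d)
    state = torch.ones(B, 1, L, d)                    # projected state, shaped like one MSA row
    msa = msa.index_add(1, torch.tensor([0]), state)  # only msa[:, 0] changes
    return msa[:, 0].sum().item(), msa[:, 1:].sum().item()   # (48.0, 0.0)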
class MSA2Pair(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_hidden=32, p_drop=0.15):
super(MSA2Pair, self).__init__()
self.norm = nn.LayerNorm(d_msa)
self.proj_left = nn.Linear(d_msa, d_hidden)
self.proj_right = nn.Linear(d_msa, d_hidden)
self.proj_out = nn.Linear(d_hidden*d_hidden, d_pair)
self.d_hidden = d_hidden
self.reset_parameter()
def reset_parameter(self):
# normal initialization
self.proj_left = init_lecun_normal(self.proj_left)
self.proj_right = init_lecun_normal(self.proj_right)
nn.init.zeros_(self.proj_left.bias)
nn.init.zeros_(self.proj_right.bias)
# zero initialize output
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
def forward(self, msa, pair, symmids, symmsub):
B, N, L = msa.shape[:3]
msa = self.norm(msa)
left = self.proj_left(msa)
right = self.proj_right(msa)
right = right/float(N)
out = einsum('bsli,bsmj->blmij', left, right).reshape(B, L, L, -1)
out = self.proj_out(out)
pair = pair + out
return pair
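# Illustrative sketch (not part of the model): the einsum in MSA2Pair.forward is an
# outer product over MSA rows, averaged over sequences (the 1/N factor is folded into
# `right`), then projected to the pair dimension. Toy shapes only.
def _demo_outer_product_pair_update():
    B, N, L, d_hidden, d_pair = 1, 8, 10, 4, 16
    left = torch.randn(B, N, L, d_hidden)
    right = torch.randn(B, N, L, d_hidden) / N
    out = einsum('bsli,bsmj->blmij', left, right).reshape(B, L, L, -1)  # (B, L, L, d_hidden**2)
    pair_update = nn.Linear(d_hidden * d_hidden, d_pair)(out)           # (B, L, L, d_pair)
    return pair_update.shape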
class PairStr2Pair(nn.Module):
def __init__(self, d_pair=128, n_head=4, d_hidden=32, d_hidden_state=16, d_rbf=64, d_state=32, p_drop=0.15):
super(PairStr2Pair, self).__init__()
self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)
self.tri_mul_out = TriangleMultiplication(d_pair, d_hidden=d_hidden)
self.tri_mul_in = TriangleMultiplication(d_pair, d_hidden, outgoing=False)
self.ff = FeedForwardLayer(d_pair, 2)
# perform a striped p2p op
def subblock(self, OP, pair, crop):
N,L = pair.shape[:2]
nbox = (L-1)//(crop//2)+1
idx = torch.triu_indices(nbox,nbox,1, device=pair.device)
ncrops = idx.shape[1]
pairnew = torch.zeros((N,L*L,pair.shape[-1]), device=pair.device, dtype=pair.dtype)
countnew = torch.zeros((N,L*L), device=pair.device, dtype=torch.int)
for i in range(ncrops):
# reindex sub-blocks
offsetC = torch.clamp( (1+idx[1,i:(i+1)])*(crop//2)-L, min=0 ) # account for going past L
offsetN = torch.zeros_like(offsetC)
mask = (offsetC>0)*((idx[0,i]+1)==idx[1,i])
offsetN[mask] = offsetC[mask]
pairIdx = torch.zeros((1,crop), dtype=torch.long, device=pair.device)
pairIdx[:,:(crop//2)] = torch.arange(crop//2, dtype=torch.long, device=pair.device)+idx[0,i:(i+1),None]*(crop//2) - offsetN[:,None]
pairIdx[:,(crop//2):] = torch.arange(crop//2, dtype=torch.long, device=pair.device)+idx[1,i:(i+1),None]*(crop//2) - offsetC[:,None]
# do reindexing
iL,iU = pairIdx[:,:,None], pairIdx[:,None,:]
paircrop = pair[:,iL,iU,:].reshape(-1,crop,crop,pair.shape[-1])
# apply the pair->pair op to this crop
paircrop = OP(paircrop).to(pair.dtype)
# unindex
iUL = (iL*L+iU).flatten()
pairnew.index_add_(1,iUL, paircrop.reshape(N,iUL.shape[0],pair.shape[-1]))
countnew.index_add_(1,iUL, torch.ones((N,iUL.shape[0]), device=pair.device, dtype=torch.int))
return pair + (pairnew/countnew[...,None]).reshape(N,L,L,-1)
def forward(self, pair, crop=-1, use_species = True, symmids=None, symmsub=None):
B,L = pair.shape[:2]
crop = 2*(crop//2) # make sure even
if (crop>0 and crop<=L):
pair = self.subblock(
lambda x: self.drop_row(self.tri_mul_out(x)),
pair, crop
)
pair = self.subblock(
lambda x: self.drop_row(self.tri_mul_in(x)),
pair, crop
)
# NOTE: only the triangle multiplications and the feed-forward layer are applied,
# matching the uncropped branch below (no row/column attention modules are defined
# in this class).
# feed forward layer
RESSTRIDE = 16384//L
for i in range((L-1)//RESSTRIDE+1):
r_i,r_j = i*RESSTRIDE, min((i+1)*RESSTRIDE,L)
pair[:,r_i:r_j] = pair[:,r_i:r_j] + self.ff(pair[:,r_i:r_j])
else:
#_nc = lambda x:torch.sum(torch.isnan(x))
pair = pair + self.drop_row(self.tri_mul_out(pair))
pair = pair + self.drop_row(self.tri_mul_in(pair))
pair = pair + self.ff(pair)
return pair
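# Illustrative sketch (not part of the model): how PairStr2Pair.subblock tiles an
# L x L pair map into crop x crop patches built from two half-crop "boxes" of
# residues. Toy sizes; returns the residue indices covered by each crop.
def _demo_pair_tiling(L=8, crop=4):
    nbox = (L - 1) // (crop // 2) + 1                # number of half-crop boxes
    idx = torch.triu_indices(nbox, nbox, 1)          # all box pairs (i, j) with i < j
    tiles = []
    for i in range(idx.shape[1]):
        b0, b1 = idx[0, i].item(), idx[1, i].item()
        offsetC = max((1 + b1) * (crop // 2) - L, 0)              # clamp the last box to L
        offsetN = offsetC if (offsetC > 0 and b0 + 1 == b1) else 0
        rows = [b0 * (crop // 2) - offsetN + k for k in range(crop // 2)] \
             + [b1 * (crop // 2) - offsetC + k for k in range(crop // 2)]
        tiles.append(rows)
    return tiles   # each entry indexes one crop x crop patch (rows == cols)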
class SCPred(nn.Module):
def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
super(SCPred, self).__init__()
self.norm_s0 = nn.LayerNorm(d_msa)
self.norm_si = nn.LayerNorm(d_state)
self.linear_s0 = nn.Linear(d_msa, d_hidden)
self.linear_si = nn.Linear(d_state, d_hidden)
# ResNet layers
self.linear_1 = nn.Linear(d_hidden, d_hidden)
self.linear_2 = nn.Linear(d_hidden, d_hidden)
self.linear_3 = nn.Linear(d_hidden, d_hidden)
self.linear_4 = nn.Linear(d_hidden, d_hidden)
# Final outputs
self.linear_out = nn.Linear(d_hidden, 20)
self.reset_parameter()
def reset_parameter(self):
# normal initialization
self.linear_s0 = init_lecun_normal(self.linear_s0)
self.linear_si = init_lecun_normal(self.linear_si)
self.linear_out = init_lecun_normal(self.linear_out)
nn.init.zeros_(self.linear_s0.bias)
nn.init.zeros_(self.linear_si.bias)
nn.init.zeros_(self.linear_out.bias)
# right before relu activation: He initializer (kaiming normal)
nn.init.kaiming_normal_(self.linear_1.weight, nonlinearity='relu')
nn.init.zeros_(self.linear_1.bias)
nn.init.kaiming_normal_(self.linear_3.weight, nonlinearity='relu')
nn.init.zeros_(self.linear_3.bias)
# right before residual connection: zero initialize
nn.init.zeros_(self.linear_2.weight)
nn.init.zeros_(self.linear_2.bias)
nn.init.zeros_(self.linear_4.weight)
nn.init.zeros_(self.linear_4.bias)
def forward(self, seq, state):
"""Predict side-chain torsion angles along with backbone torsions.
:param seq: Hidden embeddings corresponding to the query sequence.
:type seq: torch.Tensor of shape (B, L, d_msa)
:param state: State features (output l0 feature) from the previous SE(3) layer; only state[:, 1:] is used.
:type state: torch.Tensor of shape (B, L+1, d_state)
:return: Predicted torsion angles (phi, psi, omega, chi1-4 with cos/sin, Cb bend, Cb twist, CG).
:rtype: torch.Tensor of shape (B, L, 10, 2)
"""
B, L = seq.shape[:2]
seq = self.norm_s0(seq)
state = self.norm_si(state)
si = self.linear_s0(seq) + self.linear_si(state[:,1:])
si = si + self.linear_2(F.relu_(self.linear_1(F.relu_(si))))
si = si + self.linear_4(F.relu_(self.linear_3(F.relu_(si))))
si = self.linear_out(F.relu_(si))
return si.view(B, L, 10, 2)
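# Illustrative sketch (not part of the model): SCPred returns each torsion as an
# un-normalized (cos, sin) pair; a typical post-processing step normalizes the pair
# and recovers the angle with atan2. Random tensors stand in for real predictions.
def _demo_torsions_from_cos_sin():
    B, L = 1, 5
    tors = torch.randn(B, L, 10, 2)                               # shaped like the SCPred output
    tors = tors / torch.linalg.norm(tors, dim=-1, keepdim=True)   # project onto the unit circle
    angles = torch.atan2(tors[..., 1], tors[..., 0])              # (B, L, 10) angles in radians
    return angles.shape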
def update_symm_Rs(Rs, Ts, Lasu, symmsub_in, symmsub, symmRs):
def dist_error(R0,T0,Rs,Ts):
B = Ts.shape[0]
Tcom = Ts[:,:Lasu].mean(dim=1,keepdim=True)
Tcorr = torch.einsum('ij,brj->bri', R0, Ts[:,:Lasu]-Tcom) + Tcom + 10.0*T0[None,None,:]
Xsymm = torch.einsum('sij,brj->bsri', symmRs[symmsub], Tcorr).reshape(B,-1,3)
Xtrue = Ts
dsymm = torch.linalg.norm(Xsymm[:,:,None]-Xsymm[:,None,:], dim=-1)
dtrue = torch.linalg.norm(Xtrue[:,:,None]-Xtrue[:,None,:], dim=-1)
return torch.clamp( torch.abs(dsymm-dtrue), max=10.0).mean()
B = Ts.shape[0]
# symmetry correction 1: don't let COM (of entire complex) move
Tmean = Ts[:,:Lasu].reshape(-1,3).mean(dim=0)
Tmean = torch.einsum('sij,j->si', symmRs, Tmean).mean(dim=0)
Ts = Ts - Tmean
# symmetry correction 2: use minimization to minimize drms
#with torch.enable_grad():
# T0 = torch.zeros(3,device=Ts.device).requires_grad_(True)
# R0 = torch.eye(3,device=Ts.device).requires_grad_(True)
# opt = torch.optim.SGD([T0,R0], lr=0.001)
#
# if (dist_error(R0,T0,symmRs[symmsub],Ts)>0.5):
# for e in range(101):
# loss = dist_error(R0,T0,symmRs[symmsub],Ts)
# if (e%50 == 0):
# print (e,loss)
# opt.zero_grad()
# loss.backward()
# opt.step()
#
#Tcom = Ts[:,:Lasu].mean(dim=1,keepdim=True)
#Ts = torch.einsum('ij,brj->bri', R0, Ts[:,:Lasu]-Tcom) +Tcom + 10.0*T0[None,None,:]
Rs = torch.einsum('sij,brjk,slk->bsril', symmRs[symmsub], Rs[:,:Lasu], symmRs[symmsub_in])
Ts = torch.einsum('sij,brj->bsri', symmRs[symmsub], Ts[:,:Lasu])
Rs = Rs.reshape(B,-1,3,3) # (B, S*L, 3, 3)
Ts = Ts.reshape(B,-1,3) # (B, S*L, 3)
return Rs, Ts
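# Illustrative sketch (not part of the model): update_symm_Rs expands asymmetric-unit
# frames/translations to the full assembly with a stack of symmetry rotations via
# einsum. A two-fold (C2) group and toy sizes are assumed purely for demonstration.
def _demo_symm_expand():
    B, Lasu = 1, 4
    symmRs = torch.stack([torch.eye(3),
                          torch.diag(torch.tensor([-1., -1., 1.]))])  # C2: identity + 180 deg about z
    Ts_asu = torch.randn(B, Lasu, 3)
    Ts_full = torch.einsum('sij,brj->bsri', symmRs, Ts_asu).reshape(B, -1, 3)  # (B, 2*Lasu, 3)
    return Ts_full.shape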
def update_symm_subs(Rs, Ts, pair, symmids, symmsub_in, symmsub, symmRs, metasymm):
B,Ls = Ts.shape[0:2]
Osub = symmsub.shape[0]
L = Ls//Osub
com = Ts[:,:L].sum(dim=-2)
rcoms = torch.einsum('sij,bj->si', symmRs, com)
subsymms, nneighs = metasymm
symmsub_new = []
for i in range(len(subsymms)):
drcoms = torch.linalg.norm(rcoms[0,:] - rcoms[subsymms[i],:], dim=-1)
_,subs_i = torch.topk(drcoms,nneighs[i],largest=False)
subs_i,_ = torch.sort( subsymms[i][subs_i] )
symmsub_new.append(subs_i)
symmsub_new = torch.cat(symmsub_new)
s_old = symmids[symmsub[:,None],symmsub[None,:]]
s_new = symmids[symmsub_new[:,None],symmsub_new[None,:]]
# remap old->new
# a) find highest-magnitude patches
pairsub = dict()
pairmag = dict()
for i in range(Osub):
for j in range(Osub):
idx_old = s_old[i,j].item()
sub_ij = pair[:,i*L:(i+1)*L,j*L:(j+1)*L,:].clone()
mag_ij = torch.max(sub_ij.flatten()) #torch.norm(sub_ij.flatten())
if idx_old not in pairsub or mag_ij > pairmag[idx_old]:
pairmag[idx_old] = mag_ij
pairsub[idx_old] = (i,j) #sub_ij
# b) reindex
idx = torch.zeros((Osub*L,Osub*L),dtype=torch.long,device=pair.device)
idx = (
torch.arange(Osub*L,device=pair.device)[:,None]*Osub*L
+ torch.arange(Osub*L,device=pair.device)[None,:]
)
for i in range(Osub):
for j in range(Osub):
idx_new = s_new[i,j].item()
if idx_new in pairsub:
inew,jnew = pairsub[idx_new]
idx[i*L:(i+1)*L,j*L:(j+1)*L] = (
Osub*L*torch.arange(inew*L,(inew+1)*L)[:,None]
+ torch.arange(jnew*L,(jnew+1)*L)[None,:]
)
pair = pair.view(1,-1,pair.shape[-1])[:,idx.flatten(),:].view(1,Osub*L,Osub*L,pair.shape[-1])
#for i in range(Osub):
# for j in range(Osub):
# idx_new = s_new[i,j].item()
# if idx_new in pairsub:
# pass
# #pair[:,i*L:(i+1)*L,j*L:(j+1)*L,:] = pairsub[idx_new] #/pairmag[idx_new]
#if (torch.any(symmsub!=symmsub_new)):
# print (symmsub,'->',symmsub_new)
if symmsub_in is not None and symmsub_in.shape[0]>1:
Rs, Ts = update_symm_Rs(Rs, Ts, L, symmsub_in, symmsub_new, symmRs)
return Rs, Ts, pair, symmsub_new
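# Illustrative sketch (not part of the model): update_symm_subs keeps, per symmetry
# subgroup, the subunits whose rotated centers of mass lie closest to the reference
# copy (torch.topk with largest=False). Toy centers of mass for demonstration only.
def _demo_pick_nearest_subunits():
    rcoms = torch.tensor([[0., 0., 0.],
                          [1., 0., 0.],
                          [5., 0., 0.],
                          [2., 0., 0.]])                 # COM of each symmetry copy
    d = torch.linalg.norm(rcoms[0] - rcoms, dim=-1)      # distance to the reference copy
    _, subs = torch.topk(d, k=2, largest=False)          # two closest copies (includes itself)
    return torch.sort(subs).values                       # tensor([0, 1])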
class IPA(nn.Module):
def __init__(self, d_seq=256, d_pair=128, d_state=16,
n_head=12, d_hidden = 16,
n_query_points = 4, n_point_values = 8, top_k =64,
p_drop=0.1):
super(IPA, self).__init__()
self.n_head = n_head
self.dim = d_hidden
self.n_query_points = n_query_points
self.n_point_values = n_point_values
self.top_k = top_k
self.w_C = math.sqrt(2 / (9 * n_query_points))
# self.w_L = math.sqrt(1 / 3)
self.w_L = 1.0
self.to_value_for_seq = nn.Linear(d_seq, n_head*d_hidden, bias=False)
# self.to_query_for_state = nn.Linear(d_state, n_head * n_query_points * 3, bias=False)
# self.to_key_for_state = nn.Linear(d_state, n_head * n_query_points * 3, bias=False)
self.to_value_for_state = nn.Linear(d_state, n_head * n_point_values * 3, bias=False)
# self.to_q_for_pair = nn.Linear(d_pair, n_head*d_hidden, bias=False)
# self.to_k_for_pair = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_bias_for_pair = nn.Linear(d_pair, n_head, bias=False)
# self.gamma = nn.Parameter(torch.ones(1))
out_dim = n_head*d_pair + n_head*d_hidden + n_head*n_point_values*3 + n_head*n_point_values
# self.linear_seq_out = nn.Linear(out_dim, d_seq)
self.linear_state_out = nn.Linear(out_dim, d_state)
self.reset_parameter()
def reset_parameter(self):
# normal initialization
# nn.init.xavier_uniform_(self.to_q_for_pair.weight)
# nn.init.xavier_uniform_(self.to_k_for_pair.weight)
self.to_value_for_seq = init_lecun_normal(self.to_value_for_seq)
# self.to_query_for_state = init_lecun_normal(self.to_query_for_state)
# self.to_key_for_state = init_lecun_normal(self.to_key_for_state)
self.to_value_for_state = init_lecun_normal(self.to_value_for_state)
self.to_bias_for_pair = init_lecun_normal(self.to_bias_for_pair)
# self.linear_seq_out = init_lecun_normal(self.linear_seq_out)
self.linear_state_out = init_lecun_normal(self.linear_state_out)
# with torch.no_grad():
# softplus_inverse_1 = 0.541324854612918
# self.gamma.fill_(softplus_inverse_1)
def apply_RT(self, Rs, Ts, vec):
# Rs : (B, L, 3, 3)
# Ts : (B, L, 3)
# vec : (B, L, h, p, 3)
B, L, h, p, _ = vec.shape
Ts = Ts.unsqueeze(-2).unsqueeze(-2) # (B, L, 1, 1, 3)
Ts = Ts.expand(-1,-1,h,p,-1) # (B, L, h, p, 3)
return einsum('blij,blhpj->blhpi', Rs, vec) + Ts
def inverse_RT(self, Rs, Ts, vec):
# Rs : (B, L, 3, 3)
# Ts : (B, L, 3)
# vec : (B, L, h, p, 3)
B, L, h, p, _ = vec.shape
Rs = Rs.transpose(-1,-2) # (B, L, 3, 3)
Ts = - einsum('blij,blj->bli', Rs, Ts) # (B, L, 3)
return self.apply_RT(Rs, Ts, vec)
def gather_edges_3d(self, edges, neighbor_idx):
# Features (B, L, L, h, *) ==> (B, L, K, K, h, *)
# neighbor_idx (B, L, K, h)
B, L, K, h = neighbor_idx.shape
indices_3d = neighbor_idx.unsqueeze(2).repeat(1,1,K,1,1) # (B, L, K-new, K, h)
flat_indices = indices_3d.view(B, L, -1, h) # (B, L, K*K, h)
flat_indices = flat_indices.unsqueeze(-1).expand(-1,-1,-1,-1,edges.size(-1)) # (B, L, K*K, h, *)
gathered_edges = torch.gather(edges, 2, flat_indices) # (B, L, K*K, h, *)
return gathered_edges.reshape(B, L, K, K, h, -1)
def gather_edges(self, edges, neighbor_idx):
# Features [B,L,L,h,C] at Neighbor indices [B,L,K,h] => Neighbor features [B,L,K,h,C]
neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, -1, edges.size(-1))
edge_features = torch.gather(edges, 2, neighbors)
return edge_features
def forward(self, seq, pair, state, Rs, Ts, chain_mask):
B, L = state.shape[:2]
# get value for seq
value_seq = self.to_value_for_seq(seq).reshape(B, L, self.n_head, self.dim)
# get value for state
value_state = self.to_value_for_state(state).reshape(B, L, self.n_head, self.n_point_values, 3)
# get bias for pair
bias_pair = self.to_bias_for_pair(pair).reshape(B, L, L, self.n_head) # (B, L, L, n_head)
attn = bias_pair # (B, L, L, n_head)
attn = F.softmax(attn, dim=2) # (B, L, L, n_head)
# pair_qk_test = pair_qk_test[:,:,:,:4,:].view(B, L, L, -1) # (B, L, L, 4*h)
# saving_dir = "saved_models/MiniWorld_v1_0_with_template/attn_wo_strAttn/"
# draw_attn(pair_qk, saving_dir + "neighbor_qk.png")
# draw_attn_multirow(pair_qk_test, 4, saving_dir + "neighbor_qk_each.png")
# draw_attn(bias_pair, saving_dir + "bias_pair.png")
# draw_attn(-structural_attn, saving_dir + "strructural_attn.png")
# draw_attn(attn, saving_dir + "final_attn.png") d
# breakpoint()
o_tilde = einsum('bijh,bijd->bihd', attn, pair) # (B, L, n_head, d_pair)
o = einsum('bijh,bjhd->bihd', attn, value_seq) # (B, L, n_head, d_hidden)
value_state = self.apply_RT(Rs, Ts, value_state) # (B, L, n_head, n_point_values, 3)
value_state = einsum('bijh,bjhpd->bihpd', attn, value_state) # (B, L, n_head, n_point_values, 3)
o_vec = self.inverse_RT(Rs, Ts, value_state) # (B, L, n_head, n_point_values, 3)
o_vec_norm = torch.norm(o_vec, dim=-1) # (B, L, n_head, n_point_values)
# flatten to concat
o_tilde = o_tilde.reshape(B, L, -1) # (B, L, n_head*d_pair)
o = o.reshape(B, L, -1) # (B, L, n_head*d_hidden)
o_vec = o_vec.reshape(B, L, -1) # (B, L, n_head*n_point_values*3)
o_vec_norm = o_vec_norm.reshape(B, L, -1) # (B, L, n_head*n_point_values)
o_concat = torch.cat((o_tilde, o, o_vec, o_vec_norm), dim=-1) # (B, L, n_head*d_pair + n_head*d_hidden + n_head*n_point_values*3 + n_head*n_point_values)
# seq_out = self.linear_seq_out(o_concat) # (B, L, d_seq)
state_out = self.linear_state_out(o_concat) # (B, L, d_state)
return state_out
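# Illustrative sketch (not part of the model): IPA.apply_RT maps local points into the
# global frame and IPA.inverse_RT maps them back, so the two compose to (numerically)
# the identity. Random orthonormal frames from a QR factorization are an assumption
# used only for this check.
def _demo_rt_round_trip():
    B, L, h, p = 1, 3, 2, 4
    Rs, _ = torch.linalg.qr(torch.randn(B, L, 3, 3))     # random orthonormal frames
    Ts = torch.randn(B, L, 3)
    vec = torch.randn(B, L, h, p, 3)
    fwd = torch.einsum('blij,blhpj->blhpi', Rs, vec) + Ts[:, :, None, None, :]
    back = torch.einsum('blji,blhpj->blhpi', Rs, fwd - Ts[:, :, None, None, :])  # R^T (x - T)
    return torch.allclose(back, vec, atol=1e-5)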
class Str2Chain_IPA(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_state=16,
n_head=8, d_hidden = 32,
n_query_points = 4, n_point_values = 8,
p_drop=0.1):
super(Str2Chain_IPA, self).__init__()
# LayerNorm
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_state = nn.LayerNorm(d_state)
self.IPA = IPA(d_seq=d_msa, d_pair=d_pair, d_state=d_state,
n_head=n_head, d_hidden=d_hidden,
n_query_points=n_query_points, n_point_values=n_point_values,
p_drop=p_drop)
state_linear_list = []
for _ in range(3):
state_linear_list.append(nn.Linear(d_state, d_state))
state_linear_list.append(nn.ReLU())
self.state_linear = nn.Sequential(*state_linear_list)
self.state_to_RT = nn.Linear(d_state, 3 + 3)
# self.state_gate = nn.Linear(d_state, 1)
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
for layer in self.state_linear:
if isinstance(layer, nn.Linear):
layer = init_lecun_normal(layer)
# zero initialize
nn.init.zeros_(self.state_to_RT.weight)
nn.init.zeros_(self.state_to_RT.bias)
# gating: zero weights, ones for biases (mostly open gate at the beginning)
# nn.init.zeros_(self.state_gate.weight)
# nn.init.ones_(self.state_gate.bias)
def forward(self, msa, pair_in, R_in, T_in, state, chain_mask):
# msa : (B, N, L+1, d_msa)
# pair_in : (B, L+1, L+1, d_pair)
# R_in : (B, L, 3, 3), T_in : (B, L, 3)
B, N, L = msa.shape[:3]
msa = self.norm_msa(msa)
pair_in = self.norm_pair(pair_in)
state = self.norm_state(state)
seq = msa[:,0,1:] # (B, L, d_msa)
state_out = self.IPA(seq, pair_in[:,1:,1:], state[:,1:], R_in, T_in, chain_mask)
state_out = F.pad(state_out, (0,0,1,0), 'constant', 0) # (B, L+1, d_state)
# seq = seq + seq_out # (B, L, d_msa)
state = state + state_out # (B, L, d_state)
state = state + self.state_linear(state)
state = self.norm_state(state)
RT_vec = self.state_to_RT(state[:,1:]) # (B, L, 3 + 3)
Q_chain = RT_vec[:,:,:3] # (B, L, 3)
T_chain = RT_vec[:,:,3:] # (B, L, 3)
# gate = torch.sigmoid(self.state_gate(state)) # (B, L, 1)
# chain_mask : (B,L,L) 1 for same chain
# gate = chain_mask # (B, L, L)
chain_mask = chain_mask / torch.sum(chain_mask, dim=-1, keepdim=True) # normalize to sum to 1
T_chain = torch.einsum('bij,bjc->bic', chain_mask, T_chain) # (B,L,3)
Q_chain = torch.einsum('bij,bjc->bic', chain_mask, Q_chain) # (B,L,3)
Q_chain = torch.cat((torch.ones((B,L-1,1),device=Q_chain.device),Q_chain),dim=-1) # (B,L,4)
# breakpoint()
Q_chain = normQ(Q_chain)
R_chain = Qs2Rs(Q_chain)
R_out = einsum('bnij,bnjk->bnik', R_in, R_chain)
T_out = einsum('bnij,bnj->bni', R_in, T_chain) + T_in
return R_out, T_out, state
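# Illustrative sketch (not part of the model): the frame update above predicts only the
# (x, y, z) quaternion components, prepends w = 1, normalizes, and composes the small
# rotation with the incoming frame. Toy tensors; normQ/Qs2Rs are the kinematics helpers
# imported at the top of this module.
def _demo_quaternion_frame_update():
    B, L = 1, 5
    R_in = torch.eye(3).expand(B, L, 3, 3)
    qxyz = 0.1 * torch.randn(B, L, 3)                            # small predicted rotation
    Q = normQ(torch.cat((torch.ones(B, L, 1), qxyz), dim=-1))    # (B, L, 4) unit quaternions
    R_delta = Qs2Rs(Q)                                           # (B, L, 3, 3)
    R_out = einsum('blij,bljk->blik', R_in, R_delta)             # compose with the input frames
    return R_out.shape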
class Str2Residue_IPA(nn.Module):
def __init__(self,
d_msa=256, d_pair=128, d_state=16, d_rbf = 64,
n_head=8, d_hidden = 32,
n_query_points = 4, n_point_values = 8,
p_drop=0.1):
super(Str2Residue_IPA, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_state = nn.LayerNorm(d_state)
self.IPA = IPA(d_seq=d_msa, d_pair=d_pair, d_state=d_state,
n_head=n_head, d_hidden=d_hidden,
n_query_points=n_query_points, n_point_values=n_point_values,
p_drop=p_drop)
state_linear_list = []
for _ in range(3):
state_linear_list.append(nn.Linear(d_state, d_state))
state_linear_list.append(nn.ReLU())
self.state_linear = nn.Sequential(*state_linear_list)
self.state_to_RT = nn.Linear(d_state, 3 + 3)
self.sc_predictor = SCPred(d_msa=d_msa, d_state=d_state,
p_drop=p_drop)
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
for layer in self.state_linear:
if isinstance(layer, nn.Linear):
layer = init_lecun_normal(layer)
# zero initialize
nn.init.zeros_(self.state_to_RT.weight)
nn.init.zeros_(self.state_to_RT.bias)
def forward(self, msa, pair_in, R_in, T_in, state, chain_mask):
# msa : (B, N, L+1, d_msa)
# pair : (B, L+1, L+1, d_pair)
# R_in : (B, L, 3, 3), T_in : (B, L, 3)
B, N, L = msa.shape[:3]
msa = self.norm_msa(msa)
pair_in = self.norm_pair(pair_in)
state = self.norm_state(state)
seq = msa[:,0,1:] # (B, L, d_msa)
state_out = self.IPA(seq, pair_in[:,1:,1:], state[:,1:], R_in, T_in, chain_mask)
state_out = F.pad(state_out, (0,0,1,0), 'constant', 0) # (B, L+1, d_state)
# seq = seq + seq_out # (B, L, d_msa)
state = state + state_out # (B, L, d_state)
state = state + self.state_linear(state)
state = self.norm_state(state)
RT_vec = self.state_to_RT(state[:, 1:]) # (B, L, 3 + 3)
Qs = RT_vec[:,:,:3] # (B, L, 3)
T_residue = RT_vec[:,:,3:] # (B, L, 3)
Qs = torch.cat((torch.ones((B,L-1,1),device=Qs.device),Qs),dim=-1)
Qs = normQ(Qs)
R_residue = Qs2Rs(Qs)
alpha = self.sc_predictor(seq, state)
R_out = einsum('bnij,bnjk->bnik', R_in, R_residue)
T_out = einsum('bnij,bnj->bni', R_in, T_residue) + T_in
# breakpoint()
return R_out, T_out, state, alpha, pair_in
class IterBlock(nn.Module):
def __init__(self,
value_net,
d_msa=256, d_pair=128, d_rbf=64,
n_head_msa=8, n_head_pair=4,
use_global_attn=False,
d_hidden=32, d_hidden_msa=None, p_drop=0.15,
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
super(IterBlock, self).__init__()
if d_hidden_msa == None:
d_hidden_msa = d_hidden
self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair,
n_head=n_head_msa,
d_state=SE3_param['l0_out_features'],
use_global_attn=use_global_attn,
d_hidden=d_hidden_msa, p_drop=p_drop)
self.msa2pair = MSA2Pair(d_msa=d_msa, d_pair=d_pair,
d_hidden=d_hidden//2, p_drop=p_drop)
self.pair2pair = PairStr2Pair(d_pair=d_pair, n_head=n_head_pair, d_state=SE3_param['l0_out_features'],
d_hidden=d_hidden, p_drop=p_drop)
self.str2residue = Str2Residue_IPA(
d_msa=d_msa, d_pair=d_pair, d_rbf = d_rbf,
d_state=SE3_param['l0_out_features'],
p_drop=p_drop)
self.value_net = value_net
self.update_pair_by_value = UpdatePairByValue(d_pair=d_pair, d_rbf=d_rbf)
def forward(self, seq, msa, pair, R_in, T_in, xyz_in, state, idx, chain_mask, symmids, symmsub_in, symmsub, symmRs, symmmeta, topk=0, crop=-1):
# msa : (B, N, L, d_msa)
# pair : (B, L, L, d_pair)
# xyz : (B, L, 3, 3)
B, L = pair.shape[:2]
# xyzfull = xyz.view(1,B*L,3,3)
# rbf_feat = rbf(
# torch.cdist(xyzfull[:,:,1,:], xyzfull[:,:L,1,:])
# ).reshape(B,L,L,-1) + self.pos(idx)
msa = self.msa2msa(msa, pair, state, symmids, symmsub)
pair = self.msa2pair(msa, pair, symmids, symmsub)
pair = self.pair2pair(pair, crop, True, symmids, symmsub)
R_out, T_out, state, alpha, pair = self.str2residue(msa.float(), pair.float(),
R_in.detach().float(), T_in.float(),
state.float(), chain_mask.float())
xyz_out = torch.einsum('blij,blaj->blai', R_in, xyz_in) + T_in.unsqueeze(-2) # (B,L,N,3)
h_E, pae_neighbor, E_idx = self.value_net(seq.long(), idx, chain_mask, xyz_out)
pair = self.update_pair_by_value(xyz_out, pair, h_E, pae_neighbor, E_idx)
return msa, pair, R_out, T_out, state, alpha, symmsub
class UpdatePairByValue(nn.Module):
def __init__(self, d_pair=128,d_rbf=64):
super(UpdatePairByValue, self).__init__()
self.value_gate = nn.Linear(d_pair, d_pair)
self.beta = 1.0 # scale for the sigmoid; a larger beta sharpens the gate and enlarges its gradient
self.proj_dist = nn.Linear(d_rbf, d_pair)
self.reset_parameter()
def reset_parameter(self):
self.proj_dist = init_lecun_normal(self.proj_dist)
# mostly closed gate at the beginning
nn.init.zeros_(self.value_gate.weight)
self.value_gate.bias.data.fill_(-3.0) # sigmoid(-3) ~ 0.05
def forward(self, xyz, pair, h_E, pae_neighbor, E_idx):
# get xyz from R_in, T_in
B, L = xyz.shape[:2]
n_bin_pae = pae_neighbor.shape[-1]
logit_pae = torch.zeros(B, L, L, pae_neighbor.size(-1), device=pae_neighbor.device)
logit_pae = torch.scatter_add(logit_pae, 2, E_idx.unsqueeze(-1).expand(-1,-1,-1,n_bin_pae), pae_neighbor) # (B,L,L,n_bin_pae)
logit_pae = logit_pae.permute(0,3,1,2) # (B,n_bin_pae,L,L)
# scatter edge features to (B,L,L,d_edge)
h_E_full = torch.zeros(B, L, L, h_E.size(-1), device=h_E.device)
h_E_full = torch.scatter(h_E_full, 2, E_idx.unsqueeze(-1).expand(-1,-1,-1,h_E.size(-1)), h_E)
# get topk_mask (B,L,L) from E_idx
top_k_mask = torch.zeros(B, L, L, device=E_idx.device)
top_k_mask = top_k_mask.scatter(2, E_idx, 1)
value_gate = torch.sigmoid(self.beta * self.value_gate(h_E_full.clone().detach()))
value_gate = value_gate * top_k_mask.unsqueeze(-1)
Cb = get_Cb(xyz[:,:,:3])
dist_CB = rbf(
torch.cdist(Cb, Cb)
).reshape(B,L,L,-1)
pair = pair.clone()
pair[:, 1:, 1:] = pair[:, 1:, 1:] + value_gate * self.proj_dist(dist_CB)
return pair
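# Illustrative sketch (not part of the model): the scatter calls in UpdatePairByValue
# place per-neighbor features (B, L, K, C) back into a dense (B, L, L, C) pair map at
# the neighbor indices, and build a top-k mask the same way. Toy sizes only.
def _demo_scatter_neighbors_to_dense():
    B, L, K, C = 1, 6, 3, 4
    E_idx = torch.stack([torch.randperm(L)[:K] for _ in range(L)])[None]  # (B, L, K) neighbor indices
    h_E = torch.randn(B, L, K, C)
    dense = torch.zeros(B, L, L, C)
    dense = torch.scatter(dense, 2, E_idx.unsqueeze(-1).expand(-1, -1, -1, C), h_E)
    mask = torch.zeros(B, L, L).scatter(2, E_idx, 1.0)                    # 1 where a neighbor exists
    return dense.shape, int(mask.sum())                                   # (1, 6, 6, 4), 18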
class IterativeSimulator(nn.Module):
def __init__(self,
value_net,
n_extra_block=4, n_main_block=12, n_ref_block=4,
d_query = 256, d_msa=256, d_pair=128, d_hidden=32, d_rbf=64,
n_head_msa=8, n_head_pair=4,
SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'l1_out_features' : 2, 'num_edge_features':32},
SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'l1_out_features' : 2, 'num_edge_features':32},
p_drop=0.15):
super(IterativeSimulator, self).__init__()
self.value_net = value_net
self.n_extra_block = n_extra_block
self.n_main_block = n_main_block
self.n_ref_block = n_ref_block
self.proj_state = nn.Linear(SE3_param_topk['l0_out_features'], SE3_param_full['l0_out_features'])
# Update with seed sequences
if n_main_block > 0:
self.main_block = nn.ModuleList([IterBlock(value_net,
d_msa=d_msa, d_pair=d_pair,
n_head_msa=n_head_msa,
n_head_pair=n_head_pair,
d_hidden=d_hidden,
p_drop=p_drop,
use_global_attn=False,
SE3_param=SE3_param_full)
for i in range(n_main_block)])
self.proj_state2 = nn.Linear(SE3_param_full['l0_out_features'], SE3_param_topk['l0_out_features'])
# Final SE(3) refinement
if n_ref_block > 0:
self.str_refiner = Str2Residue_IPA(d_msa=d_msa, d_pair=d_pair, d_rbf = d_rbf,
d_state=SE3_param_topk['l0_out_features'],
p_drop=p_drop)
self.reset_parameter()
def reset_parameter(self):
self.proj_state = init_lecun_normal(self.proj_state)
nn.init.zeros_(self.proj_state.bias)
self.proj_state2 = init_lecun_normal(self.proj_state2)
nn.init.zeros_(self.proj_state2.bias)
def forward(self, seq, msa, pair, xyz_in, state, idx, chain_mask, symmids, symmsub, symmRs, symmmeta, use_checkpoint=False, p2p_crop=-1, topk_crop=-1):
# input:
# seq: query sequence (B, L)
# msa: seed MSA embeddings (B, N, L, d_msa)
# pair: initial residue pair embeddings (B, L, L, d_pair)
# xyz_in: initial BB coordinates (B, L, N/CA/C, 3) -> RF2
# xyz_in: initial BB coordinates (B, L, 14, 3) -> PSK, MiniWorld
# state: initial state features containing mixture of query seq, sidechain, accuracy info (B, L, d_state)
# idx: residue index
B,_,L = msa.shape[:3]
if symmsub is not None:
Lasu = L//symmsub.shape[0]
symmsub_in = symmsub.clone()
else:
Lasu = L
symmsub_in = None
R_in, T_in = rigid_from_3_points(xyz_in[:,:,0], xyz_in[:,:,1], xyz_in[:,:,2]) # (B, L, 3, 3), (B, L, 3)
# # R_chain, T_chain, R_residue, T_residue, R_sidechain, T_sidechain, state, alpha
# # TODO
# R_in_chain = torch.eye(3, device=xyz_in.device).reshape(1,1,3,3).expand(B, L, -1, -1)
# R_in_residue = R_in
# chain_mask = nn.functional.normalize(chain_mask, p=1, dim=-1) # normalize to sum to 1
# # xyz = xyz_in - T_in.unsqueeze(-2) # center on CA
# T_in_chain = torch.einsum('bij,bjc->bic', chain_mask, T_in.detach()) # (B,L,3)
# T_in_residue = T_in.detach() - einsum('blij,blj->bli', R_in_residue, T_in_chain.detach()) # (B,L,3)
state = self.proj_state(state)
R_list = list()
T_list = list()
alpha_list = list()
for i_m in range(self.n_main_block):
# R_in_chain = R_in_chain.detach() # (B,L,3,3)
# R_in_residue = R_in_residue.detach() # (B,L,3,3)
# T_in_chain = T_in_chain.detach() # (B,L,3)
# T_in_residue = T_in_residue.detach() # (B,L,3)
R_in = R_in.detach() # (B,L,3,3)
# xyz_in = xyz_in.detach() # (B,L,3,3)
# print(f"For test xyz_in[0,0] : {xyz_in[0,0]}")
if use_checkpoint:
msa, pair, R_in, T_in, state, alpha, symmsub = checkpoint.checkpoint(create_custom_forward(self.main_block[i_m]), seq, msa, pair, R_in, T_in, xyz_in, state, idx, chain_mask,
symmids, symmsub_in, symmsub, symmRs, symmmeta,
topk_crop, p2p_crop, use_reentrant=False)
else :
msa, pair, R_in, T_in, state, alpha, symmsub = self.main_block[i_m](seq, msa, pair,
R_in, T_in, xyz_in, state, idx, chain_mask, symmids, symmsub_in, symmsub, symmRs, symmmeta,
crop=p2p_crop, topk=topk_crop)
R_list.append(R_in)
T_list.append(T_in)
alpha_list.append(alpha)
state = self.proj_state2(state)
R_list = torch.stack(R_list, dim=0) # (n_blocks, B, L, 3, 3)
T_list = torch.stack(T_list, dim=0) # (n_blocks, B, L, 3)
alpha_list = torch.stack(alpha_list, dim=0) # (n_blocks, B, L, 10, 2)
return msa, pair, R_list, T_list, pair, alpha_list, state, symmsub