Source code for miniworld.models_MiniWorld_v1_5_use_interaction.Track_module

import torch.utils.checkpoint as checkpoint
from miniworld.models_MiniWorld_v1_5_use_interaction.Attention_module import *
from miniworld.utils.kinematics import normQ, Qs2Rs


# Components for three-track blocks
# 1. MSA -> MSA update (biased attention. bias from pair & structure)
# 2. Pair -> Pair update (biased attention. bias from structure)
# 3. MSA -> Pair update (extract coevolution signal)
# 4. Str -> Str update (node from MSA, edge from Pair)

# Module contains classes and functions to generate initial embeddings
class SeqSep(nn.Module):
    # Add relative positional encoding to pair features
    def __init__(self, d_model, minpos=-32, maxpos=32):
        super(SeqSep, self).__init__()
        self.minpos = minpos
        self.maxpos = maxpos
        self.nbin = abs(minpos)+maxpos+1
        self.emb = nn.Embedding(self.nbin, d_model)
    def forward(self, idx):
        # idx: (B, L)
        B, L = idx.shape[:2]
        bins = torch.arange(self.minpos, self.maxpos, device=idx.device)
        seqsep = torch.full((B,L,L), 100, dtype=idx.dtype, device=idx.device)
        seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L)
        #
        ib = torch.bucketize(seqsep, bins).long() # (B, L, L)
        emb = self.emb(ib) # (B, L, L, d_model)
        return emb

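# A minimal usage sketch (not part of the original module): SeqSep maps integer
# residue indices (B, L) to a relative-position embedding (B, L, L, d_model).
# The sizes below are arbitrary toy values.
def _example_seqsep_usage():
    seqsep = SeqSep(d_model=128)
    idx = torch.arange(64)[None]      # (B=1, L=64) residue indices
    rel_pos = seqsep(idx)             # (1, 64, 64, 128)
    return rel_pos
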
# Update MSA with biased self-attention. bias from Pair & Str
class MSAPairStr2MSA(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_state=16, d_rbf=64,
                 d_hidden=32, p_drop=0.15, use_global_attn=False):
        super(MSAPairStr2MSA, self).__init__()
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_state = nn.LayerNorm(d_state)
        self.proj_state = nn.Linear(d_state, d_msa)
        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.row_attn = MSARowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
                                                n_head=n_head, d_hidden=d_hidden)
        self.col_attn = MSAColAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
        self.ff = FeedForwardLayer(d_msa, 4, p_drop=p_drop)

        # Do proper initialization
        self.reset_parameter()
    def reset_parameter(self):
        # initialize weights to normal distribution
        self.proj_state = init_lecun_normal(self.proj_state)

        # initialize bias to zeros
        nn.init.zeros_(self.proj_state.bias)
    def forward(self, msa, pair, state, symmids, symmsub):
        """Update MSA with biased self-attention, using bias from Pair & Str.

        :param msa: MSA feature tensor.
        :type msa: torch.Tensor of shape (B, N, L, d_msa)
        :param pair: Pair feature tensor.
        :type pair: torch.Tensor of shape (B, L, L, d_pair)
        :param state: Updated node features from the SE(3) layer.
        :type state: torch.Tensor of shape (B, L, d_state)
        :param symmids: Symmetry IDs.
        :type symmids: torch.Tensor
        :param symmsub: Symmetry sub-units.
        :type symmsub: torch.Tensor
        :return: Updated MSA feature tensor.
        :rtype: torch.Tensor of shape (B, N, L, d_msa)
        """
        B, N, L, _ = msa.shape

        # prepare input bias feature by combining pair & coordinate info
        pair = self.norm_pair(pair)
        #
        # update query sequence feature (first sequence in the MSA) with feedback (state) from SE3
        state = self.norm_state(state)
        state = self.proj_state(state).reshape(B, 1, L, -1)
        msa = msa.type_as(state)
        msa = msa.index_add(1, torch.tensor([0,], device=state.device), state) # (B, N, L, d_msa)
        #
        # Apply row/column attention to msa & transform
        msa = msa + self.drop_row(self.row_attn(msa, pair)) # (B, N, L+1, d_msa)
        msa = msa + self.col_attn(msa) # (B, N, L+1, d_msa)
        msa = msa + self.ff(msa)

        return msa
class MSA2Pair(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, d_hidden=32, p_drop=0.15):
        super(MSA2Pair, self).__init__()
        self.norm = nn.LayerNorm(d_msa)
        self.proj_left = nn.Linear(d_msa, d_hidden)
        self.proj_right = nn.Linear(d_msa, d_hidden)
        self.proj_out = nn.Linear(d_hidden*d_hidden, d_pair)
        self.d_hidden = d_hidden

        self.reset_parameter()
    def reset_parameter(self):
        # normal initialization
        self.proj_left = init_lecun_normal(self.proj_left)
        self.proj_right = init_lecun_normal(self.proj_right)
        nn.init.zeros_(self.proj_left.bias)
        nn.init.zeros_(self.proj_right.bias)

        # zero initialize output
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)
    def forward(self, msa, pair, symmids, symmsub):
        B, N, L = msa.shape[:3]
        msa = self.norm(msa)

        left = self.proj_left(msa)
        right = self.proj_right(msa)
        right = right / float(N)
        out = einsum('bsli,bsmj->blmij', left, right).reshape(B, L, L, -1)
        out = self.proj_out(out)

        pair = pair + out

        return pair

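# A minimal sketch (not part of the original module) of the outer-product update in
# MSA2Pair.forward: the einsum is a sequence-summed outer product of the left/right
# projections (right is pre-divided by N, so the result is a per-sequence average).
# All sizes below are toy values.
def _example_msa2pair_outer_product():
    B, N, L, h = 1, 4, 8, 5
    left = torch.randn(B, N, L, h)
    right = torch.randn(B, N, L, h) / N
    out = torch.einsum('bsli,bsmj->blmij', left, right)   # (B, L, L, h, h)
    # the same contraction written explicitly as a broadcasted outer product
    out_ref = (left[:, :, :, None, :, None] * right[:, :, None, :, None, :]).sum(dim=1)
    assert torch.allclose(out, out_ref, atol=1e-5)
    return out.reshape(B, L, L, -1)                       # flattened to h*h, as in MSA2Pair
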
class PairStr2Pair(nn.Module):
    def __init__(self, d_pair=128, n_head=4, d_hidden=32, d_hidden_state=16,
                 d_rbf=64, d_state=32, p_drop=0.15):
        super(PairStr2Pair, self).__init__()

        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)

        self.tri_mul_out = TriangleMultiplication(d_pair, d_hidden=d_hidden)
        self.tri_mul_in = TriangleMultiplication(d_pair, d_hidden, outgoing=False)

        self.ff = FeedForwardLayer(d_pair, 2)
    # perform a striped p2p op
    def subblock(self, OP, pair, rbf_feat, crop):
        N, L = pair.shape[:2]

        nbox = (L-1)//(crop//2)+1
        idx = torch.triu_indices(nbox, nbox, 1, device=pair.device)
        ncrops = idx.shape[1]

        pairnew = torch.zeros((N,L*L,pair.shape[-1]), device=pair.device, dtype=pair.dtype)
        countnew = torch.zeros((N,L*L), device=pair.device, dtype=torch.int)
        for i in range(ncrops):
            # reindex sub-blocks
            offsetC = torch.clamp( (1+idx[1,i:(i+1)])*(crop//2)-L, min=0 ) # account for going past L
            offsetN = torch.zeros_like(offsetC)
            mask = (offsetC>0)*((idx[0,i]+1)==idx[1,i])
            offsetN[mask] = offsetC[mask]
            pairIdx = torch.zeros((1,crop), dtype=torch.long, device=pair.device)
            pairIdx[:,:(crop//2)] = torch.arange(crop//2, dtype=torch.long, device=pair.device)+idx[0,i:(i+1),None]*(crop//2) - offsetN[:,None]
            pairIdx[:,(crop//2):] = torch.arange(crop//2, dtype=torch.long, device=pair.device)+idx[1,i:(i+1),None]*(crop//2) - offsetC[:,None]

            # do reindexing
            iL, iU = pairIdx[:,:,None], pairIdx[:,None,:]
            paircrop = pair[:,iL,iU,:].reshape(-1,crop,crop,pair.shape[-1])
            rbfcrop = rbf_feat[:,iL,iU,:].reshape(-1,crop,crop,rbf_feat.shape[-1])

            # row attn
            paircrop = OP(paircrop, rbfcrop).to(pair.dtype)

            # unindex
            iUL = (iL*L+iU).flatten()
            pairnew.index_add_(1, iUL, paircrop.reshape(N,iUL.shape[0],pair.shape[-1]))
            countnew.index_add_(1, iUL, torch.ones((N,iUL.shape[0]), device=pair.device, dtype=torch.int))

        return pair + (pairnew/countnew[...,None]).reshape(N,L,L,-1)
    def forward(self, pair, crop=-1, use_species=True, symmids=None, symmsub=None):
        B, L = pair.shape[:2]

        crop = 2*(crop//2) # make sure even
        if (crop>0 and crop<=L):
            pair = self.subblock(
                lambda x,y: self.drop_row(self.tri_mul_out(x)), #, symmids)),
                pair, crop
            )
            pair = self.subblock(
                lambda x,y: self.drop_row(self.tri_mul_in(x)), #, symmids)),
                pair, crop
            )
            pair = self.subblock(
                lambda x,y: self.drop_row(self.row_attn(x,y)), #, symmids)),
                pair, crop
            )
            pair = self.subblock(
                lambda x,y: self.drop_col(self.col_attn(x,y)), #, symmids)),
                pair, crop
            )

            # feed forward layer
            RESSTRIDE = 16384//L
            for i in range((L-1)//RESSTRIDE+1):
                r_i, r_j = i*RESSTRIDE, min((i+1)*RESSTRIDE, L)
                pair[:,r_i:r_j] = pair[:,r_i:r_j] + self.ff(pair[:,r_i:r_j])
        else:
            #_nc = lambda x:torch.sum(torch.isnan(x))
            pair = pair + self.drop_row(self.tri_mul_out(pair))
            pair = pair + self.drop_row(self.tri_mul_in(pair))
            pair = pair + self.ff(pair)

        return pair
class SCPred(nn.Module):
    def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
        super(SCPred, self).__init__()
        self.norm_s0 = nn.LayerNorm(d_msa)
        self.norm_si = nn.LayerNorm(d_state)
        self.linear_s0 = nn.Linear(d_msa, d_hidden)
        self.linear_si = nn.Linear(d_state, d_hidden)

        # ResNet layers
        self.linear_1 = nn.Linear(d_hidden, d_hidden)
        self.linear_2 = nn.Linear(d_hidden, d_hidden)
        self.linear_3 = nn.Linear(d_hidden, d_hidden)
        self.linear_4 = nn.Linear(d_hidden, d_hidden)

        # Final outputs
        self.linear_out = nn.Linear(d_hidden, 20)

        self.reset_parameter()
    def reset_parameter(self):
        # normal initialization
        self.linear_s0 = init_lecun_normal(self.linear_s0)
        self.linear_si = init_lecun_normal(self.linear_si)
        self.linear_out = init_lecun_normal(self.linear_out)
        nn.init.zeros_(self.linear_s0.bias)
        nn.init.zeros_(self.linear_si.bias)
        nn.init.zeros_(self.linear_out.bias)

        # right before relu activation: He initializer (kaiming normal)
        nn.init.kaiming_normal_(self.linear_1.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear_1.bias)
        nn.init.kaiming_normal_(self.linear_3.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear_3.bias)

        # right before residual connection: zero initialize
        nn.init.zeros_(self.linear_2.weight)
        nn.init.zeros_(self.linear_2.bias)
        nn.init.zeros_(self.linear_4.weight)
        nn.init.zeros_(self.linear_4.bias)
    def forward(self, seq, state):
        """Predict side-chain torsion angles along with backbone torsions.

        :param seq: Hidden embeddings corresponding to the query sequence.
        :type seq: torch.Tensor of shape (B, L, d_msa)
        :param state: State features (output l0 feature) from the previous SE(3) layer.
        :type state: torch.Tensor of shape (B, L, d_state)
        :return: Predicted torsion angles (phi, psi, omega, chi1-4 with cos/sin, Cb bend, Cb twist, CG).
        :rtype: torch.Tensor of shape (B, L, 10, 2)
        """
        B, L = seq.shape[:2]
        seq = self.norm_s0(seq)
        state = self.norm_si(state)
        si = self.linear_s0(seq) + self.linear_si(state[:,1:])

        si = si + self.linear_2(F.relu_(self.linear_1(F.relu_(si))))
        si = si + self.linear_4(F.relu_(self.linear_3(F.relu_(si))))

        si = self.linear_out(F.relu_(si))
        return si.view(B, L, 10, 2)

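# A minimal sketch (not part of the original module): the (B, L, 10, 2) output of SCPred
# holds per-torsion (cos, sin) pairs, assuming the ordering implied by the docstring above;
# angles in radians can be recovered by normalising each pair and applying atan2.
def _example_angles_from_scpred(pred):
    # pred: (B, L, 10, 2), e.g. the return value of SCPred.forward
    pred = pred / torch.linalg.norm(pred, dim=-1, keepdim=True).clamp(min=1e-8)
    return torch.atan2(pred[..., 1], pred[..., 0])   # (B, L, 10) angles in radians
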
def update_symm_Rs(Rs, Ts, Lasu, symmsub_in, symmsub, symmRs):
    def dist_error(R0, T0, Rs, Ts):
        B = Ts.shape[0]
        Tcom = Ts[:,:Lasu].mean(dim=1,keepdim=True)
        Tcorr = torch.einsum('ij,brj->bri', R0, Ts[:,:Lasu]-Tcom) + Tcom + 10.0*T0[None,None,:]
        Xsymm = torch.einsum('sij,brj->bsri', symmRs[symmsub], Tcorr).reshape(B,-1,3)
        Xtrue = Ts
        dsymm = torch.linalg.norm(Xsymm[:,:,None]-Xsymm[:,None,:], dim=-1)
        dtrue = torch.linalg.norm(Xtrue[:,:,None]-Xtrue[:,None,:], dim=-1)
        return torch.clamp( torch.abs(dsymm-dtrue), max=10.0).mean()

    B = Ts.shape[0]

    # symmetry correction 1: don't let COM (of entire complex) move
    Tmean = Ts[:,:Lasu].reshape(-1,3).mean(dim=0)
    Tmean = torch.einsum('sij,j->si', symmRs, Tmean).mean(dim=0)
    Ts = Ts - Tmean

    # symmetry correction 2: use minimization to minimize drms
    #with torch.enable_grad():
    #    T0 = torch.zeros(3,device=Ts.device).requires_grad_(True)
    #    R0 = torch.eye(3,device=Ts.device).requires_grad_(True)
    #    opt = torch.optim.SGD([T0,R0], lr=0.001)
    #
    #    if (dist_error(R0,T0,symmRs[symmsub],Ts)>0.5):
    #        for e in range(101):
    #            loss = dist_error(R0,T0,symmRs[symmsub],Ts)
    #            if (e%50 == 0):
    #                print (e,loss)
    #            opt.zero_grad()
    #            loss.backward()
    #            opt.step()
    #
    #Tcom = Ts[:,:Lasu].mean(dim=1,keepdim=True)
    #Ts = torch.einsum('ij,brj->bri', R0, Ts[:,:Lasu]-Tcom) + Tcom + 10.0*T0[None,None,:]

    Rs = torch.einsum('sij,brjk,slk->bsril', symmRs[symmsub], Rs[:,:Lasu], symmRs[symmsub_in])
    Ts = torch.einsum('sij,brj->bsri', symmRs[symmsub], Ts[:,:Lasu])
    Rs = Rs.reshape(B,-1,3,3) # (B, S*L, 3, 3)
    Ts = Ts.reshape(B,-1,3)   # (B, S*L, 3)

    return Rs, Ts
def update_symm_subs(Rs, Ts, pair, symmids, symmsub_in, symmsub, symmRs, metasymm):
    B, Ls = Ts.shape[0:2]
    Osub = symmsub.shape[0]
    L = Ls//Osub

    com = Ts[:,:L].sum(dim=-2)
    rcoms = torch.einsum('sij,bj->si', symmRs, com)
    subsymms, nneighs = metasymm

    symmsub_new = []
    for i in range(len(subsymms)):
        drcoms = torch.linalg.norm(rcoms[0,:] - rcoms[subsymms[i],:], dim=-1)
        _, subs_i = torch.topk(drcoms, nneighs[i], largest=False)
        subs_i, _ = torch.sort( subsymms[i][subs_i] )
        symmsub_new.append(subs_i)

    symmsub_new = torch.cat(symmsub_new)

    s_old = symmids[symmsub[:,None],symmsub[None,:]]
    s_new = symmids[symmsub_new[:,None],symmsub_new[None,:]]

    # remap old->new
    # a) find highest-magnitude patches
    pairsub = dict()
    pairmag = dict()
    for i in range(Osub):
        for j in range(Osub):
            idx_old = s_old[i,j].item()
            sub_ij = pair[:,i*L:(i+1)*L,j*L:(j+1)*L,:].clone()
            mag_ij = torch.max(sub_ij.flatten()) #torch.norm(sub_ij.flatten())
            if idx_old not in pairsub or mag_ij > pairmag[idx_old]:
                pairmag[idx_old] = mag_ij
                pairsub[idx_old] = (i,j) #sub_ij

    # b) reindex
    idx = torch.zeros((Osub*L,Osub*L), dtype=torch.long, device=pair.device)
    idx = (
        torch.arange(Osub*L, device=pair.device)[:,None]*Osub*L
        + torch.arange(Osub*L, device=pair.device)[None,:]
    )
    for i in range(Osub):
        for j in range(Osub):
            idx_new = s_new[i,j].item()
            if idx_new in pairsub:
                inew, jnew = pairsub[idx_new]
                idx[i*L:(i+1)*L,j*L:(j+1)*L] = (
                    Osub*L*torch.arange(inew*L,(inew+1)*L)[:,None]
                    + torch.arange(jnew*L,(jnew+1)*L)[None,:]
                )
    pair = pair.view(1,-1,pair.shape[-1])[:,idx.flatten(),:].view(1,Osub*L,Osub*L,pair.shape[-1])

    #for i in range(Osub):
    #    for j in range(Osub):
    #        idx_new = s_new[i,j].item()
    #        if idx_new in pairsub:
    #            pass
    #            #pair[:,i*L:(i+1)*L,j*L:(j+1)*L,:] = pairsub[idx_new] #/pairmag[idx_new]

    #if (torch.any(symmsub!=symmsub_new)):
    #    print (symmsub,'->',symmsub_new)

    if symmsub_in is not None and symmsub_in.shape[0]>1:
        Rs, Ts = update_symm_Rs(Rs, Ts, L, symmsub_in, symmsub_new, symmRs)

    return Rs, Ts, pair, symmsub_new
class IPA(nn.Module):
    def __init__(self, d_seq=256, d_pair=128, d_state=16, n_head=12, d_hidden=16,
                 n_query_points=4, n_point_values=8, top_k=64, p_drop=0.1):
        super(IPA, self).__init__()
        self.n_head = n_head
        self.dim = d_hidden
        self.n_query_points = n_query_points
        self.n_point_values = n_point_values
        self.top_k = top_k
        self.w_C = math.sqrt(2 / (9 * n_query_points))
        # self.w_L = math.sqrt(1 / 3)
        self.w_L = 1.0

        self.to_value_for_seq = nn.Linear(d_seq, n_head*d_hidden, bias=False)

        # self.to_query_for_state = nn.Linear(d_state, n_head * n_query_points * 3, bias=False)
        # self.to_key_for_state = nn.Linear(d_state, n_head * n_query_points * 3, bias=False)
        self.to_value_for_state = nn.Linear(d_state, n_head * n_point_values * 3, bias=False)

        # self.to_q_for_pair = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        # self.to_k_for_pair = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_bias_for_pair = nn.Linear(d_pair, n_head, bias=False)

        # self.gamma = nn.Parameter(torch.ones(1))

        out_dim = n_head*d_pair + n_head*d_hidden + n_head*n_point_values*3 + n_head*n_point_values
        # self.linear_seq_out = nn.Linear(out_dim, d_seq)
        self.linear_state_out = nn.Linear(out_dim, d_state)

        self.reset_parameter()
    def reset_parameter(self):
        # normal initialization
        # nn.init.xavier_uniform_(self.to_q_for_pair.weight)
        # nn.init.xavier_uniform_(self.to_k_for_pair.weight)
        self.to_value_for_seq = init_lecun_normal(self.to_value_for_seq)
        # self.to_query_for_state = init_lecun_normal(self.to_query_for_state)
        # self.to_key_for_state = init_lecun_normal(self.to_key_for_state)
        self.to_value_for_state = init_lecun_normal(self.to_value_for_state)
        self.to_bias_for_pair = init_lecun_normal(self.to_bias_for_pair)
        # self.linear_seq_out = init_lecun_normal(self.linear_seq_out)
        self.linear_state_out = init_lecun_normal(self.linear_state_out)
        # with torch.no_grad():
        #     softplus_inverse_1 = 0.541324854612918
        #     self.gamma.fill_(softplus_inverse_1)
    def apply_RT(self, Rs, Ts, vec):
        # Rs  : (B, L, 3, 3)
        # Ts  : (B, L, 3)
        # vec : (B, L, h, p, 3)
        B, L, h, p, _ = vec.shape
        Ts = Ts.unsqueeze(-2).unsqueeze(-2)   # (B, L, 1, 1, 3)
        Ts = Ts.expand(-1,-1,h,p,-1)          # (B, L, h, p, 3)
        return einsum('blij,blhpj->blhpi', Rs, vec) + Ts
    def inverse_RT(self, Rs, Ts, vec):
        # Rs  : (B, L, 3, 3)
        # Ts  : (B, L, 3)
        # vec : (B, L, h, p, 3)
        B, L, h, p, _ = vec.shape
        Rs = Rs.transpose(-1,-2)                  # (B, L, 3, 3)
        Ts = - einsum('blij,blj->bli', Rs, Ts)    # (B, L, 3)
        return self.apply_RT(Rs, Ts, vec)
    def gather_edges_3d(self, edges, neighbor_idx):
        # Features (B, L, L, h, *) ==> (B, L, K, K, h, *)
        # neighbor_idx (B, L, K, h)
        B, L, K, h = neighbor_idx.shape
        indices_3d = neighbor_idx.unsqueeze(2).repeat(1,1,K,1,1)                      # (B, L, K-new, K, h)
        flat_indices = indices_3d.view(B, L, -1, h)                                   # (B, L, K*K, h)
        flat_indices = flat_indices.unsqueeze(-1).expand(-1,-1,-1,-1,edges.size(-1))  # (B, L, K*K, h, *)
        gathered_edges = torch.gather(edges, 2, flat_indices)                         # (B, L, K*K, h, *)
        return gathered_edges.reshape(B, L, K, K, h, -1)
    def gather_edges(self, edges, neighbor_idx):
        # Features [B,L,L,h,C] at Neighbor indices [B,L,K,h] => Neighbor features [B,L,K,h,C]
        neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, -1, edges.size(-1))
        edge_features = torch.gather(edges, 2, neighbors)
        return edge_features
    def forward(self, seq, pair, state, Rs, Ts, chain_mask):
        B, L = state.shape[:2]

        # get value for seq
        value_seq = self.to_value_for_seq(seq).reshape(B, L, self.n_head, self.dim)
        # get value for state
        value_state = self.to_value_for_state(state).reshape(B, L, self.n_head, self.n_point_values, 3)
        # get bias for pair
        bias_pair = self.to_bias_for_pair(pair).reshape(B, L, L, self.n_head) # (B, L, L, n_head)

        attn = bias_pair                 # (B, L, L, n_head)
        attn = F.softmax(attn, dim=2)    # (B, L, L, n_head)

        # pair_qk_test = pair_qk_test[:,:,:,:4,:].view(B, L, L, -1) # (B, L, L, 4*h)
        # saving_dir = "saved_models/MiniWorld_v1_0_with_template/attn_wo_strAttn/"
        # draw_attn(pair_qk, saving_dir + "neighbor_qk.png")
        # draw_attn_multirow(pair_qk_test, 4, saving_dir + "neighbor_qk_each.png")
        # draw_attn(bias_pair, saving_dir + "bias_pair.png")
        # draw_attn(-structural_attn, saving_dir + "strructural_attn.png")
        # draw_attn(attn, saving_dir + "final_attn.png")
        # breakpoint()

        o_tilde = einsum('bijh,bijd->bihd', attn, pair)      # (B, L, n_head, d_pair)
        o = einsum('bijh,bjhd->bihd', attn, value_seq)       # (B, L, n_head, d_hidden)

        value_state = self.apply_RT(Rs, Ts, value_state)                   # (B, L, n_head, n_point_values, 3)
        value_state = einsum('bijh,bjhpd->bihpd', attn, value_state)       # (B, L, n_head, n_point_values, 3)
        o_vec = self.inverse_RT(Rs, Ts, value_state)                       # (B, L, n_head, n_point_values, 3)
        o_vec_norm = torch.norm(o_vec, dim=-1)                             # (B, L, n_head, n_point_values)

        # flatten to concat
        o_tilde = o_tilde.reshape(B, L, -1)        # (B, L, n_head*d_pair)
        o = o.reshape(B, L, -1)                    # (B, L, n_head*d_hidden)
        o_vec = o_vec.reshape(B, L, -1)            # (B, L, n_head*n_point_values*3)
        o_vec_norm = o_vec_norm.reshape(B, L, -1)  # (B, L, n_head*n_point_values)

        o_concat = torch.cat((o_tilde, o, o_vec, o_vec_norm), dim=-1)
        # (B, L, n_head*d_pair + n_head*d_hidden + n_head*n_point_values*3 + n_head*n_point_values)

        # seq_out = self.linear_seq_out(o_concat) # (B, L, d_seq)
        state_out = self.linear_state_out(o_concat) # (B, L, d_state)

        return state_out

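# A minimal sketch (not part of the original module): apply_RT maps local-frame points
# into the global frame and inverse_RT is its exact inverse, so a roundtrip recovers the
# input. Rotations are built from random quaternions via normQ/Qs2Rs, used the same way
# as in Str2Chain_IPA below; the default IPA() sizes are used only to reach the helpers.
def _example_rt_roundtrip():
    ipa = IPA()
    B, L, h, p = 1, 8, ipa.n_head, ipa.n_point_values
    Rs = Qs2Rs(normQ(torch.randn(B, L, 4)))   # (B, L, 3, 3) random rotations
    Ts = torch.randn(B, L, 3)
    vec = torch.randn(B, L, h, p, 3)
    back = ipa.inverse_RT(Rs, Ts, ipa.apply_RT(Rs, Ts, vec))
    assert torch.allclose(back, vec, atol=1e-4)
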
class Str2Chain_IPA(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, d_state=16,
                 n_head=8, d_hidden=32, n_query_points=4, n_point_values=8, p_drop=0.1):
        super(Str2Chain_IPA, self).__init__()

        # LayerNorm
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_state = nn.LayerNorm(d_state)

        self.IPA = IPA(d_seq=d_msa, d_pair=d_pair, d_state=d_state,
                       n_head=n_head, d_hidden=d_hidden,
                       n_query_points=n_query_points, n_point_values=n_point_values,
                       p_drop=p_drop)

        state_linear_list = []
        for _ in range(3):
            state_linear_list.append(nn.Linear(d_state, d_state))
            state_linear_list.append(nn.ReLU())
        self.state_linear = nn.Sequential(*state_linear_list)

        self.state_to_RT = nn.Linear(d_state, 3 + 3)
        # self.state_gate = nn.Linear(d_state, 1)

        self.reset_parameter()
    def reset_parameter(self):
        # initialize weights to normal distribution
        for layer in self.state_linear:
            if isinstance(layer, nn.Linear):
                layer = init_lecun_normal(layer)

        # zero initialize
        nn.init.zeros_(self.state_to_RT.weight)
        nn.init.zeros_(self.state_to_RT.bias)
        # gating: zero weights, ones for biases (mostly open gate at the beginning)
        # nn.init.zeros_(self.state_gate.weight)
        # nn.init.ones_(self.state_gate.bias)
    def forward(self, msa, pair_in, R_in, T_in, state, chain_mask):
        # msa  : (B, N, L, d_msa)
        # pair : (B, L, L, d_pair)
        # xyz  : (B, L, 3, 3)
        B, N, L = msa.shape[:3]

        msa = self.norm_msa(msa)
        pair_in = self.norm_pair(pair_in)
        state = self.norm_state(state)

        seq = msa[:,0,1:] # (B, L, d_msa)

        state_out = self.IPA(seq, pair_in[:,1:,1:], state[:,1:], R_in, T_in, chain_mask)
        state_out = F.pad(state_out, (0,0,1,0), 'constant', 0) # (B, L+1, d_state)

        # seq = seq + seq_out # (B, L, d_msa)
        state = state + state_out # (B, L, d_state)
        state = state + self.state_linear(state)
        state = self.norm_state(state)

        RT_vec = self.state_to_RT(state[:,1:]) # (B, L, 3 + 3)

        Q_chain = RT_vec[:,:,:3] # (B, L, 3)
        T_chain = RT_vec[:,:,3:] # (B, L, 3)

        # gate = torch.sigmoid(self.state_gate(state)) # (B, L, 1)
        # chain_mask : (B,L,L) 1 for same chain
        # gate = chain_mask # (B, L, L)
        chain_mask = chain_mask / torch.sum(chain_mask, dim=-1, keepdim=True) # normalize to sum to 1
        T_chain = torch.einsum('bij,bjc->bic', chain_mask, T_chain) # (B,L,3)
        Q_chain = torch.einsum('bij,bjc->bic', chain_mask, Q_chain) # (B,L,3)

        Q_chain = torch.cat((torch.ones((B,L-1,1),device=Q_chain.device),Q_chain),dim=-1) # (B,L,4)
        # breakpoint()
        Q_chain = normQ(Q_chain)
        R_chain = Qs2Rs(Q_chain)

        R_out = einsum('bnij,bnjk->bnik', R_in, R_chain)
        T_out = einsum('bnij,bnj->bni', R_in, T_chain) + T_in

        return R_out, T_out, state

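# A minimal sketch (not part of the original module): Str2Chain_IPA expects
# chain_mask[b,i,j] = 1 when residues i and j belong to the same chain. After the
# row-normalisation in forward above, the einsum becomes a per-chain mean, so every
# residue of a chain receives the same rigid-body (Q, T) update. The chain labels
# below are hypothetical toy values.
def _example_chain_mask_average():
    chain_id = torch.tensor([[0, 0, 0, 1, 1]])                            # (B=1, L=5)
    chain_mask = (chain_id[:, :, None] == chain_id[:, None, :]).float()   # (1, 5, 5)
    T = torch.randn(1, 5, 3)
    w = chain_mask / chain_mask.sum(dim=-1, keepdim=True)
    return torch.einsum('bij,bjc->bic', w, T)                             # per-chain mean of T
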
class Str2Residue_IPA(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, d_state=16, d_rbf=64,
                 n_head=8, d_hidden=32, n_query_points=4, n_point_values=8, p_drop=0.1):
        super(Str2Residue_IPA, self).__init__()

        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_state = nn.LayerNorm(d_state)

        self.IPA = IPA(d_seq=d_msa, d_pair=d_pair, d_state=d_state,
                       n_head=n_head, d_hidden=d_hidden,
                       n_query_points=n_query_points, n_point_values=n_point_values,
                       p_drop=p_drop)

        state_linear_list = []
        for _ in range(3):
            state_linear_list.append(nn.Linear(d_state, d_state))
            state_linear_list.append(nn.ReLU())
        self.state_linear = nn.Sequential(*state_linear_list)

        self.state_to_RT = nn.Linear(d_state, 3 + 3)

        self.sc_predictor = SCPred(d_msa=d_msa, d_state=d_state, p_drop=p_drop)

        self.reset_parameter()
    def reset_parameter(self):
        # initialize weights to normal distribution
        for layer in self.state_linear:
            if isinstance(layer, nn.Linear):
                layer = init_lecun_normal(layer)

        # zero initialize
        nn.init.zeros_(self.state_to_RT.weight)
        nn.init.zeros_(self.state_to_RT.bias)
    def forward(self, msa, pair_in, R_in, T_in, state, chain_mask):
        # msa  : (B, N, L+1, d_msa)
        # pair : (B, L+1, L+1, d_pair)
        # xyz  : (B, L, 3, 3)
        B, N, L = msa.shape[:3]

        msa = self.norm_msa(msa)
        pair_in = self.norm_pair(pair_in)
        state = self.norm_state(state)

        seq = msa[:,0,1:] # (B, L, d_msa)

        state_out = self.IPA(seq, pair_in[:,1:,1:], state[:,1:], R_in, T_in, chain_mask)
        state_out = F.pad(state_out, (0,0,1,0), 'constant', 0) # (B, L+1, d_state)

        # seq = seq + seq_out # (B, L, d_msa)
        state = state + state_out # (B, L, d_state)
        state = state + self.state_linear(state)
        state = self.norm_state(state)

        RT_vec = self.state_to_RT(state[:, 1:]) # (B, L, 3 + 3)

        Qs = RT_vec[:,:,:3]        # (B, L, 3)
        T_residue = RT_vec[:,:,3:] # (B, L, 3)

        Qs = torch.cat((torch.ones((B,L-1,1),device=Qs.device),Qs),dim=-1)
        Qs = normQ(Qs)
        R_residue = Qs2Rs(Qs)

        alpha = self.sc_predictor(seq, state)

        R_out = einsum('bnij,bnjk->bnik', R_in, R_residue)
        T_out = einsum('bnij,bnj->bni', R_in, T_residue) + T_in
        # breakpoint()
        return R_out, T_out, state, alpha, pair_in
class IterBlock(nn.Module):
    def __init__(self, value_net, d_msa=256, d_pair=128, d_rbf=64,
                 n_head_msa=8, n_head_pair=4,
                 use_global_attn=False,
                 d_hidden=32, d_hidden_msa=None, p_drop=0.15,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
        super(IterBlock, self).__init__()
        if d_hidden_msa == None:
            d_hidden_msa = d_hidden

        self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair,
                                      n_head=n_head_msa,
                                      d_state=SE3_param['l0_out_features'],
                                      use_global_attn=use_global_attn,
                                      d_hidden=d_hidden_msa, p_drop=p_drop)
        self.msa2pair = MSA2Pair(d_msa=d_msa, d_pair=d_pair,
                                 d_hidden=d_hidden//2, p_drop=p_drop)
        self.pair2pair = PairStr2Pair(d_pair=d_pair, n_head=n_head_pair,
                                      d_state=SE3_param['l0_out_features'],
                                      d_hidden=d_hidden, p_drop=p_drop)
        self.str2residue = Str2Residue_IPA(d_msa=d_msa, d_pair=d_pair, d_rbf=d_rbf,
                                           d_state=SE3_param['l0_out_features'],
                                           p_drop=p_drop)
        self.value_net = value_net
        self.update_pair_by_value = UpdatePairByValue(d_pair=d_pair, d_rbf=d_rbf)
    def forward(self, seq, msa, pair, R_in, T_in, xyz_in, state, idx, chain_mask,
                symmids, symmsub_in, symmsub, symmRs, symmmeta, topk=0, crop=-1):
        # msa  : (B, N, L, d_msa)
        # pair : (B, L, L, d_pair)
        # xyz  : (B, L, 3, 3)
        B, L = pair.shape[:2]

        # xyzfull = xyz.view(1,B*L,3,3)
        # rbf_feat = rbf(
        #     torch.cdist(xyzfull[:,:,1,:], xyzfull[:,:L,1,:])
        # ).reshape(B,L,L,-1) + self.pos(idx)

        msa = self.msa2msa(msa, pair, state, symmids, symmsub)
        pair = self.msa2pair(msa, pair, symmids, symmsub)
        pair = self.pair2pair(pair, crop, True, symmids, symmsub)

        R_out, T_out, state, alpha, pair = self.str2residue(msa.float(), pair.float(),
                                                            R_in.detach().float(), T_in.float(),
                                                            state.float(), chain_mask.float())

        xyz_out = torch.einsum('blij,blaj->blai', R_in, xyz_in) + T_in.unsqueeze(-2) # (B,L,N,3)
        h_E, pae_neighbor, E_idx = self.value_net(seq.long(), idx, chain_mask, xyz_out)
        pair = self.update_pair_by_value(xyz_out, pair, h_E, pae_neighbor, E_idx)

        return msa, pair, R_out, T_out, state, alpha, symmsub
class UpdatePairByValue(nn.Module):
    def __init__(self, d_pair=128, d_rbf=64):
        super(UpdatePairByValue, self).__init__()
        self.value_gate = nn.Linear(d_pair, d_pair)
        self.beta = 1.0 # sigmoid temperature; a larger beta gives a steeper gate and larger gradients
        self.proj_dist = nn.Linear(d_rbf, d_pair)
        self.reset_parameter()
    def reset_parameter(self):
        self.proj_dist = init_lecun_normal(self.proj_dist)

        # mostly closed gate at the beginning
        nn.init.zeros_(self.value_gate.weight)
        self.value_gate.bias.data.fill_(-3.0) # sigmoid(-3) ~ 0.05
    def forward(self, xyz, pair, h_E, pae_neighbor, E_idx):
        # get xyz from R_in, T_in
        B, L = xyz.shape[:2]
        n_bin_pae = pae_neighbor.shape[-1]

        logit_pae = torch.zeros(B, L, L, pae_neighbor.size(-1), device=pae_neighbor.device)
        logit_pae = torch.scatter_add(logit_pae, 2,
                                      E_idx.unsqueeze(-1).expand(-1,-1,-1,n_bin_pae),
                                      pae_neighbor)   # (B,L,L,n_bin_pae)
        logit_pae = logit_pae.permute(0,3,1,2)        # (B,n_bin_pae,L,L)

        # scatter edge features to (B,L,L,d_edge)
        h_E_full = torch.zeros(B, L, L, h_E.size(-1), device=h_E.device)
        h_E_full = torch.scatter(h_E_full, 2,
                                 E_idx.unsqueeze(-1).expand(-1,-1,-1,h_E.size(-1)),
                                 h_E)

        # get topk_mask (B,L,L) from E_idx
        top_k_mask = torch.zeros(B, L, L, device=E_idx.device)
        top_k_mask = top_k_mask.scatter(2, E_idx, 1)

        value_gate = torch.sigmoid(self.beta * self.value_gate(h_E_full.clone().detach()))
        value_gate = value_gate * top_k_mask.unsqueeze(-1)

        Cb = get_Cb(xyz[:,:,:3])
        dist_CB = rbf(
            torch.cdist(Cb, Cb)
        ).reshape(B,L,L,-1)

        pair = pair.clone()
        pair[:, 1:, 1:] = pair[:, 1:, 1:] + value_gate * self.proj_dist(dist_CB)
        return pair

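# A minimal sketch (not part of the original module) of the scatter pattern used in
# UpdatePairByValue.forward: k-NN edge features (B, L, K, d) indexed by E_idx (B, L, K)
# are written into a dense (B, L, L, d) map, and a (B, L, L) mask marks which residue
# pairs actually carry an edge. All sizes below are toy values.
def _example_scatter_knn_edges():
    B, L, K, d = 1, 6, 3, 4
    h_E = torch.randn(B, L, K, d)
    E_idx = torch.randint(0, L, (B, L, K))
    h_E_full = torch.zeros(B, L, L, d)
    h_E_full = torch.scatter(h_E_full, 2, E_idx.unsqueeze(-1).expand(-1, -1, -1, d), h_E)
    top_k_mask = torch.zeros(B, L, L).scatter(2, E_idx, 1)   # 1 where a k-NN edge exists
    return h_E_full, top_k_mask
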
class IterativeSimulator(nn.Module):
    def __init__(self, value_net, n_extra_block=4, n_main_block=12, n_ref_block=4,
                 d_query=256, d_msa=256, d_pair=128, d_hidden=32, d_rbf=64,
                 n_head_msa=8, n_head_pair=4,
                 SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'l1_out_features':2, 'num_edge_features':32},
                 SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'l1_out_features':2, 'num_edge_features':32},
                 p_drop=0.15):
        super(IterativeSimulator, self).__init__()
        self.value_net = value_net
        self.n_extra_block = n_extra_block
        self.n_main_block = n_main_block
        self.n_ref_block = n_ref_block

        self.proj_state = nn.Linear(SE3_param_topk['l0_out_features'], SE3_param_full['l0_out_features'])

        # Update with seed sequences
        if n_main_block > 0:
            self.main_block = nn.ModuleList([IterBlock(value_net,
                                                       d_msa=d_msa, d_pair=d_pair,
                                                       n_head_msa=n_head_msa,
                                                       n_head_pair=n_head_pair,
                                                       d_hidden=d_hidden,
                                                       p_drop=p_drop,
                                                       use_global_attn=False,
                                                       SE3_param=SE3_param_full)
                                             for i in range(n_main_block)])

        self.proj_state2 = nn.Linear(SE3_param_full['l0_out_features'], SE3_param_topk['l0_out_features'])

        # Final SE(3) refinement
        if n_ref_block > 0:
            self.str_refiner = Str2Residue_IPA(d_msa=d_msa, d_pair=d_pair, d_rbf=d_rbf,
                                               d_state=SE3_param_topk['l0_out_features'],
                                               p_drop=p_drop)

        self.reset_parameter()
    def reset_parameter(self):
        self.proj_state = init_lecun_normal(self.proj_state)
        nn.init.zeros_(self.proj_state.bias)
        self.proj_state2 = init_lecun_normal(self.proj_state2)
        nn.init.zeros_(self.proj_state2.bias)
    def forward(self, seq, msa, pair, xyz_in, state, idx, chain_mask,
                symmids, symmsub, symmRs, symmmeta,
                use_checkpoint=False, p2p_crop=-1, topk_crop=-1):
        # input:
        #   seq: query sequence (B, L)
        #   msa: seed MSA embeddings (B, N, L, d_msa)
        #   pair: initial residue pair embeddings (B, L, L, d_pair)
        #   xyz_in: initial BB coordinates (B, L, N/CA/C, 3) -> RF2
        #   xyz_in: initial BB coordinates (B, L, 14, 3) -> PSK, MiniWorld
        #   state: initial state features containing mixture of query seq, sidechain, accuracy info (B, L, d_state)
        #   idx: residue index
        B, _, L = msa.shape[:3]
        if symmsub is not None:
            Lasu = L//symmsub.shape[0]
            symmsub_in = symmsub.clone()
        else:
            Lasu = L
            symmsub_in = None

        R_in, T_in = rigid_from_3_points(xyz_in[:,:,0], xyz_in[:,:,1], xyz_in[:,:,2]) # (B, L, 3, 3), (B, L, 3)
        #
        # R_chain, T_chain, R_residue, T_residue, R_sidechain, T_sidechain, state, alpha
        #
        # TODO
        # R_in_chain = torch.eye(3, device=xyz_in.device).reshape(1,1,3,3).expand(B, L, -1, -1)
        # R_in_residue = R_in
        # chain_mask = nn.functional.normalize(chain_mask, p=1, dim=-1) # normalize to sum to 1
        #
        # xyz = xyz_in - T_in.unsqueeze(-2) # center on CA
        # T_in_chain = torch.einsum('bij,bjc->bic', chain_mask, T_in.detach()) # (B,L,3)
        # T_in_residue = T_in.detach() - einsum('blij,blj->bli', R_in_residue, T_in_chain.detach()) # (B,L,3)

        state = self.proj_state(state)

        R_list = list()
        T_list = list()
        alpha_list = list()

        for i_m in range(self.n_main_block):
            # R_in_chain = R_in_chain.detach() # (B,L,3,3)
            # R_in_residue = R_in_residue.detach() # (B,L,3,3)
            # T_in_chain = T_in_chain.detach() # (B,L,3)
            # T_in_residue = T_in_residue.detach() # (B,L,3)
            R_in = R_in.detach() # (B,L,3,3)
            # xyz_in = xyz_in.detach() # (B,L,3,3)
            # print(f"For test xyz_in[0,0] : {xyz_in[0,0]}")
            if use_checkpoint:
                msa, pair, R_in, T_in, state, alpha, symmsub = checkpoint.checkpoint(
                    create_custom_forward(self.main_block[i_m]),
                    seq, msa, pair, R_in, T_in, xyz_in, state, idx, chain_mask,
                    symmids, symmsub_in, symmsub, symmRs, symmmeta, topk_crop, p2p_crop,
                    use_reentrant=False)
            else:
                msa, pair, R_in, T_in, state, alpha, symmsub = self.main_block[i_m](
                    seq, msa, pair, R_in, T_in, xyz_in, state, idx, chain_mask,
                    symmids, symmsub_in, symmsub, symmRs, symmmeta,
                    crop=p2p_crop, topk=topk_crop)

            R_list.append(R_in)
            T_list.append(T_in)
            alpha_list.append(alpha)

        state = self.proj_state2(state)

        R_list = torch.stack(R_list, dim=0)         # (n_blocks, B, L, 3, 3)
        T_list = torch.stack(T_list, dim=0)         # (n_blocks, B, L, 3)
        alpha_list = torch.stack(alpha_list, dim=0) # (n_blocks, B, L, 10, 2)

        return msa, pair, R_list, T_list, pair, alpha_list, state, symmsub