Source code for miniworld.models_MiniWorld_v1_5_use_interaction.Embeddings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from miniworld.utils.util import get_Cb
from miniworld.utils.util_module import create_custom_forward, rbf, init_lecun_normal
from miniworld.models_MiniWorld_v1_5_use_interaction.Attention_module import Attention, FeedForwardLayer, AttentionWithBias
from miniworld.models_MiniWorld_v1_5_use_interaction.Track_module import PairStr2Pair, UpdatePairByValue

# Module contains classes and functions to generate initial embeddings
# 20230918 PSK MiniWorld V1.5
class PositionalEncoding2D(nn.Module):
    # Add relative positional encoding to pair features
    def __init__(self, d_model, minpos=-31, maxpos=31, topk=2):
        super(PositionalEncoding2D, self).__init__()
        self.minpos = minpos
        self.maxpos = maxpos
        self.nbin = abs(minpos) + maxpos + 1 + 1
        self.emb = nn.Embedding(self.nbin, d_model)
        self.emb_chain = nn.Embedding(topk+2, d_model) # 2 = 1 for self, 1 for distance > topk
        self.topk = topk

    @torch.no_grad()
    def get_chain_break_from_mask(self, chain_mask):
        _, L, _ = chain_mask.shape
        # Extract the upper diagonal
        upper_diagonal = torch.diagonal(chain_mask[0], offset=1, dim1=-2, dim2=-1)
        # Identify breaks in the chain where the upper diagonal is 0
        breaks = (upper_diagonal == 0).nonzero(as_tuple=False).squeeze(-1)
        # Prepare starts and ends of chains based on breaks
        starts = torch.cat([torch.tensor([0], device=chain_mask.device), breaks + 1])
        ends = torch.cat([breaks, torch.tensor([L-1], device=chain_mask.device)])
        # Create the chain_break dictionary
        chain_break = {i: (start, end) for i, (start, end) in enumerate(zip(starts.tolist(), ends.tolist()))}
        return chain_break

    @torch.no_grad()
    def get_chain_embedding(self, xyz, chain_mask, topk=2):
        B, L, N, _ = xyz.shape
        chain_break = self.get_chain_break_from_mask(chain_mask) # assume that B = 1
        C = len(chain_break) # Number of chains
        topk = min(topk, C) # Limit topk to the number of chains - 1
        if topk < 1:
            return torch.zeros((B, L, L), dtype=torch.long, device=xyz.device)

        # Step 1: Calculate center of mass for each chain
        CoM_Ca = torch.zeros((B, C, 3), dtype=xyz.dtype, device=xyz.device)
        CoM_N = torch.zeros((B, C, 3), dtype=xyz.dtype, device=xyz.device)
        for i, (start, end) in enumerate(chain_break.values()):
            chain_Ca = xyz[:, start:end+1, 1, :] # Slice tensor for each chain
            chain_N = xyz[:, start:end+1, 0, :] # Slice tensor for each chain
            mass_Ca = chain_Ca.mean(dim=1)
            mass_N = chain_N.mean(dim=1)
            CoM_Ca[:, i, :] = mass_Ca
            CoM_N[:, i, :] = mass_N
        diff = CoM_Ca.unsqueeze(2) - CoM_N.unsqueeze(1) # (B, C, C, 3)
        distance = torch.norm(diff, dim=-1) # (B, C, C)

        # Step 4: Identify the topk nearest chains for each chain based on nearest_distances
        distance[:, torch.arange(C), torch.arange(C)] = float('inf') # Set the diagonal to infinity
        nearest_topk = distance.topk(k=topk, dim=-1, largest=False)[1] # (B, C, topk)

        # Step 5: Create the BxCxC index map
        idx_map = torch.full((B, C, C), topk+1, dtype=torch.long, device=xyz.device)
        for b in range(B):
            for i in range(C):
                idx_map[b, i, nearest_topk[b, i]] = torch.arange(1, topk+1, device=xyz.device)
        idx_map[:, torch.arange(C), torch.arange(C)] = 0 # zero for the same chain, 1~topk for the nearest chains, topk+1 for the rest

        # Step 6: Expand to BxLxL for embedding
        chain_idx_map = torch.zeros((L), dtype=torch.long, device=xyz.device)
        for i, (start, end) in enumerate(chain_break.values()):
            chain_idx_map[start:end+1] = i
        expanded_idx_map = idx_map[:, chain_idx_map, :][:, :, chain_idx_map]
        return expanded_idx_map # (B, L, L), 0~topk+1
    def forward(self, idx, chain_mask, xyz):
        B, L = idx.shape[:2]
        bins = torch.arange(self.minpos, self.maxpos, device=idx.device)
        seqsep = torch.full((B, L, L), 100, device=idx.device)
        seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L)
        ib = torch.bucketize(seqsep, bins).long() # (B, L, L)
        # PSK 20230913: interchain -> 64
        chain_mask = chain_mask.to(torch.bool)
        ib[~chain_mask] = self.nbin - 1
        emb = self.emb(ib) # (B, L, L, d_model)
        chain_idx = self.get_chain_embedding(xyz, chain_mask, self.topk) # (B, L, L)
        emb_chain = self.emb_chain(chain_idx) # (B, L, L, d_model)
        return emb + emb_chain
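
# Usage sketch (illustrative, doctest-style, not part of the module): a minimal
# single-chain call with assumed shapes B=1, L=20; the atom dimension of xyz is
# assumed to hold the backbone atoms (N, Ca, C), of which get_chain_embedding
# uses N and Ca.
# >>> pe = PositionalEncoding2D(d_model=128)
# >>> idx = torch.arange(20)[None]              # (B, L) residue indices
# >>> chain_mask = torch.ones(1, 20, 20)        # all pairs intra-chain
# >>> xyz = torch.randn(1, 20, 3, 3)            # (B, L, n_atoms, 3)
# >>> pe(idx, chain_mask, xyz).shape
# torch.Size([1, 20, 20, 128])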
UNK_IDX = 20
GAP_IDX = 21
MASK_IDX = 22
NUM_CLASSES = 23
class InteractionEmb(nn.Module):
    def __init__(self, d_pair=128):
        super(InteractionEmb, self).__init__()
        self.emb = nn.Embedding(3, d_pair) # 3 for 0 (no information), 1 (positive), 2 (negative)

    def forward(self, interaction_points):
        """
        Inputs:
            - interaction_points: Interaction points (B, L, L)
              0 (no information), 1 (positive), 2 (negative)
        Outputs:
            - emb: Interaction embedding (B, L, L, d_pair)
        """
        emb = self.emb(interaction_points)
        return emb
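
# Usage sketch (illustrative, not part of the module): interaction_points is
# expected to be a LongTensor with values in {0, 1, 2}; an all-zero map means
# "no interaction information".
# >>> ie = InteractionEmb(d_pair=128)
# >>> pts = torch.zeros(1, 20, 20, dtype=torch.long)
# >>> ie(pts).shape
# torch.Size([1, 20, 20, 128])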
class Species_emb(nn.Module):
    def __init__(self, d_msa=256, p_drop=0.1):
        super(Species_emb, self).__init__()
        self.species_emb = nn.Embedding(1, d_msa)

        self.reset_parameter()

    def reset_parameter(self):
        self.species_emb = init_lecun_normal(self.species_emb)

    def forward(self, msa_speices):
        msa_speices = self.species_emb(msa_speices)
        return msa_speices
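
# Usage sketch (illustrative, not part of the module): the species table holds a
# single row, so every species index must be 0 here.
# >>> se = Species_emb(d_msa=256)
# >>> se(torch.zeros(1, 8, dtype=torch.long)).shape   # (B, N) -> (B, N, d_msa)
# torch.Size([1, 8, 256])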
class MSA_emb(nn.Module):
    # Get initial seed MSA embedding
    def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=NUM_CLASSES+1,
                 minpos=-32, maxpos=32, p_drop=0.1):
        super(MSA_emb, self).__init__()
        self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
        self.species_emb = Species_emb(d_msa=d_msa, p_drop=p_drop)
        self.emb_left = nn.Embedding(NUM_CLASSES+1, d_pair) # embedding for query sequence -- used for pair embedding
        self.emb_right = nn.Embedding(NUM_CLASSES+1, d_pair) # embedding for query sequence -- used for pair embedding
        self.emb_state = nn.Embedding(NUM_CLASSES+1, d_state)
        self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos)

        self.d_init = d_init
        self.d_msa = d_msa

        self.reset_parameter()

    def reset_parameter(self):
        self.emb = init_lecun_normal(self.emb)
        self.emb_left = init_lecun_normal(self.emb_left)
        self.emb_right = init_lecun_normal(self.emb_right)
        self.emb_state = init_lecun_normal(self.emb_state)
        nn.init.zeros_(self.emb.bias)
    def forward(self, msa, msa_species, msa_zero_pos, seq, idx, chain_mask, xyz, symmids=None):
        # Inputs:
        #   - msa: Input MSA (B, N, L, d_init)
        #   - seq: Input Sequence (B, L)
        #   - idx: Residue index
        # Outputs:
        #   - msa: Initial MSA embedding (B, N, L, d_msa)
        #   - pair: Initial Pair embedding (B, L, L, d_pair)
        B, N, L = msa.shape[:3] # number of sequences in MSA

        # msa embedding
        # msa_zero_pos = msa_unk_pos | msa_mask_pos
        msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding
        msa[msa_zero_pos] = 0.0 # zero out UNK and MASK
        msa_species = self.species_emb(msa_species) # (B, N, d_model) # Species embedding
        msa_species = msa_species.unsqueeze(-2) # (B, N, 1, d_model)
        msa = torch.cat((msa_species, msa), dim=-2) # (B, N, L+1, d_model)

        # 20231116 PSK MiniWorld V0.5.2
        # zero out UNK and MASK
        # And don't update msa using query sequence
        seq = F.pad(seq, (1,0), value=NUM_CLASSES) # (B, L+1)

        # pair embedding
        left = self.emb_left(seq)[:,None] # (B, 1, L+1, d_pair)
        right = self.emb_right(seq)[:,:,None] # (B, L+1, 1, d_pair)
        pair = (left + right) # (B, L+1, L+1, d_pair)
        pair = pair.clone()
        pair[:,1:,1:] = pair[:,1:,1:] + self.pos(idx, chain_mask, xyz) # add relative position

        # state embedding
        state = self.emb_state(seq) # .repeat(oligo,1,1)

        return msa, pair, state
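
# Usage sketch (illustrative, not part of the module): shapes follow the
# comments above; the returned tensors carry one extra "species" position, so
# the sequence dimension grows from L to L+1. Species indices must be 0 because
# Species_emb holds a single embedding row.
# >>> msa_emb = MSA_emb(d_msa=256, d_pair=128, d_state=32)
# >>> B, N, L = 1, 8, 20
# >>> msa = torch.randn(B, N, L, NUM_CLASSES + 1)
# >>> msa_species = torch.zeros(B, N, dtype=torch.long)
# >>> msa_zero_pos = torch.zeros(B, N, L, dtype=torch.bool)
# >>> seq = torch.randint(0, 20, (B, L))
# >>> idx = torch.arange(L)[None]
# >>> chain_mask = torch.ones(B, L, L)
# >>> xyz = torch.randn(B, L, 3, 3)
# >>> m, p, s = msa_emb(msa, msa_species, msa_zero_pos, seq, idx, chain_mask, xyz)
# >>> m.shape, p.shape, s.shape
# (torch.Size([1, 8, 21, 256]), torch.Size([1, 21, 21, 128]), torch.Size([1, 21, 32]))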
class Extra_emb(nn.Module):
    # Get initial seed MSA embedding
    def __init__(self, d_msa=256, d_init=NUM_CLASSES+1, p_drop=0.1):
        super(Extra_emb, self).__init__()
        self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
        self.emb_q = nn.Embedding(NUM_CLASSES, d_msa) # embedding for query sequence

        self.d_init = d_init
        self.d_msa = d_msa

        self.reset_parameter()

    def reset_parameter(self):
        self.emb = init_lecun_normal(self.emb)
        nn.init.zeros_(self.emb.bias)
    def forward(self, msa, seq, idx, oligo=1):
        # Inputs:
        #   - msa: Input MSA (B, N, L, d_init)
        #   - seq: Input Sequence (B, L)
        #   - idx: Residue index
        # Outputs:
        #   - msa: Initial MSA embedding (B, N, L, d_msa)
        #N = msa.shape[1] # number of sequences in MSA
        B, N, L = msa.shape[:3]
        msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding
        seq = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
        msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA
        return msa
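
# Usage sketch (illustrative, not part of the module): the extra-MSA track only
# needs the raw MSA features and the query sequence; idx is accepted but unused
# in the body above.
# >>> ex = Extra_emb(d_msa=64)
# >>> msa = torch.randn(1, 16, 20, NUM_CLASSES + 1)
# >>> seq = torch.randint(0, 20, (1, 20))
# >>> ex(msa, seq, idx=None).shape
# torch.Size([1, 16, 20, 64])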
class TemplatePairStack(nn.Module):
    # process template pairwise features
    # use structure-biased attention
    def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=16, d_t1d=NUM_CLASSES+1, d_state=32, p_drop=0.25):
        super(TemplatePairStack, self).__init__()
        self.n_block = n_block
        # self.proj_t1d = nn.Linear(d_t1d, d_state)
        proc_s = [PairStr2Pair(d_pair=d_templ, n_head=n_head, d_hidden=d_hidden, d_state=d_state, p_drop=p_drop) for i in range(n_block)]
        self.block = nn.ModuleList(proc_s)
        self.norm = nn.LayerNorm(d_templ)
        # self.reset_parameter()

    # def reset_parameter(self):
    #     self.proj_t1d = init_lecun_normal(self.proj_t1d)
    #     nn.init.zeros_(self.proj_t1d.bias)
    def forward(self, templ, rbf_feat, t1d, use_checkpoint=False, p2p_crop=-1, symmids=None):
        B, T, L = templ.shape[:3]
        templ = templ.reshape(B*T, L, L, -1)
        # t1d = t1d.reshape(B*T, L, -1)
        # state = self.proj_t1d(t1d)

        for i_block in range(self.n_block):
            if use_checkpoint:
                templ = checkpoint.checkpoint(create_custom_forward(self.block[i_block]), templ, p2p_crop, False) #, symmids)
            else:
                templ = self.block[i_block](templ, p2p_crop, False) #, symmids)
        return self.norm(templ).reshape(B, T, L, L, -1)
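
# Usage sketch (illustrative, not part of the module): assumes each PairStr2Pair
# block returns a pair tensor of the same shape it receives, as implied by the
# loop above; rbf_feat and t1d are accepted but not used by the current body.
# >>> stack = TemplatePairStack(n_block=2, d_templ=64)
# >>> templ = torch.randn(1, 2, 20, 20, 64)       # (B, T, L, L, d_templ)
# >>> rbf_feat = torch.randn(2, 20, 20, 64)       # (B*T, L, L, d_rbf)
# >>> t1d = torch.randn(1, 2, 20, NUM_CLASSES + 1)
# >>> stack(templ, rbf_feat, t1d).shape
# torch.Size([1, 2, 20, 20, 64])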
class TemplateTorsionStack(nn.Module):
    def __init__(self, n_block=2, d_templ=64, d_rbf=64, n_head=4, d_hidden=16, p_drop=0.15):
        super(TemplateTorsionStack, self).__init__()
        self.n_block = n_block
        self.proj_pair = nn.Linear(d_templ+d_rbf, d_templ)
        proc_s = [AttentionWithBias(d_in=d_templ, d_bias=d_templ, n_head=n_head, d_hidden=d_hidden) for i in range(n_block)]
        self.row_attn = nn.ModuleList(proc_s)
        proc_s = [FeedForwardLayer(d_templ, 4, p_drop=p_drop) for i in range(n_block)]
        self.ff = nn.ModuleList(proc_s)
        self.norm = nn.LayerNorm(d_templ)

    def reset_parameter(self):
        self.proj_pair = init_lecun_normal(self.proj_pair)
        nn.init.zeros_(self.proj_pair.bias)

    def forward(self, tors, pair, rbf_feat, use_checkpoint=False):
        B, T, L = tors.shape[:3]
        tors = tors.reshape(B*T, L, -1)
        pair = pair.reshape(B*T, L, L, -1)
        pair = torch.cat((pair, rbf_feat), dim=-1)
        pair = self.proj_pair(pair)

        for i_block in range(self.n_block):
            if use_checkpoint:
                tors = tors + checkpoint.checkpoint(create_custom_forward(self.row_attn[i_block]), tors, pair)
            else:
                tors = tors + self.row_attn[i_block](tors, pair)
            tors = tors + self.ff[i_block](tors)
        return self.norm(tors).reshape(B, T, L, -1)
class Templ_emb(nn.Module):
    # Get template embedding
    # Features are
    #   t2d:
    #     - 37 distogram bins + 6 orientations (43)
    #     - Mask (missing/unaligned) (1)
    #   t1d:
    #     - tiled AA sequence (20 standard aa + gap)
    #     - confidence (1)
    def __init__(self, d_t1d=NUM_CLASSES+1, d_t2d=43+1, d_tor=30, d_pair=128, d_state=32,
                 n_block=2, d_templ=64, n_head=4, d_hidden=16, p_drop=0.25):
        super(Templ_emb, self).__init__()
        # process 2D features
        self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ)
        self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
                                             d_hidden=d_hidden, p_drop=p_drop)
        # self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)
        self.proj_templ = nn.Linear(d_templ, d_pair)

        # process torsion angles
        self.proj_t1d = nn.Linear(d_t1d+d_tor, d_templ)
        self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        self.emb = init_lecun_normal(self.emb)
        nn.init.zeros_(self.emb.bias)

        nn.init.kaiming_normal_(self.proj_templ.weight, nonlinearity='relu')
        nn.init.zeros_(self.proj_templ.bias)

        nn.init.kaiming_normal_(self.proj_t1d.weight, nonlinearity='relu')
        nn.init.zeros_(self.proj_t1d.bias)
    def _get_templ_emb(self, t1d, t2d):
        B, T, L, _ = t1d.shape
        # Prepare 2D template features
        left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1)
        right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1)
        templ = torch.cat((t2d, left, right), -1) # (B, T, L, L, NUM_CLASSES * 4)
        return self.emb(templ) # Template features (B, T, L, L, d_templ)

    def _get_templ_rbf(self, xyz_t, mask_t):
        B, T, L = xyz_t.shape[:3]

        # process each template features
        xyz_t = xyz_t.reshape(B*T, L, 3)
        mask_t = mask_t.reshape(B*T, L, L)
        diff = xyz_t.unsqueeze(2) - xyz_t.unsqueeze(1) # (B*T, L, L, 3)
        rbf_feat = rbf(torch.norm(diff, dim=-1)).to(xyz_t.dtype) * mask_t[...,None] # (B*T, L, L, d_rbf)
        return rbf_feat
    def forward(self, t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=False, p2p_crop=-1, symmids=None):
        # Input
        #   - t1d: 1D template info (B, T, L, NUM_CLASSES)
        #   - t2d: 2D template info (B, T, L, L, 44)
        #   - alpha_t: torsion angle info (B, T, L, 30)
        #   - xyz_t: template CA coordinates (B, T, L, 3)
        #   - mask_t: is valid residue pair? (B, T, L, L)
        #   - pair: query pair features (B, L+1, L+1, d_pair)
        #   - state: query state features (B, L+1, d_state)
        B, T, L, _ = t1d.shape

        pair_without_species = pair[:,1:,1:]
        state_without_species = state[:,1:]

        templ = self._get_templ_emb(t1d, t2d)
        rbf_feat = self._get_templ_rbf(xyz_t, mask_t)

        # process each template pair feature
        templ = self.templ_stack(templ, rbf_feat, t1d, use_checkpoint=use_checkpoint,
                                 p2p_crop=p2p_crop, symmids=symmids).to(pair.dtype) # (B, T, L, L, d_templ)

        # Prepare 1D template torsion angle features
        t1d = torch.cat((t1d, alpha_t), dim=-1) # (B, T, L, NUM_CLASSES+30)
        t1d = self.proj_t1d(t1d)

        # mixing query state features to template state features
        state_without_species = state_without_species.reshape(B*L, 1, -1) # (B*L, 1, d_state)
        t1d = t1d.permute(0,2,1,3).reshape(B*L, T, -1)
        if use_checkpoint:
            out = checkpoint.checkpoint(create_custom_forward(self.attn_tor), state_without_species, t1d, t1d)
            out = out.reshape(B, L, -1)
        else:
            out = self.attn_tor(state_without_species, t1d, t1d).reshape(B, L, -1)
        state = state + F.pad(out, (0,0,1,0), value=0.0) # (B, L+1, d_state)

        # mixing query pair features to template information (Template pointwise attention)
        pair_without_species = pair_without_species.reshape(B*L*L, 1, -1)
        templ = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1)
        # if use_checkpoint:
        #     out = checkpoint.checkpoint(create_custom_forward(self.attn), pair_without_species, templ, templ)
        #     out = out.reshape(B, L, L, -1)
        # else:
        #     out = self.attn(pair_without_species, templ, templ).reshape(B, L, L, -1)
        out = templ.mean(dim=-2).reshape(B, L, L, -1) # (B, L, L, d_templ)
        out = self.proj_templ(out) # (B, L, L, d_pair)
        pair_without_species = pair_without_species.reshape(B, L, L, -1)
        pair[:,1:,1:] = pair[:,1:,1:] + out

        return pair, state
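
# Usage sketch (illustrative, not part of the module): shapes follow the
# comments above; pair and state include the extra species position (L+1). The
# output shapes assume the Attention and PairStr2Pair modules preserve the
# query shapes in the way they are used above.
# >>> temb = Templ_emb()                      # defaults: d_pair=128, d_state=32, d_templ=64
# >>> B, T, L = 1, 2, 20
# >>> t1d = torch.randn(B, T, L, NUM_CLASSES + 1)
# >>> t2d = torch.randn(B, T, L, L, 44)
# >>> alpha_t = torch.randn(B, T, L, 30)
# >>> xyz_t = torch.randn(B, T, L, 3)
# >>> mask_t = torch.ones(B, T, L, L)
# >>> pair = torch.randn(B, L + 1, L + 1, 128)
# >>> state = torch.randn(B, L + 1, 32)
# >>> pair, state = temb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state)
# >>> pair.shape, state.shape
# (torch.Size([1, 21, 21, 128]), torch.Size([1, 21, 32]))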
class Recycling(nn.Module):
    def __init__(self, value_net, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
        super(Recycling, self).__init__()
        #self.emb_rbf = nn.Linear(d_rbf, d_pair)
        self.value_net = value_net
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_state = nn.LayerNorm(d_state)
        self.update_pair_by_value = UpdatePairByValue(d_pair=d_pair, d_rbf=d_rbf)

    def forward(self, target_seq, idx, chain_mask, msa, pair, state, xyz, mask_recycle=None):
        B, L = msa.shape[:2]
        msa = self.norm_msa(msa)
        L = L - 1
        state = self.norm_state(state)
        pair = self.norm_pair(pair)

        # recreate Cb given N, Ca, C
        Cb = get_Cb(xyz[:,:,:3])
        diff = Cb.unsqueeze(2) - Cb.unsqueeze(1) # (B, L, L, 3)
        dist_CB = rbf(torch.norm(diff, dim=-1)).reshape(B, L, L, -1)
        if mask_recycle != None:
            dist_CB = mask_recycle[...,None].float()*dist_CB

        h_E, pae_neighbor, E_idx = self.value_net(target_seq, idx, chain_mask, xyz)
        pair = self.update_pair_by_value(xyz, pair, h_E, pae_neighbor, E_idx)

        return msa, pair, state
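
# Call-pattern note (illustrative): value_net is constructed by the caller and
# must return (h_E, pae_neighbor, E_idx) in the format consumed by
# Track_module.UpdatePairByValue, so no self-contained example is attempted
# here. From the code above, msa carries one extra (species) position relative
# to the L residues in xyz, and the recycled msa/pair/state are re-normalized
# before the pair features are updated from the value network's output.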