Source code for miniworld.models_MiniWorld_v1_5_use_interaction.Attention_module

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum

from miniworld.utils.util_module import *  # provides init_lecun_normal, rbf, etc.

class FeedForwardLayer(nn.Module):
    def __init__(self, d_model, r_ff, p_drop=0.1):
        super(FeedForwardLayer, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, d_model*r_ff)
        self.dropout = nn.Dropout(p_drop)
        self.linear2 = nn.Linear(d_model*r_ff, d_model)

        self.reset_parameter()

    def reset_parameter(self):
        # linear layer right before ReLU: He initializer (kaiming normal)
        nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear1.bias)

        # linear layer right before residual connection: zero initialize
        nn.init.zeros_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)

    def forward(self, src):
        src = self.norm(src)
        src = self.linear2(self.dropout(F.relu_(self.linear1(src))))
        return src
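# Usage sketch (not part of the original module): the layer is applied position-wise,
# so any (..., d_model)-shaped input works; sizes below are illustrative assumptions.
def _example_feedforward():
    ff = FeedForwardLayer(d_model=256, r_ff=4)
    x = torch.randn(2, 100, 256)   # (B, L, d_model)
    return x + ff(x)               # residual update; the output projection is zero-initialized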
class Attention(nn.Module):
    # calculate multi-head attention
    def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
        super(Attention, self).__init__()
        self.h = n_head
        self.dim = d_hidden
        #
        self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)
        #
        self.to_out = nn.Linear(n_head*d_hidden, d_out)
        self.scaling = 1/math.sqrt(d_hidden)
        #
        # initialize all parameters properly
        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projections: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # to_out sits right before the residual connection: zero initialize,
        # so the residual branch acts as the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, query, key, value):
        B, Q = query.shape[:2]
        B, K = key.shape[:2]
        #
        query = self.to_q(query).reshape(B, Q, self.h, self.dim)
        key = self.to_k(key).reshape(B, K, self.h, self.dim)
        value = self.to_v(value).reshape(B, K, self.h, self.dim)
        #
        query = query * self.scaling
        attn = einsum('bqhd,bkhd->bhqk', query, key)
        attn = F.softmax(attn, dim=-1)
        #
        out = einsum('bhqk,bkhd->bqhd', attn, value)
        out = out.reshape(B, Q, self.h*self.dim)
        #
        out = self.to_out(out)
        return out
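# Usage sketch (not part of the original module): cross-attention with illustrative,
# hypothetical sizes; queries attend over a separate key/value set.
def _example_attention():
    attn = Attention(d_query=256, d_key=128, n_head=8, d_hidden=32, d_out=256)
    q = torch.randn(2, 10, 256)    # (B, Q, d_query)
    kv = torch.randn(2, 20, 128)   # (B, K, d_key)
    return attn(q, kv, kv)         # (B, Q, d_out) = (2, 10, 256)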
class AttentionWithBias(nn.Module):
    def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
        super(AttentionWithBias, self).__init__()
        self.norm_in = nn.LayerNorm(d_in)
        self.norm_bias = nn.LayerNorm(d_bias)
        #
        self.to_q = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)
        self.to_g = nn.Linear(d_in, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_in)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projections: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # bias projection: LeCun normal
        self.to_b = init_lecun_normal(self.to_b)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out sits right before the residual connection: zero initialize,
        # so the residual branch acts as the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, x, bias):
        B, L = x.shape[:2]
        #
        x = self.norm_in(x)
        bias = self.norm_bias(bias)
        #
        query = self.to_q(x).reshape(B, L, self.h, self.dim)
        key = self.to_k(x).reshape(B, L, self.h, self.dim)
        value = self.to_v(x).reshape(B, L, self.h, self.dim)
        bias = self.to_b(bias)  # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(x))
        #
        key = key * self.scaling
        attn = einsum('bqhd,bkhd->bqkh', query, key)
        attn = attn + bias
        attn = F.softmax(attn, dim=-2)
        #
        out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
        out = gate * out
        #
        out = self.to_out(out)
        return out
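# Usage sketch (not part of the original module): gated self-attention over L tokens,
# biased by a pairwise feature map; sizes are the defaults and purely illustrative.
def _example_attention_with_bias():
    attn = AttentionWithBias(d_in=256, d_bias=128, n_head=8, d_hidden=32)
    x = torch.randn(1, 50, 256)         # (B, L, d_in)
    bias = torch.randn(1, 50, 50, 128)  # (B, L, L, d_bias)
    return attn(x, bias)                # (B, L, d_in)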
# MSA Attention (row/column) from AlphaFold architecture
class SequenceWeight(nn.Module):
    def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
        super(SequenceWeight, self).__init__()
        self.h = n_head
        self.dim = d_hidden
        self.scale = 1.0 / math.sqrt(self.dim)

        self.to_query = nn.Linear(d_msa, n_head*d_hidden)
        self.to_key = nn.Linear(d_msa, n_head*d_hidden)
        self.dropout = nn.Dropout(p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        # query/key projections: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_query.weight)
        nn.init.xavier_uniform_(self.to_key.weight)

    def forward(self, msa):
        B, N, L = msa.shape[:3]

        tar_seq = msa[:,0]

        q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim)
        k = self.to_key(msa).view(B, N, L, self.h, self.dim)

        q = q * self.scale
        attn = einsum('bqihd,bkihd->bkihq', q, k)
        attn = F.softmax(attn, dim=1)
        return self.dropout(attn)
class MSARowAttentionWithBias(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
        super(MSARowAttentionWithBias, self).__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)
        #
        self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_pair, n_head, bias=False)
        self.to_g = nn.Linear(d_msa, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projections: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # bias projection: LeCun normal
        self.to_b = init_lecun_normal(self.to_b)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out sits right before the residual connection: zero initialize,
        # so the residual branch acts as the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa, pair, sym=None):  # TODO: make this a tied attention
        B, N, L = msa.shape[:3]
        # O = pair.shape[0]
        # PSK, 20230730: symmetry is not used here, so O == 1 is enforced;
        # in this code, pair.shape[0] is the batch size.
        # if (B==1):
        msa = self.norm_msa(msa)
        pair = self.norm_pair(pair)

        seq_weight = self.seq_weight(msa)  # (B, N, L, h, 1)
        query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
        key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
        value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
        bias = self.to_b(pair)  # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(msa))

        query = query * seq_weight.expand(-1, -1, -1, -1, self.dim)
        key = key * self.scaling
        attn = einsum('bsqhd,bskhd->bqkh', query, key)
        attn = attn + bias
        attn = F.softmax(attn, dim=-2)

        out = einsum('bqkh,bskhd->bsqhd', attn, value).reshape(B, N, L, -1)
        out = gate * out

        out = self.to_out(out)
        return out
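# Usage sketch (not part of the original module): row-wise MSA attention biased by the
# pair representation, with illustrative sizes; the output is typically added residually.
def _example_msa_row_attention():
    row_attn = MSARowAttentionWithBias(d_msa=256, d_pair=128, n_head=8, d_hidden=32)
    msa = torch.randn(1, 8, 50, 256)    # (B, N, L, d_msa)
    pair = torch.randn(1, 50, 50, 128)  # (B, L, L, d_pair)
    return msa + row_attn(msa, pair)    # (B, N, L, d_msa)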
class MSAColAttention(nn.Module):
    def __init__(self, d_msa=256, n_head=8, d_hidden=32):
        super(MSAColAttention, self).__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        #
        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_g = nn.Linear(d_msa, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projections: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out sits right before the residual connection: zero initialize,
        # so the residual branch acts as the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa):
        B, N, L = msa.shape[:3]
        msa_species = msa[:,:,:1]  # (B, N, 1, d_msa)
        #
        msa = self.norm_msa(msa)
        #
        query = self.to_q(msa_species).reshape(B, N, self.h, self.dim)
        key = self.to_k(msa_species).reshape(B, N, self.h, self.dim)
        value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
        gate = torch.sigmoid(self.to_g(msa))
        #
        query = query * self.scaling
        attn = einsum('bqhd,bkhd->bhqk', query, key)  # (B, h, N, N)
        attn = F.softmax(attn, dim=-1)
        #
        out = einsum('bhqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1)
        out = gate * out
        #
        out = self.to_out(out)
        return out
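# Usage sketch (not part of the original module): attention across the N sequences with
# queries/keys taken from the first position of each sequence (the msa_species slice);
# sizes are illustrative.
def _example_msa_col_attention():
    col_attn = MSAColAttention(d_msa=256, n_head=8, d_hidden=32)
    msa = torch.randn(1, 8, 50, 256)   # (B, N, L, d_msa)
    return col_attn(msa)               # (B, N, L, d_msa)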
class MSAColGlobalAttention(nn.Module):
    def __init__(self, d_msa=64, n_head=8, d_hidden=8):
        super(MSAColGlobalAttention, self).__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        #
        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_msa, d_hidden, bias=False)
        self.to_v = nn.Linear(d_msa, d_hidden, bias=False)
        self.to_g = nn.Linear(d_msa, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projections: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out sits right before the residual connection: zero initialize,
        # so the residual branch acts as the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa):
        B, N, L = msa.shape[:3]
        #
        msa = self.norm_msa(msa)
        #
        query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
        query = query.mean(dim=1)  # (B, L, h, dim)
        key = self.to_k(msa)  # (B, N, L, dim)
        value = self.to_v(msa)  # (B, N, L, dim)
        gate = torch.sigmoid(self.to_g(msa))  # (B, N, L, h*dim)
        #
        query = query * self.scaling
        attn = einsum('bihd,bkid->bihk', query, key)  # (B, L, h, N)
        attn = F.softmax(attn, dim=-1)
        #
        out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1)  # (B, 1, L, h*dim)
        out = gate * out  # (B, N, L, h*dim)
        #
        out = self.to_out(out)
        return out
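# Usage sketch (not part of the original module): global column attention where the query
# is averaged over the N sequences and keys/values share a single head; illustrative sizes.
def _example_msa_col_global_attention():
    g_attn = MSAColGlobalAttention(d_msa=64, n_head=8, d_hidden=8)
    msa = torch.randn(1, 128, 50, 64)  # (B, N, L, d_msa)
    return g_attn(msa)                 # (B, N, L, d_msa)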
# Instead of triangle attention, use tied axial attention with bias from coordinates..?
class BiasedAxialAttention(nn.Module):
    def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
        super(BiasedAxialAttention, self).__init__()
        #
        self.is_row = is_row
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_bias = nn.LayerNorm(d_pair)

        self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_pair, n_head, bias=False)
        self.to_g = nn.Linear(d_pair, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_pair)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden
        self.dim_out = d_pair

        # initialize all parameters properly
        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projections: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # bias projection: LeCun normal
        self.to_b = init_lecun_normal(self.to_b)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out sits right before the residual connection: zero initialize,
        # so the residual branch acts as the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, pair, sym=None, stride=-1):
        O, L = pair.shape[:2]  # after subunit mask is applied

        if O==1 or sym is None or sym.shape[0]!=O:  # asymmetric mode
            # pair: (B, L, L, d_pair)
            if self.is_row:
                pair = pair.permute(0,2,1,3)

            pair = self.norm_pair(pair)
            bias = self.norm_bias(pair)

            query = self.to_q(pair).reshape(O, L, L, self.h, self.dim)
            key = self.to_k(pair).reshape(O, L, L, self.h, self.dim)
            value = self.to_v(pair).reshape(O, L, L, self.h, self.dim)
            bias = self.to_b(bias)  # (B, L, L, h)
            gate = torch.sigmoid(self.to_g(pair))  # (B, L, L, h*dim)

            query = query * self.scaling
            key = key / L  # normalize for tied attention
            attn = einsum('bnihk,bnjhk->bijh', query, key)  # tied attention
            attn = attn + bias
            attn = F.softmax(attn, dim=-2)  # (B, L, L, h)

            out = einsum('bijh,bnjhd->bnihd', attn, value).reshape(O, L, L, -1)
            out = gate * out

            out = self.to_out(out)
            if self.is_row:
                out = out.permute(0,2,1,3)
        else:
            # symmetric version
            if self.is_row:
                pair = pair[sym[0,:]].permute(0,2,1,3)

            pair = self.norm_pair(pair)
            bias = self.norm_bias(pair)

            query = self.to_q(pair).reshape(O, L, L, self.h, self.dim)
            key = self.to_k(pair).reshape(O, L, L, self.h, self.dim)
            value = self.to_v(pair).reshape(O, L, L, self.h, self.dim)
            bias = self.to_b(pair)  # (B, L, L, h)
            gate = torch.sigmoid(self.to_g(pair))  # (B, L, L, h*dim)

            query = query * self.scaling
            key = key / (O*L)  # normalize for tied attention

            attn = torch.zeros((O, L, L, self.h), device=pair.device)
            for i in range(O):
                attn[i] = torch.einsum('bnihk,bnjhk->ijh', query[sym[:,i]], key[sym[:,0]])  # tied attention
            #attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention
            attn = attn + bias  # apply bias

            # softmax over dims 0 & 2
            attn = F.softmax(attn.transpose(1,2).reshape(O*L, L, self.h), dim=0).reshape(O, L, L, self.h).transpose(1,2)

            out = torch.zeros((O, L, L, self.h, self.dim), device=pair.device)
            for i in range(O):
                out[i] = torch.einsum('bijh,bnjhd->nihd', attn[sym[:,i]], value)  # tied attention
            #out = einsum('bijh,bnjhd->bnihd', attn, value).reshape(O, L, L, -1)
            out = gate * out.reshape(O, L, L, -1)

            out = self.to_out(out)
            if self.is_row:
                out = out[sym[0,:]].permute(0,2,1,3)

        return out
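# Usage sketch (not part of the original module): row- and column-wise tied axial attention
# on the pair representation in asymmetric mode (sym=None); sizes are illustrative.
def _example_biased_axial_attention():
    row = BiasedAxialAttention(d_pair=128, d_bias=128, n_head=4, d_hidden=32, is_row=True)
    col = BiasedAxialAttention(d_pair=128, d_bias=128, n_head=4, d_hidden=32, is_row=False)
    pair = torch.randn(1, 50, 50, 128)  # (B, L, L, d_pair)
    pair = pair + row(pair)             # residual update along rows
    pair = pair + col(pair)             # residual update along columns
    return pair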
class TriangleMultiplication(nn.Module):
    def __init__(self, d_pair, d_hidden=128, outgoing=True):
        super(TriangleMultiplication, self).__init__()
        self.norm = nn.LayerNorm(d_pair)
        self.left_proj = nn.Linear(d_pair, d_hidden)
        self.right_proj = nn.Linear(d_pair, d_hidden)
        self.left_gate = nn.Linear(d_pair, d_hidden)
        self.right_gate = nn.Linear(d_pair, d_hidden)
        #
        self.gate = nn.Linear(d_pair, d_pair)
        self.norm_out = nn.LayerNorm(d_hidden)
        self.out_proj = nn.Linear(d_hidden, d_pair)

        self.d_hidden = d_hidden
        self.outgoing = outgoing

        self.reset_parameter()

    def reset_parameter(self):
        # regular linear weights: LeCun normal
        self.left_proj = init_lecun_normal(self.left_proj)
        self.right_proj = init_lecun_normal(self.right_proj)

        # set biases of linear layers to zero
        nn.init.zeros_(self.left_proj.bias)
        nn.init.zeros_(self.right_proj.bias)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.left_gate.weight)
        nn.init.ones_(self.left_gate.bias)
        nn.init.zeros_(self.right_gate.weight)
        nn.init.ones_(self.right_gate.bias)
        nn.init.zeros_(self.gate.weight)
        nn.init.ones_(self.gate.bias)

        # out_proj sits right before the residual connection: zero initialize,
        # so the residual branch acts as the identity at the beginning
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, pair, sym=None):
        pair = self.norm(pair)
        O, L = pair.shape[:2]

        if O==1 or sym is None or sym.shape[0]!=O:  # asymmetric mode
            left = self.left_proj(pair)  # (B, L, L, d_h)
            left_gate = torch.sigmoid(self.left_gate(pair))
            left = left_gate * left

            right = self.right_proj(pair)  # (B, L, L, d_h)
            right_gate = torch.sigmoid(self.right_gate(pair))
            right = right_gate * right

            if self.outgoing:
                out = einsum('bikd,bjkd->bijd', left, right/float(L))
            else:
                out = einsum('bkid,bkjd->bijd', left, right/float(L))
            out = self.norm_out(out)
            out = self.out_proj(out)

            gate = torch.sigmoid(self.gate(pair))  # (B, L, L, d_pair)
            out = gate * out
        else:
            left = self.left_proj(pair)  # (B, L, L, d_h)
            left_gate = torch.sigmoid(self.left_gate(pair))
            left = left_gate * left

            right = self.right_proj(pair)  # (B, L, L, d_h)
            right_gate = torch.sigmoid(self.right_gate(pair))
            right = right_gate * right

            if self.outgoing:
                out = torch.zeros((O, L, L, self.d_hidden), device=pair.device)
                for i in range(O):
                    out[i] = torch.einsum('bikd,bjkd->ijd', left[sym[i,:]], right[sym[0,:]]/float(O*L))  # tied attention
            else:
                out = torch.zeros((O, L, L, self.d_hidden), device=pair.device)
                for i in range(O):
                    out[i] = torch.einsum('bkid,bkjd->ijd', left[sym[:,i]], right[sym[:,0]]/float(O*L))  # tied attention
            out = self.norm_out(out)
            out = self.out_proj(out)

            gate = torch.sigmoid(self.gate(pair))  # (B, L, L, d_pair)
            out = gate * out

        return out
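# Usage sketch (not part of the original module): outgoing and incoming triangle
# multiplicative updates in asymmetric mode (sym=None); sizes are illustrative.
def _example_triangle_multiplication():
    tri_out = TriangleMultiplication(d_pair=128, d_hidden=128, outgoing=True)
    tri_in = TriangleMultiplication(d_pair=128, d_hidden=128, outgoing=False)
    pair = torch.randn(1, 50, 50, 128)  # (B, L, L, d_pair)
    pair = pair + tri_out(pair)         # "outgoing" edge update
    pair = pair + tri_in(pair)          # "incoming" edge update
    return pair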
class StateAttentionGate(nn.Module):
    def __init__(self, d_state, d_pair, n_head=8, d_hidden=32, p_drop=0.1):
        super(StateAttentionGate, self).__init__()
        self.norm_state = nn.LayerNorm(d_state)
        self.norm_pair = nn.LayerNorm(d_pair)

        self.to_q = nn.Linear(d_state, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_state, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_state, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_pair, n_head, bias=False)
        self.to_g = nn.Linear(d_state, n_head*d_hidden)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projections: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # bias projection: LeCun normal
        self.to_b = init_lecun_normal(self.to_b)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)
        # (this module has no output projection, so there is no to_out to zero-initialize)

    def forward(self, state, pair, chain_mask):
        '''
        state : (B, L, d_state)
        pair : (B, L, L, d_pair)
        chain_mask : (B, L, L)
        '''
        B, L, C = state.shape[:3]
        #
        state = self.norm_state(state)
        pair = self.norm_pair(pair)

        interchain_mask = ~chain_mask.unsqueeze(-1)  # (B, L, L, 1)

        query = self.to_q(state).reshape(B, L, self.h, self.dim)  # (B, L, h, dim)
        key = self.to_k(state).reshape(B, L, self.h, self.dim)  # (B, L, h, dim)
        value = self.to_v(state).reshape(B, L, self.h, self.dim)  # (B, L, h, dim)
        bias = self.to_b(pair)  # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(state))

        key = key * self.scaling
        attn = einsum('bqhd,bkhd->bhqk', query, key)  # (B, h, L, L)
        attn = attn.reshape(B, L, L, self.h)  # (B, L, L, h)
        attn = attn + bias  # (B, L, L, h)
        attn = F.softmax(attn, dim=-2)  # (B, L, L, h)
        attn = attn * interchain_mask  # (B, L, L, h)
        attn = attn / attn.sum(dim=1, keepdim=True)  # (B, L, L, h)

        out = einsum('bhqk,bkhd->bqhd', attn, value).reshape(B, L, -1)  # (B, L, h*dim)
        gate = torch.sigmoid(self.to_g(out))  # (B, L, h*dim)
        return gate
class GraphTriangleAttention(nn.Module):
    def __init__(self, d_pair, d_state, top_k=64, d_rbf=64, n_head=8, n_query_pt=4, d_hidden=32):
        super(GraphTriangleAttention, self).__init__()
        self.n_head = n_head
        self.dim = d_hidden
        self.top_k = top_k
        self.n_query_pt = n_query_pt

        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_state = nn.LayerNorm(d_state)

        self.to_qk_point = nn.Linear(d_state, 2*n_head*n_query_pt*3, bias=False)
        self.to_bias = nn.Linear(d_rbf, 1, bias=False)

        self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_gate = nn.Linear(d_pair, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_pair)

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projections: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)
        nn.init.xavier_uniform_(self.to_qk_point.weight)

        # bias projection: LeCun normal
        self.to_bias = init_lecun_normal(self.to_bias)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_gate.weight)
        nn.init.ones_(self.to_gate.bias)

        # to_out sits right before the residual connection: zero initialize,
        # so the residual branch acts as the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def apply_RT(self, Rs, Ts, vec):
        # Rs : (B, L, 3, 3)
        # Ts : (B, L, 3)
        # vec : (B, L, h, p, 3)
        B, L, h, p, _ = vec.shape
        Ts = Ts.unsqueeze(-2).unsqueeze(-2)  # (B, L, 1, 1, 3)
        Ts = Ts.expand(-1, -1, h, p, -1)  # (B, L, h, p, 3)
        return einsum('blij,blhpj->blhpi', Rs, vec) + Ts

    def gather_edges_3d(self, edges, neighbor_idx):
        # Features (B, L, L, h, *) ==> (B, L, K, K, h, *)
        # neighbor_idx (B, L, K, h)
        B, L, K, h = neighbor_idx.shape
        neighbor_idx = neighbor_idx.permute(0,3,1,2).reshape(B*h, L, K)  # (B*h, L, K)
        edges = edges.permute(0,3,1,2,4).reshape(B*h, L, L, -1)  # (B*h, L, L, C)

        neigh_idx = neighbor_idx[:,:,:,None]*L + neighbor_idx[:,:,None,:]  # (B*h, L, K, K)
        neigh_idx = neigh_idx.reshape(B*h, L, K*K, 1).expand(-1, -1, -1, edges.shape[-1])

        edges = edges.reshape(B*h, 1, L*L, -1).expand(-1, L, -1, -1)  # expand != repeat
        pair_neigh = torch.gather(edges, 2, neigh_idx).reshape(B*h, L, K, K, -1)
        pair_neigh = pair_neigh.reshape(B, h, L, K, K, -1).permute(0,2,3,4,1,5)  # (B, L, K, K, h, C)
        return pair_neigh

    def gather_edges(self, edges, neighbor_idx):
        # Features [B,L,L,h,C] at neighbor indices [B,L,K,h] => neighbor features [B,L,K,h,C]
        neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, -1, edges.size(-1))
        edge_features = torch.gather(edges, 2, neighbors)
        return edge_features

    def forward(self, pair_in, state, Rs, Ts, pair_mask=None, use_species=True):
        '''
        Input:
            - pair_in: pairwise features. shape [B, L+1, L+1, d_pair] if use_species is True, [B, L, L, d_pair] if use_species is False
            - state: node features. shape [B, L+1, d_state] if use_species is True, [B, L, d_state] if use_species is False
            - Rs, Ts: SE(3) frames. shape [B, L, 3, 3], [B, L, 3]
            - pair_mask: mask for valid residue pairs. shape [B, L+1, L+1] if use_species is True, [B, L, L] if use_species is False
        '''
        B, L = state.shape[:2]
        if pair_mask is not None:
            pair_mask = pair_mask[...,None]

        # Input normalization
        pair = self.norm_pair(pair_in)
        state = self.norm_state(state)

        # It doesn't use the real xyz coordinates to define neighbors.
        # Instead, it generates virtual xyz coordinates to define neighbors so that it has a chance to escape from wrong structures.
        if use_species:
            state_wo_species = state[:,1:]
            qk_pt = self.to_qk_point(state_wo_species).reshape(B, L-1, self.n_head, self.n_query_pt, 6)
        else:
            qk_pt = self.to_qk_point(state).reshape(B, L, self.n_head, self.n_query_pt, 6)
        query_pt, key_pt = torch.split(qk_pt, 3, dim=-1)  # 2*(B, L, n_head, n_q, 3)

        # apply rigid transform to query & key points
        # if the query & key points are at the origin, RT_q & RT_k are just the CA atom coordinates of the given structure
        RT_q = self.apply_RT(Rs, Ts, query_pt)  # (B, L, h, n_q, 3)
        RT_k = self.apply_RT(Rs, Ts, key_pt)

        dist_qk = RT_q[:,:,None] - RT_k[:,None,:]  # (B, L, L, h, n_q, 3)
        dist_qk = torch.norm(dist_qk, dim=-1)  # (B, L, L, h, n_q)
        dist_qk = dist_qk.mean(dim=-1)  # (B, L, L, h): average distance between virtual query/key points per head

        dist = torch.zeros((B, L, L, self.n_head), device=pair.device)  # (B, L+1, L+1, h) or (B, L, L, h)
        if use_species:
            dist[:,1:,1:] = dist_qk
        else:
            dist = dist_qk

        if pair_mask is not None:
            # ignore invalid residue pairs (pairs with the virtual node)
            dist = dist * pair_mask  # (B, L, L, h)
            D_max, _ = torch.max(dist, -2, keepdim=True)  # (B, L, 1, h)
            dist = dist + (1.0 - pair_mask) * (D_max + 100.0)

        # Get top-K neighbors
        # E_idx: neighbor indices (B, L, K, h)
        TOP_K = np.minimum(self.top_k, L-1)
        _, E_idx = torch.topk(dist, TOP_K, dim=-2, largest=False)  # (B, L, K, h)

        # bias from structures
        dist_neigh = self.gather_edges_3d(dist[...,None], E_idx).reshape(B, L, TOP_K, TOP_K, self.n_head)  # (B, L, K, K, h)
        rbf_neigh = rbf(dist_neigh).reshape(B, L, TOP_K, TOP_K, self.n_head, -1)
        bias = self.to_bias(rbf_neigh).reshape(B, L, TOP_K, TOP_K, self.n_head)  # (B, L, K, K, h)

        pair_q = self.to_q(pair).reshape(B, L, L, self.n_head, self.dim)  # (B, L, L, h, d)
        pair_k = self.to_k(pair).reshape(B, L, L, self.n_head, self.dim)  # (B, L, L, h, d)
        pair_v = self.to_v(pair).reshape(B, L, L, self.n_head, self.dim)  # (B, L, L, h, d)
        gate = torch.sigmoid(self.to_gate(pair).reshape(B, L, L, self.n_head, self.dim))
        out = torch.zeros_like(gate)

        # Gather neighbor pairs
        pair_q = self.gather_edges(pair_q, E_idx)  # (B, L, K, h, d)
        pair_k = self.gather_edges(pair_k, E_idx)  # (B, L, K, h, d)
        pair_v = self.gather_edges(pair_v, E_idx)  # (B, L, K, h, d)
        gate = self.gather_edges(gate, E_idx)  # (B, L, K, h, d)

        attn_pair = torch.einsum('blihd,bljhd->blijh', pair_q, pair_k)  # (B, L, K, K, h)
        attn_pair = attn_pair + bias
        attn_pair = nn.Softmax(dim=3)(attn_pair)  # (B, L, K, K-softmaxed, h)
        if pair_mask is not None:
            pair_mask_3d = self.gather_edges_3d(pair_mask[...,None].expand(-1,-1,-1,self.n_head,-1), E_idx).reshape(B, L, TOP_K, TOP_K, self.n_head)
            attn_pair = attn_pair.masked_fill(pair_mask_3d==0, -1e4)

        out_pair = torch.einsum('blijh,bljhd->blihd', attn_pair, pair_v)  # (B, L, K, h, d)
        out_pair = gate * out_pair

        E_idx = E_idx.unsqueeze(-1).expand(-1, -1, -1, -1, out.size(-1))
        out_pair = torch.scatter_add(out, 2, E_idx, out_pair)
        out_pair = self.to_out(out_pair.reshape(B, L, L, -1))

        if pair_mask is not None:
            out_pair = pair_mask * out_pair

        return out_pair
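# Usage sketch (not part of the original module): exercises GraphTriangleAttention with
# illustrative, hypothetical sizes. It assumes rbf() from the wildcard util_module import
# returns d_rbf features per distance, as forward() itself requires.
def _example_graph_triangle_attention():
    B, L = 1, 32
    gta = GraphTriangleAttention(d_pair=128, d_state=64, top_k=16, d_rbf=64,
                                 n_head=8, n_query_pt=4, d_hidden=32)
    pair = torch.randn(B, L+1, L+1, 128)               # pair features incl. species row/column
    state = torch.randn(B, L+1, 64)                    # node features incl. species token
    Rs = torch.eye(3).expand(B, L, 3, 3).contiguous()  # per-residue rotation frames
    Ts = torch.randn(B, L, 3)                          # per-residue translations (CA positions)
    pair_mask = torch.ones(B, L+1, L+1)
    return gta(pair, state, Rs, Ts, pair_mask=pair_mask, use_species=True)  # (B, L+1, L+1, 128)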