import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from miniworld.utils.util_module import *  # expected to provide init_lecun_normal, rbf, etc.
class FeedForwardLayer(nn.Module):
def __init__(self, d_model, r_ff, p_drop=0.1):
super(FeedForwardLayer, self).__init__()
self.norm = nn.LayerNorm(d_model)
self.linear1 = nn.Linear(d_model, d_model*r_ff)
self.dropout = nn.Dropout(p_drop)
self.linear2 = nn.Linear(d_model*r_ff, d_model)
self.reset_parameter()
def reset_parameter(self):
# initialize linear layer right before ReLU: He initializer (kaiming normal)
nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
nn.init.zeros_(self.linear1.bias)
# initialize linear layer right before residual connection: zero initialize
nn.init.zeros_(self.linear2.weight)
nn.init.zeros_(self.linear2.bias)
def forward(self, src):
src = self.norm(src)
src = self.linear2(self.dropout(F.relu_(self.linear1(src))))
return src
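# Usage sketch for FeedForwardLayer (illustrative only; dimensions below are assumptions):
#   ff = FeedForwardLayer(d_model=256, r_ff=4)
#   x = torch.randn(1, 100, 256)
#   x = x + ff(x)  # (1, 100, 256); pre-LayerNorm block: the residual is added by the caller,
#                  # and the zero-initialized linear2 makes the update an identity at initialization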
class Attention(nn.Module):
# calculate multi-head attention
def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
super(Attention, self).__init__()
self.h = n_head
self.dim = d_hidden
#
self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)
#
self.to_out = nn.Linear(n_head*d_hidden, d_out)
self.scaling = 1/math.sqrt(d_hidden)
#
# initialize all parameters properly
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# to_out: right before the residual connection: zero initialize -- to make sure the residual update is the identity at the beginning
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, query, key, value):
B, Q = query.shape[:2]
B, K = key.shape[:2]
#
query = self.to_q(query).reshape(B, Q, self.h, self.dim)
key = self.to_k(key).reshape(B, K, self.h, self.dim)
value = self.to_v(value).reshape(B, K, self.h, self.dim)
#
query = query * self.scaling
attn = einsum('bqhd,bkhd->bhqk', query, key)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bhqk,bkhd->bqhd', attn, value)
out = out.reshape(B, Q, self.h*self.dim)
#
out = self.to_out(out)
return out
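# Usage sketch for Attention (illustrative only; dimensions below are assumptions):
#   attn = Attention(d_query=256, d_key=128, n_head=8, d_hidden=32, d_out=256)
#   q = torch.randn(1, 10, 256)      # (B, Q, d_query)
#   k = v = torch.randn(1, 20, 128)  # (B, K, d_key)
#   out = attn(q, k, v)              # (B, Q, d_out) = (1, 10, 256)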
class AttentionWithBias(nn.Module):
def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
super(AttentionWithBias, self).__init__()
self.norm_in = nn.LayerNorm(d_in)
self.norm_bias = nn.LayerNorm(d_bias)
#
self.to_q = nn.Linear(d_in, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_in, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_in, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_bias, n_head, bias=False)
self.to_g = nn.Linear(d_in, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_in)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, ones for biases (mostly open gate at the beginning)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before the residual connection: zero initialize -- to make sure the residual update is the identity at the beginning
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, x, bias):
B, L = x.shape[:2]
#
x = self.norm_in(x)
bias = self.norm_bias(bias)
#
query = self.to_q(x).reshape(B, L, self.h, self.dim)
key = self.to_k(x).reshape(B, L, self.h, self.dim)
value = self.to_v(x).reshape(B, L, self.h, self.dim)
bias = self.to_b(bias) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(x))
#
key = key * self.scaling
attn = einsum('bqhd,bkhd->bqkh', query, key)
attn = attn + bias
attn = F.softmax(attn, dim=-2)
#
out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
out = gate * out
#
out = self.to_out(out)
return out
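# Shape walkthrough for AttentionWithBias (above), as descriptive comments:
#   x: (B, L, d_in), bias: (B, L, L, d_bias) -> out: (B, L, d_in)
#   attn is kept in (batch, query, key, head) layout, so the softmax over dim=-2 normalizes over keys;
#   scaling the keys instead of the queries is equivalent to the usual 1/sqrt(d_hidden) scaling.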
# MSA Attention (row/column) from AlphaFold architecture
class SequenceWeight(nn.Module):
def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
super(SequenceWeight, self).__init__()
self.h = n_head
self.dim = d_hidden
self.scale = 1.0 / math.sqrt(self.dim)
self.to_query = nn.Linear(d_msa, n_head*d_hidden)
self.to_key = nn.Linear(d_msa, n_head*d_hidden)
self.dropout = nn.Dropout(p_drop)
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_query.weight)
nn.init.xavier_uniform_(self.to_key.weight)
def forward(self, msa):
B, N, L = msa.shape[:3]
tar_seq = msa[:,0]
q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim)
k = self.to_key(msa).view(B, N, L, self.h, self.dim)
q = q * self.scale
attn = einsum('bqihd,bkihd->bkihq', q, k)
attn = F.softmax(attn, dim=1)
return self.dropout(attn)
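# Note on SequenceWeight (above): it returns per-sequence attention weights of shape (B, N, L, h, 1),
# scoring each MSA sequence against the target (first) sequence at every position and normalizing
# over the sequence dimension (dim=1). MSARowAttentionWithBias uses these weights to reweight its
# per-sequence queries before the tied row attention.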
class MSARowAttentionWithBias(nn.Module):
def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
super(MSARowAttentionWithBias, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_pair = nn.LayerNorm(d_pair)
#
self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_pair, n_head, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, ones for biases (mostly open gate at the beginning)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before the residual connection: zero initialize -- to make sure the residual update is the identity at the beginning
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa, pair, sym=None): # TODO: make this tied attention
B, N, L = msa.shape[:3]
# O = pair.shape[0]
# PSK, 2023-07-30: I don't use symmetry, so O == 1 is forced; also, in my code pair.shape[0] is the batch size.
# if (B==1):
msa = self.norm_msa(msa)
pair = self.norm_pair(pair)
seq_weight = self.seq_weight(msa) # (B, N, L, h, 1)
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(msa))
query = query * seq_weight.expand(-1, -1, -1, -1, self.dim)
key = key * self.scaling
attn = einsum('bsqhd,bskhd->bqkh', query, key)
attn = attn + bias
attn = F.softmax(attn, dim=-2)
out = einsum('bqkh,bskhd->bsqhd', attn, value).reshape(B, N, L, -1)
out = gate * out
out = self.to_out(out)
return out
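# Shape walkthrough for MSARowAttentionWithBias (above), as descriptive comments:
#   msa: (B, N, L, d_msa), pair: (B, L, L, d_pair) -> out: (B, N, L, d_msa)
#   The einsum 'bsqhd,bskhd->bqkh' sums over the sequence axis s, i.e. one row-attention map is
#   tied across all sequences, with the pair features providing an additive per-head bias.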
class MSAColAttention(nn.Module):
def __init__(self, d_msa=256, n_head=8, d_hidden=32):
super(MSAColAttention, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
#
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# gating: zero weights, ones for biases (mostly open gate at the beginning)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before the residual connection: zero initialize -- to make sure the residual update is the identity at the beginning
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa):
B, N, L = msa.shape[:3]
msa_species = msa[:,:,:1] # (B, N, 1, d_msa)
#
msa = self.norm_msa(msa)
#
query = self.to_q(msa_species).reshape(B, N, self.h, self.dim)
key = self.to_k(msa_species).reshape(B, N, self.h, self.dim)
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
gate = torch.sigmoid(self.to_g(msa))
#
query = query * self.scaling
attn = einsum('bqhd,bkhd->bhqk', query, key) # (B, h, N, N)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bhqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1)
out = gate * out
#
out = self.to_out(out)
return out
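# Shape walkthrough for MSAColAttention (above), as descriptive comments:
#   msa: (B, N, L, d_msa) -> out: (B, N, L, d_msa)
#   Queries and keys come only from the first column (msa[:, :, :1], taken before norm_msa is
#   applied), so a single (N x N) attention map over sequences is shared by every residue position,
#   while values and gates are computed from the full, normalized MSA.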
class MSAColGlobalAttention(nn.Module):
def __init__(self, d_msa=64, n_head=8, d_hidden=8):
super(MSAColGlobalAttention, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
#
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, d_hidden, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# gating: zero weights, ones for biases (mostly open gate at the beginning)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before the residual connection: zero initialize -- to make sure the residual update is the identity at the beginning
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa):
B, N, L = msa.shape[:3]
#
msa = self.norm_msa(msa)
#
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
query = query.mean(dim=1) # (B, L, h, dim)
key = self.to_k(msa) # (B, N, L, dim)
value = self.to_v(msa) # (B, N, L, dim)
gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim)
#
query = query * self.scaling
attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1) # (B, 1, L, h*dim)
out = gate * out # (B, N, L, h*dim)
#
out = self.to_out(out)
return out
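# Shape walkthrough for MSAColGlobalAttention (above), as descriptive comments:
#   msa: (B, N, L, d_msa) -> out: (B, N, L, d_msa)
#   Global attention: the query is averaged over the N sequences, keys/values use a single
#   d_hidden-wide projection shared across heads, and the (B, 1, L, h*dim) attended output is
#   broadcast back to every sequence through the per-sequence gate.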
# Instead of triangle attention, use tied axial attention with bias from coordinates?
class BiasedAxialAttention(nn.Module):
def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
super(BiasedAxialAttention, self).__init__()
#
self.is_row = is_row
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_bias = nn.LayerNorm(d_pair)
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_pair, n_head, bias=False)
self.to_g = nn.Linear(d_pair, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.dim_out = d_pair
# initialize all parameters properly
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, ones for biases (mostly open gate at the beginning)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before the residual connection: zero initialize -- to make sure the residual update is the identity at the beginning
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, pair, sym=None, stride=-1):
O, L = pair.shape[:2] # after subunit mask is applied
if O==1 or sym is None or sym.shape[0]!=O: # asymm mode
# pair: (B, L, L, d_pair)
if self.is_row:
pair = pair.permute(0,2,1,3)
pair = self.norm_pair(pair)
bias = self.norm_bias(pair)
query = self.to_q(pair).reshape(O, L, L, self.h, self.dim)
key = self.to_k(pair).reshape(O, L, L, self.h, self.dim)
value = self.to_v(pair).reshape(O, L, L, self.h, self.dim)
bias = self.to_b(bias) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
query = query * self.scaling
key = key / L # normalize for tied attention
attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention
attn = attn + bias
attn = F.softmax(attn, dim=-2) # (B, L, L, h)
out = einsum('bijh,bnjhd->bnihd', attn, value).reshape(O, L, L, -1)
out = gate * out
out = self.to_out(out)
if self.is_row:
out = out.permute(0,2,1,3)
else:
# symmetric version
if self.is_row:
pair = pair[sym[0,:]].permute(0,2,1,3)
pair = self.norm_pair(pair)
bias = self.norm_bias(pair)
query = self.to_q(pair).reshape(O, L, L, self.h, self.dim)
key = self.to_k(pair).reshape(O, L, L, self.h, self.dim)
value = self.to_v(pair).reshape(O, L, L, self.h, self.dim)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
query = query * self.scaling
key = key / (O*L) # normalize for tied attention
attn=torch.zeros((O,L,L,self.h), device=pair.device)
for i in range(O):
attn[i] = torch.einsum('bnihk,bnjhk->ijh', query[sym[:,i]], key[sym[:,0]]) # tied attention
#attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention
attn = attn + bias # apply bias
# softmax over dims 0 & 2
attn = F.softmax(
attn.transpose(1,2).reshape(O*L,L,self.h), dim=0
).reshape(O,L,L,self.h).transpose(1,2)
out=torch.zeros((O,L,L,self.h,self.dim), device=pair.device)
for i in range(O):
out[i] = torch.einsum('bijh,bnjhd->nihd', attn[sym[:,i]], value) # tied attention
#out = einsum('bijh,bnjhd->bnihd', attn, value).reshape(O, L, L, -1)
out = gate * out.reshape(O,L,L,-1)
out = self.to_out(out)
if self.is_row:
out = out[sym[0,:]].permute(0,2,1,3)
return out
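# Usage sketch for BiasedAxialAttention in the asymmetric case (illustrative only; values assumed):
#   row_attn = BiasedAxialAttention(d_pair=128, d_bias=128, n_head=4, d_hidden=32, is_row=True)
#   pair = torch.randn(1, 50, 50, 128)  # (O, L, L, d_pair) with O == 1
#   pair = pair + row_attn(pair)        # (1, 50, 50, 128); residual added by the caller
#   is_row=True transposes the two residue axes so the same tied-attention code attends along rows
#   instead of columns.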
class TriangleMultiplication(nn.Module):
def __init__(self, d_pair, d_hidden=128, outgoing=True):
super(TriangleMultiplication, self).__init__()
self.norm = nn.LayerNorm(d_pair)
self.left_proj = nn.Linear(d_pair, d_hidden)
self.right_proj = nn.Linear(d_pair, d_hidden)
self.left_gate = nn.Linear(d_pair, d_hidden)
self.right_gate = nn.Linear(d_pair, d_hidden)
#
self.gate = nn.Linear(d_pair, d_pair)
self.norm_out = nn.LayerNorm(d_hidden)
self.out_proj = nn.Linear(d_hidden, d_pair)
self.d_hidden = d_hidden
self.outgoing = outgoing
self.reset_parameter()
def reset_parameter(self):
# normal distribution for regular linear weights
self.left_proj = init_lecun_normal(self.left_proj)
self.right_proj = init_lecun_normal(self.right_proj)
# Set Bias of Linear layers to zeros
nn.init.zeros_(self.left_proj.bias)
nn.init.zeros_(self.right_proj.bias)
# gating: zero weights, ones for biases (mostly open gate at the beginning)
nn.init.zeros_(self.left_gate.weight)
nn.init.ones_(self.left_gate.bias)
nn.init.zeros_(self.right_gate.weight)
nn.init.ones_(self.right_gate.bias)
nn.init.zeros_(self.gate.weight)
nn.init.ones_(self.gate.bias)
# to_out: right before the residual connection: zero initialize -- to make sure the residual update is the identity at the beginning
nn.init.zeros_(self.out_proj.weight)
nn.init.zeros_(self.out_proj.bias)
def forward(self, pair, sym=None):
pair = self.norm(pair)
O,L = pair.shape[:2]
if O==1 or sym is None or sym.shape[0]!=O: # asymm mode
left = self.left_proj(pair) # (B, L, L, d_h)
left_gate = torch.sigmoid(self.left_gate(pair))
left = left_gate * left
right = self.right_proj(pair) # (B, L, L, d_h)
right_gate = torch.sigmoid(self.right_gate(pair))
right = right_gate * right
if self.outgoing:
out = einsum('bikd,bjkd->bijd', left, right/float(L))
else:
out = einsum('bkid,bkjd->bijd', left, right/float(L))
out = self.norm_out(out)
out = self.out_proj(out)
gate = torch.sigmoid(self.gate(pair)) # (B, L, L, d_pair)
out = gate * out
else:
left = self.left_proj(pair) # (B, L, L, d_h)
left_gate = torch.sigmoid(self.left_gate(pair))
left = left_gate * left
right = self.right_proj(pair) # (B, L, L, d_h)
right_gate = torch.sigmoid(self.right_gate(pair))
right = right_gate * right
if self.outgoing:
out=torch.zeros((O,L,L,self.d_hidden), device=pair.device)
for i in range(O):
out[i] = torch.einsum('bikd,bjkd->ijd', left[sym[i,:]], right[sym[0,:]]/float(O*L)) # tied attention
else:
out=torch.zeros((O,L,L,self.d_hidden), device=pair.device)
for i in range(O):
out[i] = torch.einsum('bkid,bkjd->ijd', left[sym[:,i]], right[sym[:,0]]/float(O*L)) # tied attention
out = self.norm_out(out)
out = self.out_proj(out)
gate = torch.sigmoid(self.gate(pair)) # (B, L, L, d_pair)
out = gate * out
return out
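# Shape walkthrough for TriangleMultiplication (above), as descriptive comments:
#   pair: (B, L, L, d_pair) -> out: (B, L, L, d_pair)
#   outgoing=True contracts over the shared third residue k via 'bikd,bjkd->bijd' (edges i->k and
#   j->k); outgoing=False uses 'bkid,bkjd->bijd' (edges k->i and k->j). The division by L keeps the
#   sum over k at unit scale.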
class StateAttentionGate(nn.Module):
def __init__(self, d_state, d_pair, n_head=8, d_hidden=32, p_drop=0.1):
super(StateAttentionGate, self).__init__()
self.norm_state = nn.LayerNorm(d_state)
self.norm_pair = nn.LayerNorm(d_pair)
self.to_q = nn.Linear(d_state, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_state, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_state, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_pair, n_head, bias=False)
self.to_g = nn.Linear(d_state, n_head*d_hidden)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, ones for biases (mostly open gate at the beginning)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# NOTE: this module defines no to_out projection (it returns a gate), so there is nothing to zero-initialize before a residual connection
def forward(self, state, pair, chain_mask):
'''
state : (B, L, d_state)
pair : (B, L, L, d_pair)
chain_mask : (B, L, L)
'''
B, L, C = state.shape[:3]
#
state = self.norm_state(state)
pair = self.norm_pair(pair)
interchain_mask = ~chain_mask.unsqueeze(-1) # (B, L, L, 1)
query = self.to_q(state).reshape(B, L, self.h, self.dim) # (B, L, h, dim)
key = self.to_k(state).reshape(B, L, self.h, self.dim) # (B, L, h, dim)
value = self.to_v(state).reshape(B, L, self.h, self.dim) # (B, L, h, dim)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(state))
key = key * self.scaling
attn = einsum('bqhd,bkhd->bqkh', query, key) # (B, L, L, h)
attn = attn + bias # (B, L, L, h)
attn = F.softmax(attn, dim=-2) # softmax over the key dimension, (B, L, L, h)
attn = attn * interchain_mask # keep only inter-chain pairs, (B, L, L, h)
attn = attn / (attn.sum(dim=-2, keepdim=True) + 1e-8) # renormalize over keys after masking, (B, L, L, h)
out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1) # (B, L, h*dim)
gate = torch.sigmoid(self.to_g(out)) # (B, L, h*dim)
return gate
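# Note on StateAttentionGate (above): it returns a sigmoid gate of shape (B, L, n_head*d_hidden)
# computed from inter-chain attention over the state features. Reusing self.to_g on the attended
# output assumes d_state == n_head*d_hidden (256 with the defaults n_head=8, d_hidden=32); the
# earlier gate computed directly from the state is not used.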
class GraphTriangleAttention(nn.Module):
def __init__(self, d_pair, d_state, top_k=64, d_rbf=64, n_head=8, n_query_pt=4, d_hidden=32):
super(GraphTriangleAttention, self).__init__()
self.n_head = n_head
self.dim = d_hidden
self.top_k = top_k
self.n_query_pt = n_query_pt
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_state = nn.LayerNorm(d_state)
self.to_qk_point = nn.Linear(d_state, 2*n_head*n_query_pt*3, bias=False)
self.to_bias = nn.Linear(d_rbf, 1, bias=False)
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_gate = nn.Linear(d_pair, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
nn.init.xavier_uniform_(self.to_qk_point.weight)
# bias: normal distribution
self.to_bias = init_lecun_normal(self.to_bias)
# gating: zero weights, ones for biases (mostly open gate at the beginning)
nn.init.zeros_(self.to_gate.weight)
nn.init.ones_(self.to_gate.bias)
# to_out: right before the residual connection: zero initialize -- to make sure the residual update is the identity at the beginning
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def apply_RT(self, Rs, Ts, vec):
# Rs : (B, L, 3, 3)
# Ts : (B, L, 3)
# vec : (B, L, h, p, 3)
B, L, h, p, _ = vec.shape
Ts = Ts.unsqueeze(-2).unsqueeze(-2) # (B, L, 1, 1, 3)
Ts = Ts.expand(-1,-1,h,p,-1) # (B, L, h, p, 3)
return einsum('blij,blhpj->blhpi', Rs, vec) + Ts
def gather_edges_3d(self, edges, neighbor_idx):
# Features (B, L, L, h, *) ==> (B, L, K, K, h, *)
# neighbor_idx (B, L, K, h)
B, L, K, h = neighbor_idx.shape
neighbor_idx = neighbor_idx.permute(0,3,1,2).reshape(B*h, L, K) # (B*h, L, K)
edges = edges.permute(0,3,1,2,4).reshape(B*h, L, L, -1) # (B*h, L, L, C)
neigh_idx = neighbor_idx[:,:,:,None]*L + neighbor_idx[:,:,None,:] # (B*h, L, K, K)
neigh_idx = neigh_idx.reshape(B*h, L, K*K, 1).expand(-1,-1,-1,edges.shape[-1])
edges = edges.reshape(B*h, 1, L*L, -1).expand(-1,L,-1,-1) # expand != repeat
pair_neigh = torch.gather(edges, 2, neigh_idx).reshape(B*h, L, K, K, -1)
pair_neigh = pair_neigh.reshape(B, h, L, K, K, -1).permute(0,2,3,4,1,5) # (B, L, K, K, h, C)
return pair_neigh
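# Note on gather_edges_3d (above): for every anchor residue i it gathers the K x K block of pair
# features between i's neighbors; the flattened index n_a*L + n_b addresses edges[n_a, n_b] in the
# (L*L)-flattened edge tensor, yielding a (B, L, K, K, h, C) neighborhood per head.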
def gather_edges(self, edges, neighbor_idx):
# Features [B,L,L,h,C] at Neighbor indices [B,L,K,h] => Neighbor features [B,L,K,h,C]
neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, -1, edges.size(-1))
edge_features = torch.gather(edges, 2, neighbors)
return edge_features
def forward(self, pair_in, state, Rs, Ts, pair_mask=None, use_species = True):
'''
Input:
- pair: pairwise features. shape [B, L+1, L+1, d_pair] if use_species is True, [B, L, L, d_pair] if use_species is False
- state: node features. shape [B, L+1, d_state] if use_species is True, [B, L, d_state] if use_species is False
- Rs, Ts : SE(3) Frame. shape [B, L, 3, 3], [B, L, 3]
- pair_mask: mask for valid residue pairs. shape [B, L+1, L+1] if use_species is True, [B, L, L] if use_species is False
'''
B, L = state.shape[:2]
if pair_mask is not None:
pair_mask = pair_mask[...,None]
# Input normalization
pair = self.norm_pair(pair_in)
state = self.norm_state(state)
# It doesn't use the real xyz coordinates to define neighbors.
# Instead, it generates virtual xyz coordinates to define neighbors so that it has a chance to escape from wrong structures.
if use_species:
state_wo_species = state[:,1:]
qk_pt = self.to_qk_point(state_wo_species).reshape(B, L-1, self.n_head, self.n_query_pt, 6)
else:
qk_pt = self.to_qk_point(state).reshape(B, L, self.n_head, self.n_query_pt, 6)
query_pt, key_pt = torch.split(qk_pt, 3, dim=-1) # 2*(B, L, n_head, n_q, 3)
# apply rigid transform to query & key points
# if query & key points at the origin, the resulting RT_q & RT_k will be just CA atom coordinates of given structure
RT_q = self.apply_RT(Rs, Ts, query_pt) # (B, L, h, n_q, 3)
RT_k = self.apply_RT(Rs, Ts, key_pt)
dist_qk = RT_q[:,:,None] - RT_k[:,None,:] # (B, L, L, h, n_q, 3)
dist_qk = torch.norm(dist_qk, dim=-1) # (B, L, L, h, n_q)
dist_qk = dist_qk.mean(dim=-1) # (B, L, L, h): average distance between virtual query/key points per head
dist = torch.zeros(((B, L, L, self.n_head)), device=pair.device) # (B, L+1, L+1, h) or (B, L, L, h)
if use_species:
dist[:,1:,1:] = dist_qk
else:
dist = dist_qk
if pair_mask is not None:
# ignore invalid residue pairs (pairs with the virtual node)
dist = dist * pair_mask # (B, L, L, h)
D_max, _ = torch.max(dist, -2, keepdim=True) # (B, L, 1, h)
dist = dist + (1.0 - pair_mask) * (D_max+100.0)
# Get Top-K neighbors
# E_idx: neighbor indices (B, L, K, h)
TOP_K = np.minimum(self.top_k, L-1)
_, E_idx = torch.topk(dist, TOP_K, dim=-2, largest=False) # (B, L, K, h)
# bias from structures
dist_neigh = self.gather_edges_3d(dist[...,None], E_idx).reshape(B, L, TOP_K, TOP_K, self.n_head) # (B, L, K, K, h)
rbf_neigh = rbf(dist_neigh).reshape(B, L, TOP_K, TOP_K, self.n_head, -1)
bias = self.to_bias(rbf_neigh).reshape(B, L, TOP_K, TOP_K, self.n_head) # (B, L, K, K, h)
pair_q = self.to_q(pair).reshape(B, L, L, self.n_head, self.dim) # (B, L, L, h, d)
pair_k = self.to_k(pair).reshape(B, L, L, self.n_head, self.dim) # (B, L, L, h, d)
pair_v = self.to_v(pair).reshape(B, L, L, self.n_head, self.dim) # (B, L, L, h, d)
gate = torch.sigmoid(self.to_gate(pair).reshape(B, L, L, self.n_head, self.dim))
out = torch.zeros_like(gate)
# Gather neighbor pairs
pair_q = self.gather_edges(pair_q, E_idx) # (B, L, K, h, d)
pair_k = self.gather_edges(pair_k, E_idx) # (B, L, K, h, d)
pair_v = self.gather_edges(pair_v, E_idx) # (B, L, K, h, d)
gate = self.gather_edges(gate, E_idx) # (B, L, K, h, d)
attn_pair = torch.einsum('blihd,bljhd->blijh', pair_q, pair_k) # (B, L, K, K, h)
attn_pair = attn_pair + bias
if pair_mask is not None:
# mask invalid neighbor pairs before the softmax so they receive (near-)zero attention
pair_mask_3d = self.gather_edges_3d(pair_mask[...,None].expand(-1,-1,-1,self.n_head,-1), E_idx).reshape(B, L, TOP_K, TOP_K, self.n_head)
attn_pair = attn_pair.masked_fill(pair_mask_3d==0, -1e4)
attn_pair = F.softmax(attn_pair, dim=3) # (B, L, K, K-softmaxed, h)
out_pair = torch.einsum('blijh,bljhd->blihd', attn_pair, pair_v) # (B, L, K, h, d)
out_pair = gate * out_pair
E_idx = E_idx.unsqueeze(-1).expand(-1,-1,-1,-1,out.size(-1))
out_pair = torch.scatter_add(out, 2, E_idx, out_pair)
out_pair = self.to_out(out_pair.reshape(B, L, L, -1))
if pair_mask is not None:
out_pair = pair_mask * out_pair
return out_pair
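# Usage sketch for GraphTriangleAttention (illustrative only; shapes, values, and the behaviour of
# the rbf() helper -- assumed to return d_rbf features -- are assumptions):
#   gta = GraphTriangleAttention(d_pair=128, d_state=64, top_k=16, n_head=8, d_hidden=32)
#   B, L = 1, 40
#   pair  = torch.randn(B, L+1, L+1, 128)   # includes the species/virtual node
#   state = torch.randn(B, L+1, 64)
#   Rs = torch.eye(3).repeat(B, L, 1, 1)    # per-residue rotations, (B, L, 3, 3)
#   Ts = torch.zeros(B, L, 3)               # per-residue translations, (B, L, 3)
#   out = gta(pair, state, Rs, Ts)          # (B, L+1, L+1, 128)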