Source code for deepfold.modules.msa_row_attention_with_pair_bias

from typing import Optional

import torch
import torch.nn as nn

import deepfold.distributed.model_parallel as mp
from deepfold.modules.attention import SelfAttentionWithGate
from deepfold.modules.layer_norm import LayerNorm
from deepfold.modules.linear import Linear


class MSARowAttentionWithPairBias(nn.Module):
    """MSA Row Attention With Pair Bias module.

    Supplementary '1.6.1 MSA row-wise gated self-attention with pair bias': Algorithm 7.

    Args:
        c_m: MSA (or Extra MSA) representation dimension (channels).
        c_z: Pair representation dimension (channels).
        c_hidden: Per-head hidden dimension (channels).
        num_heads: Number of attention heads.
        inf: Safe infinity value.
        chunk_size: Optional chunk size for a batch-like dimension.

    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden: int,
        num_heads: int,
        inf: float,
        chunk_size: Optional[int],
    ) -> None:
        super().__init__()
        self.layer_norm_m = LayerNorm(c_m)
        self.layer_norm_z = LayerNorm(c_z)
        self.linear_z = Linear(c_z, num_heads, bias=False, init="normal")
        self.mha = SelfAttentionWithGate(
            c_qkv=c_m,
            c_hidden=c_hidden,
            num_heads=num_heads,
            inf=inf,
            chunk_size=chunk_size,
        )
    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """MSA Row Attention With Pair Bias forward pass.

        Args:
            m: [batch, N_seq, N_res, c_m] MSA (or Extra MSA) representation
            z: [batch, N_res, N_res, c_z] pair representation
            mask: [batch, N_seq, N_res] MSA (or Extra MSA) mask

        Returns:
            m_update: [batch, N_seq, N_res, c_m] MSA (or Extra MSA) representation update

        """
        mask = mask.unsqueeze(-2).unsqueeze(-3)
        # mask: [batch, N_seq, 1, 1, N_res]

        z = self.layer_norm_z(z)
        z = self.linear_z(z)
        if mp.is_enabled():
            z = mp.gather(z, dim=-3, bwd="all_reduce_sum_split")
        # z: [batch, N_res, N_res, num_heads]

        z = z.movedim(-1, -3).unsqueeze(-4)
        # z: [batch, 1, num_heads, N_res, N_res]

        m = self.layer_norm_m(m)
        m = self.mha(
            input_qkv=m,
            mask=mask,
            bias=z,
        )
        # m: [batch, N_seq, N_res, c_m]

        return m
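
A minimal usage sketch of this module, assuming model parallelism is disabled (``mp.is_enabled()`` returns False) so no cross-rank gather occurs. The tensor sizes below are illustrative placeholders, not defaults taken from the repository.

import torch

from deepfold.modules.msa_row_attention_with_pair_bias import MSARowAttentionWithPairBias

# Illustrative sizes: 1 batch, 8 MSA sequences, 16 residues (assumed, not canonical).
batch, n_seq, n_res = 1, 8, 16
c_m, c_z, c_hidden, num_heads = 32, 16, 8, 4

module = MSARowAttentionWithPairBias(
    c_m=c_m,
    c_z=c_z,
    c_hidden=c_hidden,
    num_heads=num_heads,
    inf=1e9,          # safe infinity used to mask out padded positions
    chunk_size=None,  # no chunking over the batch-like dimension
)

m = torch.randn(batch, n_seq, n_res, c_m)   # MSA representation
z = torch.randn(batch, n_res, n_res, c_z)   # pair representation
mask = torch.ones(batch, n_seq, n_res)      # all positions valid

m_update = module(m=m, z=z, mask=mask)
print(m_update.shape)  # torch.Size([1, 8, 16, 32])

As in Algorithm 7, the returned tensor is an update to the MSA representation; the caller is expected to add it to ``m`` residually.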