Source code for deepfold.modules.msa_column_attention

from typing import Optional

import torch
import torch.nn as nn

from deepfold.modules.attention import SelfAttentionWithGate
from deepfold.modules.layer_norm import LayerNorm


class MSAColumnAttention(nn.Module):
    """MSA Column Attention module.

    Supplementary '1.6.2 MSA column-wise gated self-attention': Algorithm 8.

    Args:
        c_m: MSA representation dimension (channels).
        c_hidden: Per-head hidden dimension (channels).
        num_heads: Number of attention heads.
        inf: Safe infinity value.
        chunk_size: Optional chunk size for a batch-like dimension.

    """

    def __init__(
        self,
        c_m: int,
        c_hidden: int,
        num_heads: int,
        inf: float,
        chunk_size: Optional[int],
    ) -> None:
        super().__init__()
        self.layer_norm_m = LayerNorm(c_m)
        self.mha = SelfAttentionWithGate(
            c_qkv=c_m,
            c_hidden=c_hidden,
            num_heads=num_heads,
            inf=inf,
            chunk_size=chunk_size,
        )
    def forward(
        self,
        m: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """MSA Column Attention forward pass.

        Args:
            m: [batch, N_seq, N_res, c_m] MSA representation
            mask: [batch, N_seq, N_res] MSA mask

        Returns:
            m: [batch, N_seq, N_res, c_m] updated MSA representation

        """
        m_transposed = m.transpose(-2, -3)
        # m_transposed: [batch, N_res, N_seq, c_m]
        mask = mask.transpose(-1, -2)
        # mask: [batch, N_res, N_seq]
        mask = mask.unsqueeze(-2).unsqueeze(-3)
        # mask: [batch, N_res, 1, 1, N_seq]
        m_transposed_normalized = self.layer_norm_m(m_transposed)
        m = self.mha(
            input_qkv=m_transposed_normalized,
            mask=mask,
            bias=None,
            add_transposed_output_to=m,
        )
        # m: [batch, N_seq, N_res, c_m]
        return m
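
# --- Usage sketch (not part of the module source) ---
# A minimal example of constructing and calling MSAColumnAttention. The
# hyperparameter values and tensor sizes below are illustrative assumptions,
# not taken from the source; only the constructor and forward signatures
# follow the code above. Column-wise attention transposes the N_seq and
# N_res axes internally, so attention runs over the sequences within each
# residue column while the output keeps the original [batch, N_seq, N_res,
# c_m] layout.
#
#     import torch
#
#     from deepfold.modules.msa_column_attention import MSAColumnAttention
#
#     module = MSAColumnAttention(
#         c_m=256,          # MSA channel dimension (illustrative)
#         c_hidden=32,      # per-head hidden dimension (illustrative)
#         num_heads=8,      # number of attention heads (illustrative)
#         inf=1e9,          # safe infinity used for masking (illustrative)
#         chunk_size=None,  # no chunking over the batch-like dimension
#     )
#
#     batch, n_seq, n_res, c_m = 1, 4, 16, 256
#     m = torch.randn(batch, n_seq, n_res, c_m)  # MSA representation
#     mask = torch.ones(batch, n_seq, n_res)     # all positions valid
#
#     out = module(m=m, mask=mask)
#     print(out.shape)  # torch.Size([1, 4, 16, 256])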