Source code for deepfold.modules.msa_column_global_attention

from typing import Optional

import torch
import torch.nn as nn

from deepfold.modules.global_attention import GlobalAttention
from deepfold.modules.layer_norm import LayerNorm


class MSAColumnGlobalAttention(nn.Module):
    """MSA Column Global Attention module.

    Supplementary '1.7.2 Unclustered MSA stack':
    Algorithm 19 MSA global column-wise gated self-attention.

    Args:
        c_e: Extra MSA representation dimension (channels).
        c_hidden: Per-head hidden dimension (channels).
        num_heads: Number of attention heads.
        inf: Safe infinity value.
        eps: Epsilon to prevent division by zero.
        chunk_size: Optional chunk size for a batch-like dimension.

    """

    def __init__(
        self,
        c_e: int,
        c_hidden: int,
        num_heads: int,
        inf: float,
        eps: float,
        chunk_size: Optional[int],
    ) -> None:
        super().__init__()
        self.layer_norm_m = LayerNorm(c_e)
        self.global_attention = GlobalAttention(
            c_e=c_e,
            c_hidden=c_hidden,
            num_heads=num_heads,
            inf=inf,
            eps=eps,
            chunk_size=chunk_size,
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """MSA Column Global Attention forward pass.

        Args:
            m: [batch, N_extra_seq, N_res, c_e] extra MSA representation
            mask: [batch, N_extra_seq, N_res] extra MSA mask

        Returns:
            m: [batch, N_extra_seq, N_res, c_e] updated extra MSA representation

        """
        m_transposed = m.transpose(-2, -3)
        # m_transposed: [batch, N_res, N_extra_seq, c_e]
        mask = mask.transpose(-1, -2)
        # mask: [batch, N_res, N_extra_seq]
        m_transposed_normalized = self.layer_norm_m(m_transposed)
        m = self.global_attention(
            m=m_transposed_normalized,
            mask=mask,
            add_transposed_output_to=m,
        )
        # m: [batch, N_extra_seq, N_res, c_e]
        return m
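A minimal usage sketch, not part of the module source: it constructs the module and runs a forward pass on random tensors with the shapes documented in the forward docstring. The hyperparameter values (c_e=64, c_hidden=8, num_heads=8, inf, eps) and the tensor sizes are illustrative assumptions, not values taken from a DeepFold configuration.

    import torch

    from deepfold.modules.msa_column_global_attention import MSAColumnGlobalAttention

    # Illustrative hyperparameters (assumed, not from a DeepFold config).
    attn = MSAColumnGlobalAttention(
        c_e=64,
        c_hidden=8,
        num_heads=8,
        inf=1e9,
        eps=1e-10,
        chunk_size=None,  # no chunking for this small example
    )

    batch, num_extra_seq, num_res = 1, 16, 32
    m = torch.randn(batch, num_extra_seq, num_res, 64)   # extra MSA representation
    mask = torch.ones(batch, num_extra_seq, num_res)     # extra MSA mask

    m = attn(m=m, mask=mask)
    # m: [batch, N_extra_seq, N_res, c_e] == [1, 16, 32, 64]

Note that the column-wise attention is realized by transposing the sequence and residue dimensions before calling GlobalAttention, while the residual update is added back in the original layout via add_transposed_output_to.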