Source code for deepfold.modules.triangular_attention

from typing import Optional

import torch
import torch.nn as nn

import deepfold.distributed.model_parallel as mp
from deepfold.modules.attention import SelfAttentionWithGate
from deepfold.modules.layer_norm import LayerNorm
from deepfold.modules.linear import Linear


class TriangleAttention(nn.Module):
    """Triangle Attention module.

    Supplementary '1.6.6 Triangular self-attention': Algorithms 13 and 14.

    Args:
        c_z: Pair or template representation dimension (channels).
        c_hidden: Hidden dimension (channels).
        num_heads: Number of attention heads.
        ta_type: "starting" or "ending".
        inf: Safe infinity value.
        chunk_size: Optional chunk size for a batch-like dimension.

    """

    def __init__(
        self,
        c_z: int,
        c_hidden: int,
        num_heads: int,
        ta_type: str,
        inf: float,
        chunk_size: Optional[int],
    ) -> None:
        super().__init__()
        self._is_starting = {"starting": True, "ending": False}[ta_type]
        self.layer_norm = LayerNorm(c_z)
        self.linear = Linear(c_z, num_heads, bias=False, init="normal")
        self.mha = SelfAttentionWithGate(
            c_qkv=c_z,
            c_hidden=c_hidden,
            num_heads=num_heads,
            inf=inf,
            chunk_size=chunk_size,
        )

    def forward(
        self,
        z: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Triangle Attention forward pass.

        Args:
            z: [batch, N_res, N_res, c_z] pair representation
            mask: [batch, N_res, N_res] pair mask

        Returns:
            z_update: [batch, N_res, N_res, c_z] pair representation update

        """
        if not self._is_starting:
            z = z.transpose(-2, -3)
            mask = mask.transpose(-1, -2)
        # z: [batch, N_res, N_res, c_z]
        # mask: [batch, N_res, N_res]

        z = self.layer_norm(z)
        # z: [batch, N_res, N_res, c_z]

        mask = mask.unsqueeze(-2).unsqueeze(-3)
        # mask: [batch, N_res, 1, 1, N_res]

        triangle_bias = self.linear(z)
        # triangle_bias: [batch, N_res, N_res, num_heads]

        if mp.is_enabled():
            triangle_bias = mp.gather(triangle_bias, dim=-3, bwd="all_reduce_sum_split")

        triangle_bias = triangle_bias.movedim(-1, -3).unsqueeze(-4).contiguous()
        # triangle_bias: [batch, 1, num_heads, N_res, N_res]

        z = self.mha(
            input_qkv=z,
            mask=mask,
            bias=triangle_bias,
        )
        # z: [batch, N_res, N_res, c_z]

        if not self._is_starting:
            z = z.transpose(-2, -3)
        # z: [batch, N_res, N_res, c_z]

        return z
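The mask and bias reshaping in the forward pass above packs several broadcast dimensions in quickly. The sketch below replays just those two steps in isolation with dummy sizes (batch=1, N_res=4, num_heads=2); the sizes and tensor values are illustrative assumptions, not anything taken from the model.

import torch

batch, n_res, num_heads = 1, 4, 2

mask = torch.ones(batch, n_res, n_res)
# [batch, N_res, N_res] -> [batch, N_res, 1, 1, N_res]:
# broadcastable over heads and query positions inside the attention module
mask = mask.unsqueeze(-2).unsqueeze(-3)
assert mask.shape == (batch, n_res, 1, 1, n_res)

triangle_bias = torch.randn(batch, n_res, n_res, num_heads)
# [batch, N_res, N_res, num_heads] -> [batch, 1, num_heads, N_res, N_res]:
# the per-head bias is aligned with the per-head attention maps
triangle_bias = triangle_bias.movedim(-1, -3).unsqueeze(-4).contiguous()
assert triangle_bias.shape == (batch, 1, num_heads, n_res, n_res)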


class TriangleAttentionStartingNode(TriangleAttention):
    """Triangle Attention Starting Node module.

    Supplementary '1.6.6 Triangular self-attention':
    Algorithm 13 Triangular gated self-attention around starting node.

    Args:
        c_z: Pair or template representation dimension (channels).
        c_hidden: Hidden dimension (channels).
        num_heads: Number of attention heads.
        inf: Safe infinity value.
        chunk_size: Optional chunk size for a batch-like dimension.

    """

    def __init__(
        self,
        c_z: int,
        c_hidden: int,
        num_heads: int,
        inf: float,
        chunk_size: Optional[int],
    ) -> None:
        super().__init__(
            c_z=c_z,
            c_hidden=c_hidden,
            num_heads=num_heads,
            ta_type="starting",
            inf=inf,
            chunk_size=chunk_size,
        )


class TriangleAttentionEndingNode(TriangleAttention):
    """Triangle Attention Ending Node module.

    Supplementary '1.6.6 Triangular self-attention':
    Algorithm 14 Triangular gated self-attention around ending node.

    Args:
        c_z: Pair or template representation dimension (channels).
        c_hidden: Hidden dimension (channels).
        num_heads: Number of attention heads.
        inf: Safe infinity value.
        chunk_size: Optional chunk size for a batch-like dimension.

    """

    def __init__(
        self,
        c_z: int,
        c_hidden: int,
        num_heads: int,
        inf: float,
        chunk_size: Optional[int],
    ) -> None:
        super().__init__(
            c_z=c_z,
            c_hidden=c_hidden,
            num_heads=num_heads,
            ta_type="ending",
            inf=inf,
            chunk_size=chunk_size,
        )
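A minimal usage sketch of the two wrapper modules follows. The hyperparameters (c_z=128, c_hidden=32, num_heads=4, inf=1e9, chunk_size=None) and input sizes are illustrative assumptions rather than values from a DeepFold configuration, the snippet assumes model parallelism is not enabled, and the residual additions only mirror how such updates are commonly applied in an Evoformer-style pair stack, not this repository's exact wiring.

import torch

from deepfold.modules.triangular_attention import (
    TriangleAttentionEndingNode,
    TriangleAttentionStartingNode,
)

starting = TriangleAttentionStartingNode(c_z=128, c_hidden=32, num_heads=4, inf=1e9, chunk_size=None)
ending = TriangleAttentionEndingNode(c_z=128, c_hidden=32, num_heads=4, inf=1e9, chunk_size=None)

z = torch.randn(1, 16, 16, 128)  # [batch, N_res, N_res, c_z] pair representation
mask = torch.ones(1, 16, 16)     # [batch, N_res, N_res] pair mask

z = z + starting(z, mask)  # gated self-attention around starting nodes (Algorithm 13)
z = z + ending(z, mask)    # gated self-attention around ending nodes (Algorithm 14)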