Source code for deepfold.modules.evoformer_block_pair_core

from typing import Optional, Tuple

import torch
import torch.nn as nn

import deepfold.distributed.model_parallel as mp
from deepfold.modules.dropout import DropoutColumnwise, DropoutRowwise
from deepfold.modules.pair_transition import PairTransition
from deepfold.modules.triangular_attention import TriangleAttentionEndingNode, TriangleAttentionStartingNode
from deepfold.modules.triangular_multiplicative_update import TriangleMultiplicationIncoming, TriangleMultiplicationOutgoing


class EvoformerBlockPairCore(nn.Module):
    """Evoformer Block Pair Core module.

    Pair stack for:
    - Supplementary '1.6 Evoformer blocks': Algorithm 6
    - Supplementary '1.7.2 Unclustered MSA stack': Algorithm 18

    The pair representation is updated by five sub-modules applied in this
    order: outgoing triangle multiplication, incoming triangle
    multiplication, starting-node triangle attention, ending-node triangle
    attention and the pair transition. Each of the first four outputs is
    passed through a dropout module together with the current ``z`` as
    ``add_output_to``.

    Args:
        c_z: Pair representation dimension (channels).
        c_hidden_tri_mul: Hidden dimension in multiplicative updates.
        c_hidden_tri_att: Hidden dimension in triangular attention.
        num_heads_tri: Number of heads used in triangular attention.
        transition_n: Channel multiplier in transitions.
        pair_dropout: Dropout rate for pair activations.
        inf: Safe infinity value.
        chunk_size_tri_att: Optional chunk size for a batch-like dimension
            in triangular attention.
        block_size_tri_mul: Optional value passed as ``block_size`` to both
            triangle multiplicative update modules.

    """

    def __init__(
        self,
        c_z: int,
        c_hidden_tri_mul: int,
        c_hidden_tri_att: int,
        num_heads_tri: int,
        transition_n: int,
        pair_dropout: float,
        inf: float,
        chunk_size_tri_att: Optional[int],
        block_size_tri_mul: Optional[int],
    ) -> None:
        super().__init__()
        # Triangle multiplicative updates (outgoing, then incoming edges).
        self.tri_mul_out = TriangleMultiplicationOutgoing(
            c_z=c_z,
            c_hidden=c_hidden_tri_mul,
            block_size=block_size_tri_mul,
        )
        self.tmo_dropout_rowwise = DropoutRowwise(
            p=pair_dropout,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            c_z=c_z,
            c_hidden=c_hidden_tri_mul,
            block_size=block_size_tri_mul,
        )
        self.tmi_dropout_rowwise = DropoutRowwise(
            p=pair_dropout,
        )
        # Triangle attention (starting node, then ending node).
        self.tri_att_start = TriangleAttentionStartingNode(
            c_z=c_z,
            c_hidden=c_hidden_tri_att,
            num_heads=num_heads_tri,
            inf=inf,
            chunk_size=chunk_size_tri_att,
        )
        self.tasn_dropout_rowwise = DropoutRowwise(
            p=pair_dropout,
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            c_z=c_z,
            c_hidden=c_hidden_tri_att,
            num_heads=num_heads_tri,
            inf=inf,
            chunk_size=chunk_size_tri_att,
        )
        # The ending-node attention is the only update that uses the
        # column-wise dropout variant.
        self.taen_dropout_columnwise = DropoutColumnwise(
            p=pair_dropout,
        )
        self.pair_transition = PairTransition(
            c_z=c_z,
            n=transition_n,
        )

    def forward(
        self,
        z: torch.Tensor,
        pair_mask: torch.Tensor,
        inplace_safe: bool,
    ) -> torch.Tensor:
        """Evoformer Block Pair Core forward pass.

        Dispatches to the model-parallel implementation when
        ``mp.is_enabled()`` is true and to the single-device implementation
        otherwise.

        Args:
            z: [batch, N_res, N_res, c_z] pair representation
            pair_mask: [batch, N_res, N_res] pair mask
            inplace_safe: Forwarded as ``inplace`` to the dropout modules
                and as ``inplace_safe`` to the pair transition.

        Returns:
            z: [batch, N_res, N_res, c_z] updated pair representation

        """
        if mp.is_enabled():
            z = self._forward_dap(z=z, pair_mask=pair_mask, inplace_safe=inplace_safe)
        else:
            z = self._forward(z=z, pair_mask=pair_mask, inplace_safe=inplace_safe)
        return z

    def _forward(
        self,
        z: torch.Tensor,
        pair_mask: torch.Tensor,
        inplace_safe: bool,
    ) -> torch.Tensor:
        """Run the pair stack without model parallelism.

        Every sub-module receives the full ``pair_mask``. Arguments and
        return value are the same as for :meth:`forward`.
        """
        z = self.tmo_dropout_rowwise(
            self.tri_mul_out(z=z, mask=pair_mask),
            add_output_to=z,
            inplace=inplace_safe,
        )
        z = self.tmi_dropout_rowwise(
            self.tri_mul_in(z=z, mask=pair_mask),
            add_output_to=z,
            inplace=inplace_safe,
        )
        z = self.tasn_dropout_rowwise(
            self.tri_att_start(z=z, mask=pair_mask),
            add_output_to=z,
            inplace=inplace_safe,
        )
        z = self.taen_dropout_columnwise(
            self.tri_att_end(z=z, mask=pair_mask),
            add_output_to=z,
            inplace=inplace_safe,
        )
        z = self.pair_transition(z, mask=pair_mask, inplace_safe=inplace_safe)
        return z

    def _forward_dap(
        self,
        z: torch.Tensor,
        pair_mask: torch.Tensor,
        inplace_safe: bool,
    ) -> torch.Tensor:
        """Run the pair stack with model parallelism enabled.

        The mask is scattered along dim -2 and along dim -1, and ``z`` is
        passed through ``mp.row_to_col`` / ``mp.col_to_row`` between
        sub-modules so that each one is paired with the matching mask.
        Arguments and return value are the same as for :meth:`forward`.

        NOTE(review): only the mask is scattered here, so ``z`` is
        presumably already scattered along its row dimension by the caller.
        Confirm against the call site.
        """
        pair_mask_row = mp.scatter(pair_mask, dim=-2)
        pair_mask_col = mp.scatter(pair_mask, dim=-1)

        z = self.tmo_dropout_rowwise(
            self.tri_mul_out(z=z, mask=pair_mask_row),
            add_output_to=z,
            inplace=inplace_safe,
        )

        z = mp.row_to_col(z)
        # NOTE(review): this is the only dropout call that passes
        # `dap_scattered_dim`. Confirm whether the other dropout calls in
        # this method are meant to omit it.
        z = self.tmi_dropout_rowwise(
            self.tri_mul_in(z=z, mask=pair_mask_col),
            dap_scattered_dim=-2,
            add_output_to=z,
            inplace=inplace_safe,
        )

        z = mp.col_to_row(z)
        z = self.tasn_dropout_rowwise(
            self.tri_att_start(z=z, mask=pair_mask_row),
            add_output_to=z,
            inplace=inplace_safe,
        )

        z = mp.row_to_col(z)
        z = self.taen_dropout_columnwise(
            self.tri_att_end(z=z, mask=pair_mask_col),
            add_output_to=z,
            inplace=inplace_safe,
        )
        # The pair transition runs before the final col_to_row call, so it
        # is given the column-scattered mask.
        z = self.pair_transition(z, mask=pair_mask_col, inplace_safe=inplace_safe)

        z = mp.col_to_row(z)
        return z