Source code for deepfold.modules.extra_msa_block
from typing import Optional, Tuple
import torch
import torch.nn as nn
import deepfold.distributed.model_parallel as mp
from deepfold.modules.dropout import DropoutRowwise
from deepfold.modules.evoformer_block_pair_core import EvoformerBlockPairCore
from deepfold.modules.msa_column_global_attention import MSAColumnGlobalAttention
from deepfold.modules.msa_row_attention_with_pair_bias import MSARowAttentionWithPairBias
from deepfold.modules.msa_transition import MSATransition
from deepfold.modules.outer_product_mean import OuterProductMean
class ExtraMSABlock(nn.Module):
"""Extra MSA Block module.
Supplementary '1.7.2 Unclustered MSA stack': Algorithm 18.
Args:
c_e: Extra MSA representation dimension (channels).
c_z: Pair representation dimension (channels).
c_hidden_msa_att: Hidden dimension in MSA attention.
c_hidden_opm: Hidden dimension in outer product mean.
c_hidden_tri_mul: Hidden dimension in multiplicative updates.
c_hidden_tri_att: Hidden dimension in triangular attention.
num_heads_msa: Number of heads used in MSA attention.
num_heads_tri: Number of heads used in triangular attention.
transition_n: Channel multiplier in transitions.
msa_dropout: Dropout rate for MSA activations.
pair_dropout: Dropout rate for pair activations.
inf: Safe infinity value.
eps: Epsilon to prevent division by zero.
eps_opm: Epsilon to prevent division by zero in outer product mean.
chunk_size_msa_att: Optional chunk size for a batch-like dimension in MSA attention.
chunk_size_opm: Optional chunk size for a batch-like dimension in outer product mean.
chunk_size_tri_att: Optional chunk size for a batch-like dimension in triangular attention.
"""
def __init__(
self,
c_e: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_tri_mul: int,
c_hidden_tri_att: int,
num_heads_msa: int,
num_heads_tri: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
eps_opm: float,
chunk_size_msa_att: Optional[int],
chunk_size_opm: Optional[int],
chunk_size_tri_att: Optional[int],
block_size_tri_mul: Optional[int],
outer_product_mean_first: bool = False,
) -> None:
super().__init__()
self.opm_first = outer_product_mean_first
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_e,
c_z=c_z,
c_hidden=c_hidden_msa_att,
num_heads=num_heads_msa,
inf=inf,
chunk_size=chunk_size_msa_att,
)
self.msa_att_col = MSAColumnGlobalAttention(
c_e=c_e,
c_hidden=c_hidden_msa_att,
num_heads=num_heads_msa,
inf=inf,
eps=eps,
chunk_size=chunk_size_msa_att,
)
self.msa_dropout_rowwise = DropoutRowwise(
p=msa_dropout,
)
self.msa_transition = MSATransition(
c_m=c_e,
n=transition_n,
)
self.outer_product_mean = OuterProductMean(
c_m=c_e,
c_z=c_z,
c_hidden=c_hidden_opm,
eps=eps_opm,
chunk_size=chunk_size_opm,
)
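        # Pair-stack core shared with the Evoformer block: triangular
        # multiplicative updates, triangular attention, and the pair transition.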
self.pair_core = EvoformerBlockPairCore(
c_z=c_z,
c_hidden_tri_mul=c_hidden_tri_mul,
c_hidden_tri_att=c_hidden_tri_att,
num_heads_tri=num_heads_tri,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
chunk_size_tri_att=chunk_size_tri_att,
block_size_tri_mul=block_size_tri_mul,
)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
inplace_safe: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Extra MSA Block forward pass.
Args:
m: [batch, N_extra_seq, N_res, c_e] extra MSA representation
z: [batch, N_res, N_res, c_z] pair representation
msa_mask: [batch, N_extra_seq, N_res] extra MSA mask
pair_mask: [batch, N_res, N_res] pair mask
Returns:
m: [batch, N_extra_seq, N_res, c_e] updated extra MSA representation
z: [batch, N_res, N_res, c_z] updated pair representation
"""
if mp.is_enabled():
msa_mask_row = mp.scatter(msa_mask, dim=-2)
msa_mask_col = mp.scatter(msa_mask, dim=-1)
if self.opm_first:
z = self.outer_product_mean(m=m, mask=msa_mask, add_output_to=z, inplace_safe=inplace_safe)
m = mp.col_to_row(m)
m = self.msa_dropout_rowwise(self.msa_att_row(m=m, z=z, mask=msa_mask_row), add_output_to=m)
m = mp.row_to_col(m)
m = self.msa_att_col(m=m, mask=msa_mask_col)
m = self.msa_transition(m=m, mask=msa_mask_col, inplace_safe=inplace_safe)
if not self.opm_first:
z = self.outer_product_mean(m=m, mask=msa_mask, add_output_to=z, inplace_safe=inplace_safe)
m = mp.col_to_row(m)
else:
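            # Single-device path: the full (unsharded) masks are used directly.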
if self.opm_first:
z = self.outer_product_mean(m=m, mask=msa_mask, add_output_to=z, inplace_safe=inplace_safe)
m = self.msa_dropout_rowwise(self.msa_att_row(m=m, z=z, mask=msa_mask), add_output_to=m)
m = self.msa_att_col(m=m, mask=msa_mask)
m = self.msa_transition(m=m, mask=msa_mask, inplace_safe=inplace_safe)
if not self.opm_first:
z = self.outer_product_mean(m=m, mask=msa_mask, add_output_to=z, inplace_safe=inplace_safe)
z = self.pair_core(z=z, pair_mask=pair_mask, inplace_safe=inplace_safe)
return m, z
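
A minimal usage sketch (not part of the module source), assuming model parallelism is disabled and using illustrative dimensions; the actual hyperparameter values are defined by the DeepFold model configuration:

import torch

from deepfold.modules.extra_msa_block import ExtraMSABlock

# Illustrative hyperparameters only; real values come from the model config.
block = ExtraMSABlock(
    c_e=64,
    c_z=128,
    c_hidden_msa_att=8,
    c_hidden_opm=32,
    c_hidden_tri_mul=128,
    c_hidden_tri_att=32,
    num_heads_msa=8,
    num_heads_tri=4,
    transition_n=4,
    msa_dropout=0.15,
    pair_dropout=0.25,
    inf=1e9,
    eps=1e-10,
    eps_opm=1e-3,
    chunk_size_msa_att=None,
    chunk_size_opm=None,
    chunk_size_tri_att=None,
    block_size_tri_mul=None,
)

batch, n_extra_seq, n_res = 1, 8, 16
m = torch.randn(batch, n_extra_seq, n_res, 64)    # extra MSA representation
z = torch.randn(batch, n_res, n_res, 128)         # pair representation
msa_mask = torch.ones(batch, n_extra_seq, n_res)  # extra MSA mask
pair_mask = torch.ones(batch, n_res, n_res)       # pair mask

m, z = block(m=m, z=z, msa_mask=msa_mask, pair_mask=pair_mask, inplace_safe=False)
# m: [1, 8, 16, 64], z: [1, 16, 16, 128]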