# Source code for deepfold.modules.extra_msa_stack
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as gradient_checkpointing_fn
import deepfold.distributed.model_parallel as mp
from deepfold.modules.extra_msa_block import ExtraMSABlock
from deepfold.utils.dist_utils import get_pad_size, pad_tensor
class ExtraMSAStack(nn.Module):
"""Extra MSA Stack module.
Supplementary '1.7.2 Unclustered MSA stack'.
Args:
c_e: Extra MSA representation dimension (channels).
c_z: Pair representation dimension (channels).
c_hidden_msa_att: Hidden dimension in MSA attention.
c_hidden_opm: Hidden dimension in outer product mean.
c_hidden_tri_mul: Hidden dimension in multiplicative updates.
c_hidden_tri_att: Hidden dimension in triangular attention.
num_heads_msa: Number of heads used in MSA attention.
num_heads_tri: Number of heads used in triangular attention.
num_blocks: Number of blocks in the stack.
transition_n: Channel multiplier in transitions.
msa_dropout: Dropout rate for MSA activations.
pair_dropout: Dropout rate for pair activations.
inf: Safe infinity value.
eps: Epsilon to prevent division by zero.
eps_opm: Epsilon to prevent division by zero in outer product mean.
chunk_size_msa_att: Optional chunk size for a batch-like dimension in MSA attention.
chunk_size_opm: Optional chunk size for a batch-like dimension in outer product mean.
chunk_size_tri_att: Optional chunk size for a batch-like dimension in triangular attention.
"""
def __init__(
self,
c_e: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_tri_mul: int,
c_hidden_tri_att: int,
num_heads_msa: int,
num_heads_tri: int,
num_blocks: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
eps_opm: float,
chunk_size_msa_att: Optional[int],
chunk_size_opm: Optional[int],
chunk_size_tri_att: Optional[int],
block_size_tri_mul: Optional[int],
outer_product_mean_first: bool = False,
) -> None:
super().__init__()
self.opm_first = outer_product_mean_first
self.blocks = nn.ModuleList(
[
ExtraMSABlock(
c_e=c_e,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_tri_mul=c_hidden_tri_mul,
c_hidden_tri_att=c_hidden_tri_att,
num_heads_msa=num_heads_msa,
num_heads_tri=num_heads_tri,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
eps_opm=eps_opm,
chunk_size_msa_att=chunk_size_msa_att,
chunk_size_opm=chunk_size_opm,
chunk_size_tri_att=chunk_size_tri_att,
block_size_tri_mul=block_size_tri_mul,
outer_product_mean_first=outer_product_mean_first,
)
for _ in range(num_blocks)
]
)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
gradient_checkpointing: bool,
inplace_safe: bool,
) -> torch.Tensor:
"""Extra MSA Stack forward pass.
Args:
m: [batch, N_extra_seq, N_res, c_e] extra MSA representation
z: [batch, N_res, N_res, c_z] pair representation
msa_mask: [batch, N_extra_seq, N_res] extra MSA mask
pair_mask: [batch, N_res, N_res] pair mask
gradient_checkpointing: whether to use gradient checkpointing
inplace_safe: whether in-place tensor updates are safe; only forwarded to the blocks in the non-checkpointed path
Returns:
z: [batch, N_res, N_res, c_z] updated pair representation
"""
if gradient_checkpointing:
assert torch.is_grad_enabled()
z = self._forward_blocks_with_gradient_checkpointing(
m=m,
z=z,
msa_mask=msa_mask,
pair_mask=pair_mask,
)
else:
z = self._forward_blocks(
m=m,
z=z,
msa_mask=msa_mask,
pair_mask=pair_mask,
inplace_safe=inplace_safe,
)
return z
def _forward_blocks(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
inplace_safe: bool,
) -> torch.Tensor:
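# When model parallelism is enabled: pad the sharded dimensions of the
# extra-MSA and pair tensors (and their masks) so they divide evenly across
# mp.size() ranks, scatter the shards, run the blocks, then gather the pair
# representation and strip the padding before returning it.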
if mp.is_enabled():
msa_col_pad_size = get_pad_size(m, -2, mp.size())
msa_row_pad_size = get_pad_size(m, -3, mp.size())
m = pad_tensor(m, -2, msa_col_pad_size)
m = pad_tensor(m, -3, msa_row_pad_size)
if self.opm_first:
m = mp.scatter(m, dim=-2)
else:
m = mp.scatter(m, dim=-3)
msa_mask = pad_tensor(msa_mask, -1, msa_col_pad_size)
msa_mask = pad_tensor(msa_mask, -2, msa_row_pad_size)
pair_pad_size = get_pad_size(z, -3, mp.size())
z = pad_tensor(z, -2, pair_pad_size)
z = pad_tensor(z, -3, pair_pad_size)
z = mp.scatter(z, dim=-3)
pair_mask = pad_tensor(pair_mask, -1, pair_pad_size)
pair_mask = pad_tensor(pair_mask, -2, pair_pad_size)
for block in self.blocks:
m, z = block(m=m, z=z, msa_mask=msa_mask, pair_mask=pair_mask, inplace_safe=inplace_safe)
if mp.is_enabled():
z = mp.gather(z, dim=-3)
if pair_pad_size != 0:
z = z[..., : z.size(-3) - pair_pad_size, : z.size(-2) - pair_pad_size, :]
return z
def _forward_blocks_with_gradient_checkpointing(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
) -> torch.Tensor:
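# Bind the constant masks into each block with functools.partial so the
# checkpoint function only has to handle the tensors that change between
# blocks (m and z).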
blocks = [partial(block, msa_mask=msa_mask, pair_mask=pair_mask) for block in self.blocks]
if mp.is_enabled():
if self.opm_first:
m = mp.scatter(m, dim=-2)
else:
m = mp.scatter(m, dim=-3)
z = mp.scatter(z, dim=-3)
for block in blocks:
m, z = gradient_checkpointing_fn(block, m, z, use_reentrant=True)
if mp.is_enabled():
z = mp.gather(z, dim=-3)
return z