Source code for deepfold.modules.msa_transition

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepfold.modules.inductor as inductor
from deepfold.modules.layer_norm import LayerNorm
from deepfold.modules.linear import Linear


[docs] class MSATransition(nn.Module): """MSA Transition module. Supplementary '1.6.3 MSA transition': Algorithm 9. Args: c_m: MSA (or Extra MSA) representation dimension (channels). n: `c_m` multiplier to obtain hidden dimension (channels). """ def __init__( self, c_m: int, n: int, ) -> None: super().__init__() self.layer_norm = LayerNorm(c_m) self.linear_1 = Linear(c_m, n * c_m, bias=True, init="relu") self.linear_2 = Linear(n * c_m, c_m, bias=True, init="final") # TODO: Chunk
[docs] def forward( self, m: torch.Tensor, mask: Optional[torch.Tensor] = None, inplace_safe: bool = False, ) -> torch.Tensor: """MSA Transition forward pass. Args: m: [batch, N_seq, N_res, c_m] MSA representation mask: [batch, N_seq, N_res] MSA mask Returns: m: [batch, N_seq, N_res, c_m] updated MSA representation """ # NOTE: DeepMind forgets to apply the MSA mask here. if mask is None: mask = m.new_ones(m.shape[:-1]) mask = mask.unsqueeze(-1) if inductor.is_enabled(): forward_fn = _forward_jit elif inductor.is_enabled_and_autograd_off(): forward_fn = _forward_jit else: forward_fn = _forward_eager return forward_fn( self.layer_norm(m), mask, self.linear_1.weight, self.linear_1.bias, self.linear_2.weight, self.linear_2.bias, m, inplace_safe, )
def _forward_eager( m: torch.Tensor, mask: torch.Tensor, w1: torch.Tensor, b1: torch.Tensor, w2: torch.Tensor, b2: torch.Tensor, out: torch.Tensor, inplace: bool, ) -> torch.Tensor: m = F.linear(m, w1, b1) m = torch.relu(m) m = F.linear(m, w2, b2) if inplace: m *= mask m += out else: m = m * mask m = out + m return m _forward_jit = torch.compile(_forward_eager) # TODO: Chunk