# Source code for deepfold.modules.extra_msa_embedder

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepfold.modules.inductor as inductor
from deepfold.modules.linear import Linear


class ExtraMSAEmbedder(nn.Module):
    """Project the raw extra-MSA features into the extra MSA representation.

    The ``extra_msa_feat`` input feature is embedded with a single learned
    linear layer (with bias).

    Supplementary '1.4 AlphaFold Inference': Algorithm 2, line 15.

    Args:
        emsa_dim: Number of channels in the ``extra_msa_feat`` input.
        c_e: Number of channels in the produced extra MSA representation.
    """

    def __init__(
        self,
        emsa_dim: int,
        c_e: int,
    ) -> None:
        super().__init__()
        # The layer's weight and bias are read directly in forward() and
        # passed to a module-level function rather than calling the layer.
        self.linear = Linear(emsa_dim, c_e, bias=True, init="default")

    def forward(
        self,
        extra_msa_feat: torch.Tensor,
    ) -> torch.Tensor:
        """Embed the extra MSA features.

        Args:
            extra_msa_feat: [batch, N_extra_seq, N_res, emsa_dim]

        Returns:
            extra_msa_embedding: [batch, N_extra_seq, N_res, c_e]
        """
        # Pick the torch.compile'd variant only when the project-wide
        # inductor switch is on; otherwise run the eager F.linear path.
        kernel = _forward_jit if inductor.is_enabled() else _forward_eager
        weight, bias = self.linear.weight, self.linear.bias
        return kernel(extra_msa_feat, weight, bias)
def _forward_eager( x: torch.Tensor, w: torch.Tensor, b: torch.Tensor, ) -> torch.Tensor: return F.linear(x, w, b) _forward_jit = torch.compile(_forward_eager)