deepfold.modules package

Submodules

deepfold.modules.alphafold module

class deepfold.modules.alphafold.AlphaFold(config: AlphaFoldConfig)

Bases: PatchedModule

AlphaFold2 module.

Supplementary ‘1.4 AlphaFold Inference’: Algorithm 2.

forward(batch: Dict[str, Tensor], recycle_hook: Callable[[int, dict, dict], None] | None = None, save_all: bool = True) Dict[str, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
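
As a rough illustration of the `recycle_hook` argument, the sketch below builds a callable that matches `Callable[[int, dict, dict], None]`. The meaning of the three arguments (recycling iteration, input features, outputs) and the dictionary keys are assumptions made for this example; they are not confirmed by the signature above.

```python
from typing import Callable, List


def make_logging_hook(log: List[str]) -> Callable[[int, dict, dict], None]:
    """Build a hook that records which output keys it saw at each call."""

    def hook(recycle_iter: int, feats: dict, outputs: dict) -> None:
        # Assumed argument order: (iteration index, input features, outputs).
        log.append(f"iter={recycle_iter} outputs={sorted(outputs)}")

    return hook


log: List[str] = []
hook = make_logging_hook(log)

# Stand-in for one call the model might make; the keys are hypothetical.
hook(0, {"aatype": None}, {"final_atom_positions": None})
```

With a constructed model, such a hook would presumably be passed as `model(batch, recycle_hook=hook)`, calling the module instance rather than `forward` directly so that registered hooks still run.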

register_dap_gradient_scaling_hooks(dap_size: int) None

deepfold.modules.angle_resnet module

class deepfold.modules.angle_resnet.AngleResnet(c_s: int, c_hidden: int, num_blocks: int, num_angles: int, eps: float)

Bases: PatchedModule

Angle Resnet module.

Supplementary ‘1.8 Structure module’: Algorithm 20, lines 11-14.

Parameters:
  • c_s – Single representation dimension (channels).

  • c_hidden – Hidden dimension (channels).

  • num_blocks – Number of resnet blocks.

  • num_angles – Number of torsion angles to generate.

  • eps – Epsilon to prevent division by zero.

forward(s: Tensor, s_initial: Tensor) Tuple[Tensor, Tensor]

Angle Resnet forward pass.

Parameters:
  • s – [batch, N_res, c_s] single representation

  • s_initial – [batch, N_res, c_s] initial single representation

Returns:
  • unnormalized_angles – [batch, N_res, num_angles, 2]

  • angles – [batch, N_res, num_angles, 2]
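
The relationship between the two returned tensors can be sketched in plain Python. This assumes each angle is a 2-vector that is divided by its Euclidean norm, with the squared norm clamped from below by `eps`; the exact clamping used by the module is an assumption here.

```python
import math
from typing import List, Tuple


def normalize_angle(pair: Tuple[float, float], eps: float) -> Tuple[float, float]:
    """Scale a 2-vector to unit length, guarding against a zero norm."""
    a, b = pair
    # Assumption: eps is a lower bound on the squared norm.
    denom = math.sqrt(max(a * a + b * b, eps))
    return (a / denom, b / denom)


unnormalized: List[Tuple[float, float]] = [(3.0, 4.0), (0.0, 0.0)]
angles = [normalize_angle(p, eps=1e-8) for p in unnormalized]
```

The all-zero input shows why `eps` is needed: without it the division would fail.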

class deepfold.modules.angle_resnet.AngleResnetBlock(c_hidden: int)

Bases: PatchedModule

Angle Resnet Block module.

forward(a: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

deepfold.modules.attention module

class deepfold.modules.attention.CrossAttentionNoGate(c_q: int, c_kv: int, c_hidden: int, num_heads: int, inf: float, chunk_size: int | None, impl: str | None = None)

Bases: PatchedModule

Cross Multi-Head Attention module without gating.

Parameters:
  • c_q – Input dimension of query data tensor (channels).

  • c_kv – Input dimension of key|value data tensor (channels).

  • c_hidden – Hidden dimension (per-head).

  • num_heads – Number of attention heads.

  • inf – Safe infinity value.

  • chunk_size – Optional chunk size for a batch-like dimension. Supplementary ‘1.11.8 Reducing the memory consumption’: Inference.

forward(input_q: Tensor, input_kv: Tensor, mask: Tensor, bias: Tensor | None) Tensor

Attention forward pass.

Parameters:
  • input_q – [*, Q, c_q] query data

  • input_kv – [*, KV, c_kv] key|value data (KV == K == V)

  • mask – Logit mask tensor broadcastable to [*, num_heads, Q, KV]

  • bias – Optional logit bias tensor broadcastable to [*, num_heads, Q, KV]

Returns:

output – [*, Q, c_q] tensor
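
A minimal sketch of how a logit mask, a logit bias, and the "safe infinity" value can combine for a single query row. It assumes the mask is 1 for valid keys and 0 for masked keys, and that masking is applied as `(mask - 1) * inf`; neither detail is confirmed by the signature above.

```python
import math
from typing import List


def masked_softmax(
    logits: List[float], mask: List[float], bias: List[float], inf: float
) -> List[float]:
    """Softmax over keys after adding a bias and a large negative mask term."""
    # Masked positions (mask == 0) are shifted by -inf, a large finite value.
    shifted = [l + b + (m - 1.0) * inf for l, m, b in zip(logits, mask, bias)]
    peak = max(shifted)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]


weights = masked_softmax(
    [1.0, 2.0, 3.0], mask=[1.0, 1.0, 0.0], bias=[0.0, 0.0, 0.0], inf=1e9
)
```

Using a large finite value instead of a true infinity avoids NaNs when an entire row is masked.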

class deepfold.modules.attention.SelfAttentionWithGate(c_qkv: int, c_hidden: int, num_heads: int, inf: float, chunk_size: int | None, impl: str | None = None)

Bases: PatchedModule

Self Multi-Head Attention module with gating.

Parameters:
  • c_qkv – Input dimension of query|key|value data tensor (channels).

  • c_hidden – Hidden dimension (per-head).

  • num_heads – Number of attention heads.

  • inf – Safe infinity value.

  • chunk_size – Optional chunk size for a batch-like dimension. Supplementary ‘1.11.8 Reducing the memory consumption’: Inference.

forward(input_qkv: Tensor, mask: Tensor, bias: Tensor | None, add_transposed_output_to: Tensor | None = None) Tensor

Attention forward pass.

Parameters:
  • input_qkv – [*, QKV, c_qkv] query|key|value data (QKV == Q == K == V)

  • mask – Logit mask tensor broadcastable to [*, num_heads, Q, K]

  • bias – Optional logit bias tensor broadcastable to [*, num_heads, Q, K]

  • add_transposed_output_to – Optional tensor to which the transposed output is added element-wise.

Returns:

output – [*, Q, c_qkv] tensor
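
The `chunk_size` idea can be illustrated without tensors: split a batch-like dimension into slices, apply the same function to each slice, and concatenate the results. The helper names below are hypothetical, and this is only a sketch of the general technique, not the module's code.

```python
from typing import Callable, List, Optional


def apply_in_chunks(
    fn: Callable[[List[float]], List[float]],
    rows: List[float],
    chunk_size: Optional[int],
) -> List[float]:
    """Apply fn to rows in slices of chunk_size; None means no chunking."""
    if chunk_size is None:
        return fn(rows)
    out: List[float] = []
    for start in range(0, len(rows), chunk_size):
        # Only one slice is processed at a time, which bounds peak memory.
        out.extend(fn(rows[start:start + chunk_size]))
    return out


def double_all(rows: List[float]) -> List[float]:
    return [2 * r for r in rows]


chunked = apply_in_chunks(double_all, [1, 2, 3, 4, 5], chunk_size=2)
unchunked = apply_in_chunks(double_all, [1, 2, 3, 4, 5], chunk_size=None)
```

Chunking trades some speed for lower peak memory; the result is the same as long as `fn` treats rows independently.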

deepfold.modules.auxiliary_heads module

class deepfold.modules.auxiliary_heads.AuxiliaryHeads(config: AuxiliaryHeadsConfig)

Bases: PatchedModule

Auxiliary Heads module.

forward(outputs: Dict[str, Tensor], seq_mask: Tensor | None = None, asym_id: Tensor | None = None) Dict[str, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class deepfold.modules.auxiliary_heads.DistogramHead(c_z: int, num_bins: int)

Bases: PatchedModule

Distogram Head module.

Computes a distogram probability distribution.

Supplementary ‘1.9.8 Distogram prediction’.

Parameters:
  • c_z – Pair representation dimension (channels).

  • num_bins – Output dimension (channels).

forward(z: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class deepfold.modules.auxiliary_heads.ExperimentallyResolvedHead(c_s: int, c_out: int)

Bases: PatchedModule

Experimentally Resolved Head module.

Supplementary ‘1.9.10 Experimentally resolved prediction’.

Parameters:
  • c_s – Single representation dimension (channels).

  • c_out – Output dimension (channels).

forward(s: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class deepfold.modules.auxiliary_heads.MaskedMSAHead(c_m: int, c_out: int)

Bases: PatchedModule

Masked MSA Head module.

Supplementary ‘1.9.9 Masked MSA prediction’.

Parameters:
  • c_m – MSA representation dimension (channels).

  • c_out – Output dimension (channels).

forward(m: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class deepfold.modules.auxiliary_heads.PerResidueLDDTCaPredictor(c_s: int, c_hidden: int, num_bins: int)

Bases: PatchedModule

Per-Residue LDDT-Ca Predictor module.

Supplementary ‘1.9.6 Model confidence prediction (pLDDT)’: Algorithm 29.

Parameters:
  • c_s – Single representation dimension (channels).

  • c_hidden – Hidden dimension (channels).

  • num_bins – Output dimension (channels).

forward(s: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
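
For orientation, the sketch below turns per-residue logits over `num_bins` into a scalar pLDDT value as the expected bin centre scaled to 0-100. This follows the description in the referenced supplementary section; whether and where DeepFold performs this step is an assumption, and the function name is hypothetical.

```python
import math
from typing import List


def plddt_from_logits(logits: List[float]) -> float:
    """Expected lDDT-Ca (0-100) from logits over equally spaced bins."""
    num_bins = len(logits)
    peak = max(logits)  # subtract the maximum for a stable softmax
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    width = 1.0 / num_bins
    centers = [(i + 0.5) * width for i in range(num_bins)]
    return 100.0 * sum(p * c for p, c in zip(probs, centers))


uniform = plddt_from_logits([0.0] * 50)
confident = plddt_from_logits([0.0] * 49 + [100.0])
```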

class deepfold.modules.auxiliary_heads.TMScoreHead(c_z: int, num_bins: int, max_bin: int)

Bases: PatchedModule

TM-Score Head module.

Supplementary ‘1.9.7 TM-score prediction’.

Parameters:
  • c_z – Pair representation dimension (channels).

  • num_bins – Output dimension (channels).

  • max_bin – Max bin range for discretizing the distribution.

forward(z: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

deepfold.modules.backbone_update module

class deepfold.modules.backbone_update.BackboneUpdate(c_s: int)

Bases: PatchedModule

Backbone Update module.

Supplementary ‘1.8.3 Backbone update’: Algorithm 23.

Parameters:

c_s – Single representation dimension (channels).

forward(s: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

deepfold.modules.dropout module

class deepfold.modules.dropout.Dropout(p: float, share_dim: int | Tuple[int, ...] = ())

Bases: PatchedModule

Dropout module.

Implementation of dropout with the ability to share the dropout mask along a particular dimension.

If not in training mode, this module computes the identity function.

Supplementary ‘1.11.6 Dropout details’.

Parameters:
  • p – Dropout rate (probability of an element to be zeroed).

  • share_dim – Dimension(s) along which the dropout mask is shared.

  • inplace – If set to True, will do this operation in-place. This is an argument of forward, not of the constructor.

forward(x: Tensor, add_output_to: Tensor, dap_scattered_dim: int | None = None, inplace: bool = False) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
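
One common way to share a dropout mask along a dimension is to sample the mask with size 1 in that dimension and let broadcasting repeat it. The sketch below computes only the mask shape; that the module works this way is an assumption, and `shared_mask_shape` is a hypothetical helper.

```python
from typing import Tuple, Union


def shared_mask_shape(
    shape: Tuple[int, ...], share_dim: Union[int, Tuple[int, ...]] = ()
) -> Tuple[int, ...]:
    """Return the shape of a dropout mask shared along share_dim."""
    dims = (share_dim,) if isinstance(share_dim, int) else tuple(share_dim)
    mask_shape = list(shape)
    for d in dims:
        # Size 1 here means one mask value is broadcast along this dimension.
        mask_shape[d] = 1
    return tuple(mask_shape)


full = shared_mask_shape((1, 128, 256, 64))
shared = shared_mask_shape((1, 128, 256, 64), share_dim=-3)
```

With the default empty `share_dim`, the mask has the full input shape and this reduces to ordinary dropout.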

class deepfold.modules.dropout.DropoutColumnwise(p: float)

Bases: Dropout

Dropout Columnwise module.

class deepfold.modules.dropout.DropoutRowwise(p: float)

Bases: Dropout

Dropout Rowwise module.

deepfold.modules.evoformer_block module

class deepfold.modules.evoformer_block.EvoformerBlock(c_m: int, c_z: int, c_hidden_msa_att: int, c_hidden_opm: int, c_hidden_tri_mul: int, c_hidden_tri_att: int, num_heads_msa: int, num_heads_tri: int, transition_n: int, msa_dropout: float, pair_dropout: float, inf: float, eps_opm: float, chunk_size_msa_att: int | None, chunk_size_opm: int | None, chunk_size_tri_att: int | None, block_size_tri_mul: int | None, outer_product_mean_first: bool = False)

Bases: PatchedModule

Evoformer Block module.

Supplementary ‘1.6 Evoformer blocks’: Algorithm 6. MSA Transition and Communication are included.

Parameters:
  • c_m – MSA representation dimension (channels).

  • c_z – Pair representation dimension (channels).

  • c_hidden_msa_att – Hidden dimension in MSA attention.

  • c_hidden_opm – Hidden dimension in outer product mean.

  • c_hidden_tri_mul – Hidden dimension in multiplicative updates.

  • c_hidden_tri_att – Hidden dimension in triangular attention.

  • num_heads_msa – Number of heads used in MSA attention.

  • num_heads_tri – Number of heads used in triangular attention.

  • transition_n – Channel multiplier in transitions.

  • msa_dropout – Dropout rate for MSA activations.

  • pair_dropout – Dropout rate for pair activations.

  • inf – Safe infinity value.

  • eps_opm – Epsilon to prevent division by zero in outer product mean.

  • chunk_size_msa_att – Optional chunk size for a batch-like dimension in MSA attention.

  • chunk_size_opm – Optional chunk size for a batch-like dimension in outer product mean.

  • chunk_size_tri_att – Optional chunk size for a batch-like dimension in triangular attention.

forward(m: Tensor, z: Tensor, msa_mask: Tensor, pair_mask: Tensor, inplace_safe: bool) Tuple[Tensor, Tensor]

Evoformer Block forward pass.

Parameters:
  • m – [batch, N_seq, N_res, c_m] MSA representation

  • z – [batch, N_res, N_res, c_z] pair representation

  • msa_mask – [batch, N_seq, N_res] MSA mask

  • pair_mask – [batch, N_res, N_res] pair mask

Returns:
  • m – [batch, N_seq, N_res, c_m] updated MSA representation

  • z – [batch, N_res, N_res, c_z] updated pair representation

deepfold.modules.evoformer_block_pair_core module

class deepfold.modules.evoformer_block_pair_core.EvoformerBlockPairCore(c_z: int, c_hidden_tri_mul: int, c_hidden_tri_att: int, num_heads_tri: int, transition_n: int, pair_dropout: float, inf: float, chunk_size_tri_att: int | None, block_size_tri_mul: int | None)

Bases: PatchedModule

Evoformer Block Pair Core module.

Pair stack for:

  • Supplementary ‘1.6 Evoformer blocks’: Algorithm 6

  • Supplementary ‘1.7.2 Unclustered MSA stack’: Algorithm 18

Parameters:
  • c_z – Pair representation dimension (channels).

  • c_hidden_tri_mul – Hidden dimension in multiplicative updates.

  • c_hidden_tri_att – Hidden dimension in triangular attention.

  • num_heads_tri – Number of heads used in triangular attention.

  • transition_n – Channel multiplier in transitions.

  • pair_dropout – Dropout rate for pair activations.

  • inf – Safe infinity value.

  • chunk_size_tri_att – Optional chunk size for a batch-like dimension in triangular attention.

forward(z: Tensor, pair_mask: Tensor, inplace_safe: bool) Tensor

Evoformer Block Core forward pass.

Parameters:
  • z – [batch, N_res, N_res, c_z] pair representation

  • pair_mask – [batch, N_res, N_res] pair mask

Returns:

z – [batch, N_res, N_res, c_z] updated pair representation

deepfold.modules.evoformer_stack module

class deepfold.modules.evoformer_stack.EvoformerStack(c_m: int, c_z: int, c_hidden_msa_att: int, c_hidden_opm: int, c_hidden_tri_mul: int, c_hidden_tri_att: int, c_s: int, num_heads_msa: int, num_heads_tri: int, num_blocks: int, transition_n: int, msa_dropout: float, pair_dropout: float, inf: float, eps_opm: float, chunk_size_msa_att: int | None, chunk_size_opm: int | None, chunk_size_tri_att: int | None, block_size_tri_mul: int | None, outer_product_mean_first: bool = False)

Bases: PatchedModule

Evoformer Stack module.

Supplementary ‘1.6 Evoformer blocks’: Algorithm 6.

Parameters:
  • c_m – MSA representation dimension (channels).

  • c_z – Pair representation dimension (channels).

  • c_hidden_msa_att – Hidden dimension in MSA attention.

  • c_hidden_opm – Hidden dimension in outer product mean.

  • c_hidden_tri_mul – Hidden dimension in multiplicative updates.

  • c_hidden_tri_att – Hidden dimension in triangular attention.

  • c_s – Single representation dimension (channels).

  • num_heads_msa – Number of heads used in MSA attention.

  • num_heads_tri – Number of heads used in triangular attention.

  • num_blocks – Number of blocks in the stack.

  • transition_n – Channel multiplier in transitions.

  • msa_dropout – Dropout rate for MSA activations.

  • pair_dropout – Dropout rate for pair activations.

  • inf – Safe infinity value.

  • eps_opm – Epsilon to prevent division by zero in outer product mean.

  • chunk_size_msa_att – Optional chunk size for a batch-like dimension in MSA attention.

  • chunk_size_opm – Optional chunk size for a batch-like dimension in outer product mean.

  • chunk_size_tri_att – Optional chunk size for a batch-like dimension in triangular attention.

forward(m: Tensor, z: Tensor, msa_mask: Tensor, pair_mask: Tensor, gradient_checkpointing: bool, inplace_safe: bool) Tuple[Tensor, Tensor, Tensor]

Evoformer Stack forward pass.

Parameters:
  • m – [batch, N_seq, N_res, c_m] MSA representation

  • z – [batch, N_res, N_res, c_z] pair representation

  • msa_mask – [batch, N_seq, N_res] MSA mask

  • pair_mask – [batch, N_res, N_res] pair mask

  • gradient_checkpointing – Whether to use gradient checkpointing.

Returns:
  • m – [batch, N_seq, N_res, c_m] updated MSA representation

  • z – [batch, N_res, N_res, c_z] updated pair representation

  • s – [batch, N_res, c_s] single representation

deepfold.modules.extra_msa_block module

class deepfold.modules.extra_msa_block.ExtraMSABlock(c_e: int, c_z: int, c_hidden_msa_att: int, c_hidden_opm: int, c_hidden_tri_mul: int, c_hidden_tri_att: int, num_heads_msa: int, num_heads_tri: int, transition_n: int, msa_dropout: float, pair_dropout: float, inf: float, eps: float, eps_opm: float, chunk_size_msa_att: int | None, chunk_size_opm: int | None, chunk_size_tri_att: int | None, block_size_tri_mul: int | None, outer_product_mean_first: bool = False)

Bases: PatchedModule

Extra MSA Block module.

Supplementary ‘1.7.2 Unclustered MSA stack’: Algorithm 18.

Parameters:
  • c_e – Extra MSA representation dimension (channels).

  • c_z – Pair representation dimension (channels).

  • c_hidden_msa_att – Hidden dimension in MSA attention.

  • c_hidden_opm – Hidden dimension in outer product mean.

  • c_hidden_tri_mul – Hidden dimension in multiplicative updates.

  • c_hidden_tri_att – Hidden dimension in triangular attention.

  • num_heads_msa – Number of heads used in MSA attention.

  • num_heads_tri – Number of heads used in triangular attention.

  • transition_n – Channel multiplier in transitions.

  • msa_dropout – Dropout rate for MSA activations.

  • pair_dropout – Dropout rate for pair activations.

  • inf – Safe infinity value.

  • eps – Epsilon to prevent division by zero.

  • eps_opm – Epsilon to prevent division by zero in outer product mean.

  • chunk_size_msa_att – Optional chunk size for a batch-like dimension in MSA attention.

  • chunk_size_opm – Optional chunk size for a batch-like dimension in outer product mean.

  • chunk_size_tri_att – Optional chunk size for a batch-like dimension in triangular attention.

forward(m: Tensor, z: Tensor, msa_mask: Tensor, pair_mask: Tensor, inplace_safe: bool) Tuple[Tensor, Tensor]

Extra MSA Block forward pass.

Parameters:
  • m – [batch, N_extra_seq, N_res, c_e] extra MSA representation

  • z – [batch, N_res, N_res, c_z] pair representation

  • msa_mask – [batch, N_extra_seq, N_res] extra MSA mask

  • pair_mask – [batch, N_res, N_res] pair mask

Returns:
  • m – [batch, N_extra_seq, N_res, c_e] updated extra MSA representation

  • z – [batch, N_res, N_res, c_z] updated pair representation

deepfold.modules.extra_msa_embedder module

class deepfold.modules.extra_msa_embedder.ExtraMSAEmbedder(emsa_dim: int, c_e: int)

Bases: PatchedModule

Extra MSA Embedder module.

Embeds the “extra_msa_feat” feature.

Supplementary ‘1.4 AlphaFold Inference’: Algorithm 2, line 15.

Parameters:
  • emsa_dim – Input extra_msa_feat dimension (channels).

  • c_e – Output extra MSA representation dimension (channels).

forward(extra_msa_feat: Tensor) Tensor

Extra MSA Embedder forward pass.

Parameters:

extra_msa_feat – [batch, N_extra_seq, N_res, emsa_dim]

Returns:

extra_msa_embedding – [batch, N_extra_seq, N_res, c_e]

deepfold.modules.extra_msa_stack module

class deepfold.modules.extra_msa_stack.ExtraMSAStack(c_e: int, c_z: int, c_hidden_msa_att: int, c_hidden_opm: int, c_hidden_tri_mul: int, c_hidden_tri_att: int, num_heads_msa: int, num_heads_tri: int, num_blocks: int, transition_n: int, msa_dropout: float, pair_dropout: float, inf: float, eps: float, eps_opm: float, chunk_size_msa_att: int | None, chunk_size_opm: int | None, chunk_size_tri_att: int | None, block_size_tri_mul: int | None, outer_product_mean_first: bool = False)

Bases: PatchedModule

Extra MSA Stack module.

Supplementary ‘1.7.2 Unclustered MSA stack’.

Parameters:
  • c_e – Extra MSA representation dimension (channels).

  • c_z – Pair representation dimension (channels).

  • c_hidden_msa_att – Hidden dimension in MSA attention.

  • c_hidden_opm – Hidden dimension in outer product mean.

  • c_hidden_tri_mul – Hidden dimension in multiplicative updates.

  • c_hidden_tri_att – Hidden dimension in triangular attention.

  • num_heads_msa – Number of heads used in MSA attention.

  • num_heads_tri – Number of heads used in triangular attention.

  • num_blocks – Number of blocks in the stack.

  • transition_n – Channel multiplier in transitions.

  • msa_dropout – Dropout rate for MSA activations.

  • pair_dropout – Dropout rate for pair activations.

  • inf – Safe infinity value.

  • eps – Epsilon to prevent division by zero.

  • eps_opm – Epsilon to prevent division by zero in outer product mean.

  • chunk_size_msa_att – Optional chunk size for a batch-like dimension in MSA attention.

  • chunk_size_opm – Optional chunk size for a batch-like dimension in outer product mean.

  • chunk_size_tri_att – Optional chunk size for a batch-like dimension in triangular attention.

forward(m: Tensor, z: Tensor, msa_mask: Tensor, pair_mask: Tensor, gradient_checkpointing: bool, inplace_safe: bool) Tensor

Extra MSA Stack forward pass.

Parameters:
  • m – [batch, N_extra_seq, N_res, c_e] extra MSA representation

  • z – [batch, N_res, N_res, c_z] pair representation

  • msa_mask – [batch, N_extra_seq, N_res] extra MSA mask

  • pair_mask – [batch, N_res, N_res] pair mask

  • gradient_checkpointing – Whether to use gradient checkpointing.

Returns:

z – [batch, N_res, N_res, c_z] updated pair representation

deepfold.modules.global_attention module

class deepfold.modules.global_attention.GlobalAttention(c_e: int, c_hidden: int, num_heads: int, inf: float, eps: float, chunk_size: int | None)

Bases: PatchedModule

Global Attention module.

Parameters:
  • c_e – Extra MSA representation dimension (channels).

  • c_hidden – Per-head hidden dimension (channels).

  • num_heads – Number of attention heads.

  • inf – Safe infinity value.

  • eps – Epsilon to prevent division by zero.

  • chunk_size – Optional chunk size for a batch-like dimension.

forward(m: Tensor, mask: Tensor, add_transposed_output_to: Tensor | None) Tensor

Global Attention forward pass.

Parameters:
  • m – [batch, N_res, N_extra_seq, c_e] transposed extra MSA representation

  • mask – [batch, N_res, N_extra_seq] transposed extra MSA mask

  • add_transposed_output_to – Optional tensor to which the transposed output is added element-wise.

Returns:

m – [batch, N_extra_seq, N_res, c_e] updated extra MSA representation

deepfold.modules.inductor module

deepfold.modules.inductor.disable()
deepfold.modules.inductor.enable() None
deepfold.modules.inductor.is_enabled() bool
deepfold.modules.inductor.is_enabled_and_autograd_off() bool

deepfold.modules.input_embedder module

class deepfold.modules.input_embedder.InputEmbedder(tf_dim: int, msa_dim: int, c_z: int, c_m: int, max_relative_index: int, **kwargs)

Bases: PatchedModule

Input Embedder module.

Supplementary ‘1.5 Input embeddings’.

Parameters:
  • tf_dim – Input target_feat dimension (channels).

  • msa_dim – Input msa_feat dimension (channels).

  • c_z – Output pair representation dimension (channels).

  • c_m – Output MSA representation dimension (channels).

  • max_relative_index – Relative position clip distance (relpos_k).

forward(target_feat: Tensor, residue_index: Tensor, msa_feat: Tensor, inplace_safe: bool) Tuple[Tensor, Tensor]

Input Embedder forward pass.

Supplementary ‘1.5 Input embeddings’: Algorithm 3.

Parameters:
  • target_feat – [batch, N_res, tf_dim]

  • residue_index – [batch, N_res]

  • msa_feat – [batch, N_clust, N_res, msa_dim]

Returns:
  • msa_emb – [batch, N_clust, N_res, c_m]

  • pair_emb – [batch, N_res, N_res, c_z]
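
The role of the relative position clip distance can be sketched as follows: pairwise residue index offsets are clipped to `[-k, k]` and shifted to non-negative bin indices, giving `2k + 1` classes for a one-hot encoding. This mirrors the relative position encoding described in the referenced supplementary section; the function name is hypothetical and the module's exact implementation may differ.

```python
from typing import List


def relpos_bins(residue_index: List[int], max_relative_index: int) -> List[List[int]]:
    """Bin index in [0, 2k] for every residue pair (i, j)."""
    k = max_relative_index
    return [
        # Clip the offset r_i - r_j to [-k, k], then shift so bins start at 0.
        [min(max(ri - rj, -k), k) + k for rj in residue_index]
        for ri in residue_index
    ]


bins = relpos_bins([0, 1, 5], max_relative_index=2)
num_classes = 2 * 2 + 1
```

Offsets beyond the clip distance all land in the first or last bin, so distant pairs share one encoding.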

class deepfold.modules.input_embedder.InputEmbedderMultimer(tf_dim: int, msa_dim: int, c_z: int, c_m: int, max_relative_index: int, use_chain_relative: bool, max_relative_chain: int, **kwargs)

Bases: PatchedModule

Input Embedder module for multimer model.

forward(target_feat: Tensor, residue_index: Tensor, msa_feat: Tensor, asym_id: Tensor, entity_id: Tensor, sym_id: Tensor, inplace_safe: bool) Tuple[Tensor, Tensor]

Input Embedder forward pass.

relpos(residue_index: Tensor, asym_id: Tensor, entity_id: Tensor, sym_id: Tensor) Tensor

deepfold.modules.invariant_point_attention module

class deepfold.modules.invariant_point_attention.InvariantPointAttention(c_s: int, c_z: int, c_hidden: int, num_heads: int, num_qk_points: int, num_v_points: int, separate_kv: bool, inf: float, eps: float)

Bases: PatchedModule

Invariant Point Attention (IPA) module.

Supplementary ‘1.8.2 Invariant point attention (IPA)’: Algorithm 22.

Parameters:
  • c_s – Single representation dimension (channels).

  • c_z – Pair representation dimension (channels).

  • c_hidden – Hidden dimension (channels).

  • num_heads – Number of attention heads.

  • num_qk_points – Number of query/key points.

  • num_v_points – Number of value points.

  • separate_kv – Separate key/value projection.

  • inf – Safe infinity value.

  • eps – Epsilon to prevent division by zero.

forward(s: Tensor, z: Tensor, r: Rigid, mask: Tensor, inplace_safe: bool) Tensor

Invariant Point Attention (IPA) forward pass.

Parameters:
  • s – [batch, N_res, c_s] single representation

  • z – [batch, N_res, N_res, c_z] pair representation

  • r – [batch, N_res] rigid transformations

  • mask – [batch, N_res] sequence mask

Returns:

s_update – [batch, N_res, c_s] single representation update

class deepfold.modules.invariant_point_attention.InvariantPointAttentionMultimer(c_s: int, c_z: int, c_hidden: int, num_heads: int, num_qk_points: int, num_v_points: int, inf: float = 100000.0, eps: float = 1e-08)

Bases: PatchedModule

forward(s: Tensor, z: Tensor, r: Rigid3Array, mask: Tensor, inplace_safe: bool) Tensor

Parameters:
  • s – [*, N_res, C_s] single representation

  • z – [*, N_res, N_res, C_z] pair representation

  • r – [*, N_res] transformation object

  • mask – [*, N_res] mask

Returns:

[*, N_res, C_s] single representation update

class deepfold.modules.invariant_point_attention.PointProjection(c_hidden: int, num_points: int, no_heads: int, is_multimer: bool, return_local_points: bool = False)

Bases: PatchedModule

forward(activations: Tensor, rigids: Rigid | Rigid3Array) Tensor | Tuple[Tensor, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

deepfold.modules.invariant_point_attention.ipa_point_weights_init_(weights_data: Tensor) None

deepfold.modules.layer_norm module

class deepfold.modules.layer_norm.LayerNorm(in_channels: int, eps: float = 1e-05)

Bases: PatchedModule

Layer Normalization module.

Supplementary ‘1.11.4 Parameters initialization’: Layer normalization.

Parameters:
  • in_channels – Last dimension of the input tensor.

  • eps – A value added to the denominator for numerical stability.

forward(x: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
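
As a reminder of where `eps` enters, here is a plain-Python sketch of layer normalization over the last dimension of a single vector. It omits the learnable scale and bias and uses the biased variance, which is the standard definition; it is not the module's code.

```python
import math
from typing import List


def layer_norm(x: List[float], eps: float = 1e-5) -> List[float]:
    """Normalize x to zero mean and (approximately) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance
    # eps is added to the variance so the denominator never reaches zero.
    return [(v - mean) / math.sqrt(var + eps) for v in x]


y = layer_norm([1.0, 2.0, 3.0])
constant = layer_norm([4.0, 4.0, 4.0])
```

The constant input has zero variance, so only `eps` keeps the division well defined.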

deepfold.modules.linear module

class deepfold.modules.linear.Linear(in_features: int, out_features: int, bias: bool = True, init: str = 'default')

Bases: Linear

Linear transformation with extra non-standard initializations.

Supplementary ‘1.11.4 Parameters initialization’: Linear layers.

Parameters:
  • in_features – Last dimension of the input tensor.

  • out_features – Last dimension of the output tensor.

  • bias – Whether to learn an additive bias. Default: True.

  • init – Parameter initialization method. One of:

      - “default”: LeCun (fan-in) with a truncated normal distribution
      - “relu”: He initialization with a truncated normal distribution
      - “glorot”: fan-average Glorot uniform initialization
      - “gating”: Weights=0, Bias=1
      - “normal”: Normal initialization with std=1/sqrt(fan_in)
      - “final”: Weights=0, Bias=0

deepfold.modules.linear.final_init_(weight_data: Tensor) None
deepfold.modules.linear.gating_init_(weight_data: Tensor) None
deepfold.modules.linear.glorot_uniform_init_(weight_data: Tensor) None
deepfold.modules.linear.he_normal_init_(weight_data: Tensor) None
deepfold.modules.linear.lecun_normal_init_(weight_data: Tensor) None
deepfold.modules.linear.normal_init_(weight_data: Tensor) None
deepfold.modules.linear.trunc_normal_init_(weight_data: Tensor, scale: float = 1.0, fan: str = 'fan_in') None
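
The scales behind the initialization names can be sketched numerically. The helpers below are hypothetical and compute only target standard deviations or bounds from the fan sizes. The `"fan_out"` and `"fan_avg"` options are assumptions (only `'fan_in'` appears in the signature above), and implementations may additionally correct the standard deviation for truncation.

```python
import math


def trunc_normal_std(
    fan_in: int, fan_out: int, scale: float = 1.0, fan: str = "fan_in"
) -> float:
    """Target std sqrt(scale / fan) for variance-scaling initialization."""
    if fan == "fan_in":
        f = float(fan_in)
    elif fan == "fan_out":  # assumed option name
        f = float(fan_out)
    elif fan == "fan_avg":  # assumed option name
        f = (fan_in + fan_out) / 2.0
    else:
        raise ValueError(f"unknown fan mode: {fan}")
    return math.sqrt(scale / max(1.0, f))


def glorot_uniform_bound(fan_in: int, fan_out: int) -> float:
    """Bound b of U(-b, b) for fan-average Glorot uniform initialization."""
    return math.sqrt(6.0 / (fan_in + fan_out))


lecun_std = trunc_normal_std(256, 128, scale=1.0)  # "default"
he_std = trunc_normal_std(256, 128, scale=2.0)     # "relu"
glorot_bound = glorot_uniform_bound(256, 128)      # "glorot"
```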

deepfold.modules.msa_column_attention module

class deepfold.modules.msa_column_attention.MSAColumnAttention(c_m: int, c_hidden: int, num_heads: int, inf: float, chunk_size: int | None)

Bases: PatchedModule

MSA Column Attention module.

Supplementary ‘1.6.2 MSA column-wise gated self-attention’: Algorithm 8.

Parameters:
  • c_m – MSA representation dimension (channels).

  • c_hidden – Per-head hidden dimension (channels).

  • num_heads – Number of attention heads.

  • inf – Safe infinity value.

  • chunk_size – Optional chunk size for a batch-like dimension.

forward(m: Tensor, mask: Tensor) Tensor

MSA Column Attention forward pass.

Parameters:
  • m – [batch, N_seq, N_res, c_m] MSA representation

  • mask – [batch, N_seq, N_res] MSA mask

Returns:

m – [batch, N_seq, N_res, c_m] updated MSA representation

deepfold.modules.msa_column_global_attention module

class deepfold.modules.msa_column_global_attention.MSAColumnGlobalAttention(c_e: int, c_hidden: int, num_heads: int, inf: float, eps: float, chunk_size: int | None)

Bases: PatchedModule

MSA Column Global Attention module.

Supplementary ‘1.7.2 Unclustered MSA stack’: Algorithm 19 (MSA global column-wise gated self-attention).

Parameters:
  • c_e – Extra MSA representation dimension (channels).

  • c_hidden – Per-head hidden dimension (channels).

  • num_heads – Number of attention heads.

  • inf – Safe infinity value.

  • eps – Epsilon to prevent division by zero.

  • chunk_size – Optional chunk size for a batch-like dimension.

forward(m: Tensor, mask: Tensor) Tensor

MSA Column Global Attention forward pass.

Parameters:
  • m – [batch, N_extra_seq, N_res, c_e] extra MSA representation

  • mask – [batch, N_extra_seq, N_res] extra MSA mask

Returns:

m – [batch, N_extra_seq, N_res, c_e] updated extra MSA representation

deepfold.modules.msa_row_attention_with_pair_bias module

class deepfold.modules.msa_row_attention_with_pair_bias.MSARowAttentionWithPairBias(c_m: int, c_z: int, c_hidden: int, num_heads: int, inf: float, chunk_size: int | None)

Bases: PatchedModule

MSA Row Attention With Pair Bias module.

Supplementary ‘1.6.1 MSA row-wise gated self-attention with pair bias’: Algorithm 7.

Parameters:
  • c_m – MSA (or Extra MSA) representation dimension (channels).

  • c_z – Pair representation dimension (channels).

  • c_hidden – Per-head hidden dimension (channels).

  • num_heads – Number of attention heads.

  • inf – Safe infinity value.

  • chunk_size – Optional chunk size for a batch-like dimension.

forward(m: Tensor, z: Tensor, mask: Tensor) Tensor

MSA Row Attention With Pair Bias forward pass.

Parameters:
  • m – [batch, N_seq, N_res, c_m] MSA (or Extra MSA) representation

  • z – [batch, N_res, N_res, c_z] pair representation

  • mask – [batch, N_seq, N_res] MSA (or Extra MSA) mask

Returns:

m_update – [batch, N_seq, N_res, c_m] MSA (or Extra MSA) representation update

deepfold.modules.msa_transition module

class deepfold.modules.msa_transition.MSATransition(c_m: int, n: int)

Bases: PatchedModule

MSA Transition module.

Supplementary ‘1.6.3 MSA transition’: Algorithm 9.

Parameters:
  • c_m – MSA (or Extra MSA) representation dimension (channels).

  • n – c_m multiplier to obtain hidden dimension (channels).

forward(m: Tensor, mask: Tensor | None = None, inplace_safe: bool = False) Tensor

MSA Transition forward pass.

Parameters:
  • m – [batch, N_seq, N_res, c_m] MSA representation

  • mask – [batch, N_seq, N_res] MSA mask

Returns:

[batch, N_seq, N_res, c_m] updated MSA representation

Return type:

m
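Algorithm 9 is a two-layer MLP that widens each feature vector by the factor n and projects it back. The sketch below applies that recipe to a single feature vector in plain Python; it is an illustration under stated assumptions, not the deepfold implementation. The weights are random placeholders, the learnable LayerNorm scale and offset are left out, and masking and any residual connection are omitted.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    # weight has shape [out_features][in_features]
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def transition(x: Vector, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Vector:
    """LayerNorm -> Linear(c -> n*c) -> ReLU -> Linear(n*c -> c)."""
    hidden = [max(0.0, h) for h in linear(layer_norm(x), w1, b1)]
    return linear(hidden, w2, b2)


c_m, n = 4, 4  # illustrative sizes; the hidden width is n * c_m
rng = random.Random(0)
w1 = [[rng.uniform(-1, 1) for _ in range(c_m)] for _ in range(n * c_m)]
b1 = [0.0] * (n * c_m)
w2 = [[rng.uniform(-1, 1) for _ in range(n * c_m)] for _ in range(c_m)]
b2 = [0.0] * c_m

x = [0.5, -1.0, 2.0, 0.25]
y = transition(x, w1, b1, w2, b2)
```

The same structure applies to the pair transition, with c_z in place of c_m.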

deepfold.modules.outer_product_mean module

class deepfold.modules.outer_product_mean.OuterProductMean(c_m: int, c_z: int, c_hidden: int, eps: float, chunk_size: int | None)

Bases: PatchedModule

Outer Product Mean module.

Supplementary ‘1.6.4 Outer product mean’: Algorithm 10.

Parameters:
  • c_m – MSA (or Extra MSA) representation dimension (channels).

  • c_z – Pair representation dimension (channels).

  • c_hidden – Hidden dimension (channels).

  • eps – Epsilon to prevent division by zero.

  • chunk_size – Optional chunk size for a batch-like dimension.

forward(m: Tensor, mask: Tensor, add_output_to: Tensor, inplace_safe: bool) Tensor

Outer Product Mean forward pass.

Parameters:
  • m – [batch, N_seq, N_res, c_m] MSA representation

  • mask – [batch, N_seq, N_res] MSA mask

  • add_output_to – pair representation to which the outer product mean update is added

Returns:

[batch, N_res, N_res, c_z] updated pair representation

Return type:

outer
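The core of Algorithm 10 is an outer product of two per-residue projections, averaged over the sequences of the MSA. The following plain-Python sketch shows only that averaging step and is not the deepfold code. It assumes a and b are the two linear projections of the layer-normalized MSA, and that masked averaging divides by eps plus the number of sequences valid at both positions; that normalization is an assumption suggested by the eps parameter. The final projection to c_z is omitted.

```python
from typing import List

Tensor3 = List[List[List[float]]]  # [N_seq][N_res][c]
Mask = List[List[float]]           # [N_seq][N_res]


def outer_product_mean(a: Tensor3, b: Tensor3, mask: Mask, eps: float = 1e-3) -> Tensor3:
    """Return [N_res][N_res][c*c] masked mean of outer products over sequences."""
    n_seq, n_res, c = len(a), len(a[0]), len(a[0][0])
    out = [[[0.0] * (c * c) for _ in range(n_res)] for _ in range(n_res)]
    for i in range(n_res):
        for j in range(n_res):
            # Number of sequences in which both positions i and j are valid.
            norm = sum(mask[s][i] * mask[s][j] for s in range(n_seq))
            for p in range(c):
                for q in range(c):
                    total = sum(
                        mask[s][i] * a[s][i][p] * mask[s][j] * b[s][j][q]
                        for s in range(n_seq)
                    )
                    # Flatten the (p, q) outer-product index into one channel.
                    out[i][j][p * c + q] = total / (eps + norm)
    return out


# Two sequences, one residue, two channels.
a = [[[1.0, 2.0]], [[3.0, 4.0]]]
b = [[[1.0, 2.0]], [[3.0, 4.0]]]
full = outer_product_mean(a, b, mask=[[1.0], [1.0]], eps=0.0)
first_only = outer_product_mean(a, b, mask=[[1.0], [0.0]], eps=0.0)
```

With both sequences valid, full[0][0] is the average of the two flattened outer products; masking the second sequence leaves only the first one in first_only[0][0].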

deepfold.modules.pair_transition module

class deepfold.modules.pair_transition.PairTransition(c_z: int, n: int)

Bases: PatchedModule

Pair Transition module.

Supplementary ‘1.6.7 Transition in the pair stack’: Algorithm 15.

Parameters:
  • c_z – Pair or template representation dimension (channels).

  • n – c_z multiplier to obtain hidden dimension (channels).

forward(z: Tensor, mask: Tensor | None = None, inplace_safe: bool = False) Tensor

Pair Transition forward pass.

Parameters:
  • z – [batch, N_res, N_res, c_z] pair representation

  • mask – [batch, N_res, N_res] pair mask

Returns:

[batch, N_res, N_res, c_z] updated pair representation

Return type:

z

deepfold.modules.recycling_embedder module

class deepfold.modules.recycling_embedder.RecyclingEmbedder(c_m: int, c_z: int, min_bin: float, max_bin: float, num_bins: int, inf: float)

Bases: PatchedModule

Recycling Embedder module.

Supplementary ‘1.10 Recycling iterations’.

Parameters:
  • c_m – MSA representation dimension (channels).

  • c_z – Pair representation dimension (channels).

  • min_bin – Smallest distogram bin (Angstroms).

  • max_bin – Largest distogram bin (Angstroms).

  • num_bins – Number of distogram bins.

  • inf – Safe infinity value.

forward(m: Tensor, z: Tensor, m0_prev: Tensor, z_prev: Tensor, x_prev: Tensor, inplace_safe: bool) Tuple[Tensor, Tensor]

Recycling Embedder forward pass.

Supplementary ‘1.10 Recycling iterations’: Algorithm 32.

Parameters:
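  • m – [batch, N_clust, N_res, c_m] MSA representation

  • z – [batch, N_res, N_res, c_z] pair representation

  • m0_prev – [batch, N_res, c_m] first row of the MSA representation from the previous iteration

  • z_prev – [batch, N_res, N_res, c_z] pair representation from the previous iteration

  • x_prev – [batch, N_res, 3] predicted coordinates from the previous iteration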

Returns:

  • m – [batch, N_clust, N_res, c_m]

  • z – [batch, N_res, N_res, c_z]

Return type:

Tuple[Tensor, Tensor]
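Algorithm 32 turns the pairwise distances between the previous iteration's coordinates into a one-hot distogram feature before embedding them. The sketch below shows one plausible reading of the min_bin, max_bin, num_bins and inf parameters and is not the deepfold implementation. It assumes num_bins equally spaced lower edges from min_bin to max_bin, with the last bin open-ended up to inf. The helper names and the numeric values in the example are illustrative.

```python
import math
from typing import List, Sequence


def bin_edges(min_bin: float, max_bin: float, num_bins: int) -> List[float]:
    """Equally spaced lower bin edges, in Angstroms."""
    step = (max_bin - min_bin) / (num_bins - 1)
    return [min_bin + k * step for k in range(num_bins)]


def one_hot_distance(
    x_i: Sequence[float],
    x_j: Sequence[float],
    lower: List[float],
    inf: float = 1e9,
) -> List[float]:
    """One-hot encode the distance between two 3D points."""
    # The upper edge of each bin is the lower edge of the next one;
    # the last bin extends to the "safe infinity" value.
    upper = lower[1:] + [inf]
    d = math.dist(x_i, x_j)
    return [1.0 if lo < d < up else 0.0 for lo, up in zip(lower, upper)]


lower = bin_edges(min_bin=3.25, max_bin=20.75, num_bins=15)
near = one_hot_distance((0.0, 0.0, 0.0), (5.0, 0.0, 0.0), lower)
far = one_hot_distance((0.0, 0.0, 0.0), (100.0, 0.0, 0.0), lower)
same = one_hot_distance((1.0, 2.0, 3.0), (1.0, 2.0, 3.0), lower)
```

Under these assumptions a distance below min_bin, including the zero distance of a residue to itself, falls into no bin and yields an all-zero vector.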

deepfold.modules.single_transition module

class deepfold.modules.single_transition.SingleTransition(c_s: int, dropout_rate: float)

Bases: PatchedModule

Single Transition module.

Supplementary ‘1.8 Structure module’: Algorithm 20, lines 8-9.

Parameters:
  • c_s – Single representation dimension (channels).

  • dropout_rate – Dropout rate.

forward(s: Tensor, inplace_safe: bool) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

deepfold.modules.structure_module module

class deepfold.modules.structure_module.StructureModule(c_s: int, c_z: int, c_hidden_ipa: int, c_hidden_ang_res: int, num_heads_ipa: int, num_qk_points: int, num_v_points: int, is_multimer: bool, dropout_rate: float, num_blocks: int, num_ang_res_blocks: int, num_angles: int, scale_factor: float, inf: float, eps: float)

Bases: PatchedModule

Structure Module.

Supplementary ‘1.8 Structure module’: Algorithm 20.

Parameters:
  • c_s – Single representation dimension (channels).

  • c_z – Pair representation dimension (channels).

  • c_hidden_ipa – Hidden dimension in invariant point attention.

  • c_hidden_ang_res – Hidden dimension in angle resnet.

  • num_heads_ipa – Number of heads used in invariant point attention.

  • num_qk_points – Number of query/key points in invariant point attention.

  • num_v_points – Number of value points in invariant point attention.

  • dropout_rate – Dropout rate in structure module.

  • num_blocks – Number of shared blocks in the forward pass.

  • num_ang_res_blocks – Number of blocks in angle resnet.

  • num_angles – Number of angles in angle resnet.

  • scale_factor – Scale translation factor.

  • inf – Safe infinity value.

  • eps – Epsilon to prevent division by zero.

forward(s: Tensor, z: Tensor, mask: Tensor, aatype: Tensor, inplace_safe: bool) Dict[str, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

deepfold.modules.template_angle_embedder module

class deepfold.modules.template_angle_embedder.TemplateAngleEmbedder(ta_dim: int, c_m: int)

Bases: PatchedModule

Template Angle Embedder module.

Embeds the “template_angle_feat” feature.

Supplementary ‘1.4 AlphaFold Inference’: Algorithm 2, line 7.

Parameters:
  • ta_dim – Input template_angle_feat dimension (channels).

  • c_m – Output MSA representation dimension (channels).

forward(template_angle_feat: Tensor) Tensor

Template Angle Embedder forward pass.

Parameters:

template_angle_feat – [batch, N_templ, N_res, ta_dim]

Returns:

[batch, N_templ, N_res, c_m]

Return type:

template_angle_embedding

deepfold.modules.template_pair_block module

class deepfold.modules.template_pair_block.TemplatePairBlock(c_t: int, c_hidden_tri_att: int, c_hidden_tri_mul: int, num_heads_tri: int, pair_transition_n: int, dropout_rate: float, inf: float, chunk_size_tri_att: int | None, block_size_tri_mul: int | None, tri_att_first: bool = True)

Bases: PatchedModule

Template Pair Block module.

Supplementary ‘1.7.1 Template stack’: Algorithm 16.

Parameters:
  • c_t – Template representation dimension (channels).

  • c_hidden_tri_att – Hidden dimension in triangular attention.

  • c_hidden_tri_mul – Hidden dimension in multiplicative updates.

  • num_heads_tri – Number of heads used in triangular attention.

  • pair_transition_n – Channel multiplier in pair transition.

  • dropout_rate – Dropout rate for pair activations.

  • inf – Safe infinity value.

  • chunk_size_tri_att – Optional chunk size for a batch-like dimension in triangular attention.

forward(t: Tensor, mask: Tensor, inplace_safe: bool) Tensor

Template Pair Block forward pass.

Parameters:
  • t – [batch, N_templ, N_res, N_res, c_t] template representation

  • mask – [batch, N_res, N_res] pair mask

Returns:

[batch, N_templ, N_res, N_res, c_t] updated template representation

Return type:

t

deepfold.modules.template_pair_embedder module

class deepfold.modules.template_pair_embedder.TemplatePairEmbedder(tp_dim: int, c_t: int, **kwargs)

Bases: PatchedModule

Template Pair Embedder module.

Embeds the “template_pair_feat” feature.

Supplementary ‘1.4 AlphaFold Inference’: Algorithm 2, line 9.

Parameters:
  • tp_dim – Input template_pair_feat dimension (channels).

  • c_t – Output template representation dimension (channels).

build_template_pair_feat(feats: Dict[str, Tensor], min_bin: float, max_bin: float, num_bins: int, use_unit_vector: bool, inf: float, eps: float, dtype: dtype) Tensor

forward(template_pair_feat: Tensor) Tensor

Template Pair Embedder forward pass.

Parameters:

template_pair_feat – [batch, N_res, N_res, tp_dim]

Returns:

[batch, N_res, N_res, c_t]

Return type:

template_pair_embedding

class deepfold.modules.template_pair_embedder.TemplatePairEmbedderMultimer(c_z: int, c_t: int, c_dgram: int, c_aatype: int, **kwargs)

Bases: PatchedModule

build_template_pair_feat(feats: Dict[str, Tensor], min_bin: float, max_bin: float, num_bins: int, inf: float, eps: float, dtype: dtype) Dict[str, Tensor]

forward(query_embedding: Tensor, multichain_mask_2d: Tensor, template_dgram: Tensor, aatype_one_hot: Tensor, pseudo_beta_mask: Tensor, backbone_mask: Tensor, unit_vector: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

deepfold.modules.template_pair_stack module

class deepfold.modules.template_pair_stack.TemplatePairStack(c_t: int, c_hidden_tri_att: int, c_hidden_tri_mul: int, num_blocks: int, num_heads_tri: int, pair_transition_n: int, dropout_rate: float, inf: float, chunk_size_tri_att: int | None, block_size_tri_mul: int | None, tri_att_first: bool = True)

Bases: PatchedModule

Template Pair Stack module.

Supplementary ‘1.7.1 Template stack’: Algorithm 16.

Parameters:
  • c_t – Template representation dimension (channels).

  • c_hidden_tri_att – Hidden dimension in triangular attention.

  • c_hidden_tri_mul – Hidden dimension in multiplicative updates.

  • num_blocks – Number of blocks in the stack.

  • num_heads_tri – Number of heads used in triangular attention.

  • pair_transition_n – Channel multiplier in pair transition.

  • dropout_rate – Dropout rate for pair activations.

  • inf – Safe infinity value.

  • chunk_size_tri_att – Optional chunk size for a batch-like dimension in triangular attention.

forward(t: Tensor, mask: Tensor, gradient_checkpointing: bool, inplace_safe: bool) Tensor

Template Pair Stack forward pass.

Parameters:
  • t – [batch, N_templ, N_res, N_res, c_t] template representation

  • mask – [batch, N_res, N_res] pair mask

  • gradient_checkpointing – whether to use gradient checkpointing

Returns:

[batch, N_templ, N_res, N_res, c_t] updated template representation

Return type:

t

deepfold.modules.template_pointwise_attention module

class deepfold.modules.template_pointwise_attention.TemplatePointwiseAttention(c_t: int, c_z: int, c_hidden: int, num_heads: int, inf: float, chunk_size: int | None)

Bases: PatchedModule

Template Pointwise Attention module.

Supplementary ‘1.7.1 Template stack’: Algorithm 17.

Parameters:
  • c_t – Template representation dimension (channels).

  • c_z – Pair representation dimension (channels).

  • c_hidden – Hidden dimension (per-head).

  • num_heads – Number of attention heads.

  • inf – Safe infinity value.

  • chunk_size – Optional chunk size for a batch-like dimension.

forward(t: Tensor, z: Tensor, template_mask: Tensor) Tensor

Template Pointwise Attention forward pass.

Parameters:
  • t – [batch, N_templ, N_res, N_res, c_t] template representation

  • z – [batch, N_res, N_res, c_z] pair representation

  • template_mask – [batch, N_templ] template mask

Returns:

[batch, N_res, N_res, c_z] pair representation update from template representation

Return type:

z_update
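In Algorithm 17 each pair position (i, j) attends over the N_templ template entries at that position, and the template mask removes padded templates from the softmax. The sketch below shows a common way to apply such a mask with a large finite inf value; it is an assumption about how the inf parameter is used, not the deepfold code. It covers one head and one pair position, and the names attend_over_templates, q, keys and values are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def softmax(logits: Vector) -> Vector:
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def template_weights(q: Vector, keys: List[Vector], template_mask: Vector,
                     inf: float = 1e9) -> Vector:
    """Softmax weights over templates; masked templates get a bias of -inf."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [
        scale * sum(a * b for a, b in zip(q, k)) + inf * (m - 1.0)
        for k, m in zip(keys, template_mask)
    ]
    return softmax(logits)


def attend_over_templates(q: Vector, keys: List[Vector], values: List[Vector],
                          template_mask: Vector, inf: float = 1e9) -> Vector:
    w = template_weights(q, keys, template_mask, inf)
    channels = len(values[0])
    return [sum(w_s * v[ch] for w_s, v in zip(w, values)) for ch in range(channels)]


q = [1.0, 0.0]
keys = [[1.0, 0.0], [5.0, 0.0]]
values = [[1.0, 2.0], [100.0, 200.0]]

masked = attend_over_templates(q, keys, values, template_mask=[1.0, 0.0])
unmasked = template_weights(q, keys, template_mask=[1.0, 1.0])
```

A large finite value is used instead of float("inf") so that the arithmetic stays well defined in reduced precision.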

deepfold.modules.template_projection module

class deepfold.modules.template_projection.TemplateProjection(c_t: int, c_z: int)

Bases: PatchedModule

Template Projection module.

Multimer ‘7.7. Architectural Modifications’.

Parameters:
  • c_t – Template representation dimension (channels).

  • c_z – Pair representation dimension (channels).

forward(t: Tensor) Tensor

Template Projection forward pass.

Parameters:

t – [batch, N_templ, N_res, N_res, c_t] template representation

Returns:

[batch, N_res, N_res, c_z] pair representation update from template representation

Return type:

z_update

deepfold.modules.triangular_attention module

class deepfold.modules.triangular_attention.TriangleAttention(c_z: int, c_hidden: int, num_heads: int, ta_type: str, inf: float, chunk_size: int | None)

Bases: PatchedModule

Triangle Attention module.

Supplementary ‘1.6.6 Triangular self-attention’: Algorithms 13 and 14.

Parameters:
  • c_z – Pair or template representation dimension (channels).

  • c_hidden – Hidden dimension (channels).

  • num_heads – Number of attention heads.

  • ta_type – “starting” or “ending”

  • inf – Safe infinity value.

  • chunk_size – Optional chunk size for a batch-like dimension.

forward(z: Tensor, mask: Tensor) Tensor

Triangle Attention forward pass.

Parameters:
  • z – [batch, N_res, N_res, c_z] pair representation

  • mask – [batch, N_res, N_res] pair mask

Returns:

[batch, N_res, N_res, c_z] pair representation update

Return type:

z_update

class deepfold.modules.triangular_attention.TriangleAttentionEndingNode(c_z: int, c_hidden: int, num_heads: int, inf: float, chunk_size: int | None)

Bases: TriangleAttention

Triangle Attention Ending Node module.

Supplementary ‘1.6.6 Triangular self-attention’: Algorithm 14 Triangular gated self-attention around ending node.

Parameters:
  • c_z – Pair or template representation dimension (channels).

  • c_hidden – Hidden dimension (channels).

  • num_heads – Number of attention heads.

  • inf – Safe infinity value.

  • chunk_size – Optional chunk size for a batch-like dimension.

class deepfold.modules.triangular_attention.TriangleAttentionStartingNode(c_z: int, c_hidden: int, num_heads: int, inf: float, chunk_size: int | None)

Bases: TriangleAttention

Triangle Attention Starting Node module.

Supplementary ‘1.6.6 Triangular self-attention’: Algorithm 13 Triangular gated self-attention around starting node.

Parameters:
  • c_z – Pair or template representation dimension (channels).

  • c_hidden – Hidden dimension (channels).

  • num_heads – Number of attention heads.

  • inf – Safe infinity value.

  • chunk_size – Optional chunk size for a batch-like dimension.

deepfold.modules.triangular_multiplicative_update module

class deepfold.modules.triangular_multiplicative_update.TriangleMultiplicationIncoming(c_z: int, c_hidden: int, block_size: int | None)

Bases: TriangleMultiplicativeUpdate

Triangle Multiplication Incoming module.

Supplementary ‘1.6.5 Triangular multiplicative update’: Algorithm 12 Triangular multiplicative update using “incoming” edges.

Parameters:
  • c_z – Pair or template representation dimension (channels).

  • c_hidden – Hidden dimension (channels).

class deepfold.modules.triangular_multiplicative_update.TriangleMultiplicationOutgoing(c_z: int, c_hidden: int, block_size: int | None)

Bases: TriangleMultiplicativeUpdate

Triangle Multiplication Outgoing module.

Supplementary ‘1.6.5 Triangular multiplicative update’: Algorithm 11 Triangular multiplicative update using “outgoing” edges.

Parameters:
  • c_z – Pair or template representation dimension (channels).

  • c_hidden – Hidden dimension (channels).

class deepfold.modules.triangular_multiplicative_update.TriangleMultiplicativeUpdate(c_z: int, c_hidden: int, tmu_type: str, block_size: int | None = None)

Bases: PatchedModule

Triangle Multiplicative Update module.

Supplementary ‘1.6.5 Triangular multiplicative update’: Algorithms 11 and 12.

Parameters:
  • c_z – Pair or template representation dimension (channels).

  • c_hidden – Hidden dimension (channels).

  • tmu_type – “outgoing” or “incoming”

forward(z: Tensor, mask: Tensor) Tensor

Triangle Multiplicative Update forward pass.

Parameters:
  • z – [batch, N_res, N_res, c_z] pair representation

  • mask – [batch, N_res, N_res] pair mask

Returns:

[batch, N_res, N_res, c_z] pair representation update

Return type:

z_update
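The two tmu_type values differ only in which edges of the triangle are contracted: Algorithm 11 ("outgoing") sums a[i][k] * b[j][k] over k, while Algorithm 12 ("incoming") sums a[k][i] * b[k][j] over k. The sketch below shows that contraction for a single channel in plain Python. The function name triangle_contract is hypothetical, and the layer norms, gates and linear projections of the full algorithms are omitted.

```python
from typing import List

Matrix = List[List[float]]


def triangle_contract(a: Matrix, b: Matrix, tmu_type: str) -> Matrix:
    """Contract two [N_res][N_res] edge projections over the third node k."""
    n = len(a)
    if tmu_type == "outgoing":
        # Edges i -> k and j -> k, both leaving the pair's nodes.
        return [[sum(a[i][k] * b[j][k] for k in range(n)) for j in range(n)]
                for i in range(n)]
    if tmu_type == "incoming":
        # Edges k -> i and k -> j, both arriving at the pair's nodes.
        return [[sum(a[k][i] * b[k][j] for k in range(n)) for j in range(n)]
                for i in range(n)]
    raise ValueError(f"unknown tmu_type: {tmu_type!r}")


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
outgoing = triangle_contract(a, b, "outgoing")
incoming = triangle_contract(a, b, "incoming")
```

In matrix terms the outgoing case is a multiplied by the transpose of b, and the incoming case is the transpose of a multiplied by b, applied independently per channel.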

Module contents