Source code for deepfold.modules.template_pointwise_attention

from typing import Optional

import torch
import torch.nn as nn

from deepfold.modules.attention import CrossAttentionNoGate


[docs] class TemplatePointwiseAttention(nn.Module): """Template Pointwise Attention module. Supplementary '1.7.1 Template stack': Algorithm 17. Args: c_t: Template representation dimension (channels). c_z: Pair representation dimension (channels). c_hidden: Hidden dimension (per-head). num_heads: Number of attention heads. inf: Safe infinity value. chunk_size: Optional chunk size for a batch-like dimension. """ def __init__( self, c_t: int, c_z: int, c_hidden: int, num_heads: int, inf: float, chunk_size: Optional[int], ) -> None: super().__init__() self.mha = CrossAttentionNoGate( c_q=c_z, c_kv=c_t, c_hidden=c_hidden, num_heads=num_heads, inf=inf, chunk_size=chunk_size, impl="torch", )
[docs] def forward( self, t: torch.Tensor, z: torch.Tensor, template_mask: torch.Tensor, ) -> torch.Tensor: """Template Pointwise Attention forward pass. Args: t: [batch, N_templ, N_res, N_res, c_t] template representation z: [batch, N_res, N_res, c_z] pair representation template_mask: [batch, N_templ] template mask Returns: z_update: [batch, N_res, N_res, c_z] pair representation update from template representation """ t = t.movedim(-4, -2) # t: [batch, N_res, N_res, N_templ, c_t] z = z.unsqueeze(-2) # z: [batch, N_res, N_res, 1, c_z] template_mask = template_mask.unsqueeze(-2).unsqueeze(-3).unsqueeze(-4).unsqueeze(-5) # template_mask: [batch, 1, 1, 1, 1, N_templ] z = self.mha( input_q=z, input_kv=t, mask=template_mask, bias=None, ) # z: [batch, N_res, N_res, 1, c_z] z = z.squeeze(-2) # z: [batch, N_res, N_res, c_z] return z