Source code for deepfold.modules.template_projection

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepfold.modules.inductor as inductor
from deepfold.modules.linear import Linear


class TemplateProjection(nn.Module):
    """Project the averaged template representation into the pair channel space.

    Multimer '7.7. Architectural Modifications'.

    Args:
        c_t: Template representation dimension (channels).
        c_z: Pair representation dimension (channels).

    """

    def __init__(self, c_t: int, c_z: int) -> None:
        super().__init__()
        # Only the weight and bias of this layer are consumed by `forward`;
        # they are handed to the standalone projection kernels.
        self.linear_t = Linear(c_t, c_z, bias=True, init="default")

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Compute the pair representation update from template features.

        Args:
            t: [batch, N_templ, N_res, N_res, c_t] template representation

        Returns:
            z_update: [batch, N_res, N_res, c_z] pair representation update
                from template representation

        """
        # Collapse the template axis (fourth from the end) by averaging.
        # averaged: [*, N_res, N_res, c_t]
        averaged = t.mean(dim=-4)

        # Use the compiled kernel when inductor is enabled, otherwise the
        # plain eager implementation; both take the same arguments.
        project = (
            _forward_template_projection_jit
            if inductor.is_enabled()
            else _forward_template_projection_eager
        )

        # result: [batch, N_res, N_res, c_z]
        return project(averaged, self.linear_t.weight, self.linear_t.bias)
def _forward_template_projection_eager( t: torch.Tensor, w_proj: torch.Tensor, b_proj: torch.Tensor, ) -> torch.Tensor: return F.linear(F.relu(t), w_proj, b_proj) _forward_template_projection_jit = torch.compile(_forward_template_projection_eager)