# Source code for deepfold.modules.template_angle_embedder
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepfold.modules.inductor as inductor
from deepfold.modules.linear import Linear
# [docs]
class TemplateAngleEmbedder(nn.Module):
    """Template Angle Embedder module.

    Embeds the "template_angle_feat" feature.

    Supplementary '1.4 AlphaFold Inference': Algorithm 2, line 7.

    Args:
        ta_dim: Input `template_angle_feat` dimension (channels).
        c_m: Output MSA representation dimension (channels).

    """

    def __init__(
        self,
        ta_dim: int,
        c_m: int,
    ) -> None:
        super().__init__()
        # init="relu" matches the ReLU applied between the two linear
        # layers in the module-level forward helpers below.
        self.linear_1 = Linear(ta_dim, c_m, bias=True, init="relu")
        self.linear_2 = Linear(c_m, c_m, bias=True, init="relu")

    def forward(
        self,
        template_angle_feat: torch.Tensor,
    ) -> torch.Tensor:
        """Template Angle Embedder forward pass.

        Args:
            template_angle_feat: [batch, N_templ, N_res, ta_dim]

        Returns:
            template_angle_embedding: [batch, N_templ, N_res, c_m]

        """
        # dap1 fusion regression.
        # NOTE(review): the original if/elif selected _forward_jit in both
        # branches; collapsed into a single or-condition (same short-circuit
        # evaluation order, identical behavior).
        if inductor.is_enabled() or inductor.is_enabled_and_autograd_off():
            forward_fn = _forward_jit
        else:
            forward_fn = _forward_eager
        return forward_fn(
            template_angle_feat,
            self.linear_1.weight,
            self.linear_1.bias,
            self.linear_2.weight,
            self.linear_2.bias,
        )
def _forward_eager(
x: torch.Tensor,
w1: torch.Tensor,
b1: torch.Tensor,
w2: torch.Tensor,
b2: torch.Tensor,
) -> torch.Tensor:
x = F.linear(x, w1, b1)
x = torch.relu(x)
x = F.linear(x, w2, b2)
return x
# Compiled variant of _forward_eager; selected in forward() when inductor
# fusion is enabled. torch.compile wraps the function lazily, so actual
# compilation happens on first call, not here at import time.
_forward_jit = torch.compile(_forward_eager)