Source code for deepfold.modules.backbone_update
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepfold.modules.inductor as inductor
from deepfold.modules.linear import Linear
[docs]
class BackboneUpdate(nn.Module):
    """Backbone Update module.

    Supplementary '1.8.3 Backbone update': Algorithm 23.

    Holds one learned affine layer mapping ``c_s`` input channels to 6 output
    features. In Algorithm 23 the six outputs are three quaternion components
    and a translation vector; how they are consumed is up to the caller
    (NOTE(review): confirm ordering against the rigid-update code).

    Args:
        c_s: Single representation dimension (channels).
    """

    def __init__(self, c_s: int) -> None:
        super().__init__()
        # 6 outputs with bias; `init="final"` is an option of the project's
        # `Linear` wrapper (its exact initialisation is defined there).
        self.linear = Linear(c_s, 6, bias=True, init="final")

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        """Project the single representation ``s`` to the 6 update values.

        Args:
            s: Single representation tensor; presumably ``[..., c_s]``
                (TODO: confirm against callers).

        Returns:
            Result of ``F.linear(s, weight, bias)`` using this module's
            parameters — presumably ``[..., 6]`` given ``Linear(c_s, 6)``.
        """
        # Pick the torch.compile'd kernel only when the project's inductor
        # switch is on; otherwise run the plain eager function.
        use_compiled = inductor.is_enabled()
        kernel = _forward_jit if use_compiled else _forward_eager
        # Parameters are handed to a module-level function instead of calling
        # self.linear(s), presumably so that one pure function can be compiled.
        weight = self.linear.weight
        bias = self.linear.bias
        return kernel(s, weight, bias)
def _forward_eager(
    x: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
) -> torch.Tensor:
    """Apply the affine map ``x @ w.T + b`` in eager mode.

    Thin wrapper around ``torch.nn.functional.linear``. It is kept as a free
    function so a compiled variant of the exact same computation can exist
    alongside it.

    Args:
        x: Input tensor; the last dimension must match ``w.shape[-1]``.
        w: Weight matrix of shape ``(out_features, in_features)``.
        b: Bias vector of length ``out_features``.

    Returns:
        ``x`` with its last dimension projected to ``out_features``.
    """
    projected = F.linear(x, w, b)
    return projected
# torch.compile-wrapped counterpart of `_forward_eager`; BackboneUpdate.forward
# dispatches to it whenever `inductor.is_enabled()` returns True.
_forward_jit = torch.compile(model=_forward_eager)