Source code for deepfold.modules.angle_resnet

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepfold.modules.inductor as inductor
from deepfold.modules.linear import Linear


class AngleResnet(nn.Module):
    """Angle Resnet module.

    Supplementary '1.8 Structure module': Algorithm 20, lines 11-14.

    Projects the current and the initial single representations into a shared
    hidden space and sums them. It refines the sum with ``num_blocks`` residual
    blocks and predicts ``num_angles`` torsion angles as 2D vectors. The angles
    are returned both raw and normalized.

    Args:
        c_s: Single representation dimension (channels).
        c_hidden: Hidden dimension (channels).
        num_blocks: Number of resnet blocks.
        num_angles: Number of torsion angles to generate.
        eps: Lower bound applied to the squared vector norm before the square
            root, preventing division by zero during normalization.
    """

    def __init__(
        self,
        c_s: int,
        c_hidden: int,
        num_blocks: int,
        num_angles: int,
        eps: float,
    ) -> None:
        super().__init__()
        self.c_s = c_s
        self.c_hidden = c_hidden
        self.num_blocks = num_blocks
        self.num_angles = num_angles
        self.eps = eps

        # Submodules are registered in a fixed order: linear_in, linear_initial,
        # layers, linear_out. This order defines state_dict / parameters() order.
        # NOTE(review): if Linear's init draws from the global RNG, it also fixes
        # the random init sequence — confirm in deepfold.modules.linear.
        self.linear_in = Linear(c_s, c_hidden, bias=True, init="default")
        self.linear_initial = Linear(c_s, c_hidden, bias=True, init="default")
        resnet_blocks = [AngleResnetBlock(c_hidden=c_hidden) for _ in range(num_blocks)]
        self.layers = nn.ModuleList(resnet_blocks)
        self.linear_out = Linear(c_hidden, num_angles * 2, bias=True, init="default")

    def forward(
        self,
        s: torch.Tensor,
        s_initial: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict torsion angles from the single representation.

        Args:
            s: [batch, N_res, c_s] single representation
            s_initial: [batch, N_res, c_s] initial single representation

        Returns:
            unnormalized_angles: [batch, N_res, num_angles, 2] raw 2D vectors
            angles: [batch, N_res, num_angles, 2] the same vectors divided by
                sqrt(max(squared norm, eps))
        """
        # The supplement's pseudocode omits the ReLUs on the inputs, but the
        # reference source applies them. They are kept for compatibility with
        # the pretrained weights.
        initial_proj = self.linear_initial(torch.relu(s_initial))
        current_proj = self.linear_in(torch.relu(s))
        hidden = current_proj + initial_proj
        # hidden: [batch, N_res, c_hidden]

        for resnet_block in self.layers:
            hidden = resnet_block(hidden)

        raw_angles = self.linear_out(torch.relu(hidden))
        # raw_angles: [batch, N_res, num_angles * 2]

        # Use the torch.compile'd helper when the inductor backend is enabled.
        angles_fn = _forward_angles_jit if inductor.is_enabled() else _forward_angles_eager
        return angles_fn(raw_angles, self.num_angles, self.eps)
class AngleResnetBlock(nn.Module):
    """Residual block used by the angle resnet.

    Computes ``a + linear_2(relu(linear_1(relu(a))))``. The computation is
    delegated to an eager or a torch.compile'd functional helper, depending on
    whether the inductor backend is enabled.

    Args:
        c_hidden: Hidden dimension (channels); both linear layers map
            c_hidden -> c_hidden.
    """

    def __init__(self, c_hidden: int) -> None:
        super().__init__()
        # NOTE(review): init="final" presumably zero-initializes linear_2 so the
        # block starts as an identity map — confirm in deepfold.modules.linear.
        self.linear_1 = Linear(c_hidden, c_hidden, bias=True, init="relu")
        self.linear_2 = Linear(c_hidden, c_hidden, bias=True, init="final")

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        """Apply the residual block.

        Args:
            a: [..., c_hidden] input activations.

        Returns:
            ``a`` plus the two-layer MLP update, with the same shape as ``a``.
        """
        if not inductor.is_enabled():
            block_fn = _forward_angle_resnet_block_eager
        else:
            block_fn = _forward_angle_resnet_block_jit
        # Raw parameters are passed so the functional helper can be compiled
        # independently of the module.
        weights_and_biases = (
            self.linear_1.weight,
            self.linear_1.bias,
            self.linear_2.weight,
            self.linear_2.bias,
        )
        return block_fn(a, *weights_and_biases)
def _forward_angles_eager( s: torch.Tensor, num_angles: int, eps: float, ) -> Tuple[torch.Tensor, torch.Tensor]: s = s.view(s.shape[:-1] + (num_angles, 2)) # s: [batch, N_res, num_angles, 2] unnormalized_angles = s # unnormalized_angles: [batch, N_res, num_angles, 2] norm_denom = torch.sqrt( torch.clamp( torch.sum(s**2, dim=-1, keepdim=True), min=eps, ) ) angles = s / norm_denom return unnormalized_angles, angles _forward_angles_jit = torch.compile(_forward_angles_eager) def _forward_angle_resnet_block_eager( a: torch.Tensor, w1: torch.Tensor, b1: torch.Tensor, w2: torch.Tensor, b2: torch.Tensor, ) -> torch.Tensor: x = torch.relu(a) x = F.linear(x, w1, b1) x = torch.relu(x) x = F.linear(x, w2, b2) y = a + x return y _forward_angle_resnet_block_jit = torch.compile(_forward_angle_resnet_block_eager)