Source code for deepfold.modules.single_transition

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepfold.modules.inductor as inductor
from deepfold.modules.layer_norm import LayerNorm
from deepfold.modules.linear import Linear


class SingleTransition(nn.Module):
    """Transition block applied to the single representation.

    Supplementary '1.8 Structure module': Algorithm 20, lines 8-9.

    The block runs three ``c_s -> c_s`` linear layers with a ReLU after the
    first two, adds the result back onto the input, and then applies dropout
    followed by layer normalization.

    Args:
        c_s: Single representation dimension (channels).
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        c_s: int,
        dropout_rate: float,
    ) -> None:
        super().__init__()
        # Attribute names and registration order are kept as-is: they
        # determine the parameter names of the module.
        self.linear_1 = Linear(c_s, c_s, bias=True, init="relu")
        self.linear_2 = Linear(c_s, c_s, bias=True, init="relu")
        self.linear_3 = Linear(c_s, c_s, bias=True, init="final")
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = LayerNorm(c_s)

    def forward(
        self,
        s: torch.Tensor,
        inplace_safe: bool,
    ) -> torch.Tensor:
        """Run the transition on ``s``.

        Args:
            s: Single representation tensor.
            inplace_safe: Passed through to the transition function; when
                true the residual sum is accumulated into ``s`` in place.

        Returns:
            The tensor after the residual transition, dropout and layer
            normalization.
        """
        # Use the compiled function when inductor is enabled, else eager.
        transition = _forward_jit if inductor.is_enabled() else _forward_eager

        first, second, third = self.linear_1, self.linear_2, self.linear_3
        updated = transition(
            s,
            first.weight,
            first.bias,
            second.weight,
            second.bias,
            third.weight,
            third.bias,
            inplace_safe,
        )

        dropped = self.dropout(updated)
        return self.layer_norm(dropped)
def _forward_eager(
    s: torch.Tensor,
    w1: torch.Tensor,
    b1: torch.Tensor,
    w2: torch.Tensor,
    b2: torch.Tensor,
    w3: torch.Tensor,
    b3: torch.Tensor,
    inplace: bool,
) -> torch.Tensor:
    """Apply the three-layer transition and add the result onto ``s``.

    Args:
        s: Input tensor; also the residual the update is added to.
        w1: Weight of the first linear layer.
        b1: Bias of the first linear layer.
        w2: Weight of the second linear layer.
        b2: Bias of the second linear layer.
        w3: Weight of the third linear layer.
        b3: Bias of the third linear layer.
        inplace: If true, the update is added into ``s`` itself and ``s`` is
            returned; otherwise ``s`` is left untouched and a new tensor is
            returned.

    Returns:
        ``s`` plus the output of the three linear layers.
    """
    # The first two layers are each followed by a ReLU.
    hidden = s
    for weight, bias in ((w1, b1), (w2, b2)):
        hidden = torch.relu(F.linear(hidden, weight, bias))

    # The last layer has no activation; its output is the residual update.
    update = F.linear(hidden, w3, b3)

    if not inplace:
        return s + update

    s += update
    return s


# Compiled counterpart of the eager implementation above.
_forward_jit = torch.compile(_forward_eager)

# TODO: Chunk