Source code for deepfold.modules.triangular_multiplicative_update

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepfold.distributed.model_parallel as mp
import deepfold.modules.inductor as inductor
from deepfold.modules.layer_norm import LayerNorm
from deepfold.modules.linear import Linear
from deepfold.utils.iter_utils import slice_generator
from deepfold.utils.precision import is_fp16_enabled


class TriangleMultiplicativeUpdate(nn.Module):
    """Triangle Multiplicative Update module.

    Supplementary '1.6.5 Triangular multiplicative update': Algorithms 11 and 12.

    Args:
        c_z: Pair or template representation dimension (channels).
        c_hidden: Hidden dimension (channels).
        tmu_type: "outgoing" or "incoming".
        block_size: Optional block size for the chunked inference path.

    """

    def __init__(
        self,
        c_z: int,
        c_hidden: int,
        tmu_type: str,
        block_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._is_outgoing = {"outgoing": True, "incoming": False}[tmu_type]
        self.block_size = block_size
        self.linear_ab_p = Linear(c_z, c_hidden * 2, init="default")
        self.linear_ab_g = Linear(c_z, c_hidden * 2, init="gating")
        # self.linear_a_p = Linear(c_z, c_hidden, bias=True, init="default")
        # self.linear_a_g = Linear(c_z, c_hidden, bias=True, init="gating")
        # self.linear_b_p = Linear(c_z, c_hidden, bias=True, init="default")
        # self.linear_b_g = Linear(c_z, c_hidden, bias=True, init="gating")
        self.linear_g = Linear(c_z, c_z, bias=True, init="gating")
        self.linear_z = Linear(c_hidden, c_z, bias=True, init="final")
        self.layer_norm_in = LayerNorm(c_z)
        self.layer_norm_out = LayerNorm(c_hidden)
    def forward(
        self,
        z: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Triangle Multiplicative Update forward pass.

        Args:
            z: [batch, N_res, N_res, c_z] pair representation
            mask: [batch, N_res, N_res] pair mask

        Returns:
            z_update: [batch, N_res, N_res, c_z] pair representation update

        """
        z = self.layer_norm_in(z)
        # z: [batch, N_res, N_res, c_z]
        mask = mask.unsqueeze(-1)
        # mask: [batch, N_res, N_res, 1]

        if self.block_size is not None:
            return self._inference_forward(z, mask)

        # TODO: Fusion with a.float, b.float (?)
        a, b = _compute_projections(
            z,
            mask,
            self.linear_ab_g.weight,
            self.linear_ab_g.bias,
            self.linear_ab_p.weight,
            self.linear_ab_p.bias,
        )  # .chunk(2, dim=-1)

        if mp.is_enabled():
            if self._is_outgoing:
                b = mp.gather(b, dim=-3, bwd="all_reduce_sum_split")
            else:
                a = mp.gather(a, dim=-2, bwd="all_reduce_sum_split")

        if is_fp16_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                x = self._combine_projections(a.float(), b.float())
        else:
            x = self._combine_projections(a, b)
        # x: [batch, N_res, N_res, c_hidden]
        del a, b

        x = self.layer_norm_out(x)
        # x: [batch, N_res, N_res, c_hidden]
        x = _compute_output(
            x,
            z,
            self.linear_z.weight,
            self.linear_z.bias,
            self.linear_g.weight,
            self.linear_g.bias,
        )
        # x: [batch, N_res, N_res, c_z]
        return x
    def _combine_projections(
        self,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        if self._is_outgoing:
            a = a.movedim(-1, -3)
            b = b.swapdims(-1, -3)
        else:
            a = a.swapdims(-1, -3)
            b = b.movedim(-1, -3)
        p = torch.matmul(a, b)
        return p.movedim(-3, -1)

    def _inference_forward(
        self,
        z: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            Outgoing:
                z: [B, N', N, C]
                mask: [B, N', N, 1]
            Incoming:
                z: [B, N, N', C]
                mask: [B, N, N', 1]

        Returns:
            z: [*, N', N, C] outgoing
               [*, N, N', C] incoming

        Notes:
            Avoid block sizes that are too small (< 256).

        """
        out = torch.empty_like(z)
        par_dim = z.shape[-3] if self._is_outgoing else z.shape[-2]  # N'
        assert self.block_size is not None

        for i_begin, i_end in slice_generator(0, par_dim, self.block_size):
            if self._is_outgoing:
                z_i = z[:, i_begin:i_end, :, :]  # [B, N'', N, C]
                a_chunk, _ = _compute_projections(
                    z_i,
                    mask[:, i_begin:i_end, :, :],
                    self.linear_ab_g.weight,
                    self.linear_ab_g.bias,
                    self.linear_ab_p.weight,
                    self.linear_ab_p.bias,
                )  # a_chunk: [B, I'', K, C]
                a_chunk = a_chunk.movedim(-1, -3)  # [B, C, I'', K]
            else:  # is_incoming
                z_i = z[:, :, i_begin:i_end, :]  # [B, N, N'', C]
                _, b_chunk = _compute_projections(
                    z_i,
                    mask[:, :, i_begin:i_end, :],
                    self.linear_ab_g.weight,
                    self.linear_ab_g.bias,
                    self.linear_ab_p.weight,
                    self.linear_ab_p.bias,
                )  # b_chunk: [B, K, J'', C]
                b_chunk = b_chunk.movedim(-1, -3)  # [B, C, K, J'']

            for j_begin, j_end in slice_generator(0, par_dim, self.block_size):
                if self._is_outgoing:
                    z_j = z[:, j_begin:j_end, :, :]  # [B, N'', N, C]
                    _, b_chunk = _compute_projections(
                        z_j,
                        mask[:, j_begin:j_end, :, :],
                        self.linear_ab_g.weight,
                        self.linear_ab_g.bias,
                        self.linear_ab_p.weight,
                        self.linear_ab_p.bias,
                    )  # b_chunk: [B, K'', J, C]
                    b_chunk = b_chunk.swapdims(-1, -3)  # [B, C, K'', J]
                    b_chunk = b_chunk.contiguous()
                else:
                    z_j = z[:, :, j_begin:j_end, :]  # [B, N, N'', C]
                    a_chunk, _ = _compute_projections(
                        z_j,
                        mask[:, :, j_begin:j_end, :],
                        self.linear_ab_g.weight,
                        self.linear_ab_g.bias,
                        self.linear_ab_p.weight,
                        self.linear_ab_p.bias,
                    )  # a_chunk: [B, K, I'', C]
                    a_chunk = a_chunk.swapdims(-1, -3)  # [B, C, I'', K]
                    a_chunk = a_chunk.contiguous()

                if mp.is_enabled():
                    for r in range(mp.size()):
                        if self._is_outgoing:
                            if r == mp.rank():
                                buf = b_chunk.clone()
                            else:
                                buf = torch.empty_like(b_chunk)
                            buf = mp.broadcast(buf, r)
                            x_chunk = torch.matmul(a_chunk, buf)
                            del buf
                        else:
                            if r == mp.rank():
                                buf = a_chunk.clone()
                            else:
                                buf = torch.empty_like(a_chunk)
                            buf = mp.broadcast(buf, r)
                            x_chunk = torch.matmul(buf, b_chunk)
                            del buf
                        x_chunk = x_chunk.movedim(-3, -1)

                        j_global_begin = par_dim * r + j_begin
                        j_global_end = min(j_global_begin + self.block_size, par_dim * (r + 1))
                        if self._is_outgoing:
                            out[:, i_begin:i_end, j_global_begin:j_global_end, :] = x_chunk
                        else:
                            out[:, j_global_begin:j_global_end, i_begin:i_end, :] = x_chunk
                        del x_chunk
                else:
                    x_chunk = torch.matmul(a_chunk, b_chunk).movedim(-3, -1)
                    # x_chunk: [B, C, I', J'] -> [B, I', J', C]
                    if self._is_outgoing:
                        out[:, i_begin:i_end, j_begin:j_end, :] = x_chunk
                    else:
                        out[:, j_begin:j_end, i_begin:i_end, :] = x_chunk

        for i_begin, i_end in slice_generator(0, z.shape[-3], self.block_size):
            for j_begin, j_end in slice_generator(0, z.shape[-2], self.block_size):
                z_chunk = z[:, i_begin:i_end, j_begin:j_end, :]
                x_chunk = self.layer_norm_out(out[:, i_begin:i_end, j_begin:j_end, :])
                out[:, i_begin:i_end, j_begin:j_end, :] = _compute_output(
                    x_chunk,
                    z_chunk,
                    self.linear_z.weight,
                    self.linear_z.bias,
                    self.linear_g.weight,
                    self.linear_g.bias,
                )

        return out
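# --- Illustrative sketch (not part of the original module source) -------------
# _combine_projections implements Algorithms 11/12 as a transposed matmul over
# the shared residue index k. The check below restates it as an einsum: for
# "outgoing" edges p_ij = sum_k a_ik * b_jk, for "incoming" edges
# p_ij = sum_k a_ki * b_kj. The function name and tensor sizes are arbitrary
# choices for this sketch.
def _check_combine_projections_einsum() -> None:
    a = torch.randn(2, 8, 8, 4)  # [B, N_res, N_res, c_hidden]
    b = torch.randn(2, 8, 8, 4)
    outgoing = TriangleMultiplicativeUpdate(c_z=4, c_hidden=4, tmu_type="outgoing")
    incoming = TriangleMultiplicativeUpdate(c_z=4, c_hidden=4, tmu_type="incoming")
    # Algorithm 11 ("outgoing"): contract over the second residue index of both a and b.
    assert torch.allclose(
        outgoing._combine_projections(a, b),
        torch.einsum("bikc,bjkc->bijc", a, b),
        atol=1e-5,
    )
    # Algorithm 12 ("incoming"): contract over the first residue index of both a and b.
    assert torch.allclose(
        incoming._combine_projections(a, b),
        torch.einsum("bkic,bkjc->bijc", a, b),
        atol=1e-5,
    )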
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
    """Triangle Multiplication Outgoing module.

    Supplementary '1.6.5 Triangular multiplicative update': Algorithm 11
    Triangular multiplicative update using "outgoing" edges.

    Args:
        c_z: Pair or template representation dimension (channels).
        c_hidden: Hidden dimension (channels).

    """

    def __init__(
        self,
        c_z: int,
        c_hidden: int,
        block_size: Optional[int],
    ) -> None:
        super().__init__(
            c_z=c_z,
            c_hidden=c_hidden,
            tmu_type="outgoing",
            block_size=block_size,
        )
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
    """Triangle Multiplication Incoming module.

    Supplementary '1.6.5 Triangular multiplicative update': Algorithm 12
    Triangular multiplicative update using "incoming" edges.

    Args:
        c_z: Pair or template representation dimension (channels).
        c_hidden: Hidden dimension (channels).

    """

    def __init__(
        self,
        c_z: int,
        c_hidden: int,
        block_size: Optional[int],
    ) -> None:
        super().__init__(
            c_z=c_z,
            c_hidden=c_hidden,
            tmu_type="incoming",
            block_size=block_size,
        )
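# --- Illustrative sketch (not part of the original module source) -------------
# Typical usage of the two subclasses on a random pair representation, assuming
# model parallelism and inductor compilation are disabled. Sizes and the
# function name are arbitrary. Note that the blocked path reuses the output
# buffer for the hidden activations, so this sketch keeps c_hidden == c_z; a
# tiny block_size is used only for demonstration, whereas the notes above
# advise block sizes of at least 256.
def _example_usage() -> None:
    c_z = c_hidden = 64
    z = torch.randn(1, 32, 32, c_z)  # [batch, N_res, N_res, c_z]
    mask = torch.ones(1, 32, 32)     # [batch, N_res, N_res]

    tmu_out = TriangleMultiplicationOutgoing(c_z=c_z, c_hidden=c_hidden, block_size=None)
    tmu_in = TriangleMultiplicationIncoming(c_z=c_z, c_hidden=c_hidden, block_size=None)

    # Residual updates, as the pair stack typically applies them.
    z = z + tmu_out(z, mask)
    z = z + tmu_in(z, mask)

    # Chunked inference path: block_size is a plain attribute, so setting it
    # switches forward() to _inference_forward(). Both paths should agree.
    with torch.no_grad():
        reference = tmu_out(z, mask)
        tmu_out.block_size = 16
        blocked = tmu_out(z, mask)
    assert torch.allclose(reference, blocked, atol=1e-5)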
def _compute_projections_eager(
    z: torch.Tensor,
    mask: torch.Tensor,
    w_ab_g: torch.Tensor,
    b_ab_g: torch.Tensor,
    w_ab_p: torch.Tensor,
    b_ab_p: torch.Tensor,
) -> torch.Tensor:
    ab = F.linear(z, w_ab_g, b_ab_g)
    ab = torch.sigmoid(ab) * mask
    ab = ab * F.linear(z, w_ab_p, b_ab_p)
    return ab


_compute_projections_jit = torch.compile(_compute_projections_eager)


def _compute_projections(
    z: torch.Tensor,
    mask: torch.Tensor,
    w_ab_g: torch.Tensor,
    b_ab_g: torch.Tensor,
    w_ab_p: torch.Tensor,
    b_ab_p: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if inductor.is_enabled():
        compute_projections_fn = _compute_projections_jit
    else:
        compute_projections_fn = _compute_projections_eager
    a, b = compute_projections_fn(z, mask, w_ab_g, b_ab_g, w_ab_p, b_ab_p).chunk(2, dim=-1)
    return a, b


def _compute_output_eager(
    x: torch.Tensor,
    z: torch.Tensor,
    w_z: torch.Tensor,
    b_z: torch.Tensor,
    w_g: torch.Tensor,
    b_g: torch.Tensor,
) -> torch.Tensor:
    x = F.linear(x, w_z, b_z)
    g = torch.sigmoid(F.linear(z, w_g, b_g))
    x = x * g
    return x


_compute_output_jit = torch.compile(_compute_output_eager)


def _compute_output(
    x: torch.Tensor,
    z: torch.Tensor,
    w_z: torch.Tensor,
    b_z: torch.Tensor,
    w_g: torch.Tensor,
    b_g: torch.Tensor,
) -> torch.Tensor:
    if inductor.is_enabled():
        compute_output_fn = _compute_output_jit
    elif inductor.is_enabled_and_autograd_off():
        compute_output_fn = _compute_output_jit
    else:
        compute_output_fn = _compute_output_eager
    return compute_output_fn(x, z, w_z, b_z, w_g, b_g)
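# --- Illustrative sketch (not part of the original module source) -------------
# The fused linear_ab_{g,p} layers project to 2 * c_hidden and are split with
# chunk(2, dim=-1); because every op in _compute_projections_eager is
# element-wise after the linear layers, this is equivalent to the separate
# per-branch projections kept commented out in __init__, with the fused weights
# being the row-wise concatenation of the per-branch weights. The function name
# and tensor sizes below are arbitrary choices for this sketch.
def _check_fused_projection_split() -> None:
    c_z, c_hidden = 8, 4
    z = torch.randn(2, 6, 6, c_z)
    mask = torch.ones(2, 6, 6, 1)
    w_g = torch.randn(2 * c_hidden, c_z)
    b_g = torch.randn(2 * c_hidden)
    w_p = torch.randn(2 * c_hidden, c_z)
    b_p = torch.randn(2 * c_hidden)

    a, b = _compute_projections(z, mask, w_g, b_g, w_p, b_p)

    # Reference: per-branch projections using the two halves of the fused weights.
    a_ref = torch.sigmoid(F.linear(z, w_g[:c_hidden], b_g[:c_hidden])) * mask
    a_ref = a_ref * F.linear(z, w_p[:c_hidden], b_p[:c_hidden])
    b_ref = torch.sigmoid(F.linear(z, w_g[c_hidden:], b_g[c_hidden:])) * mask
    b_ref = b_ref * F.linear(z, w_p[c_hidden:], b_p[c_hidden:])

    assert torch.allclose(a, a_ref, atol=1e-6)
    assert torch.allclose(b, b_ref, atol=1e-6)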