Source code for deepfold.modules.outer_product_mean

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepfold.distributed.model_parallel as mp
import deepfold.modules.inductor as inductor
from deepfold.modules.layer_norm import LayerNorm
from deepfold.modules.linear import Linear
from deepfold.utils.iter_utils import slice_generator
from deepfold.utils.precision import is_fp16_enabled


class OuterProductMean(nn.Module):
    """Outer Product Mean module.

    Supplementary '1.6.4 Outer product mean': Algorithm 10.

    Args:
        c_m: MSA (or Extra MSA) representation dimension (channels).
        c_z: Pair representation dimension (channels).
        c_hidden: Hidden dimension (channels).
        eps: Epsilon to prevent division by zero.
        chunk_size: Optional chunk size for a batch-like dimension.

    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden: int,
        eps: float,
        chunk_size: Optional[int],
    ) -> None:
        super().__init__()
        assert eps == 1e-3
        self.c_m = c_m
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.eps = eps
        self.chunk_size = chunk_size
        self.layer_norm = LayerNorm(c_m)
        self.linear_1 = Linear(c_m, c_hidden, bias=True, init="default")
        self.linear_2 = Linear(c_m, c_hidden, bias=True, init="default")
        self.linear_out = Linear(c_hidden**2, c_z, bias=True, init="final")

    def forward(
        self,
        m: torch.Tensor,
        mask: torch.Tensor,
        add_output_to: torch.Tensor,
        inplace_safe: bool,
    ) -> torch.Tensor:
        """Outer Product Mean forward pass.

        Args:
            m: [batch, N_seq, N_res, c_m] MSA representation
            mask: [batch, N_seq, N_res] MSA mask
            add_output_to: pair representation to which the outer product update is added
            inplace_safe: whether the update may be written into add_output_to in place

        Returns:
            outer: [batch, N_res, N_res, c_z] updated pair representation

        """
        if is_fp16_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                return self._forward(m.float(), mask, add_output_to, inplace_safe)
        else:
            return self._forward(m, mask, add_output_to, inplace_safe)

    def _forward(
        self,
        m: torch.Tensor,
        mask: torch.Tensor,
        add_output_to: torch.Tensor,
        inplace_safe: bool,
    ) -> torch.Tensor:
        m = self.layer_norm(m)
        # m: [batch, N_seq, N_res, c_m]

        mask = mask.unsqueeze(-1)
        # mask: [batch, N_seq, N_res, 1]

        if mp.is_enabled():
            mask_s = mp.scatter(mask, dim=2)
            a, b = _forward_linear_a_b(
                m,
                self.linear_1.weight,
                self.linear_1.bias,
                self.linear_2.weight,
                self.linear_2.bias,
                mask_s,
            )
            b = mp.gather(b, dim=-2, bwd="all_reduce_sum_split")
        else:
            a, b = _forward_linear_a_b(
                m,
                self.linear_1.weight,
                self.linear_1.bias,
                self.linear_2.weight,
                self.linear_2.bias,
                mask,
            )
        # a: [batch, N_seq, N_res, c_hidden]
        # b: [batch, N_seq, N_res, c_hidden]

        if inductor.is_enabled():
            # TODO: does it work with chunked forward?
            outer = _forward_outer_jit(
                a,
                b,
                self.linear_out.weight,
                self.linear_out.bias,
                a.shape[0],  # batch
                a.shape[2],  # a_N_res
                b.shape[2],  # b_N_res
                a.shape[3],  # c_hidden
            )
        else:
            a = a.transpose(-2, -3)
            # a: [batch, N_res, N_seq, c_hidden]

            b = b.transpose(-2, -3)
            # b: [batch, N_res, N_seq, c_hidden]

            outer = self._outer_forward(a=a, b=b)
            # outer: [batch, N_res, N_res, c_z]

        norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
        # norm: [batch, N_res, N_res, 1]

        if mp.is_enabled():
            norm = mp.scatter(norm, dim=-3)

        outer = _forward_normalize_add(norm, outer, add_output_to, self.eps, inplace_safe)
        # outer: [batch, N_res, N_res, c_z]

        return outer

    def _outer_forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        if self.chunk_size is None:
            return self._outer(a=a, b=b)
        else:
            return self._outer_chunked(a=a, b=b, chunk_size=self.chunk_size)

    def _outer(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        outer = torch.einsum("...bac,...dae->...bdce", a, b)
        # outer: [batch, a_N_res, b_N_res, c_hidden, c_hidden]

        outer = outer.reshape(outer.shape[:-2] + (self.c_hidden * self.c_hidden,))
        # outer: [batch, a_N_res, b_N_res, c_hidden * c_hidden]

        outer = self.linear_out(outer)
        # outer: [batch, a_N_res, b_N_res, c_z]

        return outer

    def _outer_chunked(self, a: torch.Tensor, b: torch.Tensor, chunk_size: int) -> torch.Tensor:
        outer_chunks = []
        subbatch_size = a.size(1)
        for left, right in slice_generator(0, subbatch_size, chunk_size):
            a_chunk = a[:, left:right]
            outer_chunk = self._outer(a=a_chunk, b=b)
            outer_chunks.append(outer_chunk)
        return torch.cat(outer_chunks, dim=1)
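

# Illustrative sketch, not part of the original module: the same masked
# outer-product-mean update as OuterProductMean._forward, written as a single
# einsum over pre-projected, pre-masked inputs and without LayerNorm, model
# parallelism, chunking, or inductor paths. The function name and its
# argument layout are hypothetical; it is meant only as a readable reference
# for the math, not as the module's implementation.
def _reference_outer_product_mean_sketch(
    a: torch.Tensor,  # [batch, N_seq, N_res, c_hidden], masked projection of m
    b: torch.Tensor,  # [batch, N_seq, N_res, c_hidden], masked projection of m
    mask: torch.Tensor,  # [batch, N_seq, N_res, 1]
    w_o: torch.Tensor,  # [c_z, c_hidden * c_hidden]
    b_o: torch.Tensor,  # [c_z]
    eps: float = 1e-3,
) -> torch.Tensor:
    # Sum the outer products a_i (x) b_j over the sequence dimension s.
    outer = torch.einsum("bsic,bsjd->bijcd", a, b)
    # outer: [batch, N_res, N_res, c_hidden, c_hidden]

    outer = F.linear(outer.flatten(-2, -1), w_o, b_o)
    # outer: [batch, N_res, N_res, c_z]

    # Divide by the number of unmasked sequence pairs (the "mean" part),
    # mirroring the norm einsum and eps used in _forward above.
    norm = torch.einsum("bsic,bsjc->bijc", mask, mask)
    # norm: [batch, N_res, N_res, 1]

    return outer / (norm + eps)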


def _forward_linear_a_b_eager(
    m: torch.Tensor,
    w1: torch.Tensor,
    b1: torch.Tensor,
    w2: torch.Tensor,
    b2: torch.Tensor,
    mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    a = F.linear(m, w1, b1) * mask
    b = F.linear(m, w2, b2) * mask
    return a, b


_forward_linear_a_b_jit = torch.compile(_forward_linear_a_b_eager)


def _forward_linear_a_b(
    m: torch.Tensor,
    w1: torch.Tensor,
    b1: torch.Tensor,
    w2: torch.Tensor,
    b2: torch.Tensor,
    mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if inductor.is_enabled():
        forward_linear_a_b_fn = _forward_linear_a_b_jit
    elif inductor.is_enabled_and_autograd_off():
        forward_linear_a_b_fn = _forward_linear_a_b_jit
    else:
        forward_linear_a_b_fn = _forward_linear_a_b_eager
    return forward_linear_a_b_fn(m, w1, b1, w2, b2, mask)


def _forward_outer_eager(
    a: torch.Tensor,
    b: torch.Tensor,
    w_o: torch.Tensor,
    b_o: torch.Tensor,
    batch: int,
    a_N_res: int,
    b_N_res: int,
    c_hidden: int,
) -> torch.Tensor:
    # a: [batch, N_seq, N_res, c_hidden]
    a = a.transpose(-2, -3)
    # a: [batch, N_res, N_seq, c_hidden]
    a = a.transpose(-1, -2)
    # a: [batch, N_res, c_hidden, N_seq]
    a = a.flatten(1, 2)
    # a: [batch, N_res * c_hidden, N_seq]

    # b: [batch, N_seq, N_res, c_hidden]
    b = b.flatten(-2, -1)
    # b: [batch, N_seq, N_res * c_hidden]

    outer = torch.bmm(a, b)
    # outer: [batch, a_N_res * c_hidden, b_N_res * c_hidden]

    outer = outer.reshape((batch, a_N_res, c_hidden, b_N_res, c_hidden))
    # outer: [batch, a_N_res, c_hidden, b_N_res, c_hidden]

    outer = outer.transpose(2, 3)
    # outer: [batch, a_N_res, b_N_res, c_hidden, c_hidden]

    outer = outer.flatten(-2, -1)
    # outer: [batch, a_N_res, b_N_res, c_hidden * c_hidden]

    outer = F.linear(outer, w_o, b_o)
    # outer: [batch, a_N_res, b_N_res, c_z]

    return outer


_forward_outer_jit = torch.compile(_forward_outer_eager)


def _forward_normalize_add_eager(
    norm: torch.Tensor,
    outer: torch.Tensor,
    z: torch.Tensor,
    eps: float,
    inplace: bool,
) -> torch.Tensor:
    if inplace:
        outer /= norm + eps
        z += outer
        return z
    else:
        outer = outer / (norm + eps)
        return z + outer


_forward_normalize_add_jit = torch.compile(_forward_normalize_add_eager)


def _forward_normalize_add(
    norm: torch.Tensor,
    outer: torch.Tensor,
    z: torch.Tensor,
    eps: float,
    inplace: bool,
) -> torch.Tensor:
    if inductor.is_enabled():
        forward_normalize_add_fn = _forward_normalize_add_jit
    else:
        forward_normalize_add_fn = _forward_normalize_add_eager
    return forward_normalize_add_fn(norm, outer, z, eps, inplace)
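

# Hedged usage sketch, not part of the original module: builds the module with
# small, hypothetical dimensions and runs one forward pass on random tensors.
# Shapes follow the docstrings above and eps=1e-3 matches the assert in
# __init__; it assumes model parallelism, inductor compilation, and fp16
# autocast are all disabled in this process.
if __name__ == "__main__":
    batch, n_seq, n_res = 1, 8, 16
    c_m, c_z, c_hidden = 32, 64, 16

    opm = OuterProductMean(c_m=c_m, c_z=c_z, c_hidden=c_hidden, eps=1e-3, chunk_size=4)

    m = torch.randn(batch, n_seq, n_res, c_m)
    mask = torch.ones(batch, n_seq, n_res)
    z = torch.zeros(batch, n_res, n_res, c_z)

    out = opm(m=m, mask=mask, add_output_to=z, inplace_safe=False)
    print(out.shape)  # expected: torch.Size([1, 16, 16, 64])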