Source code for deepfold.modules.attention

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepfold.modules.inductor as inductor
from deepfold.modules.linear import Linear
from deepfold.modules.tweaks import evo_attn
from deepfold.utils.iter_utils import slice_generator


def _attention_gate_eager(
    output: torch.Tensor,
    gate: torch.Tensor,
    linear_g_bias: torch.Tensor,
) -> torch.Tensor:
    gate = torch.sigmoid(gate + linear_g_bias)
    output = output * gate
    return output


_attention_gate_jit = torch.compile(_attention_gate_eager)


class SelfAttentionWithGate(nn.Module):
    """Self Multi-Head Attention module with gating.

    Args:
        c_qkv: Input dimension of query|key|value data tensor (channels).
        c_hidden: Hidden dimension (per-head).
        num_heads: Number of attention heads.
        inf: Safe infinity value.
        chunk_size: Optional chunk size for a batch-like dimension.
            Supplementary '1.11.8 Reducing the memory consumption': Inference.
        impl: Optional attention implementation override ("torch" or "evo").

    """

    def __init__(
        self,
        c_qkv: int,
        c_hidden: int,
        num_heads: int,
        inf: float,
        chunk_size: Optional[int],
        impl: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.c_qkv = c_qkv
        self.c_hidden = c_hidden
        self.num_heads = num_heads
        self.inf = inf
        self.chunk_size = chunk_size
        self.impl = impl

        total_dim = c_hidden * num_heads
        self.linear_q = Linear(c_qkv, total_dim, bias=False, init="glorot")
        self.linear_k = Linear(c_qkv, total_dim, bias=False, init="glorot")
        self.linear_v = Linear(c_qkv, total_dim, bias=False, init="glorot")
        self.linear_g = Linear(c_qkv, total_dim, init="gating")
        self.linear_o = Linear(c_hidden * num_heads, c_qkv, bias=True, init="final")

        # Probe for the optional fused Evoformer attention kernel.
        try:
            from deepfold_kernels.evoformer_attn import DS4Sci_EvoformerAttention
        except ModuleNotFoundError:
            from deepfold.modules.tweaks import evo_attn

            # Disable evoformer attention
            evo_attn.disable()
    def forward(
        self,
        input_qkv: torch.Tensor,
        mask: torch.Tensor,
        bias: Optional[torch.Tensor],
        add_transposed_output_to: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attention forward pass.

        Args:
            input_qkv: [*, QKV, c_qkv] query data (QKV == Q == K == V)
            mask: Logit mask tensor broadcastable to [*, num_heads, Q, K]
            bias: Optional logit bias tensor broadcastable to [*, num_heads, Q, K]
            add_transposed_output_to: Optional tensor to which the transposed
                output is added elementwise.

        Returns:
            output: [*, Q, c_qkv] tensor

        """
        query, key, value, gate = self._prep_qkvg(input_qkv)
        # query: [*, num_heads, Q, c_hidden]
        # key: [*, num_heads, K, c_hidden]
        # value: [*, num_heads, V, c_hidden]
        # gate: [*, Q, num_heads * c_hidden]

        if evo_attn.is_enabled():
            impl = "evo"
        else:
            impl = "torch"
        if self.impl is not None:
            impl = self.impl

        if impl == "torch":
            output = self._attention_forward(query, key, value, mask, bias)
            # output: [*, num_heads, Q, c_hidden] for torch implementation.
            output = output.transpose(-2, -3)
            # output: [*, Q, num_heads, c_hidden]
        elif impl == "evo":
            from deepfold.ops.evoformer_attention import deepspeed_evo_attn

            mask_bias = (mask - 1.0) * self.inf
            biases = [mask_bias]
            if bias is not None:
                biases.append(bias)
            # output = deepspeed_evo_attn(query / math.sqrt(self.c_hidden), key, value, biases)
            output = deepspeed_evo_attn(query, key, value, biases)
        else:
            raise ValueError(f"Unsupported implementation '{impl}'")
        del query, key, value

        output = output.reshape(*output.shape[:-2], self.num_heads * self.c_hidden)
        # output = output.reshape(output.shape[:-2] + (self.num_heads * self.c_hidden,))
        # output: [*, Q, num_heads * c_hidden]

        if self.training:
            output = _attention_gate_jit(output, gate, self.linear_g.bias)
        else:
            output = _attention_gate_eager(output, gate, self.linear_g.bias)
        # output: [*, Q, num_heads * c_hidden]

        if add_transposed_output_to is None:
            output = self.linear_o(output)
        else:
            output = _linear_transpose_add(
                output,
                self.linear_o.weight,
                self.linear_o.bias,
                add_transposed_output_to,
            )
        # output: [*, Q, c_qkv]

        return output
    def _prep_qkvg(
        self,
        input_qkv: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # input_qkv: [*, QKV, c_qkv]

        q = self.linear_q(input_qkv)
        k = self.linear_k(input_qkv)
        v = self.linear_v(input_qkv)
        g = self.linear_g(input_qkv)
        # q: [*, Q, num_heads * c_hidden]
        # k: [*, K, num_heads * c_hidden]
        # v: [*, V, num_heads * c_hidden]
        # g: [*, Q, num_heads * c_hidden]

        q = q.view(q.shape[:-1] + (self.num_heads, self.c_hidden))
        k = k.view(k.shape[:-1] + (self.num_heads, self.c_hidden))
        v = v.view(v.shape[:-1] + (self.num_heads, self.c_hidden))
        # q: [*, Q, num_heads, c_hidden]
        # k: [*, K, num_heads, c_hidden]
        # v: [*, V, num_heads, c_hidden]

        q = q.transpose(-2, -3)
        k = k.transpose(-2, -3)
        v = v.transpose(-2, -3)
        # q: [*, num_heads, Q, c_hidden]
        # k: [*, num_heads, K, c_hidden]
        # v: [*, num_heads, V, c_hidden]

        # q = q / math.sqrt(self.c_hidden)  # scaling moved to the _attention functions

        return q, k, v, g

    def _attention_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if self.chunk_size is None:
            if inductor.is_enabled():
                return _attention_eager(query, key, value, mask, bias, self.inf)
            else:
                return _attention_jit(query, key, value, mask, bias, self.inf)
        else:
            return _attention_chunked(query, key, value, mask, bias, self.inf, self.chunk_size)
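# Illustrative usage sketch (added for documentation; not part of the original
# module). It exercises SelfAttentionWithGate on the "torch" path with arbitrary,
# assumed hyperparameters and only checks the tensor shapes stated in the docstring.
def _example_self_attention_with_gate() -> None:
    attn = SelfAttentionWithGate(
        c_qkv=64, c_hidden=16, num_heads=4, inf=1e9, chunk_size=None, impl="torch"
    )
    attn.eval()  # eval mode selects the eager gating kernel
    x = torch.randn(2, 128, 64)  # [*, QKV, c_qkv]
    mask = torch.ones(2, 1, 1, 128)  # broadcastable to [*, num_heads, Q, K]; all ones => no masking
    with torch.no_grad():
        out = attn(input_qkv=x, mask=mask, bias=None)
    assert out.shape == (2, 128, 64)  # [*, Q, c_qkv]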
class CrossAttentionNoGate(nn.Module):
    """Cross Multi-Head Attention module without gating.

    Args:
        c_q: Input dimension of query data tensor (channels).
        c_kv: Input dimension of key|value data tensor (channels).
        c_hidden: Hidden dimension (per-head).
        num_heads: Number of attention heads.
        inf: Safe infinity value.
        chunk_size: Optional chunk size for a batch-like dimension.
            Supplementary '1.11.8 Reducing the memory consumption': Inference.
        impl: Optional attention implementation override ("torch" or "evo").

    """

    def __init__(
        self,
        c_q: int,
        c_kv: int,
        c_hidden: int,
        num_heads: int,
        inf: float,
        chunk_size: Optional[int],
        impl: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.c_q = c_q
        self.c_kv = c_kv
        self.c_hidden = c_hidden
        self.num_heads = num_heads
        self.inf = inf
        self.chunk_size = chunk_size
        self.impl = impl

        self.linear_q = Linear(c_q, c_hidden * num_heads, bias=False, init="glorot")
        self.linear_k = Linear(c_kv, c_hidden * num_heads, bias=False, init="glorot")
        self.linear_v = Linear(c_kv, c_hidden * num_heads, bias=False, init="glorot")
        self.linear_o = Linear(c_hidden * num_heads, c_q, bias=True, init="final")
    def forward(
        self,
        input_q: torch.Tensor,
        input_kv: torch.Tensor,
        mask: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Attention forward pass.

        Args:
            input_q: [*, Q, c_q] query data
            input_kv: [*, KV, c_kv] key|value data (KV == K == V)
            mask: Logit mask tensor broadcastable to [*, num_heads, Q, KV]
            bias: Optional logit bias tensor broadcastable to [*, num_heads, Q, KV]

        Returns:
            output: [*, Q, c_q] tensor

        """
        query, key, value = self._prep_qkv(input_q, input_kv)
        # query: [*, num_heads, Q, c_hidden]
        # key: [*, num_heads, K, c_hidden]
        # value: [*, num_heads, V, c_hidden]

        if evo_attn.is_enabled():
            impl = "evo"
        else:
            impl = "torch"
        if self.impl is not None:
            impl = self.impl

        if impl == "torch":
            output = self._attention_forward(query, key, value, mask, bias)
            # output: [*, num_heads, Q, c_hidden] for torch implementation.
            output = output.transpose(-2, -3)
            # output: [*, Q, num_heads, c_hidden]
        elif impl == "evo":
            from deepfold.ops.evoformer_attention import deepspeed_evo_attn

            mask_bias = (mask - 1.0) * self.inf
            biases = [mask_bias]
            if bias is not None:
                biases.append(bias)
            # output = deepspeed_evo_attn(query / math.sqrt(self.c_hidden), key, value, biases)
            output = deepspeed_evo_attn(query, key, value, biases)
        else:
            raise ValueError(f"Unsupported implementation '{impl}'")
        del query, key, value

        output = output.reshape(output.shape[:-2] + (self.num_heads * self.c_hidden,))
        # output: [*, Q, num_heads * c_hidden]

        output = self.linear_o(output)
        # output: [*, Q, c_q]

        return output
    def _prep_qkv(
        self,
        input_q: torch.Tensor,
        input_kv: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # input_q: [*, Q, c_q]
        # input_kv: [*, KV, c_kv]

        q = self.linear_q(input_q)
        k = self.linear_k(input_kv)
        v = self.linear_v(input_kv)
        # q: [*, Q, num_heads * c_hidden]
        # k: [*, K, num_heads * c_hidden]
        # v: [*, V, num_heads * c_hidden]

        q = q.view(q.shape[:-1] + (self.num_heads, self.c_hidden))
        k = k.view(k.shape[:-1] + (self.num_heads, self.c_hidden))
        v = v.view(v.shape[:-1] + (self.num_heads, self.c_hidden))
        # q: [*, Q, num_heads, c_hidden]
        # k: [*, K, num_heads, c_hidden]
        # v: [*, V, num_heads, c_hidden]

        q = q.transpose(-2, -3)
        k = k.transpose(-2, -3)
        v = v.transpose(-2, -3)
        # q: [*, num_heads, Q, c_hidden]
        # k: [*, num_heads, K, c_hidden]
        # v: [*, num_heads, V, c_hidden]

        # q = q / math.sqrt(self.c_hidden)  # scaling moved to the _attention functions

        return q, k, v

    def _attention_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        if self.chunk_size is None:
            if inductor.is_enabled():
                return _attention_eager(query, key, value, mask, bias, self.inf)
            else:
                return _attention_jit(query, key, value, mask, bias, self.inf)
        else:
            return _attention_chunked(query, key, value, mask, bias, self.inf, self.chunk_size)
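# Illustrative usage sketch (added for documentation; not part of the original
# module). It exercises CrossAttentionNoGate on the "torch" path with arbitrary,
# assumed hyperparameters, where queries and key|value pairs come from tensors
# with different lengths and channel sizes.
def _example_cross_attention_no_gate() -> None:
    attn = CrossAttentionNoGate(
        c_q=32, c_kv=64, c_hidden=16, num_heads=4, inf=1e9, chunk_size=None, impl="torch"
    )
    q = torch.randn(2, 100, 32)  # [*, Q, c_q]
    kv = torch.randn(2, 150, 64)  # [*, KV, c_kv]
    mask = torch.ones(2, 1, 1, 150)  # broadcastable to [*, num_heads, Q, KV]; all ones => no masking
    with torch.no_grad():
        out = attn(input_q=q, input_kv=kv, mask=mask, bias=None)
    assert out.shape == (2, 100, 32)  # [*, Q, c_q]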
def _attention_eager(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor,
    bias: Optional[torch.Tensor],
    inf: float,
) -> torch.Tensor:
    # query: [*, num_heads, Q, c_hidden]
    # key: [*, num_heads, K, c_hidden]
    # value: [*, num_heads, V, c_hidden]
    # mask: Logit mask tensor broadcastable to [*, num_heads, Q, K]
    # bias: Optional logit bias tensor broadcastable to [*, num_heads, Q, K]
    # inf: Safe infinity value.
    # Assuming K == V.

    key = torch.swapdims(key, -2, -1)
    # key: [*, num_heads, c_hidden, K]

    scaling = 1.0 / math.sqrt(query.size(-1))
    a = torch.matmul(query * scaling, key)
    # a: [*, num_heads, Q, K]

    a += (mask - 1.0) * inf
    # a: [*, num_heads, Q, K]

    if bias is not None:
        a += bias
        # a: [*, num_heads, Q, K]

    a = torch.softmax(a, dim=-1)
    # a: [*, num_heads, Q, K]

    a = torch.matmul(a, value)
    # a: [*, num_heads, Q, c_hidden]

    return a


_attention_jit = torch.compile(_attention_eager)


def _attention_chunked(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor,
    bias: Optional[torch.Tensor],
    inf: float,
    chunk_size: int,
) -> torch.Tensor:
    output_chunks = []
    subbatch_size = query.size(1)
    if inductor.is_enabled():
        attention_fn = _attention_jit
    else:
        attention_fn = _attention_eager
    for left, right in slice_generator(0, subbatch_size, chunk_size):
        query_chunk = query[:, left:right]
        key_chunk = key[:, left:right]
        value_chunk = value[:, left:right]
        mask_chunk = mask[:, left:right] if mask.size(1) == subbatch_size else mask
        bias_chunk = bias[:, left:right] if bias is not None and bias.size(1) == subbatch_size else bias
        output_chunk = attention_fn(
            query=query_chunk,
            key=key_chunk,
            value=value_chunk,
            mask=mask_chunk,
            bias=bias_chunk,
            inf=inf,
        )
        output_chunks.append(output_chunk)
    return torch.cat(output_chunks, dim=1)


def _linear_transpose_add_eager(
    x: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    return y + F.linear(x, w, b).transpose(-2, -3)


_linear_transpose_add_jit = torch.compile(_linear_transpose_add_eager)


def _linear_transpose_add(
    x: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    if inductor.is_enabled():
        linear_transpose_add_fn = _linear_transpose_add_jit
    else:
        linear_transpose_add_fn = _linear_transpose_add_eager
    return linear_transpose_add_fn(x, w, b, y)
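# Illustrative consistency check (added for documentation; not part of the original
# module). Under the assumption that dim 1 is the batch-like dimension being
# chunked, _attention_chunked should reproduce the unchunked _attention_eager
# result up to floating-point tolerance. Shapes are arbitrary example values.
def _example_chunked_attention_consistency() -> None:
    q = torch.randn(1, 6, 4, 10, 16)  # [batch, S, num_heads, Q, c_hidden]
    k = torch.randn(1, 6, 4, 10, 16)  # [batch, S, num_heads, K, c_hidden]
    v = torch.randn(1, 6, 4, 10, 16)  # [batch, S, num_heads, V, c_hidden]
    mask = torch.ones(1, 6, 1, 1, 10)  # broadcastable to [batch, S, num_heads, Q, K]
    full = _attention_eager(q, k, v, mask, bias=None, inf=1e9)
    chunked = _attention_chunked(q, k, v, mask, bias=None, inf=1e9, chunk_size=2)
    assert torch.allclose(full, chunked, atol=1e-6)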