Source code for deepfold.ops.evoformer_attention

from typing import List

import torch

try:
    from deepfold_kernels.evoformer_attn import DS4Sci_EvoformerAttention
except ModuleNotFoundError:
    from deepfold.modules.tweaks import evo_attn

    # Disable evoformer attention
    evo_attn.disable()

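# For reference, a minimal pure-PyTorch sketch of the computation the kernel
# performs. This is an illustration only: `_evoformer_attn_reference` is a
# hypothetical helper, not part of this module, and it assumes the kernel
# applies the standard 1 / sqrt(C_hidden) scaling internally.
def _evoformer_attn_reference(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
) -> torch.Tensor:
    # q/k/v: [*, H, Q/K/V, C_hidden]; each bias broadcasts to [*, H, Q, K]
    logits = torch.matmul(q, k.transpose(-1, -2)) * q.shape[-1] ** -0.5
    for b in biases:
        logits = logits + b
    weights = torch.softmax(logits, dim=-1)
    # [*, H, Q, C_hidden]; note that deepspeed_evo_attn below returns the
    # transposed layout [*, Q, H, C_hidden].
    return torch.matmul(weights, v)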

@torch.jit.ignore
def deepspeed_evo_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
) -> torch.Tensor:
    """
    Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel.

    Args:
        q: [*, H, Q, C_hidden] query data
        k: [*, H, K, C_hidden] key data
        v: [*, H, V, C_hidden] value data
        biases: List of biases that broadcast to [*, H, Q, K]

    Returns:
        [*, Q, H, C_hidden] attention output
    """

    def reshape_dims(x):
        # Pad with singleton dims (or flatten extra dims) so the tensor has
        # exactly two leading batch dimensions.
        no_batch_dims = len(x.shape[:-3])
        if no_batch_dims < 2:
            return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
        if no_batch_dims > 2:
            return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
        return x

    # [*, Q/K, H, C_hidden]
    q = q.transpose(-2, -3)
    k = k.transpose(-2, -3)
    v = v.transpose(-2, -3)

    # Reshape tensors to match the expected input shape [B, N, Q/K, H, C_hidden]
    # of DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
    orig_shape = q.shape
    if len(orig_shape[:-3]) != 2:
        q = reshape_dims(q)
        k = reshape_dims(k)
        v = reshape_dims(v)
        biases = [reshape_dims(b) for b in biases]

    # The DeepSpeed attention kernel requires bf16 or fp16 inputs.
    # Cast to bf16 so the kernel can be used during inference.
    orig_dtype = q.dtype
    if orig_dtype not in [torch.bfloat16, torch.float16]:
        o = DS4Sci_EvoformerAttention(
            q.to(dtype=torch.bfloat16),
            k.to(dtype=torch.bfloat16),
            v.to(dtype=torch.bfloat16),
            [b.to(dtype=torch.bfloat16) for b in biases],
        )
        o = o.to(dtype=orig_dtype)
    else:
        o = DS4Sci_EvoformerAttention(q, k, v, biases)

    # Restore the original batch dimensions.
    o = o.reshape(orig_shape)
    return o
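
# A usage sketch. The shapes below are made up for illustration; actually
# running this requires a CUDA device and the compiled deepfold_kernels
# extension providing DS4Sci_EvoformerAttention.
if __name__ == "__main__":
    batch, n_seq, n_res, n_heads, c_hidden = 1, 8, 64, 4, 32
    # q/k/v: [*, H, Q, C_hidden] with two leading batch dims (batch, n_seq)
    q = torch.randn(batch, n_seq, n_heads, n_res, c_hidden, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # A mask-style bias that broadcasts to [*, H, Q, K]
    mask_bias = torch.zeros(batch, n_seq, 1, 1, n_res, device="cuda")
    # fp32 inputs are cast to bf16 internally and cast back on return
    o = deepspeed_evo_attn(q, k, v, [mask_bias])
    print(o.shape)  # torch.Size([1, 8, 64, 4, 32])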