import torch
from typing import List

from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention


@torch.jit.ignore
def deepspeed_evo_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
) -> torch.Tensor:
    """
    Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel.

    Args:
        q:
            [*, H, Q, C_hidden] query data
        k:
            [*, H, K, C_hidden] key data
        v:
            [*, H, V, C_hidden] value data
        biases:
            List of biases that broadcast to [*, H, Q, K]

    Returns:
        [*, Q, H, C_hidden] attention output
    """
    def reshape_dims(x):
        # Pad with leading singleton dims, or flatten extra leading dims,
        # so that every input has exactly two batch dims.
        no_batch_dims = len(x.shape[:-3])
        if no_batch_dims < 2:
            return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
        if no_batch_dims > 2:
            return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
        return x

    # [*, Q/K, H, C_hidden]
    q = q.transpose(-2, -3)
    k = k.transpose(-2, -3)
    v = v.transpose(-2, -3)

    # Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden]
    # for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
    orig_shape = q.shape
    if len(orig_shape[:-3]) != 2:
        q = reshape_dims(q)
        k = reshape_dims(k)
        v = reshape_dims(v)
        biases = [reshape_dims(b) for b in biases]

    # The DeepSpeed attention kernel requires inputs of type bf16 or fp16.
    # Cast to bf16 so the kernel can also be used during inference.
    orig_dtype = q.dtype
    if orig_dtype not in [torch.bfloat16, torch.float16]:
        o = DS4Sci_EvoformerAttention(
            q.to(dtype=torch.bfloat16),
            k.to(dtype=torch.bfloat16),
            v.to(dtype=torch.bfloat16),
            [b.to(dtype=torch.bfloat16) for b in biases],
        )
        o = o.to(dtype=orig_dtype)
    else:
        o = DS4Sci_EvoformerAttention(q, k, v, biases)

    o = o.reshape(orig_shape)
    return o
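
# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original source).
# Shapes follow the docstring above: q/k/v are [*, H, Q/K, C_hidden] and each
# bias must broadcast to [*, H, Q, K]. The dimension names below (batch,
# n_seq, n_heads, n_res, c_hidden) are assumptions chosen for the example.
# Requires a CUDA device and DeepSpeed built with the evoformer_attn op.
if __name__ == "__main__":
    batch, n_seq, n_heads, n_res, c_hidden = 1, 8, 4, 64, 32

    q = torch.randn(batch, n_seq, n_heads, n_res, c_hidden, device="cuda")
    k = torch.randn(batch, n_seq, n_heads, n_res, c_hidden, device="cuda")
    v = torch.randn(batch, n_seq, n_heads, n_res, c_hidden, device="cuda")

    # A mask-style bias and a pair-style bias, both broadcastable to
    # [batch, n_seq, n_heads, n_res, n_res].
    mask_bias = torch.zeros(batch, n_seq, 1, 1, n_res, device="cuda")
    pair_bias = torch.randn(batch, 1, n_heads, n_res, n_res, device="cuda")

    # Inputs are fp32 here, so the wrapper casts to bf16 for the kernel
    # and casts the output back to fp32.
    out = deepspeed_evo_attn(q, k, v, [mask_bias, pair_bias])
    print(out.shape)  # torch.Size([1, 8, 64, 4, 32]), i.e. [*, Q, H, C_hidden]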