import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepfold.modules.inductor as inductor
from deepfold.modules.linear import Linear
from deepfold.utils.iter_utils import slice_generator
class GlobalAttention(nn.Module):
"""Global Attention module.
Args:
c_e: Extra MSA representation dimension (channels).
c_hidden: Per-head hidden dimension (channels).
num_heads: Number of attention heads.
inf: Safe infinity value.
eps: Epsilon to prevent division by zero.
chunk_size: Optional chunk size for a batch-like dimension.
"""
def __init__(
self,
c_e: int,
c_hidden: int,
num_heads: int,
inf: float,
eps: float,
chunk_size: Optional[int],
) -> None:
super().__init__()
self.c_e = c_e
self.c_hidden = c_hidden
self.num_heads = num_heads
self.inf = inf
self.eps = eps
self.chunk_size = chunk_size
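        # Note the asymmetry of global attention: queries get num_heads * c_hidden
        # channels, while keys and values share a single c_hidden projection
        # across all heads.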
self.linear_q = Linear(c_e, c_hidden * num_heads, bias=False, init="glorot")
self.linear_k = Linear(c_e, c_hidden, bias=False, init="glorot")
self.linear_v = Linear(c_e, c_hidden, bias=False, init="glorot")
self.linear_g = Linear(c_e, c_hidden * num_heads, bias=True, init="gating")
self.linear_o = Linear(c_hidden * num_heads, c_e, bias=True, init="final")
def forward(
self,
m: torch.Tensor,
mask: torch.Tensor,
add_transposed_output_to: Optional[torch.Tensor],
) -> torch.Tensor:
"""Global Attention forward pass.
Args:
m: [batch, N_res, N_extra_seq, c_e] transposed extra MSA representation
mask: [batch, N_res, N_extra_seq] transposed extra MSA mask
add_transposed_output_to:
                Optional tensor to which the transposed output will be added element-wise.
        Returns:
            m: [batch, N_res, N_extra_seq, c_e] updated extra MSA representation,
                or [batch, N_extra_seq, N_res, c_e] when `add_transposed_output_to`
                is given, in which case the transposed output is added to that tensor.
"""
if self.chunk_size is None:
return self._forward(
m=m,
mask=mask,
add_transposed_output_to=add_transposed_output_to,
)
else:
return self._forward_chunked(
m=m,
mask=mask,
chunk_size=self.chunk_size,
add_transposed_output_to=add_transposed_output_to,
)
def _forward(
self,
m: torch.Tensor,
mask: torch.Tensor,
add_transposed_output_to: Optional[torch.Tensor],
) -> torch.Tensor:
        # torch.cuda.nvtx.range_push("global_attention")
q = _mul_sum_x2_add_div(m, mask, self.eps)
# q: [batch, N_res, c_e]
q = self.linear_q(q)
# q: [batch, N_res, num_heads * c_hidden]
q = q * math.sqrt(1 / self.c_hidden)
# q: [batch, N_res, num_heads * c_hidden]
q = q.view(q.shape[:-1] + (self.num_heads, self.c_hidden))
# q: [batch, N_res, num_heads, c_hidden]
k = self.linear_k(m)
# k: [batch, N_res, N_extra_seq, c_hidden]
v = self.linear_v(m)
# v: [batch, N_res, N_extra_seq, c_hidden]
a = torch.matmul(q, k.transpose(-1, -2))
# a: [batch, N_res, num_heads, N_extra_seq]
a = _add_softmax(a, mask, self.inf)
# a: [batch, N_res, num_heads, N_extra_seq]
o = torch.matmul(a, v)
# o: [batch, N_res, num_heads, c_hidden]
g = _linear(m, self.linear_g.weight, self.linear_g.bias)
# g: [batch, N_res, N_extra_seq, num_heads * c_hidden]
o = _sigmoid_mul(g, o, self.num_heads, self.c_hidden)
o = o.reshape(o.shape[:-2] + (self.num_heads * self.c_hidden,))
# o: [batch, N_res, N_extra_seq, num_heads * c_hidden]
if add_transposed_output_to is None:
m = self.linear_o(o)
else:
m = _linear_transpose_add(
o,
self.linear_o.weight,
self.linear_o.bias,
add_transposed_output_to,
)
        # m: [batch, N_res, N_extra_seq, c_e], or [batch, N_extra_seq, N_res, c_e]
        # when the output was added to `add_transposed_output_to`
return m
def _forward_chunked(
self,
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
add_transposed_output_to: Optional[torch.Tensor],
) -> torch.Tensor:
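        # Slice the batch-like dimension (dim 1) into chunks of `chunk_size`,
        # run the unchunked forward on each slice, and concatenate the results;
        # the optional transposed add is applied once at the end.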
output_chunks = []
subbatch_size = m.size(1)
for left, right in slice_generator(0, subbatch_size, chunk_size):
m_chunk = m[:, left:right]
mask_chunk = mask[:, left:right]
output_chunk = self._forward(
m=m_chunk,
mask=mask_chunk,
add_transposed_output_to=None,
)
output_chunks.append(output_chunk)
out = torch.cat(output_chunks, dim=1)
if add_transposed_output_to is None:
return out
else:
return add_transposed_output_to + out.transpose(-2, -3)
def _mul_sum_x2_add_div_eager(
m: torch.Tensor,
mask: torch.Tensor,
eps: float,
) -> torch.Tensor:
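    # Mask-weighted mean of m over the extra-sequence dimension:
    # sum(m * mask) / (sum(mask) + eps). This averaged representation is what
    # the query projection sees in global attention.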
q_num = torch.sum(m * mask.unsqueeze(-1), dim=-2)
q_den = torch.sum(mask, dim=-1).add(eps).unsqueeze(-1)
q = q_num / q_den
return q
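# Each helper below comes in an eager and a torch.compile'd flavor; the
# dispatch wrappers pick the compiled variant only when the inductor backend
# is enabled.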
_mul_sum_x2_add_div_jit = torch.compile(_mul_sum_x2_add_div_eager)
def _mul_sum_x2_add_div(
m: torch.Tensor,
mask: torch.Tensor,
eps: float,
) -> torch.Tensor:
if inductor.is_enabled():
mul_sum_x2_add_div_fn = _mul_sum_x2_add_div_jit
else:
mul_sum_x2_add_div_fn = _mul_sum_x2_add_div_eager
return mul_sum_x2_add_div_fn(m, mask, eps)
def _add_softmax_eager(
a: torch.Tensor,
mask: torch.Tensor,
inf: float,
) -> torch.Tensor:
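    # Turn the 0/1 mask into a large negative bias on masked positions, add it
    # to the attention logits, and normalize with softmax over the key dimension.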
bias = ((mask - 1.0) * inf).unsqueeze(-2)
a = a + bias
a = torch.softmax(a, dim=-1)
return a
_add_softmax_jit = torch.compile(_add_softmax_eager)
def _add_softmax(
a: torch.Tensor,
mask: torch.Tensor,
inf: float,
) -> torch.Tensor:
if inductor.is_enabled():
add_softmax_fn = _add_softmax_jit
else:
add_softmax_fn = _add_softmax_eager
return add_softmax_fn(a, mask, inf)
def _linear_eager(
x: torch.Tensor,
w: torch.Tensor,
b: torch.Tensor,
) -> torch.Tensor:
return F.linear(x, w, b)
_linear_jit = torch.compile(_linear_eager)
def _linear(
x: torch.Tensor,
w: torch.Tensor,
b: torch.Tensor,
) -> torch.Tensor:
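    # The compiled path is taken only when inductor is enabled and autograd is
    # off; training falls back to the eager F.linear.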
if inductor.is_enabled_and_autograd_off():
linear_fn = _linear_jit
else:
linear_fn = _linear_eager
return linear_fn(x, w, b)
def _sigmoid_mul_eager(
g: torch.Tensor,
o: torch.Tensor,
num_heads: int,
c_hidden: int,
) -> torch.Tensor:
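    # Sigmoid gating: reshape the gate to per-head channels and broadcast the
    # attention output across the extra-sequence dimension before multiplying.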
g = torch.sigmoid(g)
g = g.view(g.shape[:-1] + (num_heads, c_hidden))
o = o.unsqueeze(-3) * g
return o
_sigmoid_mul_jit = torch.compile(_sigmoid_mul_eager)
def _sigmoid_mul(
g: torch.Tensor,
o: torch.Tensor,
num_heads: int,
c_hidden: int,
) -> torch.Tensor:
if inductor.is_enabled():
sigmoid_mul_fn = _sigmoid_mul_jit
else:
sigmoid_mul_fn = _sigmoid_mul_eager
return sigmoid_mul_fn(g, o, num_heads, c_hidden)
def _linear_transpose_add_eager(
x: torch.Tensor,
w: torch.Tensor,
b: torch.Tensor,
out: torch.Tensor,
) -> torch.Tensor:
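    # Apply the output projection, swap the N_res and N_extra_seq dimensions,
    # and add the result to `out`; kept as a single expression so the compiled
    # variant can treat it as one fused epilogue.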
return out + F.linear(x, w, b).transpose(-2, -3)
_linear_transpose_add_jit = torch.compile(_linear_transpose_add_eager)
def _linear_transpose_add(
x: torch.Tensor,
w: torch.Tensor,
b: torch.Tensor,
out: torch.Tensor,
) -> torch.Tensor:
if inductor.is_enabled():
linear_transpose_add_fn = _linear_transpose_add_jit
else:
linear_transpose_add_fn = _linear_transpose_add_eager
return linear_transpose_add_fn(x, w, b, out)
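

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module: the shapes and the
    # hyperparameter values (c_hidden, num_heads, inf, eps) below are
    # illustrative assumptions, not taken from any DeepFold configuration;
    # only the tensor layouts follow the docstrings above.
    batch, n_res, n_extra_seq, c_e = 1, 8, 16, 64
    attention = GlobalAttention(
        c_e=c_e,
        c_hidden=8,
        num_heads=8,
        inf=1e9,
        eps=1e-10,
        chunk_size=None,
    )
    m = torch.randn(batch, n_res, n_extra_seq, c_e)
    mask = torch.ones(batch, n_res, n_extra_seq)
    out = attention(m=m, mask=mask, add_transposed_output_to=None)
    # Without `add_transposed_output_to`, the result keeps the input layout:
    # [batch, N_res, N_extra_seq, c_e].
    print(out.shape)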