import math
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepfold.modules.inductor as inductor
from deepfold.modules.linear import Linear
from deepfold.utils.geometry import Rigid3Array, Vec3Array, square_euclidean_distance
from deepfold.utils.precision import is_fp16_enabled
from deepfold.utils.rigid_utils import Rigid
from deepfold.utils.tensor_utils import add, flatten_final_dims
class InvariantPointAttention(nn.Module):
"""Invariant Point Attention (IPA) module.
Supplementary '1.8.2 Invariant point attention (IPA)': Algorithm 22.
Args:
c_s: Single representation dimension (channels).
c_z: Pair representation dimension (channels).
c_hidden: Hidden dimension (channels).
num_heads: Number of attention heads.
num_qk_points: Number of query/key points.
num_v_points: Number of value points.
separate_kv: Separate key/value projection.
inf: Safe infinity value.
eps: Epsilon to prevent division by zero.
"""
def __init__(
self,
c_s: int,
c_z: int,
c_hidden: int,
num_heads: int,
num_qk_points: int,
num_v_points: int,
separate_kv: bool,
inf: float,
eps: float,
) -> None:
super().__init__()
self.c_s = c_s
self.c_z = c_z
self.c_hidden = c_hidden
self.num_heads = num_heads
self.num_qk_points = num_qk_points
self.num_v_points = num_v_points
self.separate_kv = separate_kv
self.inf = inf
self.eps = eps
# These linear layers differ from their specifications in the supplement.
# There, they lack bias and use Glorot initialization.
# Here as in the official source, they have bias
# and use the default Lecun initialization.
hc = c_hidden * num_heads
self.linear_q = Linear(c_s, hc, bias=True, init="default")
if self.separate_kv:
self.linear_k = Linear(c_s, hc, bias=True, init="default")
self.linear_v = Linear(c_s, hc, bias=True, init="default")
else:
self.linear_kv = Linear(c_s, 2 * hc, bias=True, init="default")
hpq = num_heads * num_qk_points * 3
self.linear_q_points = Linear(c_s, hpq, bias=True, init="default")
hpk = self.num_heads * self.num_qk_points * 3
hpv = self.num_heads * self.num_v_points * 3
if self.separate_kv:
self.linear_k_points = Linear(c_s, hpk, bias=True, init="default")
self.linear_v_points = Linear(c_s, hpv, bias=True, init="default")
else:
hpkv = hpk + hpv
self.linear_kv_points = Linear(c_s, hpkv, bias=True, init="default")
self.linear_b = Linear(c_z, num_heads, bias=True, init="default")
self.head_weights = nn.Parameter(torch.zeros(num_heads))
ipa_point_weights_init_(self.head_weights.data)
concat_out_dim = num_heads * (c_z + c_hidden + num_v_points * 4)
self.linear_out = Linear(concat_out_dim, c_s, bias=True, init="final")
def forward(
self,
s: torch.Tensor,
z: torch.Tensor,
r: Rigid,
mask: torch.Tensor,
inplace_safe: bool,
) -> torch.Tensor:
"""Invariant Point Attention (IPA) forward pass.
Args:
s: [batch, N_res, c_s] single representation
z: [batch, N_res, N_res, c_z] pair representation
r: [batch, N_res] rigids transformation
mask: [batch, N_res] sequence mask
Returns:
s_update: [batch, N_res, c_s] single representation update
"""
#######################################
# Generate scalar and point activations
#######################################
if self.separate_kv: # Multimer
q = self.linear_q(s)
bias = self.linear_b(z)
k = self.linear_k(s)
v = self.linear_v(s)
q_pts = self.linear_q_points(s)
k_pts = self.linear_k_points(s)
v_pts = self.linear_v_points(s)
else:
if inductor.is_enabled():
forward_linears_fn = _forward_linears_on_inputs_jit
else:
forward_linears_fn = _forward_linears_on_inputs_eager
q, bias, kv, q_pts, kv_pts = forward_linears_fn(
s,
z,
self.linear_q.weight,
self.linear_q.bias,
self.linear_b.weight,
self.linear_b.bias,
self.linear_kv.weight,
self.linear_kv.bias,
self.linear_q_points.weight,
self.linear_q_points.bias,
self.linear_kv_points.weight,
self.linear_kv_points.bias,
)
# q: [batch, N_res, num_heads * c_hidden]
# b: [batch, N_res, N_res, num_heads]
# kv: [batch, N_res, num_heads * 2 * c_hidden]
# k/v: [batch, N_res, num_heads * c_hidden]
# q_pts: [batch, N_res, num_heads * num_qk_points * 3]
# kv_pts: [batch, N_res, num_heads * (num_qk_points + num_v_points) * 3]
q = q.view(q.shape[:-1] + (self.num_heads, self.c_hidden))
# q: [batch, N_res, num_heads, c_hidden]
if self.separate_kv:
k = k.view(k.shape[:-1] + (self.num_heads, -1))
v = v.view(v.shape[:-1] + (self.num_heads, -1))
else:
kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))
# kv: [batch, N_res, num_heads, 2 * c_hidden]
k, v = torch.split(kv, self.c_hidden, dim=-1)
# k: [batch, N_res, num_heads, c_hidden]
# v: [batch, N_res, num_heads, c_hidden]
def process_points(pts: torch.Tensor, num_points: int) -> torch.Tensor:
shape = pts.shape[:-1] + (pts.shape[-1] // 3, 3)
if self.separate_kv:
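# Multimer weights store points with the x/y/z components chunked per head,
# so group by head before splitting the last dimension into coordinate chunks.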
pts = pts.view(pts.shape[:-1] + (self.num_heads, num_points * 3))
pts = torch.split(pts, pts.shape[-1] // 3, dim=-1)
pts = torch.stack(pts, dim=-1).view(*shape)
pts = r[..., None].apply(pts)
return pts.view(pts.shape[:-2] + (self.num_heads, num_points, 3))
q_pts = process_points(q_pts, self.num_qk_points)
# q_pts: [batch, N_res, num_heads, num_qk_points, 3]
if self.separate_kv:
k_pts = process_points(k_pts, self.num_qk_points)
v_pts = process_points(v_pts, self.num_v_points)
# k_pts: [batch, N_res, num_heads, num_qk_points, 3]
# v_pts: [batch, N_res, num_heads, num_v_points, 3]
else:
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = r[..., None].apply(kv_pts)
# kv_pts: [batch, N_res, num_heads * (num_qk_points + num_v_points), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, self.num_qk_points + self.num_v_points, 3))
# kv_pts: [batch, N_res, num_heads, (num_qk_points + num_v_points), 3]
k_pts, v_pts = torch.split(kv_pts, (self.num_qk_points, self.num_v_points), dim=-2)
# k_pts: [batch, N_res, num_heads, num_qk_points, 3]
# v_pts: [batch, N_res, num_heads, num_v_points, 3]
##########################
# Compute attention scores
##########################
if is_fp16_enabled():
with torch.cuda.amp.autocast(enabled=False):
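# Compute q·k in fp32: fp16 attention logits can overflow before the softmax.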
a = torch.matmul(
q.float().movedim(-3, -2), # q: [batch, num_heads, N_res, c_hidden]
k.float().movedim(-3, -1), # k: [batch, num_heads, c_hidden, N_res]
)
else:
a = torch.matmul(
q.movedim(-3, -2), # q: [batch, num_heads, N_res, c_hidden]
k.movedim(-3, -1), # k: [batch, num_heads, c_hidden, N_res]
)
# a: [batch, num_heads, N_res, N_res]
if inductor.is_enabled():
forward_a_fn = _forward_a_jit
else:
forward_a_fn = _forward_a_eager
a = forward_a_fn(a, bias, q_pts, k_pts, mask, self.c_hidden, self.num_heads, self.head_weights, self.num_qk_points, self.inf, inplace_safe)
# a: [batch, num_heads, N_res, N_res]
################
# Compute output
################
o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)
# o: [batch, N_res, num_heads, c_hidden]
o = o.reshape(o.shape[:-2] + (self.num_heads * self.c_hidden,))
# o: [batch, N_res, num_heads * c_hidden]
o_pt = torch.sum((a.unsqueeze(-3).unsqueeze(-1) * v_pts.swapdims(-4, -3).movedim(-1, -3).unsqueeze(-3)), dim=-2)
# o_pt: [batch, num_heads, 3, N_res, num_v_points]
o_pt = o_pt.movedim(-3, -1).swapdims(-3, -4)
# o_pt: [batch, N_res, num_heads, num_v_points, 3]
o_pt = r.unsqueeze(-1).unsqueeze(-2).invert_apply(o_pt)
# o_pt: [batch, N_res, num_heads, num_v_points, 3]
if inductor.is_enabled():
forward_o_pt_norm_fn = _forward_o_pt_norm_jit
else:
forward_o_pt_norm_fn = _forward_o_pt_norm_eager
o_pt_norm = forward_o_pt_norm_fn(o_pt, self.eps)
# o_pt_norm: [batch, N_res, num_heads, num_v_points]
o_pt_norm = o_pt_norm.reshape(o_pt_norm.shape[:-2] + (self.num_heads * self.num_v_points,))
# o_pt_norm: [batch, N_res, num_heads * num_v_points]
o_pt = o_pt.reshape(o_pt.shape[:-3] + (self.num_heads * self.num_v_points, 3))
# o_pt: [batch, N_res, num_heads * num_v_points, 3]
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
# o_pair: [batch, N_res, num_heads, c_z]
o_pair = o_pair.reshape(o_pair.shape[:-2] + (self.num_heads * self.c_z,))
# o_pair: [batch, N_res, num_heads * c_z]
o_cat = (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair)
o_cat = torch.cat(o_cat, dim=-1)
# o_cat: [batch, N_res, num_heads * (c_hidden + num_v_points * 4 + c_z)]
s_update = self.linear_out(o_cat.to(dtype=z.dtype))
# s_update: [batch, N_res, c_s]
return s_update
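# Example usage (a minimal sketch; the sizes are illustrative, and
# Rigid.identity is assumed to follow the OpenFold-style signature):
#
#   ipa = InvariantPointAttention(
#       c_s=384, c_z=128, c_hidden=16, num_heads=12,
#       num_qk_points=4, num_v_points=8,
#       separate_kv=False, inf=1e5, eps=1e-8,
#   )
#   s = torch.randn(1, 64, 384)
#   z = torch.randn(1, 64, 64, 128)
#   mask = torch.ones(1, 64)
#   r = Rigid.identity((1, 64), dtype=s.dtype, device=s.device)
#   s_update = ipa(s, z, r, mask, inplace_safe=False)  # [1, 64, 384]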
def ipa_point_weights_init_(weights_data: torch.Tensor) -> None:
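# ln(e - 1): the inverse of softplus at 1, so softplus(head_weights) == 1.0 at init.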
softplus_inverse_1 = 0.541324854612918
weights_data.fill_(softplus_inverse_1)
def _forward_linears_on_inputs_eager(
s: torch.Tensor,
z: torch.Tensor,
w_q: torch.Tensor,
b_q: torch.Tensor,
w_b: torch.Tensor,
b_b: torch.Tensor,
w_kv: torch.Tensor,
b_kv: torch.Tensor,
w_q_points: torch.Tensor,
b_q_points: torch.Tensor,
w_kv_points: torch.Tensor,
b_kv_points: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q = F.linear(s, w_q, b_q)
b = F.linear(z, w_b, b_b)
kv = F.linear(s, w_kv, b_kv)
q_pts = F.linear(s, w_q_points, b_q_points)
kv_pts = F.linear(s, w_kv_points, b_kv_points)
return q, b, kv, q_pts, kv_pts
_forward_linears_on_inputs_jit = torch.compile(_forward_linears_on_inputs_eager)
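# forward() selects between the eager and compiled projection variants based on
# inductor.is_enabled(), mirroring the dispatch used for _forward_a below.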
def _forward_a_eager(
a: torch.Tensor,
b: torch.Tensor,
q_pts: torch.Tensor,
k_pts: torch.Tensor,
mask: torch.Tensor,
c_hidden: int,
num_heads: int,
head_weights: torch.Tensor,
num_qk_points: int,
inf: float,
inplace: bool,
) -> torch.Tensor:
# a: [batch, num_heads, N_res, N_res]
# b: [batch, N_res, N_res, num_heads]
# q_pts: [batch, N_res, num_heads, num_qk_points, 3]
# k_pts: [batch, N_res, num_heads, num_qk_points, 3]
# mask: [batch, N_res]
if inplace:
a *= math.sqrt(1.0 / (3 * c_hidden))
a += math.sqrt(1.0 / 3) * b.movedim(-1, -3)
else:
a = a * math.sqrt(1.0 / (3 * c_hidden))
a = a + (math.sqrt(1.0 / 3) * b.movedim(-1, -3))
# a: [batch, num_heads, N_res, N_res]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) # outer subtraction
if inplace:
pt_att *= pt_att
else:
pt_att = pt_att**2
# pt_att: [batch, N_res, N_res, num_heads, num_qk_points, 3]
pt_att = sum(torch.unbind(pt_att, dim=-1))
# pt_att: [batch, N_res, N_res, num_heads, num_qk_points]
head_weights = F.softplus(head_weights)
head_weights = head_weights.view((1,) * (pt_att.ndim - 2) + (num_heads, 1))
if inplace:
head_weights *= math.sqrt(1.0 / (3 * (num_qk_points * 9.0 / 2)))
else:
head_weights = head_weights * math.sqrt(1.0 / (3 * (num_qk_points * 9.0 / 2)))
# head_weights: [1, 1, 1, num_heads, 1]
if inplace:
pt_att *= head_weights
else:
pt_att = pt_att * head_weights
# pt_att: [batch, N_res, N_res, num_heads, num_qk_points]
pt_att = -0.5 * torch.sum(pt_att, dim=-1)
# pt_att: [batch, N_res, N_res, num_heads]
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) # outer product
square_mask = (square_mask - 1.0) * inf
# square_mask: [batch, N_res, N_res]
pt_att = pt_att.movedim(-1, -3)
# pt_att: [batch, num_heads, N_res, N_res]
if inplace:
a += pt_att
else:
a = a + pt_att
# a: [batch, num_heads, N_res, N_res]
if inplace:
a += square_mask.unsqueeze(-3)
else:
a = a + square_mask.unsqueeze(-3)
# a: [batch, num_heads, N_res, N_res]
a = torch.softmax(a, dim=-1)
# a: [batch, num_heads, N_res, N_res]
return a
_forward_a_jit = torch.compile(_forward_a_eager)
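# For reference, the logits assembled in _forward_a_eager implement Algorithm 22
# of the AlphaFold 2 supplement:
#
#   a_ij = softmax_j( (1/sqrt(3)) * ( q_i . k_j / sqrt(c_hidden)
#                                     + b_ij
#                                     - (gamma_h * w_C / 2) * sum_p ||T_i q_i^p - T_j k_j^p||^2 ) )
#
# with gamma_h = softplus(head_weights) and w_C = sqrt(2 / (9 * num_qk_points));
# masked pairs receive a -inf additive bias before the softmax.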
def _forward_o_pt_norm_eager(o_pt: torch.Tensor, eps: float) -> torch.Tensor:
return torch.sqrt(torch.sum(o_pt**2, dim=-1) + eps)
_forward_o_pt_norm_jit = torch.compile(_forward_o_pt_norm_eager)
class PointProjection(nn.Module):
"""Project activations to per-head 3D points and lift them into the global frame."""
def __init__(
self,
c_hidden: int,
num_points: int,
no_heads: int,
is_multimer: bool,
return_local_points: bool = False,
):
super().__init__()
self.return_local_points = return_local_points
self.no_heads = no_heads
self.num_points = num_points
self.is_multimer = is_multimer
# TODO: Multimer requires this to be run with fp32 precision during training
self.linear = Linear(c_hidden, no_heads * 3 * num_points)
def forward(
self,
activations: torch.Tensor,
rigids: Union[Rigid, Rigid3Array],
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO: Needs to run in high precision during training
points_local = self.linear(activations)
out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3)
if self.is_multimer:
points_local = points_local.view(points_local.shape[:-1] + (self.no_heads, -1))
points_local = torch.split(points_local, points_local.shape[-1] // 3, dim=-1)
points_local = torch.stack(points_local, dim=-1).view(out_shape)
points_global = rigids[..., None, None].apply(points_local)
if self.return_local_points:
return points_global, points_local
return points_global
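# A minimal usage sketch (shapes illustrative):
#
#   proj = PointProjection(c_hidden=384, num_points=4, no_heads=12, is_multimer=True)
#   # activations: [*, N_res, 384], rigids: [*, N_res]
#   # returns global-frame points of shape [*, N_res, 12, 4, 3]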
class InvariantPointAttentionMultimer(nn.Module):
def __init__(
self,
c_s: int,
c_z: int,
c_hidden: int,
num_heads: int,
num_qk_points: int,
num_v_points: int,
inf: float = 1e5,
eps: float = 1e-8,
):
"""
Args:
c_s:
Single representation channel dimension
c_z:
Pair representation channel dimension
c_hidden:
Hidden channel dimension
num_heads:
Number of attention heads
num_qk_points:
Number of query/key points to generate
num_v_points:
Number of value points to generate
"""
super().__init__()
self.c_s = c_s
self.c_z = c_z
self.c_hidden = c_hidden
self.num_heads = num_heads
self.num_qk_points = num_qk_points
self.num_v_points = num_v_points
self.inf = inf
self.eps = eps
# These linear layers differ from their specifications in the supplement.
# There, they lack bias and use Glorot initialization.
# Here as in the official source, they have bias and use the default Lecun initialization.
hc = self.c_hidden * self.num_heads
self.linear_q = Linear(self.c_s, hc, bias=False)
self.linear_q_points = PointProjection(self.c_s, self.num_qk_points, self.num_heads, is_multimer=True)
self.linear_k = Linear(self.c_s, hc, bias=False)
self.linear_v = Linear(self.c_s, hc, bias=False)
self.linear_k_points = PointProjection(self.c_s, self.num_qk_points, self.num_heads, is_multimer=True)
self.linear_v_points = PointProjection(self.c_s, self.num_v_points, self.num_heads, is_multimer=True)
self.linear_b = Linear(self.c_z, self.num_heads)
self.head_weights = nn.Parameter(torch.zeros(num_heads))
ipa_point_weights_init_(self.head_weights)
concat_out_dim = self.num_heads * (self.c_z + self.c_hidden + self.num_v_points * 4)
self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
self.softmax = nn.Softmax(dim=-2)
def forward(
self,
s: torch.Tensor,
z: torch.Tensor,
r: Rigid3Array,
mask: torch.Tensor,
inplace_safe: bool,
) -> torch.Tensor:
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
r:
[*, N_res] transformation object
mask:
[*, N_res] mask
Returns:
[*, N_res, C_s] single representation update
"""
a = 0.0
point_variance = max(self.num_qk_points, 1) * 9.0 / 2
point_weights = math.sqrt(1.0 / point_variance)
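# Numerically stable softplus: log(1 + e^x) == logaddexp(x, 0).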
softplus = lambda x: torch.logaddexp(x, torch.zeros_like(x))
head_weights = softplus(self.head_weights)
point_weights = point_weights * head_weights
#######################################
# Generate scalar and point activations
#######################################
# q_pts, k_pts: Vec3Array of shape [*, N_res, H, P_qk]
q_pts = Vec3Array.from_array(self.linear_q_points(s, r))
k_pts = Vec3Array.from_array(self.linear_k_points(s, r))
pt_att = square_euclidean_distance(q_pts.unsqueeze(-3), k_pts.unsqueeze(-4), epsilon=0.0)
pt_att = torch.sum(pt_att * point_weights[..., None], dim=-1) * (-0.5)
pt_att = pt_att.to(dtype=s.dtype)
a = add(a, pt_att, inplace_safe)
scalar_variance = max(self.c_hidden, 1) * 1.0
scalar_weights = math.sqrt(1.0 / scalar_variance)
# [*, N_res, H * C_hidden]
q = self.linear_q(s)
k = self.linear_k(s)
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.num_heads, -1))
k = k.view(k.shape[:-1] + (self.num_heads, -1))
if inplace_safe:
q *= scalar_weights
a += torch.einsum("...qhc,...khc->...qkh", q, k)
else:
q = q * scalar_weights
a = a + torch.einsum("...qhc,...khc->...qkh", q, k)
##########################
# Compute attention scores
##########################
# [*, N_res, N_res, H]
b = self.linear_b(z)
if inplace_safe:
a += b
else:
a = a + b
# [*, N_res, N_res]
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = self.inf * (square_mask - 1)
if inplace_safe:
a += square_mask.unsqueeze(-1)
a *= math.sqrt(1.0 / 3) # Normalize by number of logit terms (3)
else:
a = a + square_mask.unsqueeze(-1)
a = a * math.sqrt(1.0 / 3) # Normalize by number of logit terms (3)
a = self.softmax(a)
# [*, N_res, H * C_hidden]
v = self.linear_v(s)
# [*, N_res, H, C_hidden]
v = v.view(v.shape[:-1] + (self.num_heads, -1))
o = torch.einsum("...qkh,...khc->...qhc", a, v)
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, N_res, H, P_v, 3]
v_pts = Vec3Array.from_array(self.linear_v_points(s, r))
# o_pt: Vec3Array [*, N_res, N_res, H, P_v], summed over keys -> [*, N_res, H, P_v]
o_pt = v_pts[..., None, :, :, :] * a.unsqueeze(-1)
o_pt = o_pt.sum(dim=-3)
# o_pt: Vec3Array [*, N_res, H * P_v], mapped back into the local frame
o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
o_pt = r[..., None].apply_inverse_to_point(o_pt)
o_pt_flat = [o_pt.x, o_pt.y, o_pt.z]
o_pt_flat = [x.to(dtype=a.dtype) for x in o_pt_flat]
# [*, N_res, H * P_v]
o_pt_norm = o_pt.norm(epsilon=self.eps)
o_pair = torch.einsum("...ijh,...ijc->...ihc", a, z.to(dtype=a.dtype))
# [*, N_res, H * C_z]
o_pair = flatten_final_dims(o_pair, 2)
# [*, N_res, C_s]
s_update = self.linear_out(torch.cat((o, *o_pt_flat, o_pt_norm, o_pair), dim=-1).to(dtype=z.dtype))
return s_update
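# Example usage (a minimal sketch; Rigid3Array.identity is assumed to accept a
# batch shape and device, as in the OpenFold geometry utilities):
#
#   ipa = InvariantPointAttentionMultimer(
#       c_s=384, c_z=128, c_hidden=16, num_heads=12,
#       num_qk_points=4, num_v_points=8,
#   )
#   s = torch.randn(1, 64, 384)
#   z = torch.randn(1, 64, 64, 128)
#   mask = torch.ones(1, 64)
#   r = Rigid3Array.identity((1, 64), device=s.device)
#   s_update = ipa(s, z, r, mask, inplace_safe=False)  # [1, 64, 384]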