Source code for deepfold.modules.template_pair_embedder

from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepfold.common.residue_constants as rc
import deepfold.modules.inductor as inductor
from deepfold.modules.layer_norm import LayerNorm
from deepfold.modules.linear import Linear
from deepfold.utils.rigid_utils import Rigid


class TemplatePairEmbedder(nn.Module):
    """Template Pair Embedder module.

    Embeds the "template_pair_feat" feature.

    Supplementary '1.4 AlphaFold Inference': Algorithm 2, line 9.

    Args:
        tp_dim: Input `template_pair_feat` dimension (channels).
        c_t: Output template representation dimension (channels).

    """

    def __init__(
        self,
        tp_dim: int,
        c_t: int,
        **kwargs,
    ) -> None:
        super().__init__()
        self.tp_dim = tp_dim
        self.c_t = c_t
        self.linear = Linear(tp_dim, c_t, bias=True, init="relu")
    def forward(
        self,
        template_pair_feat: torch.Tensor,
    ) -> torch.Tensor:
        """Template Pair Embedder forward pass.

        Args:
            template_pair_feat: [batch, N_res, N_res, tp_dim]

        Returns:
            template_pair_embedding: [batch, N_res, N_res, c_t]

        """
        return self.linear(template_pair_feat)
    def build_template_pair_feat(
        self,
        feats: Dict[str, torch.Tensor],
        min_bin: float,
        max_bin: float,
        num_bins: int,
        use_unit_vector: bool,
        inf: float,
        eps: float,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Builds the concatenated template pair features: distogram, pseudo-beta
        mask, residue-type one-hots, inter-residue unit vectors, and backbone mask
        (see the usage sketch after this class)."""
        template_pseudo_beta = feats["template_pseudo_beta"]
        template_pseudo_beta_mask = feats["template_pseudo_beta_mask"]
        template_aatype = feats["template_aatype"]
        template_all_atom_mask = feats["template_all_atom_mask"]
        self._initialize_buffers(
            min_bin=min_bin,
            max_bin=max_bin,
            num_bins=num_bins,
            inf=inf,
            device=template_pseudo_beta.device,
        )
        if inductor.is_enabled():
            compute_part1_fn = _compute_part1_jit
        else:
            compute_part1_fn = _compute_part1_eager
        to_concat, aatype_one_hot = compute_part1_fn(
            template_pseudo_beta,
            template_pseudo_beta_mask,
            template_aatype,
            self.lower,
            self.upper,
            rc.restype_num + 2,
        )
        num_res = template_aatype.shape[-1]
        to_concat.append(aatype_one_hot.unsqueeze(-3).expand(*aatype_one_hot.shape[:-2], num_res, -1, -1))
        to_concat.append(aatype_one_hot.unsqueeze(-2).expand(*aatype_one_hot.shape[:-2], -1, num_res, -1))
        n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
        rigids = Rigid.make_transform_from_reference(
            n_xyz=feats["template_all_atom_positions"][..., n, :],
            ca_xyz=feats["template_all_atom_positions"][..., ca, :],
            c_xyz=feats["template_all_atom_positions"][..., c, :],
            eps=eps,
        )
        points = rigids.get_trans().unsqueeze(-3)
        rigid_vec = rigids.unsqueeze(-1).invert_apply(points)
        if inductor.is_enabled():
            compute_part2_fn = _compute_part2_jit
        else:
            compute_part2_fn = _compute_part2_eager
        t = compute_part2_fn(
            rigid_vec,
            eps,
            template_all_atom_mask,
            n,
            ca,
            c,
            use_unit_vector,
            to_concat,
            dtype,
        )
        return t
    def _initialize_buffers(
        self,
        min_bin: float,
        max_bin: float,
        num_bins: int,
        inf: float,
        device: torch.device,
    ) -> None:
        if not hasattr(self, "lower") or not hasattr(self, "upper"):
            bins = torch.linspace(
                start=min_bin,
                end=max_bin,
                steps=num_bins,
                device=device,
                requires_grad=False,
            )
            lower = torch.pow(bins, 2)
            upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
            self.register_buffer("lower", lower, persistent=False)
            self.register_buffer("upper", upper, persistent=False)
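
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): wires build_template_pair_feat
# into the embedder above. All hyperparameters and shapes below are illustrative
# assumptions; with num_bins=39 and rc.restype_num == 20, the concatenated
# feature has 39 + 1 + 2 * (20 + 2) + 3 + 1 = 88 channels, which is what tp_dim
# is assumed to be here.
def _example_template_pair_embedding() -> torch.Tensor:
    num_res, num_atoms = 16, 37  # atom37 layout, as in residue_constants
    feats = {
        "template_pseudo_beta": torch.randn(1, num_res, 3),
        "template_pseudo_beta_mask": torch.ones(1, num_res),
        "template_aatype": torch.randint(0, 20, (1, num_res)),
        "template_all_atom_positions": torch.randn(1, num_res, num_atoms, 3),
        "template_all_atom_mask": torch.ones(1, num_res, num_atoms),
    }
    embedder = TemplatePairEmbedder(tp_dim=88, c_t=64)
    t = embedder.build_template_pair_feat(
        feats,
        min_bin=3.25,
        max_bin=50.75,
        num_bins=39,
        use_unit_vector=False,
        inf=1e5,
        eps=1e-6,
        dtype=torch.float32,
    )
    return embedder(t)  # [1, num_res, num_res, 64]
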
def _compute_part1_eager(
    tpb: torch.Tensor,
    template_mask: torch.Tensor,
    template_aatype: torch.Tensor,
    lower: torch.Tensor,
    upper: torch.Tensor,
    num_classes: int,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
    template_mask_2d = template_mask.unsqueeze(-1) * template_mask.unsqueeze(-2)
    dgram = torch.sum(
        input=(tpb.unsqueeze(-2) - tpb.unsqueeze(-3)) ** 2,
        dim=-1,
        keepdim=True,
    )
    dgram = ((dgram > lower) * (dgram < upper)).to(dtype=dgram.dtype)
    to_concat = [dgram, template_mask_2d.unsqueeze(-1)]
    aatype_one_hot = F.one_hot(
        template_aatype,
        num_classes=num_classes,
    )
    return to_concat, aatype_one_hot


_compute_part1_jit = torch.compile(_compute_part1_eager)


def _compute_part2_eager(
    rigid_vec: torch.Tensor,
    eps: float,
    t_aa_masks: torch.Tensor,
    n: int,
    ca: int,
    c: int,
    use_unit_vector: bool,
    to_concat: List[torch.Tensor],
    dtype: torch.dtype,
) -> torch.Tensor:
    inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))
    template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
    template_mask_2d = template_mask.unsqueeze(-1) * template_mask.unsqueeze(-2)
    inv_distance_scalar = inv_distance_scalar * template_mask_2d
    unit_vector = rigid_vec * inv_distance_scalar.unsqueeze(-1)
    if not use_unit_vector:
        unit_vector = unit_vector * 0.0
    to_concat.extend(torch.unbind(unit_vector.unsqueeze(-2), dim=-1))
    to_concat.append(template_mask_2d.unsqueeze(-1))
    t = torch.cat(to_concat, dim=-1)
    t = t * template_mask_2d.unsqueeze(-1)
    t = t.to(dtype=dtype)
    return t


_compute_part2_jit = torch.compile(_compute_part2_eager)
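
# ---------------------------------------------------------------------------
# Minimal sketch (not part of the original module) of the binning trick in
# _compute_part1_eager: the squared bin edges `lower`/`upper` turn a squared
# pairwise distance into a one-hot distogram row with two broadcast
# comparisons. The bin edges below are illustrative values, not taken from
# this module's configuration.
def _example_distogram_row() -> torch.Tensor:
    bins = torch.linspace(3.25, 50.75, steps=39)
    lower = bins**2
    upper = torch.cat([lower[1:], lower.new_tensor([1e5])], dim=-1)
    d2 = torch.tensor([[10.0**2]])  # one residue pair at 10 Angstrom, kept squared
    return ((d2 > lower) * (d2 < upper)).to(dtype=d2.dtype)  # [1, 39], one hot bin
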
class TemplatePairEmbedderMultimer(nn.Module):
    """Template Pair Embedder module for the multimer model.

    Embeds the individual template pair feature components and sums them into a
    single [batch, N_res, N_res, c_t] activation.

    Args:
        c_z: Pair representation dimension (channels).
        c_t: Output template representation dimension (channels).
        c_dgram: Template distogram dimension (channels).
        c_aatype: Residue-type one-hot dimension (channels).

    """

    def __init__(
        self,
        c_z: int,
        c_t: int,
        c_dgram: int,
        c_aatype: int,
        **kwargs,
    ) -> None:
        super().__init__()
        self.c_dgram = c_dgram
        self.c_aatype = c_aatype
        self.dgram_linear = Linear(c_dgram, c_t, init="relu")
        self.aatype_linear_1 = Linear(c_aatype, c_t, init="relu")
        self.aatype_linear_2 = Linear(c_aatype, c_t, init="relu")
        self.query_embedding_layer_norm = LayerNorm(c_z)
        self.query_embedding_linear = Linear(c_z, c_t, init="relu")
        self.pseudo_beta_mask_linear = Linear(1, c_t, init="relu")
        self.x_linear = Linear(1, c_t, init="relu")
        self.y_linear = Linear(1, c_t, init="relu")
        self.z_linear = Linear(1, c_t, init="relu")
        self.backbone_mask_linear = Linear(1, c_t, init="relu")

    # Reuse the monomer implementation for the distogram bin-edge buffers.
    _initialize_buffers = TemplatePairEmbedder._initialize_buffers
    def forward(
        self,
        query_embedding: torch.Tensor,
        multichain_mask_2d: torch.Tensor,
        template_dgram: torch.Tensor,
        aatype_one_hot: torch.Tensor,
        pseudo_beta_mask: torch.Tensor,
        backbone_mask: torch.Tensor,  # [..., N_res]
        unit_vector: torch.Tensor,  # [..., N_res, N_res, 3]
    ) -> torch.Tensor:
        # Build 2D pseudo-beta mask
        pseudo_beta_mask_2d = pseudo_beta_mask[..., :, None] * pseudo_beta_mask[..., None, :]
        pseudo_beta_mask_2d *= multichain_mask_2d
        template_dgram *= pseudo_beta_mask_2d[..., None]
        act = self.dgram_linear(template_dgram)
        act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])
        aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
        act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
        act += self.aatype_linear_2(aatype_one_hot[..., :, None, :])
        backbone_mask_2d = backbone_mask[..., :, None] * backbone_mask[..., None, :]
        backbone_mask_2d *= multichain_mask_2d
        # backbone_mask_2d: [1, N_res, N_res]
        x, y, z = [(coord * backbone_mask_2d).to(dtype=query_embedding.dtype) for coord in unit_vector.unbind(dim=-1)]
        act += self.x_linear(x[..., None])
        act += self.y_linear(y[..., None])
        act += self.z_linear(z[..., None])
        act += self.backbone_mask_linear(backbone_mask_2d[..., None].to(dtype=query_embedding.dtype))
        query_embedding = self.query_embedding_layer_norm(query_embedding)
        act += self.query_embedding_linear(query_embedding)
        return act
    def build_template_pair_feat(
        self,
        feats: Dict[str, torch.Tensor],
        min_bin: float,
        max_bin: float,
        num_bins: int,
        inf: float,
        eps: float,
        dtype: torch.dtype,
    ) -> Dict[str, torch.Tensor]:
        """Builds the multimer template pair features as a dict whose keys match
        the keyword arguments of `forward` (see the usage sketch after this class)."""
        template_aatype = feats["template_aatype"]
        template_all_atom_positions = feats["template_all_atom_positions"]
        template_all_atom_mask = feats["template_all_atom_mask"]
        if inductor.is_enabled():
            pseudo_beta_fn = _pseudo_beta_fn_jit
        else:
            pseudo_beta_fn = _pseudo_beta_fn_eager
        template_pseudo_beta, template_pseudo_beta_mask = pseudo_beta_fn(
            aatype=template_aatype,
            all_atom_positions=template_all_atom_positions,
            all_atom_mask=template_all_atom_mask,
        )
        self._initialize_buffers(
            min_bin=min_bin,
            max_bin=max_bin,
            num_bins=num_bins,
            inf=inf,
            device=template_pseudo_beta.device,
        )
        if inductor.is_enabled():
            compute_part1_fn = _compute_multimer_part1_jit
        else:
            compute_part1_fn = _compute_multimer_part1_eager
        dgram, aatype_one_hot = compute_part1_fn(
            template_pseudo_beta,
            template_aatype,
            self.lower,
            self.upper,
            self.c_aatype,
        )
        n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
        rigids = Rigid.make_transform_from_reference(
            n_xyz=feats["template_all_atom_positions"][..., n, :],
            ca_xyz=feats["template_all_atom_positions"][..., ca, :],
            c_xyz=feats["template_all_atom_positions"][..., c, :],
            eps=eps,
        )
        backbone_mask = template_all_atom_mask[..., n] * template_all_atom_mask[..., ca] * template_all_atom_mask[..., c]
        if inductor.is_enabled():
            compute_unit_vector = _compute_multimer_part2_jit
        else:
            compute_unit_vector = _compute_multimer_part2_eager
        points = rigids.get_trans().unsqueeze(-3)
        rigid_vec = rigids.unsqueeze(-1).invert_apply(points)
        unit_vector = compute_unit_vector(rigid_vec, eps, template_all_atom_mask, n, ca, c)
        return {
            "template_dgram": dgram,
            "aatype_one_hot": aatype_one_hot,
            "pseudo_beta_mask": template_pseudo_beta_mask,
            "backbone_mask": backbone_mask,
            "unit_vector": unit_vector,
        }
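
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module) for the multimer embedder:
# build_template_pair_feat returns a dict whose entries feed forward() together
# with the pair representation and a same-chain mask. The channel sizes, bins,
# and random inputs are illustrative assumptions (c_dgram must equal num_bins,
# c_aatype the one-hot width rc.restype_num + 2).
def _example_multimer_template_pair_embedding() -> torch.Tensor:
    num_res, num_atoms = 16, 37
    feats = {
        "template_aatype": torch.randint(0, 20, (1, num_res)),
        "template_all_atom_positions": torch.randn(1, num_res, num_atoms, 3),
        "template_all_atom_mask": torch.ones(1, num_res, num_atoms),
    }
    embedder = TemplatePairEmbedderMultimer(c_z=128, c_t=64, c_dgram=39, c_aatype=22)
    template_feats = embedder.build_template_pair_feat(
        feats,
        min_bin=3.25,
        max_bin=50.75,
        num_bins=39,
        inf=1e5,
        eps=1e-6,
        dtype=torch.float32,
    )
    query_embedding = torch.randn(1, num_res, num_res, 128)
    multichain_mask_2d = torch.ones(1, num_res, num_res)  # single-chain toy case
    return embedder(
        query_embedding=query_embedding,
        multichain_mask_2d=multichain_mask_2d,
        **template_feats,
    )  # [1, num_res, num_res, 64]
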
def _compute_multimer_part1_eager(
    template_pseudo_beta: torch.Tensor,
    template_aatype: torch.Tensor,
    lower: torch.Tensor,
    upper: torch.Tensor,
    num_classes: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    dgram = torch.sum(
        input=(template_pseudo_beta.unsqueeze(-2) - template_pseudo_beta.unsqueeze(-3)) ** 2,
        dim=-1,
        keepdim=True,
    )
    dgram = ((dgram > lower) * (dgram < upper)).to(dtype=dgram.dtype)
    aatype_one_hot = F.one_hot(
        template_aatype,
        num_classes=num_classes,
    )
    return dgram, aatype_one_hot


_compute_multimer_part1_jit = torch.compile(_compute_multimer_part1_eager)


def _compute_multimer_part2_eager(
    rigid_vec: torch.Tensor,
    eps: float,
    t_aa_masks: torch.Tensor,
    n: int,
    ca: int,
    c: int,
) -> torch.Tensor:
    inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))
    template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
    template_mask_2d = template_mask.unsqueeze(-1) * template_mask.unsqueeze(-2)
    inv_distance_scalar = inv_distance_scalar * template_mask_2d
    unit_vector = rigid_vec * inv_distance_scalar.unsqueeze(-1)
    return unit_vector


_compute_multimer_part2_jit = torch.compile(_compute_multimer_part2_eager)


def _pseudo_beta_fn_eager(
    aatype: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    is_gly = torch.eq(aatype, rc.restype_order["G"])
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    pseudo_beta = torch.where(
        torch.tile(is_gly.unsqueeze(-1), [1] * is_gly.ndim + [3]),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )
    pseudo_beta_mask = torch.where(
        is_gly,
        all_atom_mask[..., ca_idx],
        all_atom_mask[..., cb_idx],
    )
    return pseudo_beta, pseudo_beta_mask


_pseudo_beta_fn_jit = torch.compile(_pseudo_beta_fn_eager)
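
# ---------------------------------------------------------------------------
# Minimal sketch (not part of the original module) of the pseudo-beta rule in
# _pseudo_beta_fn_eager: glycine has no CB atom, so its CA coordinates are used
# instead; every other residue type uses CB. The coordinates are toy values.
def _example_pseudo_beta_selection() -> torch.Tensor:
    num_res, num_atoms = 2, 37
    aatype = torch.tensor([rc.restype_order["G"], rc.restype_order["A"]])
    positions = torch.zeros(num_res, num_atoms, 3)
    positions[0, rc.atom_order["CA"]] = torch.tensor([1.0, 0.0, 0.0])  # GLY -> CA
    positions[1, rc.atom_order["CB"]] = torch.tensor([0.0, 2.0, 0.0])  # ALA -> CB
    mask = torch.ones(num_res, num_atoms)
    pseudo_beta, _pseudo_beta_mask = _pseudo_beta_fn_eager(aatype, positions, mask)
    return pseudo_beta  # [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]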