Source code for deepfold.modules.structure_module

from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepfold.common.residue_constants as rc
from deepfold.modules.angle_resnet import AngleResnet
from deepfold.modules.backbone_update import BackboneUpdate
from deepfold.modules.invariant_point_attention import InvariantPointAttention, InvariantPointAttentionMultimer
from deepfold.modules.layer_norm import LayerNorm
from deepfold.modules.linear import Linear
from deepfold.modules.single_transition import SingleTransition
from deepfold.utils.geometry import Rigid3Array
from deepfold.utils.geometry.quat_rigid import QuatRigid
from deepfold.utils.rigid_utils import Rigid, Rotation
from deepfold.utils.tensor_utils import add


class StructureModule(nn.Module):
    """Structure Module.

    Supplementary '1.8 Structure module': Algorithm 20.

    Args:
        c_s: Single representation dimension (channels).
        c_z: Pair representation dimension (channels).
        c_hidden_ipa: Hidden dimension in invariant point attention.
        c_hidden_ang_res: Hidden dimension in angle resnet.
        num_heads_ipa: Number of heads used in invariant point attention.
        num_qk_points: Number of query/key points in invariant point attention.
        num_v_points: Number of value points in invariant point attention.
        is_multimer: Whether to use the multimer variants of IPA and the backbone update.
        dropout_rate: Dropout rate in structure module.
        num_blocks: Number of shared blocks in the forward pass.
        num_ang_res_blocks: Number of blocks in angle resnet.
        num_angles: Number of angles in angle resnet.
        scale_factor: Scale translation factor.
        inf: Safe infinity value.
        eps: Epsilon to prevent division by zero.

    """

    def __init__(
        self,
        c_s: int,
        c_z: int,
        c_hidden_ipa: int,
        c_hidden_ang_res: int,
        num_heads_ipa: int,
        num_qk_points: int,
        num_v_points: int,
        is_multimer: bool,
        dropout_rate: float,
        num_blocks: int,
        num_ang_res_blocks: int,
        num_angles: int,
        scale_factor: float,
        inf: float,
        eps: float,
    ) -> None:
        super(StructureModule, self).__init__()
        self.c_s = c_s
        self.c_z = c_z
        self.c_hidden_ipa = c_hidden_ipa
        self.c_hidden_ang_res = c_hidden_ang_res
        self.num_heads_ipa = num_heads_ipa
        self.num_qk_points = num_qk_points
        self.num_v_points = num_v_points
        self.dropout_rate = dropout_rate
        self.num_blocks = num_blocks
        self.num_ang_res_blocks = num_ang_res_blocks
        self.num_angles = num_angles
        self.scale_factor = scale_factor
        self.is_multimer = is_multimer
        self.inf = inf
        self.eps = eps

        self.layer_norm_s = LayerNorm(c_s)
        self.layer_norm_z = LayerNorm(c_z)
        self.linear_in = Linear(c_s, c_s, bias=True, init="default")
        if not self.is_multimer:
            self.ipa = InvariantPointAttention(
                c_s=c_s,
                c_z=c_z,
                c_hidden=c_hidden_ipa,
                num_heads=num_heads_ipa,
                num_qk_points=num_qk_points,
                num_v_points=num_v_points,
                separate_kv=self.is_multimer,
                inf=inf,
                eps=eps,
            )
        else:
            self.ipa = InvariantPointAttentionMultimer(
                c_s=c_s,
                c_z=c_z,
                c_hidden=c_hidden_ipa,
                num_heads=num_heads_ipa,
                num_qk_points=num_qk_points,
                num_v_points=num_v_points,
                inf=inf,
                eps=eps,
            )
        self.ipa_dropout = nn.Dropout(dropout_rate)
        self.layer_norm_ipa = LayerNorm(c_s)
        self.transition = SingleTransition(
            c_s=c_s,
            dropout_rate=dropout_rate,
        )
        if self.is_multimer:
            self.bb_update = QuatRigid(c_hidden=c_s, full_quat=False)
        else:
            self.bb_update = BackboneUpdate(c_s=c_s)
        self.angle_resnet = AngleResnet(
            c_s=c_s,
            c_hidden=c_hidden_ang_res,
            num_blocks=num_ang_res_blocks,
            num_angles=num_angles,
            eps=eps,
        )
        # Buffers registered lazily in `_initialize_buffers`:
        #   self.default_frames
        #   self.group_idx
        #   self.atom_mask
        #   self.lit_positions
    def forward(
        self,
        s: torch.Tensor,
        z: torch.Tensor,
        mask: torch.Tensor,
        aatype: torch.Tensor,
        inplace_safe: bool,
    ) -> Dict[str, torch.Tensor]:
        """Dispatches to the monomer or multimer forward pass."""
        if self.is_multimer:
            outputs = self._forward_multimer(s, z, mask, aatype, inplace_safe)
        else:
            outputs = self._forward_monomer(s, z, mask, aatype, inplace_safe)
        return outputs
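    # A shape sketch (illustrative; B = batch, N = N_res):
    #   s: [B, N, c_s]    z: [B, N, N, c_z]    mask: [B, N]    aatype: [B, N]
    # Both passes run num_blocks iterations and stack the per-block outputs
    # along dim=1, e.g. "sm_positions": [B, num_blocks, N, 14, 3], while
    # "sm_single" is the final [B, N, c_s] single representation.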
    def _forward_monomer(
        self,
        s: torch.Tensor,
        z: torch.Tensor,
        mask: torch.Tensor,
        aatype: torch.Tensor,
        inplace_safe: bool,
    ) -> Dict[str, torch.Tensor]:
        """Structure Module forward pass.

        Args:
            s: [batch, N_res, c_s] single representation
            z: [batch, N_res, N_res, c_z] pair representation
            mask: [batch, N_res] sequence mask
            aatype: [batch, N_res] amino acid indices

        Returns:
            dictionary containing:
                "sm_frames": [batch, num_blocks, N_res, 7]
                "sm_sidechain_frames": [batch, num_blocks, N_res, 8, 4, 4]
                "sm_unnormalized_angles": [batch, num_blocks, N_res, 7, 2]
                "sm_angles": [batch, num_blocks, N_res, 7, 2]
                "sm_positions": [batch, num_blocks, N_res, 14, 3]
                "sm_states": [batch, num_blocks, N_res, c_s]
                "sm_single": [batch, N_res, c_s] updated single representation

        """
        self._initialize_buffers(dtype=s.dtype, device=s.device)
        s = self.layer_norm_s(s)  # s: [batch, N_res, c_s]
        z = self.layer_norm_z(z)  # z: [batch, N_res, N_res, c_z]
        s_initial = s  # s_initial: [batch, N_res, c_s]
        s = self.linear_in(s)  # s: [batch, N_res, c_s]
        rigids = Rigid.identity(
            shape=s.shape[:-1],
            dtype=s.dtype,
            device=s.device,
            requires_grad=self.training,
            fmt="quat",
        )
        outputs = []
        for _ in range(self.num_blocks):
            s = add(s, self.ipa(s=s, z=z, r=rigids, mask=mask, inplace_safe=inplace_safe), inplace_safe)
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)  # s: [batch, N_res, c_s]
            s = self.transition(s, inplace_safe)  # s: [batch, N_res, c_s]
            rigids = rigids.compose_q_update_vec(self.bb_update(s))
            # To hew as closely as possible to AlphaFold, we convert our
            # quaternion-based transformations to rotation-matrix ones here.
            backb_to_global = Rigid(
                rots=Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None),
                trans=rigids.get_trans(),
            )
            backb_to_global = backb_to_global.scale_translation(self.scale_factor)
            unnormalized_angles, angles = self.angle_resnet(s=s, s_initial=s_initial)
            # unnormalized_angles: [batch, N_res, num_angles, 2]
            # angles: [batch, N_res, num_angles, 2]
            all_frames_to_global = _torsion_angles_to_frames(
                r=backb_to_global,
                alpha=angles,
                aatype=aatype,
                rrgdf=self.default_frames,
            )
            pred_xyz = _frames_and_literature_positions_to_atom14_pos(
                r=all_frames_to_global,
                aatype=aatype,
                default_frames=self.default_frames,
                group_idx=self.group_idx,
                atom_mask=self.atom_mask,
                lit_positions=self.lit_positions,
            )
            # pred_xyz: [batch, N_res, 14, 3]
            scaled_rigids = rigids.scale_translation(self.scale_factor)
            output = {
                "sm_frames": scaled_rigids.to_tensor_7(),
                "sm_sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "sm_unnormalized_angles": unnormalized_angles,
                "sm_angles": angles,
                "sm_positions": pred_xyz,
                "sm_states": s,
            }
            outputs.append(output)
            rigids = rigids.stop_rot_gradient()
        outputs = self._stack_tensors(outputs, dim=1)
        outputs["sm_single"] = s
        return outputs

    def _initialize_buffers(self, dtype: torch.dtype, device: torch.device) -> None:
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(
                    rc.restype_rigid_group_default_frame,
                    dtype=dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "group_idx"):
            self.register_buffer(
                "group_idx",
                torch.tensor(
                    rc.restype_atom14_to_rigid_group,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "atom_mask"):
            self.register_buffer(
                "atom_mask",
                torch.tensor(
                    rc.restype_atom14_mask,
                    dtype=dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "lit_positions"):
            self.register_buffer(
                "lit_positions",
                torch.tensor(
                    rc.restype_atom14_rigid_group_positions,
                    dtype=dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

    def _stack_tensors(
        self,
        outputs: List[Dict[str, torch.Tensor]],
        dim: int,
    ) -> Dict[str, torch.Tensor]:
        stacked = {}
        for key in list(outputs[0].keys()):
            stacked[key] = torch.stack(
                tensors=[output[key] for output in outputs],
                dim=dim,
            )
        return stacked

    def _forward_multimer(
        self,
        s: torch.Tensor,
        z: torch.Tensor,
        mask: torch.Tensor,
        aatype: torch.Tensor,
        inplace_safe: bool,
    ) -> Dict[str, torch.Tensor]:
        self._initialize_buffers(dtype=s.dtype, device=s.device)
        # [*, N, C_s]
        s = self.layer_norm_s(s)
        # [*, N, N, C_z]
        z = self.layer_norm_z(z)
        # [*, N, C_s]
        s_initial = s
        s = self.linear_in(s)
        # [*, N]
        rigids = Rigid3Array.identity(s.shape[:-1], s.device)
        outputs = []
        for _ in range(self.num_blocks):
            # [*, N, C_s]
            s = add(s, self.ipa(s=s, z=z, r=rigids, mask=mask, inplace_safe=inplace_safe), inplace_safe)
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s, inplace_safe)
            # [*, N]
            rigids = rigids @ self.bb_update(s)
            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)
            all_frames_to_global = _torsion_angles_to_frames(
                r=rigids.scale_translation(self.scale_factor),
                alpha=angles,
                aatype=aatype,
                rrgdf=self.default_frames,
            )
            pred_xyz = _frames_and_literature_positions_to_atom14_pos(
                r=all_frames_to_global,
                aatype=aatype,
                default_frames=self.default_frames,
                group_idx=self.group_idx,
                atom_mask=self.atom_mask,
                lit_positions=self.lit_positions,
            )
            output = {
                "sm_frames": rigids.scale_translation(self.scale_factor).to_tensor(),
                "sm_sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "sm_unnormalized_angles": unnormalized_angles,
                "sm_angles": angles,
                "sm_positions": pred_xyz,
                "sm_states": s,
            }
            outputs.append(output)
            rigids = rigids.stop_rot_gradient()
        outputs = self._stack_tensors(outputs, dim=1)
        outputs["sm_single"] = s
        return outputs
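# A minimal usage sketch (not part of the module). The hyperparameter values
# below are illustrative assumptions, roughly following AlphaFold's published
# configuration, and are not prescribed by this file:
#
#     module = StructureModule(
#         c_s=384, c_z=128, c_hidden_ipa=16, c_hidden_ang_res=128,
#         num_heads_ipa=12, num_qk_points=4, num_v_points=8, is_multimer=False,
#         dropout_rate=0.1, num_blocks=8, num_ang_res_blocks=2, num_angles=7,
#         scale_factor=10.0, inf=1e5, eps=1e-8,
#     ).eval()
#     B, N = 1, 16
#     out = module(
#         s=torch.randn(B, N, 384),
#         z=torch.randn(B, N, N, 128),
#         mask=torch.ones(B, N),
#         aatype=torch.zeros(B, N, dtype=torch.long),
#         inplace_safe=False,
#     )
#     assert out["sm_positions"].shape == (B, 8, N, 14, 3)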
def _torsion_angles_to_frames(
    r: Rigid | Rigid3Array,
    alpha: torch.Tensor,
    aatype: torch.Tensor,
    rrgdf: torch.Tensor,
) -> Rigid | Rigid3Array:
    rigid_type = type(r)

    # [*, N, 8, 4, 4]
    default_4x4 = rrgdf[aatype, ...]

    # [*, N, 8] transformations, i.e.
    # one [*, N, 8, 3, 3] stack of rotation matrices and
    # one [*, N, 8, 3] stack of translation vectors
    default_r = rigid_type.from_tensor_4x4(default_4x4)

    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
    bb_rot[..., 1] = 1

    # [*, N, 8, 2]
    alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)

    # [*, N, 8, 3, 3]
    # Produces rotation matrices of the form:
    # [
    #   [1, 0,   0  ],
    #   [0, a_2, -a_1],
    #   [0, a_1, a_2 ]
    # ]
    # This follows the original code rather than the supplement, which uses
    # different indices.
    # all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
    all_rots = alpha.new_zeros(default_r.shape + (4, 4))
    all_rots[..., 0, 0] = 1
    all_rots[..., 1, 1] = alpha[..., 1]
    all_rots[..., 1, 2] = -alpha[..., 0]
    all_rots[..., 2, 1:3] = alpha

    # all_rots = Rigid(Rotation(rot_mats=all_rots), None)
    all_rots = rigid_type.from_tensor_4x4(all_rots)
    all_frames = default_r.compose(all_rots)

    chi2_frame_to_frame = all_frames[..., 5]
    chi3_frame_to_frame = all_frames[..., 6]
    chi4_frame_to_frame = all_frames[..., 7]

    chi1_frame_to_bb = all_frames[..., 4]
    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

    all_frames_to_bb = rigid_type.cat(
        [
            all_frames[..., :5],
            chi2_frame_to_bb.unsqueeze(-1),
            chi3_frame_to_bb.unsqueeze(-1),
            chi4_frame_to_bb.unsqueeze(-1),
        ],
        dim=-1,
    )

    # all_frames_to_global = r[..., None].compose_jit(all_frames_to_bb)
    all_frames_to_global = r[..., None].compose(all_frames_to_bb)

    return all_frames_to_global


def _frames_and_literature_positions_to_atom14_pos(
    r: Rigid | Rigid3Array,
    aatype: torch.Tensor,
    default_frames: torch.Tensor,
    group_idx: torch.Tensor,
    atom_mask: torch.Tensor,
    lit_positions: torch.Tensor,
) -> torch.Tensor:
    # [*, N, 14, 4, 4]
    # default_4x4 = default_frames[aatype, ...]

    # [*, N, 14]
    group_mask = group_idx[aatype, ...]

    # [*, N, 14, 8]
    group_mask = F.one_hot(
        group_mask,
        num_classes=default_frames.shape[-3],
    )

    # [*, N, 14, 8]
    t_atoms_to_global = r[..., None, :] * group_mask

    # [*, N, 14]
    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))

    # [*, N, 14, 1]
    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

    # [*, N, 14, 3]
    lit_positions = lit_positions[aatype, ...]
    pred_positions = t_atoms_to_global.apply(lit_positions)
    pred_positions = pred_positions * atom_mask

    return pred_positions
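The x-axis rotation that `_torsion_angles_to_frames` assembles from each
(sin, cos) pair can be checked in isolation. The standalone sketch below is
illustrative only; `x_rotation_from_sincos` is a hypothetical helper written
for this example, not part of the module. It reproduces the same index layout
and verifies that the result is a proper rotation:

    import torch

    def x_rotation_from_sincos(alpha: torch.Tensor) -> torch.Tensor:
        """Build [..., 3, 3] x-axis rotations from [..., 2] (sin, cos) pairs."""
        sin, cos = alpha[..., 0], alpha[..., 1]
        rot = alpha.new_zeros(alpha.shape[:-1] + (3, 3))
        rot[..., 0, 0] = 1.0
        rot[..., 1, 1] = cos   # all_rots[..., 1, 1] = alpha[..., 1]
        rot[..., 1, 2] = -sin  # all_rots[..., 1, 2] = -alpha[..., 0]
        rot[..., 2, 1] = sin   # all_rots[..., 2, 1:3] = alpha
        rot[..., 2, 2] = cos
        return rot

    theta = torch.tensor(0.3)
    alpha = torch.stack([theta.sin(), theta.cos()])  # [2]
    r = x_rotation_from_sincos(alpha)  # [3, 3]
    # Orthogonal with unit determinant, i.e. a proper rotation (up to float error).
    assert torch.allclose(r @ r.T, torch.eye(3), atol=1e-6)
    assert torch.allclose(torch.det(r), torch.tensor(1.0), atol=1e-6)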