Source code for deepfold.modules.alphafold

import logging
from dataclasses import asdict
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepfold.common.residue_constants as rc
import deepfold.distributed as dist
import deepfold.distributed.model_parallel as mp
import deepfold.modules.inductor as inductor
from deepfold.config import AlphaFoldConfig
from deepfold.modules.auxiliary_heads import AuxiliaryHeads
from deepfold.modules.evoformer_stack import EvoformerStack
from deepfold.modules.extra_msa_embedder import ExtraMSAEmbedder
from deepfold.modules.extra_msa_stack import ExtraMSAStack
from deepfold.modules.input_embedder import InputEmbedder, InputEmbedderMultimer
from deepfold.modules.recycling_embedder import RecyclingEmbedder  # OpenFoldRecyclingEmbedder
from deepfold.modules.structure_module import StructureModule
from deepfold.modules.template_angle_embedder import TemplateAngleEmbedder
from deepfold.modules.template_pair_embedder import TemplatePairEmbedder, TemplatePairEmbedderMultimer
from deepfold.modules.template_pair_stack import TemplatePairStack
from deepfold.modules.template_pointwise_attention import TemplatePointwiseAttention
from deepfold.modules.template_projection import TemplateProjection
from deepfold.utils.tensor_utils import add, batched_gather, tensor_tree_map

logger = logging.getLogger(__name__)


class AlphaFold(nn.Module):
    """AlphaFold2 module.

    Supplementary '1.4 AlphaFold Inference': Algorithm 2.

    The sub-modules are built from the per-module config dataclasses held by
    ``config``. The template sub-modules only exist when
    ``config.templates_enabled`` is true, and the monomer/multimer variants of
    the input and template embedders are selected by ``config.is_multimer``.

    NOTE(review): the indentation of this class was reconstructed from a
    whitespace-collapsed listing; spots where the nesting could not be
    recovered unambiguously carry their own ``NOTE(review)`` comment.
    """

    def __init__(self, config: AlphaFoldConfig) -> None:
        """Instantiate every sub-module from its config dataclass.

        Args:
            config: Model configuration. Each ``*_config`` attribute is
                expanded with ``dataclasses.asdict`` into keyword arguments
                of the corresponding sub-module (``AuxiliaryHeads`` receives
                its config object directly).
        """
        super().__init__()
        if not config.is_multimer:
            self.input_embedder = InputEmbedder(
                **asdict(config.input_embedder_config),
            )
        else:
            self.input_embedder = InputEmbedderMultimer(
                **asdict(config.input_embedder_config),
            )
        self.recycling_embedder = RecyclingEmbedder(
            **asdict(config.recycling_embedder_config),
        )
        if config.templates_enabled:
            self.template_angle_embedder = TemplateAngleEmbedder(
                **asdict(config.template_angle_embedder_config),
            )
            self.template_pair_stack = TemplatePairStack(
                **asdict(config.template_pair_stack_config),
            )
            if not config.is_multimer:
                self.template_pair_embedder = TemplatePairEmbedder(
                    **asdict(config.template_pair_embedder_config),
                )
                self.template_pointwise_attention = TemplatePointwiseAttention(
                    **asdict(config.template_pointwise_attention_config),
                )
            else:
                self.template_pair_embedder = TemplatePairEmbedderMultimer(
                    **asdict(config.template_pair_embedder_config),
                )
                self.template_projection = TemplateProjection(
                    **asdict(config.template_projection_config),
                )
        self.extra_msa_embedder = ExtraMSAEmbedder(
            **asdict(config.extra_msa_embedder_config),
        )
        self.extra_msa_stack = ExtraMSAStack(
            **asdict(config.extra_msa_stack_config),
        )
        self.evoformer_stack = EvoformerStack(
            **asdict(config.evoformer_stack_config),
        )
        self.structure_module = StructureModule(
            **asdict(config.structure_module_config),
        )
        self.auxiliary_heads = AuxiliaryHeads(config.auxiliary_heads_config)
        self.config = config

    def forward(
        self,
        batch: Dict[str, torch.Tensor],
        recycle_hook: Callable[[int, dict, dict], None] | None = None,
        save_all: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Run all recycling iterations and the auxiliary heads.

        Args:
            batch: Feature tensors whose last axis indexes the recycling
                iteration; ``batch["aatype"].shape[-1] - 1`` iterations run
                without gradients, followed by one final iteration.
            recycle_hook: Optional callback invoked after every iteration as
                ``recycle_hook(recycle_iter, feats, outputs)``. When given,
                the auxiliary heads are also evaluated on the intermediate
                iterations so the hook sees their outputs.
            save_all: When false, the ``"msa"`` and ``"pair"`` entries are
                removed from the returned dictionary.

        Returns:
            The outputs of the final iteration merged with the auxiliary
            head outputs.
        """
        # Initialize previous recycling embeddings:
        prevs = self._initialize_prevs(batch)

        # Asym id for multimer
        asym_id = None
        if "asym_id" in batch:
            # NOTE: Multimer
            asym_id = batch["asym_id"][..., -1].contiguous()

        # Forward iterations with autograd disabled:
        num_recycling_iters = batch["aatype"].shape[-1] - 1
        recycle_iter = 0
        for _ in range(num_recycling_iters):
            feats = tensor_tree_map(fn=lambda t: t[..., recycle_iter].contiguous(), tree=batch)
            with torch.no_grad():
                outputs, prevs = self._forward_iteration(
                    feats=feats,
                    prevs=prevs,
                    gradient_checkpointing=False,
                )
                # NOTE(review): the hook is assumed to run inside the
                # no-grad context; the collapsed source does not show the
                # nesting. Confirm against the upstream file.
                if recycle_hook is not None:  # Inference
                    aux_outputs = self.auxiliary_heads(outputs, feats["seq_mask"], asym_id)
                    outputs.update(aux_outputs)
                    recycle_hook(recycle_iter, feats, outputs)
                del outputs
            # Keeps counting so the final iteration reports the last index.
            recycle_iter += 1  # For the last iteration

        # https://github.com/pytorch/pytorch/issues/65766
        if torch.is_autocast_enabled():
            torch.clear_autocast_cache()

        # Final iteration with autograd enabled:
        feats = tensor_tree_map(fn=lambda t: t[..., -1].contiguous(), tree=batch)
        outputs, prevs = self._forward_iteration(
            feats=feats,
            prevs=prevs,
            gradient_checkpointing=(self.training and mp.size() <= 1),
        )
        del prevs

        outputs["msa"] = outputs["msa"].to(dtype=torch.float32)
        outputs["pair"] = outputs["pair"].to(dtype=torch.float32)
        outputs["single"] = outputs["single"].to(dtype=torch.float32)

        # Run auxiliary heads:
        aux_outputs = self.auxiliary_heads(outputs, feats["seq_mask"], asym_id)
        outputs.update(aux_outputs)

        if recycle_hook is not None:  # Inference
            recycle_hook(recycle_iter, feats, outputs)

        if not save_all:
            outputs.pop("msa", None)
            outputs.pop("pair", None)

        return outputs

    def _forward_iteration(
        self,
        feats: Dict[str, torch.Tensor],
        prevs: Dict[str, torch.Tensor],
        gradient_checkpointing: bool,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Run one recycling iteration of the trunk and structure module.

        Args:
            feats: Features of a single recycling iteration.
            prevs: Recycled tensors under ``"m0_prev"``, ``"z_prev"`` and
                ``"x_prev"``. The three entries are popped from the dict.
            gradient_checkpointing: Forwarded to the template, extra-MSA and
                Evoformer stacks.

        Returns:
            ``(outputs, prevs)`` where ``outputs`` holds ``"msa"``,
            ``"pair"``, ``"single"``, the structure module outputs and the
            final atom37 positions/mask/frames, and ``prevs`` holds the
            tensors to recycle into the next iteration.
        """
        outputs = {}

        # Number of MSA cluster rows (axis -3 of msa_feat); used below to cut
        # the template rows back off the MSA representation.
        num_clust = feats["msa_feat"].shape[-3]

        # Whether the model uses in-place operations:
        inplace_safe = not (self.training or torch.is_grad_enabled())

        seq_mask = feats["seq_mask"]
        # seq_mask: [batch, N_res]

        pair_mask = seq_mask.unsqueeze(-1) * seq_mask.unsqueeze(-2)  # outer product
        # pair_mask: [batch, N_res, N_res]

        msa_mask = feats["msa_mask"]
        # msa_mask: [batch, N_clust, N_res]

        # Initialize MSA and pair representations:
        if not self.config.is_multimer:
            m, z = self.input_embedder(
                target_feat=feats["target_feat"],
                residue_index=feats["residue_index"],
                msa_feat=feats["msa_feat"],
                inplace_safe=inplace_safe,
            )
        else:
            m, z = self.input_embedder(
                target_feat=feats["target_feat"],
                residue_index=feats["residue_index"],
                msa_feat=feats["msa_feat"],
                asym_id=feats["asym_id"],
                entity_id=feats["entity_id"],
                sym_id=feats["sym_id"],
                inplace_safe=inplace_safe,
            )
        # m: [batch, N_clust, N_res, c_m]
        # z: [batch, N_res, N_res, c_z]

        # Extract recycled representations:
        m0_prev = prevs.pop("m0_prev")
        z_prev = prevs.pop("z_prev")
        x_prev = prevs.pop("x_prev")

        x_prev = _pseudo_beta(
            aatype=feats["aatype"],
            all_atom_positions=x_prev,
            dtype=z.dtype,
        )

        m, z = self.recycling_embedder(
            m=m,
            z=z,
            m0_prev=m0_prev,
            z_prev=z_prev,
            x_prev=x_prev,
            inplace_safe=inplace_safe,
        )

        del m0_prev, z_prev, x_prev

        # Embed templates and merge with MSA/pair representation:
        if self.config.templates_enabled:
            template_feats = {k: t for k, t in feats.items() if k.startswith("template_")}
            template_embeds = self._embed_templates(
                feats=template_feats,
                z=z,
                pair_mask=pair_mask,
                asym_id=feats["asym_id"] if self.config.is_multimer else None,
                gradient_checkpointing=gradient_checkpointing,
                multichain_mask_2d=feats.get("template_multichain_mask_2d", None),
                inplace_safe=inplace_safe,
            )
            # multichain_mask_2d: [batch, N_res, N_res, N_templ]

            t = template_embeds["template_pair_embedding"]
            z = add(z, t, inplace_safe)
            # z: [batch, N_res, N_res, c_z]
            del t

            if self.config.embed_template_torsion_angles:
                # Template angle rows are appended to the MSA rows, so the
                # MSA mask has to grow by the same number of rows.
                m = torch.cat([m, template_embeds["template_angle_embedding"]], dim=-3)
                # m: [batch, N_seq, N_res, c_m]
                if not self.config.is_multimer:
                    msa_mask = torch.cat(
                        [
                            feats["msa_mask"],
                            feats["template_torsion_angles_mask"][..., 2],
                        ],
                        dim=-2,
                    )
                    # msa_mask: [batch, N_seq, N_res]
                else:
                    msa_mask = torch.cat(
                        [
                            feats["msa_mask"],
                            template_embeds["template_mask"],
                        ],
                        dim=-2,
                    )

            del template_feats, template_embeds

        # num_seq = m.shape[1]

        # Embed extra MSA features and merge with pairwise embeddings:
        if self.config.is_multimer:
            extra_msa_fn = _build_extra_msa_feat_multimer
        else:
            extra_msa_fn = _build_extra_msa_feat

        # N_extra_seq = feats["extra_msa"].shape[1]
        a = self.extra_msa_embedder(extra_msa_fn(feats))
        # a: [batch, N_extra_seq, N_res, c_e]

        z = self.extra_msa_stack(
            m=a,
            z=z,
            msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
            pair_mask=pair_mask.to(dtype=m.dtype),
            gradient_checkpointing=gradient_checkpointing,
            inplace_safe=inplace_safe,
        )
        # z: [batch, N_res, N_res, c_z]

        del a

        # Evoformer forward pass:
        m, z, s = self.evoformer_stack(
            m=m,
            z=z,
            msa_mask=msa_mask.to(dtype=m.dtype),
            pair_mask=pair_mask.to(dtype=z.dtype),
            gradient_checkpointing=gradient_checkpointing,
            inplace_safe=inplace_safe,
        )
        # m: [batch, N_seq, N_res, c_m]
        # z: [batch, N_res, N_res, c_z]
        # s: [batch, N_res, c_s]

        outputs["msa"] = m[:, :num_clust]
        outputs["pair"] = z
        outputs["single"] = s

        # Predict 3D structure:
        sm_outputs = self.structure_module(
            s=outputs["single"].to(dtype=torch.float32),
            z=outputs["pair"].to(dtype=torch.float32),
            mask=feats["seq_mask"].to(dtype=s.dtype),
            aatype=feats["aatype"],
            inplace_safe=inplace_safe,
        )
        outputs.update(sm_outputs)

        outputs["final_atom_positions"] = _atom14_to_atom37(
            atom14_positions=outputs["sm_positions"][:, -1],
            residx_atom37_to_atom14=feats["residx_atom37_to_atom14"],
            atom37_atom_exists=feats["atom37_atom_exists"],
        )
        outputs["final_atom_mask"] = feats["atom37_atom_exists"].to(dtype=outputs["final_atom_positions"].dtype)
        outputs["final_affine_tensor"] = outputs["sm_frames"][:, -1]

        # Save embeddings for next recycling iteration:
        prevs = {}
        prevs["m0_prev"] = m[:, 0]
        prevs["z_prev"] = outputs["pair"]
        prevs["x_prev"] = outputs["final_atom_positions"]

        return outputs, prevs

    def _initialize_prevs(
        self,
        batch: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Create all-zero recycling tensors for the first iteration.

        Sizes come from the first two axes of ``batch["aatype"]`` and from
        the ``c_m``/``c_z`` attributes of the input embedder; device and
        dtype follow ``batch["msa_feat"]``, except ``"x_prev"`` which is
        always float32.
        """
        prevs = {}
        batch_size = batch["aatype"].shape[0]
        num_res = batch["aatype"].shape[1]
        c_m = self.input_embedder.c_m
        c_z = self.input_embedder.c_z
        device = batch["msa_feat"].device
        dtype = batch["msa_feat"].dtype
        prevs["m0_prev"] = torch.zeros(
            size=[batch_size, num_res, c_m],
            device=device,
            dtype=dtype,
        )
        prevs["z_prev"] = torch.zeros(
            size=[batch_size, num_res, num_res, c_z],
            device=device,
            dtype=dtype,
        )
        prevs["x_prev"] = torch.zeros(
            size=[batch_size, num_res, rc.atom_type_num, 3],
            device=device,
            dtype=torch.float32,
        )
        return prevs

    def _embed_templates(
        self,
        feats: Dict[str, torch.Tensor],
        z: torch.Tensor,
        pair_mask: torch.Tensor,
        gradient_checkpointing: bool,
        asym_id: Optional[torch.Tensor] = None,
        multichain_mask_2d: Optional[torch.Tensor] = None,  # [..., N_res, N_res, N_templ]
        inplace_safe: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Embed the template features into pair (and optionally angle) embeddings.

        Args:
            feats: The ``template_*`` entries of the iteration features.
            z: Pair representation, used for shapes/dtype and as the query
                of the template embedder/attention.
            pair_mask: Pair mask passed to the template pair stack.
            gradient_checkpointing: Forwarded to the template pair stack.
            asym_id: Chain ids; required (asserted) in multimer mode.
            multichain_mask_2d: Optional per-template chain mask; in
                multimer mode a same-chain mask is derived from ``asym_id``
                when it is absent.
            inplace_safe: When true, per-template embeddings are written
                into a preallocated tensor instead of being stacked.

        Returns:
            A dict with ``"template_pair_embedding"`` and, when
            ``config.embed_template_torsion_angles`` is set,
            ``"template_angle_embedding"`` (plus ``"template_mask"`` in
            multimer mode).
        """
        # Embed the templates one at a time:
        num_res = z.shape[-2]
        num_templ = feats["template_aatype"].shape[-2]

        if inplace_safe:
            # Preallocated buffer; every template slot is overwritten below.
            t_pair = z.new_ones(
                z.shape[:-3]
                + (
                    num_templ,
                    num_res,
                    num_res,
                    self.config.template_pair_embedder_config.c_t,
                )
            )
        else:
            pair_embeds = []

        for i in range(num_templ):
            single_template_feats = tensor_tree_map(fn=lambda t: t[:, i], tree=feats)
            if multichain_mask_2d is not None:
                single_multichain_mask_2d = multichain_mask_2d[..., i]
            else:
                single_multichain_mask_2d = None

            if not self.config.is_multimer:
                t = self.template_pair_embedder.build_template_pair_feat(
                    feats=single_template_feats,
                    min_bin=self.config.template_pair_feat_distogram_min_bin,
                    max_bin=self.config.template_pair_feat_distogram_max_bin,
                    num_bins=self.config.template_pair_feat_distogram_num_bins,
                    use_unit_vector=self.config.template_pair_feat_use_unit_vector,
                    inf=self.config.template_pair_feat_inf,
                    eps=self.config.template_pair_feat_eps,
                    dtype=z.dtype,
                )
                t = self.template_pair_embedder(t)
                # t: [batch, N_res, N_res, c_t]
            else:
                assert asym_id is not None
                if single_multichain_mask_2d is None:
                    # Fall back to a same-chain mask.
                    single_multichain_mask_2d = asym_id[..., :, None] == asym_id[..., None, :]
                # single_multichain_mask_2d: [batch, N_res, N_res]
                t = self.template_pair_embedder.build_template_pair_feat(
                    feats=single_template_feats,
                    min_bin=self.config.template_pair_feat_distogram_min_bin,
                    max_bin=self.config.template_pair_feat_distogram_max_bin,
                    num_bins=self.config.template_pair_feat_distogram_num_bins,
                    inf=self.config.template_pair_feat_inf,
                    eps=self.config.template_pair_feat_eps,
                    dtype=z.dtype,
                )
                t = self.template_pair_embedder(
                    query_embedding=z,
                    multichain_mask_2d=single_multichain_mask_2d,
                    **t,
                )
                # t: [batch, N_res, N_res, c_t]

            if inplace_safe:
                t_pair[..., i, :, :, :] = t
            else:
                pair_embeds.append(t)
            del t

        if not inplace_safe:
            t_pair = torch.stack(pair_embeds, dim=-4)
            # t_pair: [batch, N_templ, N_res, N_res, c_t]
            del pair_embeds

        t = self.template_pair_stack(
            t=t_pair,
            mask=pair_mask.to(dtype=z.dtype),
            gradient_checkpointing=gradient_checkpointing,
            inplace_safe=inplace_safe,
        )
        # t: [batch, N_templ, N_res, N_res, c_t]
        del t_pair

        if self.config.is_multimer:
            t = self.template_projection(t=t)
        else:
            t = self.template_pointwise_attention(
                t=t,
                z=z,
                template_mask=feats["template_mask"].to(dtype=z.dtype),
            )
            # NOTE(review): the collapsed source does not show whether the
            # template mask is applied only on this (monomer) branch or after
            # the if/else for both branches. It is kept monomer-only here;
            # confirm against the upstream file before merging.
            t = _apply_template_mask(
                t=t,
                template_mask=feats["template_mask"],
                inplace_safe=inplace_safe,
            )
        # t: [batch, N_res, N_res, c_z]

        template_embeds = {}
        template_embeds["template_pair_embedding"] = t

        if self.config.embed_template_torsion_angles:
            if self.config.is_multimer:
                template_angle_feat, template_mask = _build_template_angle_feat_multimer(feats)
                template_embeds["template_mask"] = template_mask
            else:
                template_angle_feat = _build_template_angle_feat(feats)
            a = self.template_angle_embedder(template_angle_feat)
            # a: [batch, N_templ, N_res, c_m]
            template_embeds["template_angle_embedding"] = a

        return template_embeds

    def register_dap_gradient_scaling_hooks(self, dap_size: int) -> None:
        """Scale gradients of the stack block parameters by ``dap_size``.

        A gradient hook multiplying by ``dap_size`` is registered on every
        parameter whose name starts with ``"blocks."`` in the Evoformer
        stack, the extra-MSA stack and, when present, the template pair
        stack. The per-stack hook counts are logged on the main process.

        Args:
            dap_size: Factor applied to the gradients.
        """
        num_registered_hooks = {
            "evoformer_stack": 0,
            "extra_msa_stack": 0,
            "template_pair_stack": 0,
        }

        stacks = {
            "evoformer_stack": self.evoformer_stack,
            "extra_msa_stack": self.extra_msa_stack,
        }
        # The template pair stack only exists when templates are enabled.
        if hasattr(self, "template_pair_stack"):
            stacks["template_pair_stack"] = self.template_pair_stack

        for stack_name, stack in stacks.items():
            for name, param in stack.named_parameters():
                if name.startswith("blocks."):
                    param.register_hook(lambda grad: grad * dap_size)
                    num_registered_hooks[stack_name] += 1

        if dist.is_main_process():
            logger.info("register_dap_gradient_scaling_hooks: " f"num_registered_hooks={num_registered_hooks}")
def _pseudo_beta_eager(
    aatype: torch.Tensor,
    all_atom_positions: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Select the pseudo-beta atom position of every residue.

    Glycine residues take the CA position, every other residue takes the CB
    position. The result is cast to ``dtype``.
    """
    is_gly = torch.eq(aatype, rc.restype_order["G"])
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    pseudo_beta = torch.where(
        # Repeat the glycine flag across the three coordinates.
        torch.tile(is_gly.unsqueeze(-1), [1] * is_gly.ndim + [3]),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )
    return pseudo_beta.to(dtype=dtype)


_pseudo_beta_jit = torch.compile(_pseudo_beta_eager)


def _pseudo_beta(
    aatype: torch.Tensor,
    all_atom_positions: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Dispatch to the compiled pseudo-beta kernel when inductor is enabled."""
    if inductor.is_enabled():
        pseudo_beta_fn = _pseudo_beta_jit
    else:
        pseudo_beta_fn = _pseudo_beta_eager
    return pseudo_beta_fn(
        aatype,
        all_atom_positions,
        dtype,
    )


def _apply_template_mask_eager(
    t: torch.Tensor,
    template_mask: torch.Tensor,
    inplace_safe: bool,
) -> torch.Tensor:
    """Zero the template embedding of samples that have no valid template.

    ``template_mask`` is summed over axis 1; samples whose sum is not
    positive get their ``t`` multiplied by zero. With ``inplace_safe`` the
    multiplication happens in place on ``t``.
    """
    t_mask = (torch.sum(template_mask, dim=1) > 0).to(dtype=t.dtype)
    t_mask = t_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
    if inplace_safe:
        t *= t_mask
    else:
        t = t * t_mask
    return t


_apply_template_mask_jit = torch.compile(_apply_template_mask_eager)


def _apply_template_mask(
    t: torch.Tensor,
    template_mask: torch.Tensor,
    inplace_safe: bool,
) -> torch.Tensor:
    """Dispatch to the compiled template-mask kernel when inductor is enabled."""
    if inductor.is_enabled():
        apply_template_mask_fn = _apply_template_mask_jit
    else:
        apply_template_mask_fn = _apply_template_mask_eager
    return apply_template_mask_fn(t, template_mask, inplace_safe)


def _build_extra_msa_feat_eager(
    extra_msa: torch.Tensor,
    extra_has_deletion: torch.Tensor,
    extra_deletion_value: torch.Tensor,
    num_classes: int,
) -> torch.Tensor:
    """Concatenate the one-hot extra MSA with its two deletion features.

    The last axis of the result has ``num_classes + 2`` entries.
    """
    msa_1hot = F.one_hot(extra_msa, num_classes=num_classes)
    msa_feat = [
        msa_1hot,
        extra_has_deletion.unsqueeze(-1),
        extra_deletion_value.unsqueeze(-1),
    ]
    return torch.cat(msa_feat, dim=-1)


_build_extra_msa_feat_jit = torch.compile(_build_extra_msa_feat_eager)


def _build_extra_msa_feat(feats: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Build the monomer extra-MSA features from the feature dictionary."""
    if inductor.is_enabled():
        build_extra_msa_feat_fn = _build_extra_msa_feat_jit
    else:
        build_extra_msa_feat_fn = _build_extra_msa_feat_eager
    return build_extra_msa_feat_fn(
        extra_msa=feats["extra_msa"],
        extra_has_deletion=feats["extra_has_deletion"],
        extra_deletion_value=feats["extra_deletion_value"],
        num_classes=23,
    )


def _build_extra_msa_feat_multimer_eager(
    extra_msa: torch.Tensor,
    extra_deletion_matrix: torch.Tensor,
    num_classes: int,
) -> torch.Tensor:
    """Build multimer extra-MSA features from the raw deletion matrix.

    The deletion matrix is turned into a has-deletion flag (clamped to
    ``[0, 1]``) and a squashed deletion value ``atan(d / 3) * 2 / pi``, both
    appended to the one-hot MSA along the last axis.
    """
    msa_1hot = F.one_hot(extra_msa, num_classes=num_classes)
    has_deletion = torch.clamp(extra_deletion_matrix, min=0.0, max=1.0)[..., None]
    deletion_value = (torch.atan(extra_deletion_matrix / 3.0) * (2.0 / torch.pi))[..., None]
    return torch.cat([msa_1hot, has_deletion, deletion_value], dim=-1)


_build_extra_msa_feat_multimer_jit = torch.compile(_build_extra_msa_feat_multimer_eager)


def _build_extra_msa_feat_multimer(feats: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Build the multimer extra-MSA features from the feature dictionary.

    Uses the compiled kernel only when inductor is enabled, like the other
    dispatchers in this module (the branches used to be swapped).
    """
    # 23 = 20 amino acids + 'X' for unknown + gap + bert mask
    if inductor.is_enabled():
        build_extra_msa_feat_fn = _build_extra_msa_feat_multimer_jit
    else:
        build_extra_msa_feat_fn = _build_extra_msa_feat_multimer_eager
    return build_extra_msa_feat_fn(
        extra_msa=feats["extra_msa"],
        extra_deletion_matrix=feats["extra_deletion_matrix"],
        num_classes=23,
    )


def _build_template_angle_feat(feats: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Concatenate the monomer template angle features along the last axis.

    The parts are: one-hot ``template_aatype`` (22 classes), the torsion
    angle sin/cos and the alternative torsion angle sin/cos (each flattened
    to 14 values), and the torsion angle mask.
    """
    template_aatype = feats["template_aatype"]
    torsion_angles_sin_cos = feats["template_torsion_angles_sin_cos"]
    alt_torsion_angles_sin_cos = feats["template_alt_torsion_angles_sin_cos"]
    torsion_angles_mask = feats["template_torsion_angles_mask"]
    template_angle_feat = torch.cat(
        [
            F.one_hot(template_aatype, num_classes=22),
            torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
            alt_torsion_angles_sin_cos.reshape(*alt_torsion_angles_sin_cos.shape[:-2], 14),
            torsion_angles_mask,
        ],
        dim=-1,
    )
    return template_angle_feat


def _build_template_angle_feat_multimer(feats: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build multimer template angle features from template atom positions.

    Chi angles are computed from the template atoms; the feature is the
    concatenation of one-hot ``template_aatype`` (22 classes), masked
    sin/cos of the chi angles and the chi mask.

    Returns:
        ``(template_angle_feat, template_mask)`` where ``template_mask`` is
        the first chi-mask channel, both in the dtype of the template atom
        positions.
    """
    template_aatype = feats["template_aatype"]
    template_all_atom_positions = feats["template_all_atom_positions"]
    template_all_atom_mask = feats["template_all_atom_mask"]
    dtype = template_all_atom_positions.dtype
    template_chi_angles, template_chi_mask = _compute_chi_angles(
        positions=template_all_atom_positions,
        mask=template_all_atom_mask,
        aatype=template_aatype,
    )
    template_angle_feat = torch.cat(
        [
            F.one_hot(template_aatype, num_classes=22),
            torch.sin(template_chi_angles) * template_chi_mask,
            torch.cos(template_chi_angles) * template_chi_mask,
            template_chi_mask,
        ],
        dim=-1,
    ).to(dtype)
    # NOTE: Multimer model gets `template_mask` from the angle features.
    template_mask = template_chi_mask[..., 0].to(dtype=dtype)
    return template_angle_feat, template_mask


def _atom14_to_atom37(
    atom14_positions: torch.Tensor,
    residx_atom37_to_atom14: torch.Tensor,
    atom37_atom_exists: torch.Tensor,
) -> torch.Tensor:
    """Gather atom14 coordinates into the atom37 layout.

    Slots flagged as non-existent in ``atom37_atom_exists`` are zeroed.
    """
    # atom14_positions: [batch, N_res, 14, 3]
    # residx_atom37_to_atom14: [batch, N_res, 37]
    # atom37_atom_exists: [batch, N_res, 37]

    indices = residx_atom37_to_atom14.unsqueeze(-1)
    # indices: [batch, N_res, 37, 1]
    indices = indices.expand(-1, -1, -1, 3)
    # indices: [batch, N_res, 37, 3]

    atom37_positions = torch.gather(atom14_positions, 2, indices)
    # atom37_positions: [batch, N_res, 37, 3]

    atom37_mask = atom37_atom_exists.unsqueeze(-1)
    # atom37_mask: [batch, N_res, 37, 1]

    atom37_positions = atom37_positions * atom37_mask
    # atom37_positions: [batch, N_res, 37, 3]

    return atom37_positions


def _compute_chi_angles(
    positions: torch.Tensor,
    mask: torch.Tensor,
    aatype: torch.Tensor,
    chi_atom_indices: Optional[torch.Tensor] = None,
    chi_angles_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the chi angles given all atom positions and the amino acid type.

    Args:
        positions: [..., 37, 3] Atom positions in atom37 format.
        mask: [..., 37] Atom mask.
        aatype: [...] Amino acid type integer code.
        chi_atom_indices: Optional per-residue-type table of the atoms that
            define each chi angle; built from ``rc.CHI_ATOM_INDICES`` when
            omitted.
        chi_angles_mask: Optional per-residue-type table of which chi angles
            exist; built from ``rc.chi_angles_mask`` plus an all-zero row
            for UNK when omitted.

    Returns:
        chi_angles: [batch, N_res, 4].
        chi_mask: [batch, N_res, 4].

    """
    assert positions.shape[-2] == rc.atom_type_num
    assert mask.shape[-1] == rc.atom_type_num
    num_batch_dims = aatype.ndim

    if chi_atom_indices is None:
        chi_atom_indices = positions.new_tensor(rc.CHI_ATOM_INDICES, dtype=torch.int64)
    # chi_atom_indices: [restype=21, chis=4, atoms=4]

    # Remove gaps
    aatype = torch.clamp(aatype, max=20)

    # Select atoms to compute chis.
    atom_indices = chi_atom_indices[..., aatype, :, :]
    # atom_indices: [batch, N_res, chis=4, atoms=4]

    x, y, z = torch.unbind(positions, dim=-1)
    x = batched_gather(x, atom_indices, -1, num_batch_dims).unsqueeze(-1)
    y = batched_gather(y, atom_indices, -1, num_batch_dims).unsqueeze(-1)
    z = batched_gather(z, atom_indices, -1, num_batch_dims).unsqueeze(-1)
    xyz = torch.cat([x, y, z], dim=-1)
    a, b, c, d = torch.unbind(xyz, dim=-2)

    chi_angles = _dihedral_angle(a, b, c, d)

    if chi_angles_mask is None:
        chi_angles_mask = list(rc.chi_angles_mask)
        chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])  # UNK
        chi_angles_mask = mask.new_tensor(chi_angles_mask)

    chi_mask = chi_angles_mask[aatype, :]
    # chi_mask[batch, N_res, chi=4]

    # A chi angle is only valid when all four of its atoms are present.
    chi_angle_atoms_mask = batched_gather(mask, atom_indices, -1, num_batch_dims)
    chi_angle_atoms_mask = torch.prod(chi_angle_atoms_mask, -1, dtype=chi_angle_atoms_mask.dtype)
    chi_mask = chi_mask * chi_angle_atoms_mask

    return chi_angles, chi_mask


def _dihedral_angle(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    d: torch.Tensor,
) -> torch.Tensor:
    """Computes torsion angle for a quadruple of points.

    For points (a, b, c, d), this is the angle between the planes defined by
    points (a, b, c) and (b, c, d). It is also known as the dihedral angle.

    The inputs hold xyz coordinates on the last axis; the result drops that
    axis and is in radians (it is the output of ``torch.atan2``).
    """
    v1 = a - b
    v2 = b - c
    v3 = d - c

    c1 = v1.cross(v2, dim=-1)
    c2 = v3.cross(v2, dim=-1)
    c3 = c2.cross(c1, dim=-1)

    first = torch.einsum("...i,...i", c3, v2)
    v2_mag = v2.norm(dim=-1)
    second = v2_mag * torch.einsum("...i,...i", c1, c2)

    return torch.atan2(first, second)