Source code for deepfold.losses.auxillary

import torch

from deepfold.common import residue_constants as rc
from deepfold.losses.utils import sigmoid_cross_entropy
from deepfold.utils.tensor_utils import masked_mean


def experimentally_resolved_loss(
    logits: torch.Tensor,
    atom37_atom_exists: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    min_resolution: float,
    max_resolution: float,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Loss for predicting whether an atom is experimentally resolved in a high-resolution structure.

    Jumper et al. (2021) Suppl. Sec. 1.9.10, '"Experimentally resolved" prediction'.

    Args:
        logits:
            Logits of shape [*, N_res, 37]. Log probability that an atom is
            resolved in atom37 representation; can be converted to a
            probability by applying a sigmoid.
        atom37_atom_exists:
            Labels of shape [*, N_res, 37].
        all_atom_mask:
            Mask of shape [*, N_res, 37].
        resolution:
            Resolution of each example, of shape [*].

    NOTE: This loss is used during fine-tuning on high-resolution X-ray
    crystal and cryo-EM structures with a resolution better than 0.3 nm.
    NMR and distillation examples have zero resolution.
    """
    errors = sigmoid_cross_entropy(logits=logits, labels=all_atom_mask)
    loss = torch.sum(errors * atom37_atom_exists, dim=-1)
    # Normalize by the per-example atom count; indexing with [..., None]
    # (rather than .view(-1, 1)) keeps broadcasting correct for arbitrary
    # leading batch dimensions.
    loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2))[..., None])
    loss = torch.sum(loss, dim=-1)
    loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution))
    return loss
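# A minimal smoke test for experimentally_resolved_loss, assuming the
# definitions above. The batch size, sequence length, and resolution window
# are illustrative assumptions, not values from any training config.
example_logits = torch.randn(2, 16, 37)
example_atom_exists = torch.ones(2, 16, 37)
example_all_atom_mask = (torch.rand(2, 16, 37) > 0.5).float()
example_resolution = torch.tensor([2.5, 9.0])  # second example is out of range
example_loss = experimentally_resolved_loss(
    example_logits,
    example_atom_exists,
    example_all_atom_mask,
    example_resolution,
    min_resolution=0.1,
    max_resolution=3.0,
)
assert example_loss.shape == (2,)
assert example_loss[1] == 0.0  # out-of-window examples contribute no loss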
def repr_norm_loss(
    msa_norm: torch.Tensor,
    pair_norm: torch.Tensor,
    msa_mask: torch.Tensor,
    pseudo_beta_mask: torch.Tensor,
    eps: float = 1e-5,
    tolerance: float = 0.0,
) -> torch.Tensor:
    """Representation norm loss of Uni-Fold."""

    def norm_loss(x):
        max_norm = x.shape[-1] ** 0.5
        norm = torch.sqrt(torch.sum(x**2, dim=-1) + eps)
        error = torch.nn.functional.relu((norm - max_norm).abs() - tolerance)
        return error

    pair_norm_error = norm_loss(pair_norm.float())
    msa_norm_error = norm_loss(msa_norm.float())
    pair_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]

    pair_norm_loss = masked_mean(pair_mask.float(), pair_norm_error, dim=(-1, -2))
    msa_norm_loss = masked_mean(msa_mask.float(), msa_norm_error, dim=(-1, -2))

    loss = pair_norm_loss + msa_norm_loss
    return loss
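# A sketch of the intended behaviour of repr_norm_loss, assuming the
# definitions above and illustrative shapes. Standard-normal activations have
# an L2 norm close to sqrt(C) over C channels, so they incur near-zero loss,
# while rescaled activations are penalized.
n_seq, n_res, c_m, c_z = 4, 8, 32, 16
example_msa = torch.randn(n_seq, n_res, c_m)
example_pair = torch.randn(n_res, n_res, c_z)
example_msa_mask = torch.ones(n_seq, n_res)
example_pb_mask = torch.ones(n_res)
small = repr_norm_loss(example_msa, example_pair, example_msa_mask, example_pb_mask)
large = repr_norm_loss(100.0 * example_msa, 100.0 * example_pair, example_msa_mask, example_pb_mask)
assert large > small  # inflated norms are penalized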
def get_asym_mask(asym_id):
    """Get the mask for each asym_id. [*, NR] -> [*, NC, NR]"""
    # This function assumes that valid asym_id values are dense in [1, NC].
    asym_type = torch.arange(1, torch.amax(asym_id) + 1, device=asym_id.device)  # [NC]
    return (asym_id[..., None, :] == asym_type[:, None]).float()
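# Worked example for get_asym_mask, assuming the definition above: two chains
# over five residues.
example_asym_id = torch.tensor([1, 1, 2, 2, 2])
example_mask = get_asym_mask(example_asym_id)  # shape [2, 5]
assert torch.equal(example_mask[0], torch.tensor([1.0, 1.0, 0.0, 0.0, 0.0]))  # chain 1
assert torch.equal(example_mask[1], torch.tensor([0.0, 0.0, 1.0, 1.0, 1.0]))  # chain 2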
def chain_centre_mass_loss(
    pred_atom_positions: torch.Tensor,
    true_atom_positions: torch.Tensor,
    atom_mask: torch.Tensor,
    asym_id: torch.Tensor,
    eps: float = 1e-10,
) -> torch.Tensor:
    """Chain centre-of-mass loss for multimers.

    Compares pairwise distances between per-chain CA centres of mass in the
    prediction and the ground truth. The penalty is nonzero only where a
    predicted inter-chain distance is more than 4 Angstrom smaller than the
    true one, i.e. where chains are predicted too close together.
    """
    ca_pos = rc.atom_order["CA"]
    pred_atom_positions = pred_atom_positions[..., ca_pos, :].float()  # [B, NR, 3]
    true_atom_positions = true_atom_positions[..., ca_pos, :].float()  # [B, NR, 3]
    atom_mask = atom_mask[..., ca_pos].bool()  # [B, NR]
    assert len(pred_atom_positions.shape) == 3

    asym_mask = get_asym_mask(asym_id) * atom_mask[..., None, :]  # [B, NC, NR]
    asym_exists = torch.any(asym_mask, dim=-1).float()  # [B, NC]

    def get_asym_centres(pos):
        pos = pos[..., None, :, :] * asym_mask[..., :, :, None]  # [B, NC, NR, 3]
        return torch.sum(pos, dim=-2) / (torch.sum(asym_mask, dim=-1)[..., None] + eps)

    pred_centres = get_asym_centres(pred_atom_positions)  # [B, NC, 3]
    true_centres = get_asym_centres(true_atom_positions)  # [B, NC, 3]

    def get_dist(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
        return torch.sqrt((p1[..., :, None, :] - p2[..., None, :, :]).square().sum(-1) + eps)

    pred_dists = get_dist(pred_centres, pred_centres)  # [B, NC, NC]
    true_dists = get_dist(true_centres, true_centres)  # [B, NC, NC]

    # Nonzero only where pred_dist < true_dist - 4.
    losses = (pred_dists - true_dists + 4).clamp(max=0).square() * 0.0025
    loss_mask = asym_exists[..., :, None] * asym_exists[..., None, :]  # [B, NC, NC]

    loss = masked_mean(loss_mask, losses, dim=(-1, -2))
    return loss
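# A minimal sketch for chain_centre_mass_loss, assuming the definitions above:
# one example with two chains, CA-only atom masks, and illustrative
# coordinates. The predicted chain centres sit 2 A apart while the true
# centres are 20 A apart, which is more than 4 A closer and hence penalized.
ca = rc.atom_order["CA"]
example_pred = torch.zeros(1, 4, 37, 3)
example_true = torch.zeros(1, 4, 37, 3)
example_true[0, 2:, ca, 0] = 20.0  # chain 2 far away in the ground truth
example_pred[0, 2:, ca, 0] = 2.0  # chain 2 collapsed onto chain 1 in the prediction
example_atom_mask = torch.zeros(1, 4, 37)
example_atom_mask[..., ca] = 1.0
example_asym_id = torch.tensor([[1, 1, 2, 2]])
example_loss = chain_centre_mass_loss(example_pred, example_true, example_atom_mask, example_asym_id)
assert example_loss > 0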