Source code for deepfold.losses.masked_msa
import torch
import torch.nn.functional as F
from deepfold.losses.utils import softmax_cross_entropy
[docs]
def masked_msa_loss(
    logits: torch.Tensor,
    true_msa: torch.Tensor,
    bert_mask: torch.Tensor,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Compute the BERT-style masked MSA reconstruction loss.

    See Supplementary '1.9.9 Masked MSA prediction'.

    The per-position cross-entropy is averaged over the positions selected
    by `bert_mask`. Up to the `eps` term this equals
    `sum(errors * bert_mask) / sum(bert_mask)` over the last two dimensions,
    but it is evaluated in stages (see the inline comments) as an
    FP16-friendly form of the same average.

    Args:
        logits: [*, N_seq, N_res, 23] predicted residue distribution.
        true_msa: [*, N_seq, N_res] true MSA, given as integer class
            indices (they are one-hot encoded into 23 classes).
        bert_mask: [*, N_seq, N_res] MSA mask; non-zero entries mark the
            positions that contribute to the loss.
        eps: Small constant added to the denominator so that an all-zero
            mask does not divide by zero.

    Returns:
        Masked MSA loss with the N_seq and N_res dimensions reduced away,
        i.e. one value per leading batch index.
    """
    half = 0.5

    # Scaled count of masked positions per batch element; the scale is
    # undone by the final multiplication below.
    normaliser = eps + torch.sum(half * bert_mask, dim=(-1, -2))

    targets = F.one_hot(true_msa, num_classes=23)
    cross_entropy = softmax_cross_entropy(logits=logits, labels=targets)

    # FP16-friendly averaging: reduce over residues first, normalise each
    # per-sequence partial sum, and only then reduce over sequences.
    per_sequence = (cross_entropy * bert_mask).sum(dim=-1)
    per_sequence = per_sequence / normaliser.unsqueeze(-1)
    return per_sequence.sum(dim=-1) * half