deepfold.losses.masked_msa.masked_msa_loss

deepfold.losses.masked_msa.masked_msa_loss(logits: Tensor, true_msa: Tensor, bert_mask: Tensor, eps: float = 1e-08) → Tensor

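Computes the BERT-style masked MSA loss: the cross-entropy between the predicted and true residue types, averaged over the MSA positions selected by bert_mask.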

See the AlphaFold 2 Supplementary Information, section ‘1.9.9 Masked MSA prediction’.
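Written in terms of this function's arguments, the objective is a masked average of per-position cross-entropy. In the formula below, m = bert_mask, p = softmax(logits) over the 23 classes, and y = one_hot(true_msa). The placement of eps in the normalizer is an assumption based on its documented role as a guard against an empty mask, not a quote of the deepfold source.

```latex
\mathcal{L}_{\text{msa}}
  = -\frac{1}{\epsilon + \sum_{s,i} m_{si}}
    \sum_{s,i} m_{si} \sum_{c=1}^{23} y^{c}_{si} \log p^{c}_{si}
```

For a binary mask, the denominator is the number of masked positions N_mask (plus eps), which recovers the supplement's form -1/N_mask Σ over masked (s, i) of Σ_c y log p.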

Parameters:
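  • logits – [*, N_seq, N_res, 23] unnormalized predicted scores over the 23 residue classes for each MSA position

  • true_msa – [*, N_seq, N_res] ground-truth MSA, encoded as integer residue class indices (0 to 22)

  • bert_mask – [*, N_seq, N_res] mask that is 1 at the positions masked out of the input MSA (the positions to be predicted) and 0 elsewhere; only these positions contribute to the loss

  • eps – small constant added to the normalizer so that the loss stays finite when no position is masked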

Returns:

Masked MSA loss (cross-entropy averaged over the masked positions)
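To make the computation concrete, here is a minimal NumPy sketch of the same masked cross-entropy. It is not the deepfold implementation: the real function takes torch Tensors, and the name masked_msa_loss_sketch, the use of NumPy, and the final averaging over leading batch dimensions into a single scalar are all assumptions made for illustration.

```python
import numpy as np


def masked_msa_loss_sketch(
    logits: np.ndarray,     # [*, N_seq, N_res, 23] unnormalized scores
    true_msa: np.ndarray,   # [*, N_seq, N_res] integer class indices in 0..22
    bert_mask: np.ndarray,  # [*, N_seq, N_res] 1.0 where a position was masked
    eps: float = 1e-8,
) -> float:
    """Hypothetical NumPy restatement of the masked MSA loss (not deepfold's code)."""
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Per-position cross-entropy: -log p(true class), which equals
    # -sum_c one_hot(true)_c * log p_c.
    nll = -np.take_along_axis(log_probs, true_msa[..., None], axis=-1)[..., 0]

    # Average only over masked positions; eps keeps an empty mask from dividing by 0.
    mask = bert_mask.astype(log_probs.dtype)
    per_example = (nll * mask).sum(axis=(-1, -2)) / (eps + mask.sum(axis=(-1, -2)))

    # Assumption: reduce any leading batch dimensions to one scalar with a mean.
    return float(per_example.mean())


# Usage: all-zero logits give every class probability 1/23, so the loss is
# log(23) ≈ 3.1355 whenever at least one position is masked.
true_msa = np.zeros((1, 3, 4), dtype=np.int64)
bert_mask = np.zeros((1, 3, 4))
bert_mask[0, 0, :2] = 1.0
print(masked_msa_loss_sketch(np.zeros((1, 3, 4, 23)), true_msa, bert_mask))
```

Because unmasked positions are multiplied by zero before summing, predictions there have no effect on the value; this is what makes the loss "BERT-style", since only the corrupted MSA entries are scored.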