import logging
from typing import Dict, Tuple

import torch
import torch.nn as nn

from deepfold.config import LossConfig
from deepfold.losses.auxillary import experimentally_resolved_loss
from deepfold.losses.confidence import plddt_loss, tm_loss
from deepfold.losses.geometry import compute_renamed_ground_truth, distogram_loss, fape_loss, supervised_chi_loss
from deepfold.losses.masked_msa import masked_msa_loss
from deepfold.losses.violation import find_structural_violations, violation_loss

logger = logging.getLogger(__name__)
class AlphaFoldLoss(nn.Module):
"""AlphaFold loss module.
Supplementary '1.9 Loss functions and auxiliary heads'.
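
    Example (illustrative sketch; assumes ``loss_config`` is a ``LossConfig``
    and that ``outputs`` and ``batch`` come from the model forward pass and
    the data pipeline, respectively)::

        loss_fn = AlphaFoldLoss(loss_config)
        total_loss, loss_dict = loss_fn(outputs, batch)
        total_loss.backward()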
"""
def __init__(self, config: LossConfig) -> None:
super().__init__()
self.fape_loss_config = config.fape_loss_config
self.supervised_chi_loss_config = config.supervised_chi_loss_config
self.distogram_loss_config = config.distogram_loss_config
self.masked_msa_loss_config = config.masked_msa_loss_config
self.plddt_loss_config = config.plddt_loss_config
self.experimentally_resolved_loss_config = config.experimentally_resolved_loss_config
self.violation_loss_config = config.violation_loss_config
self.tm_loss_config = config.tm_loss_config
def forward(
self,
outputs: Dict[str, torch.Tensor],
batch: Dict[str, torch.Tensor],
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""AlphaFold loss forward pass.
Args:
outputs: forward pass output dict
batch: train batch dict
Returns:
scaled_weight_total_loss: total loss connected to the graph
losses: dict with loss detached from the graph
"""
batch_size = batch["aatype"].size(0)
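        # Lazily compute structural violations from the final structure-module
        # iteration unless the caller has already supplied them.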
if "violations" not in outputs.keys():
outputs["violations"] = find_structural_violations(
batch=batch,
atom14_pred_positions=outputs["sm_positions"][:, -1],
violation_tolerance_factor=self.violation_loss_config.violation_tolerance_factor,
clash_overlap_tolerance=self.violation_loss_config.clash_overlap_tolerance,
)
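        # Rename ground-truth atoms of 180-degree-symmetric side chains to the
        # naming that best matches the prediction before computing the FAPE and
        # violation terms (AlphaFold supplement, symmetric atom renaming).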
if "renamed_atom14_gt_positions" not in outputs.keys():
batch.update(compute_renamed_ground_truth(batch=batch, atom14_pred_positions=outputs["sm_positions"][:, -1]))
losses = {}
losses["fape"] = fape_loss(
outputs=outputs,
batch=batch,
backbone_clamp_distance=self.fape_loss_config.backbone_clamp_distance,
backbone_loss_unit_distance=self.fape_loss_config.backbone_loss_unit_distance,
backbone_weight=self.fape_loss_config.backbone_weight,
sidechain_clamp_distance=self.fape_loss_config.sidechain_clamp_distance,
sidechain_length_scale=self.fape_loss_config.sidechain_length_scale,
sidechain_weight=self.fape_loss_config.sidechain_weight,
eps=self.fape_loss_config.eps,
)
losses["supervised_chi"] = supervised_chi_loss(
angles_sin_cos=outputs["sm_angles"],
unnormalized_angles_sin_cos=outputs["sm_unnormalized_angles"],
aatype=batch["aatype"],
seq_mask=batch["seq_mask"],
chi_mask=batch["chi_mask"],
chi_angles_sin_cos=batch["chi_angles_sin_cos"],
chi_weight=self.supervised_chi_loss_config.chi_weight,
angle_norm_weight=self.supervised_chi_loss_config.angle_norm_weight,
eps=self.supervised_chi_loss_config.eps,
)
losses["distogram"] = distogram_loss(
logits=outputs["distogram_logits"],
pseudo_beta=batch["pseudo_beta"],
pseudo_beta_mask=batch["pseudo_beta_mask"],
min_bin=self.distogram_loss_config.min_bin,
max_bin=self.distogram_loss_config.max_bin,
num_bins=self.distogram_loss_config.num_bins,
eps=self.distogram_loss_config.eps,
)
losses["masked_msa"] = masked_msa_loss(
logits=outputs["masked_msa_logits"],
true_msa=batch["true_msa"],
bert_mask=batch["bert_mask"],
eps=self.masked_msa_loss_config.eps,
)
losses["plddt_loss"] = plddt_loss(
logits=outputs["lddt_logits"],
all_atom_pred_pos=outputs["final_atom_positions"],
all_atom_positions=batch["all_atom_positions"],
all_atom_mask=batch["all_atom_mask"],
resolution=batch["resolution"],
cutoff=self.plddt_loss_config.cutoff,
num_bins=self.plddt_loss_config.num_bins,
min_resolution=self.plddt_loss_config.min_resolution,
max_resolution=self.plddt_loss_config.max_resolution,
eps=self.plddt_loss_config.eps,
)
losses["experimentally_resolved"] = experimentally_resolved_loss(
logits=outputs["experimentally_resolved_logits"],
atom37_atom_exists=batch["atom37_atom_exists"],
all_atom_mask=batch["all_atom_mask"],
resolution=batch["resolution"],
min_resolution=self.experimentally_resolved_loss_config.min_resolution,
max_resolution=self.experimentally_resolved_loss_config.max_resolution,
eps=self.experimentally_resolved_loss_config.eps,
)
losses["violation"] = violation_loss(
violations=outputs["violations"],
atom14_atom_exists=batch["atom14_atom_exists"],
eps=self.violation_loss_config.eps,
)
if self.tm_loss_config.enabled:
losses["tm"] = tm_loss(
logits=outputs["tm_logits"],
final_affine_tensor=outputs["final_affine_tensor"],
backbone_rigid_tensor=batch["backbone_rigid_tensor"],
backbone_rigid_mask=batch["backbone_rigid_mask"],
resolution=batch["resolution"],
max_bin=self.tm_loss_config.max_bin,
num_bins=self.tm_loss_config.num_bins,
min_resolution=self.tm_loss_config.min_resolution,
max_resolution=self.tm_loss_config.max_resolution,
eps=self.tm_loss_config.eps,
)
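        # Every loss term is expected to be per-example, i.e. shape [batch_size].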
for loss in losses.values():
assert loss.size() == (batch_size,)
weighted_losses = {}
weighted_losses["fape"] = losses["fape"] * self.fape_loss_config.weight
weighted_losses["supervised_chi"] = losses["supervised_chi"] * self.supervised_chi_loss_config.weight
weighted_losses["distogram"] = losses["distogram"] * self.distogram_loss_config.weight
weighted_losses["masked_msa"] = losses["masked_msa"] * self.masked_msa_loss_config.weight
weighted_losses["plddt_loss"] = losses["plddt_loss"] * self.plddt_loss_config.weight
weighted_losses["experimentally_resolved"] = losses["experimentally_resolved"] * self.experimentally_resolved_loss_config.weight
weighted_losses["violation"] = losses["violation"] * self.violation_loss_config.weight
if self.tm_loss_config.enabled:
weighted_losses["tm"] = losses["tm"] * self.tm_loss_config.weight
        # Guard against divergence: replace any non-finite loss with a zero of
        # the same shape so a single bad term cannot poison the step's gradients.
        for name in list(weighted_losses.keys()):
            loss = weighted_losses[name]
            if not torch.isfinite(loss).all():
                logger.warning(f"Loss warning! weighted_losses['{name}']: {loss}")
                loss = torch.zeros_like(loss, requires_grad=True)
            weighted_losses[name] = loss
        # Built-in sum adds the per-loss tensors elementwise, preserving the
        # [batch_size] shape; a torch-level reduction would collapse the batch
        # dimension too early.
        weighted_total_loss = sum(weighted_losses.values())
# To decrease the relative importance of short sequences, we multiply the final loss
# of each training example by the square root of the number of residues after cropping.
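        # For example, a 64-residue chain inside a 256-residue crop is scaled by
        # sqrt(64) = 8, while a full 256-residue crop is scaled by sqrt(256) = 16.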
assert batch["seq_length"].size() == (batch_size,)
seq_length = batch["seq_length"].float()
crop_size = torch.ones_like(seq_length) * batch["aatype"].size(1)
scale = torch.sqrt(torch.minimum(seq_length, crop_size))
scaled_weighted_total_loss = scale * weighted_total_loss
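        # Report per-loss batch means, detached and cloned so that logging and
        # metric tracking do not retain the autograd graph.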
losses = {key: tensor.detach().clone().mean() for key, tensor in losses.items()}
losses["weighted_total_loss"] = weighted_total_loss.detach().clone().mean()
losses["scaled_weighted_total_loss"] = scaled_weighted_total_loss.detach().clone().mean()
scaled_weighted_total_loss = scaled_weighted_total_loss.mean()
return scaled_weighted_total_loss, losses