from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from deepfold.common import residue_constants as rc
from deepfold.losses.procrustes import kabsch
from deepfold.losses.utils import softmax_cross_entropy
from deepfold.utils.rigid_utils import Rigid, Rotation
from deepfold.utils.tensor_utils import masked_mean

def compute_fape(
    pred_frames: Rigid,
    target_frames: Rigid,
    frames_mask: torch.Tensor,
    pred_positions: torch.Tensor,
    target_positions: torch.Tensor,
    positions_mask: torch.Tensor,
    length_scale: float,
    l1_clamp_distance: Optional[float] = None,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Computes FAPE loss.

    Args:
        pred_frames: Rigid object of predicted frames. [*, N_frames]
        target_frames: Rigid object of ground truth frames. [*, N_frames]
        frames_mask: Binary mask for the frames. [*, N_frames]
        pred_positions: Predicted atom positions. [*, N_pts, 3]
        target_positions: Ground truth positions. [*, N_pts, 3]
        positions_mask: Positions mask. [*, N_pts]
        length_scale: Length scale by which the loss is divided.
        l1_clamp_distance: Cutoff above which distance errors are clamped.
        eps: Small value used to regularize denominators.

    Returns:
        FAPE loss tensor.
    """
    # Express all atom positions in the local coordinates of every frame.
    # [*, N_frames, N_pts, 3]
    local_pred_pos = pred_frames.invert()[..., None].apply(pred_positions[..., None, :, :])
    local_target_pos = target_frames.invert()[..., None].apply(target_positions[..., None, :, :])

    error_dist = torch.sqrt(torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps)

    if l1_clamp_distance is not None:
        error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)

    normed_error = error_dist / length_scale
    normed_error = normed_error * frames_mask[..., None]
    normed_error = normed_error * positions_mask[..., None, :]

    # FP16-friendly averaging: normalize over frames and points separately
    # to keep the intermediate sums small.
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))

    return normed_error
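
# A minimal smoke-test sketch for compute_fape (not part of the library API).
# It builds identity frames with Rigid.from_tensor_4x4, which this module
# already uses; all shapes and values are illustrative assumptions.
def _example_compute_fape() -> torch.Tensor:
    frames = Rigid.from_tensor_4x4(torch.eye(4).repeat(8, 1, 1))  # [N_frames=8]
    pred = torch.randn(16, 3)  # [N_pts=16, 3]
    target = pred + 0.1 * torch.randn(16, 3)  # slightly perturbed copy
    return compute_fape(
        pred_frames=frames,
        target_frames=frames,
        frames_mask=torch.ones(8),
        pred_positions=pred,
        target_positions=target,
        positions_mask=torch.ones(16),
        length_scale=10.0,
        l1_clamp_distance=10.0,
    )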

def backbone_loss(
    backbone_rigid_tensor: torch.Tensor,
    backbone_rigid_mask: torch.Tensor,
    traj: torch.Tensor,
    use_clamped_fape: Optional[torch.Tensor] = None,
    clamp_distance: float = 10.0,
    loss_unit_distance: float = 10.0,
    eps: float = 1e-4,
) -> torch.Tensor:
    """Backbone FAPE loss, evaluated over the structure-module trajectory."""
    pred_aff = Rigid.from_tensor_7(traj)
    pred_aff = Rigid(
        Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
        pred_aff.get_trans(),
    )

    # NOTE: DeepMind somehow gets hold of a tensor_7 version of the backbone
    # tensor, normalizes it, and then turns it back into a rotation matrix.
    gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)

    fape_value = compute_fape(
        pred_frames=pred_aff,
        target_frames=gt_aff[:, None],
        frames_mask=backbone_rigid_mask[:, None],
        pred_positions=pred_aff.get_trans(),
        target_positions=gt_aff[:, None].get_trans(),
        positions_mask=backbone_rigid_mask[:, None],
        l1_clamp_distance=clamp_distance,
        length_scale=loss_unit_distance,
        eps=eps,
    )

    if use_clamped_fape is not None:
        unclamped_fape_value = compute_fape(
            pred_frames=pred_aff,
            target_frames=gt_aff[:, None],
            frames_mask=backbone_rigid_mask[:, None],
            pred_positions=pred_aff.get_trans(),
            target_positions=gt_aff[:, None].get_trans(),
            positions_mask=backbone_rigid_mask[:, None],
            l1_clamp_distance=None,
            length_scale=loss_unit_distance,
            eps=eps,
        )
        # Blend clamped and unclamped FAPE per sample.
        use_clamped_fape = use_clamped_fape.unsqueeze(-1)
        fape_value = fape_value * use_clamped_fape + unclamped_fape_value * (1 - use_clamped_fape)

    # Average over the trajectory (intermediate structure) dimension.
    fape_value = torch.mean(fape_value, dim=1)

    return fape_value
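
# A hedged sketch of calling backbone_loss (not part of the library API).
# Shapes are assumptions: batch=1, 2 trajectory blocks, 5 residues; traj uses
# the 7-component (quaternion + translation) frame format consumed by
# Rigid.from_tensor_7 above.
def _example_backbone_loss() -> torch.Tensor:
    b, blocks, n = 1, 2, 5
    traj = torch.zeros(b, blocks, n, 7)
    traj[..., 0] = 1.0  # identity quaternions, zero translations
    gt = torch.eye(4).repeat(b, n, 1, 1)  # identity ground-truth frames
    return backbone_loss(
        backbone_rigid_tensor=gt,
        backbone_rigid_mask=torch.ones(b, n),
        traj=traj,
    )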

def sidechain_loss(
    sidechain_frames: torch.Tensor,
    sidechain_atom_pos: torch.Tensor,
    rigidgroups_gt_frames: torch.Tensor,
    rigidgroups_alt_gt_frames: torch.Tensor,
    rigidgroups_gt_exists: torch.Tensor,
    renamed_atom14_gt_positions: torch.Tensor,
    renamed_atom14_gt_exists: torch.Tensor,
    alt_naming_is_better: torch.Tensor,
    clamp_distance: float = 10.0,
    length_scale: float = 10.0,
    eps: float = 1e-4,
) -> torch.Tensor:
    """Sidechain FAPE loss over rigid groups, using symmetry-renamed ground truth."""
    # Resolve ambiguous (symmetry-equivalent) ground-truth frames per residue.
    renamed_gt_frames = (1.0 - alt_naming_is_better[..., None, None, None]) * rigidgroups_gt_frames
    renamed_gt_frames = renamed_gt_frames + alt_naming_is_better[..., None, None, None] * rigidgroups_alt_gt_frames

    # Only the final structure-module block contributes to the sidechain loss.
    sidechain_frames = sidechain_frames[:, -1]

    # Flatten the residue and rigid-group (or atom) dimensions together.
    batch_dims = sidechain_frames.shape[:-4]
    sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
    sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
    renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
    renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
    rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
    sidechain_atom_pos = sidechain_atom_pos[:, -1]
    sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
    renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(*batch_dims, -1, 3)
    renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)

    fape_value = compute_fape(
        pred_frames=sidechain_frames,
        target_frames=renamed_gt_frames,
        frames_mask=rigidgroups_gt_exists,
        pred_positions=sidechain_atom_pos,
        target_positions=renamed_atom14_gt_positions,
        positions_mask=renamed_atom14_gt_exists,
        l1_clamp_distance=clamp_distance,
        length_scale=length_scale,
        eps=eps,
    )

    return fape_value
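
# A hedged sketch of calling sidechain_loss (not part of the library API).
# Shapes are assumptions: batch=1, 2 structure-module blocks, 3 residues,
# 8 rigid groups, and the 14-atom ("atom14") representation per residue.
def _example_sidechain_loss() -> torch.Tensor:
    b, blocks, n = 1, 2, 3
    frames = torch.eye(4).repeat(b, blocks, n, 8, 1, 1)
    gt_frames = torch.eye(4).repeat(b, n, 8, 1, 1)
    return sidechain_loss(
        sidechain_frames=frames,
        sidechain_atom_pos=torch.zeros(b, blocks, n, 14, 3),
        rigidgroups_gt_frames=gt_frames,
        rigidgroups_alt_gt_frames=gt_frames,
        rigidgroups_gt_exists=torch.ones(b, n, 8),
        renamed_atom14_gt_positions=torch.zeros(b, n, 14, 3),
        renamed_atom14_gt_exists=torch.ones(b, n, 14),
        alt_naming_is_better=torch.zeros(b, n),
    )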

def fape_loss(
    outputs: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    backbone_clamp_distance: float,
    backbone_loss_unit_distance: float,
    backbone_weight: float,
    sidechain_clamp_distance: float,
    sidechain_length_scale: float,
    sidechain_weight: float,
    eps: float = 1e-4,
) -> torch.Tensor:
    backbone_loss_value = backbone_loss(
        backbone_rigid_tensor=batch["backbone_rigid_tensor"],
        backbone_rigid_mask=batch["backbone_rigid_mask"],
        traj=outputs["sm_frames"],
        use_clamped_fape=batch.get("use_clamped_fape", None),
        clamp_distance=backbone_clamp_distance,
        loss_unit_distance=backbone_loss_unit_distance,
        eps=eps,
    )
    sidechain_loss_value = sidechain_loss(
        sidechain_frames=outputs["sm_sidechain_frames"],
        sidechain_atom_pos=outputs["sm_positions"],
        rigidgroups_gt_frames=batch["rigidgroups_gt_frames"],
        rigidgroups_alt_gt_frames=batch["rigidgroups_alt_gt_frames"],
        rigidgroups_gt_exists=batch["rigidgroups_gt_exists"],
        renamed_atom14_gt_positions=batch["renamed_atom14_gt_positions"],
        renamed_atom14_gt_exists=batch["renamed_atom14_gt_exists"],
        alt_naming_is_better=batch["alt_naming_is_better"],
        clamp_distance=sidechain_clamp_distance,
        length_scale=sidechain_length_scale,
        eps=eps,
    )
    fape_loss_value = backbone_loss_value * backbone_weight + sidechain_loss_value * sidechain_weight
    return fape_loss_value
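
# A hedged end-to-end sketch of fape_loss (not part of the library API).
# The dict keys follow the lookups above; all shapes are illustrative, and
# the 0.5/0.5 weighting mirrors AlphaFold-style configs (an assumption, not
# a constant of this repository).
def _example_fape_loss() -> torch.Tensor:
    b, blocks, n = 1, 2, 3
    traj = torch.zeros(b, blocks, n, 7)
    traj[..., 0] = 1.0  # identity quaternions
    outputs = {
        "sm_frames": traj,
        "sm_sidechain_frames": torch.eye(4).repeat(b, blocks, n, 8, 1, 1),
        "sm_positions": torch.zeros(b, blocks, n, 14, 3),
    }
    batch = {
        "backbone_rigid_tensor": torch.eye(4).repeat(b, n, 1, 1),
        "backbone_rigid_mask": torch.ones(b, n),
        "rigidgroups_gt_frames": torch.eye(4).repeat(b, n, 8, 1, 1),
        "rigidgroups_alt_gt_frames": torch.eye(4).repeat(b, n, 8, 1, 1),
        "rigidgroups_gt_exists": torch.ones(b, n, 8),
        "renamed_atom14_gt_positions": torch.zeros(b, n, 14, 3),
        "renamed_atom14_gt_exists": torch.ones(b, n, 14),
        "alt_naming_is_better": torch.zeros(b, n),
    }
    return fape_loss(
        outputs,
        batch,
        backbone_clamp_distance=10.0,
        backbone_loss_unit_distance=10.0,
        backbone_weight=0.5,
        sidechain_clamp_distance=10.0,
        sidechain_length_scale=10.0,
        sidechain_weight=0.5,
    )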

def supervised_chi_loss(
    angles_sin_cos: torch.Tensor,
    unnormalized_angles_sin_cos: torch.Tensor,
    aatype: torch.Tensor,
    seq_mask: torch.Tensor,
    chi_mask: torch.Tensor,
    chi_angles_sin_cos: torch.Tensor,
    chi_weight: float,
    angle_norm_weight: float,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Torsion angle loss.

    Supplementary '1.9.1 Side chain and backbone torsion angle loss':
    Algorithm 27 Side chain and backbone torsion angle loss.

    Args:
        angles_sin_cos: Predicted angles. [*, N, 7, 2]
        unnormalized_angles_sin_cos: The same angles, but unnormalized. [*, N, 7, 2]
        aatype: Residue type (amino acid) indices. [*, N]
        seq_mask: Sequence mask. [*, N]
        chi_mask: Chi angle mask. [*, N, 4]
        chi_angles_sin_cos: Ground truth chi angles. [*, N, 4, 2]
        chi_weight: Weight for the angle component of the loss.
        angle_norm_weight: Weight for the normalization component of the loss.
        eps: Small value used to regularize denominators.

    Returns:
        Torsion angle loss tensor.
    """
    # The last four of the seven torsions are the side-chain chi angles.
    pred_angles = angles_sin_cos[..., 3:, :]
    residue_type_one_hot = F.one_hot(aatype, rc.restype_num + 1)
    chi_pi_periodic = torch.einsum(
        "ijk,kl->ijl",
        residue_type_one_hot.type(angles_sin_cos.dtype),
        angles_sin_cos.new_tensor(rc.chi_pi_periodic),
    )

    true_chi = chi_angles_sin_cos[:, None]
    # For pi-periodic chi angles, a copy shifted by pi (sign-flipped sin and
    # cos) is an equally valid ground truth; score against the closer of the two.
    shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1).unsqueeze(-4)
    true_chi_shifted = shifted_mask * true_chi
    sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
    sq_chi_error_shifted = torch.sum((true_chi_shifted - pred_angles) ** 2, dim=-1)
    sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
    sq_chi_loss = masked_mean(chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3))
    loss = chi_weight * sq_chi_loss

    # Penalize unnormalized (sin, cos) vectors that stray from the unit circle.
    angle_norm = torch.sqrt(torch.sum(unnormalized_angles_sin_cos**2, dim=-1) + eps)
    norm_error = torch.abs(angle_norm - 1.0)
    angle_norm_loss = masked_mean(seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3))
    loss = loss + angle_norm_weight * angle_norm_loss

    return loss
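
# A hedged sketch of calling supervised_chi_loss (not part of the library
# API). Shapes are assumptions: batch=1, 2 blocks, 4 residues; the weights
# 0.5 and 0.01 mirror AlphaFold-style configs rather than this repository.
def _example_supervised_chi_loss() -> torch.Tensor:
    b, blocks, n = 1, 2, 4
    angles = F.normalize(torch.randn(b, blocks, n, 7, 2), dim=-1)
    aatype = torch.zeros(b, n, dtype=torch.long)  # all-alanine sequence
    return supervised_chi_loss(
        angles_sin_cos=angles,
        unnormalized_angles_sin_cos=torch.randn(b, blocks, n, 7, 2),
        aatype=aatype,
        seq_mask=torch.ones(b, n),
        chi_mask=torch.ones(b, n, 4),
        chi_angles_sin_cos=F.normalize(torch.randn(b, n, 4, 2), dim=-1),
        chi_weight=0.5,
        angle_norm_weight=0.01,
    )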

def compute_distogram(
    positions: torch.Tensor,
    mask: torch.Tensor,
    min_bin: float = 2.3125,
    max_bin: float = 21.6875,
    num_bins: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Assigns each residue pair to a true distance bin and builds the pair mask."""
    boundaries = torch.linspace(min_bin, max_bin, steps=(num_bins - 1), device=positions.device)
    # Square the boundaries and compare against squared distances to avoid a sqrt.
    boundaries = boundaries**2
    positions = positions.float()

    dists = torch.sum((positions[..., :, None, :] - positions[..., None, :, :]) ** 2, dim=-1, keepdim=True).detach()
    # The bin index is the number of boundaries below the squared distance.
    true_bins = torch.sum(dists > boundaries, dim=-1)

    mask = mask.float()
    pair_mask = mask[..., :, None] * mask[..., None, :]

    return true_bins, pair_mask
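
# A small sanity-check sketch for compute_distogram (not part of the library
# API): bin indices must lie in [0, num_bins) and the pair mask is the outer
# product of the residue mask with itself.
def _example_compute_distogram() -> None:
    positions = torch.randn(1, 6, 3) * 5.0  # pseudo-beta coordinates, Angstroms
    mask = torch.ones(1, 6)
    true_bins, pair_mask = compute_distogram(positions, mask)
    assert true_bins.shape == (1, 6, 6)
    assert int(true_bins.max()) < 64 and pair_mask.shape == (1, 6, 6)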

def distogram_loss(
    logits: torch.Tensor,
    pseudo_beta: torch.Tensor,
    pseudo_beta_mask: torch.Tensor,
    min_bin: float = 2.3125,
    max_bin: float = 21.6875,
    num_bins: int = 64,
    eps: float = 1e-6,
) -> torch.Tensor:
    true_bins, square_mask = compute_distogram(
        positions=pseudo_beta,
        mask=pseudo_beta_mask,
        min_bin=min_bin,
        max_bin=max_bin,
        num_bins=num_bins,
    )

    errors = softmax_cross_entropy(logits=logits, labels=F.one_hot(true_bins, num_bins))

    # FP16-friendly sum.
    denom = eps + torch.sum(square_mask, dim=(-1, -2))
    mean = errors * square_mask
    mean = torch.sum(mean, dim=-1)
    mean = mean / denom[..., None]
    mean = torch.sum(mean, dim=-1)

    return mean
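
# A hedged sketch of calling distogram_loss (not part of the library API);
# logits would normally come from the distogram head, here they are random.
def _example_distogram_loss() -> torch.Tensor:
    b, n = 1, 6
    logits = torch.randn(b, n, n, 64)  # [*, N, N, num_bins]
    pseudo_beta = torch.randn(b, n, 3) * 5.0
    return distogram_loss(logits, pseudo_beta, torch.ones(b, n))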

def compute_renamed_ground_truth(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
    eps: float = 1e-10,
) -> Dict[str, torch.Tensor]:
    """Picks either the standard or the alternative (symmetry-swapped)
    ground-truth naming for ambiguous side-chain atoms, per residue,
    whichever agrees better with the predicted positions.
    """
    atom14_pred_positions = atom14_pred_positions.float()

    # Pairwise atom distance matrices for the prediction and both namings.
    # [*, N, N, 14, 14]
    pred_dists = torch.sqrt(
        eps + torch.sum((atom14_pred_positions[..., None, :, None, :] - atom14_pred_positions[..., None, :, None, :, :]) ** 2, dim=-1)
    )
    atom14_gt_positions = batch["atom14_gt_positions"].float()
    gt_dists = torch.sqrt(eps + torch.sum((atom14_gt_positions[..., None, :, None, :] - atom14_gt_positions[..., None, :, None, :, :]) ** 2, dim=-1))
    atom14_alt_gt_positions = batch["atom14_alt_gt_positions"].float()
    alt_gt_dists = torch.sqrt(
        eps + torch.sum((atom14_alt_gt_positions[..., None, :, None, :] - atom14_alt_gt_positions[..., None, :, None, :, :]) ** 2, dim=-1)
    )

    # Absolute distance-matrix errors for each naming (despite the variable
    # names, these are not full lDDT scores).
    lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
    alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)

    atom14_gt_exists = batch["atom14_gt_exists"].float()
    atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"].float()
    # Score only pairs of an ambiguous atom (rows) and an unambiguous atom
    # (columns) where both atoms exist in the ground truth.
    mask = (
        atom14_gt_exists[..., None, :, None]
        * atom14_atom_is_ambiguous[..., None, :, None]
        * atom14_gt_exists[..., None, :, None, :]
        * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
    )

    per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
    alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3))

    fp_type = atom14_pred_positions.dtype
    alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)

    # Blend positions and masks according to the per-residue choice.
    renamed_atom14_gt_positions = (1.0 - alt_naming_is_better[..., None, None]) * atom14_gt_positions
    renamed_atom14_gt_positions = renamed_atom14_gt_positions + alt_naming_is_better[..., None, None] * atom14_alt_gt_positions
    renamed_atom14_gt_mask = (1.0 - alt_naming_is_better[..., None]) * atom14_gt_exists
    renamed_atom14_gt_mask = renamed_atom14_gt_mask + alt_naming_is_better[..., None] * batch["atom14_alt_gt_exists"].float()

    return {
        "alt_naming_is_better": alt_naming_is_better,
        "renamed_atom14_gt_positions": renamed_atom14_gt_positions,
        "renamed_atom14_gt_exists": renamed_atom14_gt_mask,
    }
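
# A hedged sketch of compute_renamed_ground_truth (not part of the library
# API). With no ambiguous atoms the mask zeroes both scores, so the standard
# naming is kept everywhere; shapes are illustrative.
def _example_compute_renamed_ground_truth() -> Dict[str, torch.Tensor]:
    b, n = 1, 3
    batch = {
        "atom14_gt_positions": torch.randn(b, n, 14, 3),
        "atom14_alt_gt_positions": torch.randn(b, n, 14, 3),
        "atom14_gt_exists": torch.ones(b, n, 14),
        "atom14_alt_gt_exists": torch.ones(b, n, 14),
        "atom14_atom_is_ambiguous": torch.zeros(b, n, 14),
    }
    out = compute_renamed_ground_truth(batch, torch.randn(b, n, 14, 3))
    assert out["alt_naming_is_better"].shape == (b, n)
    return out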

@torch.jit.script
def compute_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    nan_to: float = 1e8,
) -> torch.Tensor:
    """Calculates RMSD.

    This function does not superimpose the positions. If given, `atom_mask`
    must be a boolean tensor, since it is used to index the squared differences.
    """
    sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
    if atom_mask is not None:
        sq_diff = sq_diff[atom_mask]
    msd = torch.mean(sq_diff)
    msd = torch.nan_to_num(msd, nan=nan_to)  # Guard against empty selections.
    return torch.sqrt(msd + eps)
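
# A small sketch for compute_rmsd (not part of the library API). Note the
# boolean mask: unlike the float masks used elsewhere in this module, the
# mask here is applied by indexing.
def _example_compute_rmsd() -> torch.Tensor:
    true_pos = torch.randn(10, 3)
    pred_pos = true_pos + 0.05 * torch.randn(10, 3)
    mask = torch.ones(10, dtype=torch.bool)
    return compute_rmsd(true_pos, pred_pos, atom_mask=mask)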

def kabsch_rmsd(
    true_atom_pos: torch.Tensor,
    pred_atom_pos: torch.Tensor,
    atom_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """RMSD after optimal superposition of `true_atom_pos` onto `pred_atom_pos`."""
    r, x = kabsch(true_atom_pos, pred_atom_pos, weights=atom_mask)
    aligned_true_atom_pos = true_atom_pos @ r + x
    return compute_rmsd(aligned_true_atom_pos, pred_atom_pos, atom_mask=atom_mask)
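
# A hedged sketch for kabsch_rmsd (not part of the library API): a rigidly
# rotated and translated copy should align back to near-zero RMSD, assuming
# deepfold.losses.procrustes.kabsch follows the (r, t) convention used above.
def _example_kabsch_rmsd() -> torch.Tensor:
    import math

    true_pos = torch.randn(10, 3)
    c, s = math.cos(0.3), math.sin(0.3)
    rot = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    pred_pos = true_pos @ rot.T + torch.tensor([1.0, 2.0, 3.0])
    return kabsch_rmsd(true_pos, pred_pos)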

def superimpose(
    src_atoms: torch.Tensor,
    tgt_atoms: torch.Tensor,
    mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Superimposes `src_atoms` onto `tgt_atoms` by minimizing RMSD, using the
    Kabsch (SVD) algorithm.

    Args:
        src_atoms: Source coordinates. [*, N, 3]
        tgt_atoms: Target coordinates. [*, N, 3]
        mask: Atom mask. [*, N]

    Returns:
        superimposed: Superimposed coordinates. [*, N, 3]
        rmsds: Final RMSDs. [*]
    """
    r, t = kabsch(src_atoms, tgt_atoms, weights=mask)
    superimposed = src_atoms @ r + t
    # compute_rmsd indexes with the mask, so convert it to boolean here.
    rmsds = compute_rmsd(superimposed, tgt_atoms, atom_mask=mask.bool())
    return superimposed, rmsds
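
# A hedged sketch for superimpose (not part of the library API): a purely
# translated copy should superimpose exactly, up to the eps regularization
# inside compute_rmsd.
def _example_superimpose() -> Tuple[torch.Tensor, torch.Tensor]:
    src = torch.randn(8, 3)
    tgt = src + torch.tensor([0.0, 0.0, 5.0])
    superimposed, rmsds = superimpose(src, tgt, torch.ones(8))
    return superimposed, rmsds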