from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from deepfold.common import residue_constants as rc
from deepfold.losses.utils import calculate_bin_centers, softmax_cross_entropy
from deepfold.utils.rigid_utils import Rigid
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
    """Convert per-residue lDDT-Ca logits into plDDT scores in [0, 100]."""
    num_bins = logits.shape[-1]
    bin_width = 1.0 / num_bins
    # Bin centers: 0.5 * width, 1.5 * width, ..., 1 - 0.5 * width.
    bounds = torch.arange(
        start=(0.5 * bin_width),
        end=1.0,
        step=bin_width,
        device=logits.device,
    )
    probs = torch.softmax(logits, dim=-1)
    # Expected lDDT-Ca: probability-weighted sum over bin centers.
    pred_lddt_ca = torch.sum(
        probs * bounds.view(*((1,) * (probs.ndim - 1)), *bounds.shape),
        dim=-1,
    )
    return pred_lddt_ca * 100
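

# Example (a minimal sketch; the shapes below are illustrative, not required):
# converting per-residue confidence logits into plDDT values in [0, 100].
#
#     logits = torch.randn(2, 128, 50)  # [batch, num_res, num_bins]
#     plddt = compute_plddt(logits)     # [2, 128], each entry in [0, 100]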
def lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """Calculate the lDDT score.

    Atom pairs within `cutoff` in the true structure are scored by how many of
    the 0.5/1.0/2.0/4.0 Angstrom tolerance thresholds the prediction satisfies.
    """
    n = all_atom_mask.shape[-2]
    dmat_true = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :]) ** 2,
            dim=-1,
        )
    )
    dmat_pred = torch.sqrt(
        eps
        + torch.sum(
            (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
            dim=-1,
        )
    )
    # Score only pairs within the cutoff in the true structure where both
    # atoms are resolved, excluding self-pairs.
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * torch.swapdims(all_atom_mask, -2, -1)
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )
    dist_l1 = torch.abs(dmat_true - dmat_pred)
    score = (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )
    score = score * 0.25
    dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
    score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
    return score
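

# Example (a minimal sketch with random coordinates, not a real structure):
# per-residue lDDT between a "predicted" and a "true" set of CA positions.
#
#     true_pos = torch.randn(1, 64, 3)          # [*, num_atoms, 3]
#     pred_pos = true_pos + 0.1 * torch.randn_like(true_pos)
#     mask = torch.ones(1, 64, 1)                # [*, num_atoms, 1]
#     per_res = lddt(pred_pos, true_pos, mask)   # [1, 64], values in [0, 1]
#     global_score = lddt(pred_pos, true_pos, mask, per_residue=False)  # [1]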
def lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """Calculate the lDDT score over alpha-carbon atoms only."""
    ca_pos = rc.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
    return lddt(
        all_atom_pred_pos,
        all_atom_positions,
        all_atom_mask,
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
    )
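

# Example (a minimal sketch; the atom37 layout indexed by rc.atom_order is
# assumed): lddt_ca slices out the CA channel before delegating to lddt.
#
#     pred = torch.randn(1, 64, 37, 3)   # [*, num_res, 37, 3]
#     true = torch.randn(1, 64, 37, 3)
#     mask = torch.ones(1, 64, 37)       # [*, num_res, 37]
#     score = lddt_ca(pred, true, mask)  # [1, 64]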
def plddt_loss(
    logits: torch.Tensor,
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    cutoff: float = 15.0,
    num_bins: int = 50,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
    eps: float = 1e-10,
) -> torch.Tensor:
    """Calculate the plDDT loss."""
    ca_pos = rc.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim
    # The true lDDT-Ca score is a fixed target; no gradient flows through it.
    score = lddt(
        all_atom_pred_pos,
        all_atom_positions,
        all_atom_mask,
        cutoff=cutoff,
        eps=eps,
    ).detach()
    # Discretize the score into `num_bins` classes for cross-entropy.
    bin_index = torch.floor(score * num_bins).long()
    bin_index = torch.clamp(bin_index, max=(num_bins - 1))
    lddt_ca_one_hot = F.one_hot(bin_index, num_classes=num_bins)
    errors = softmax_cross_entropy(logits=logits, labels=lddt_ca_one_hot)
    all_atom_mask = all_atom_mask.squeeze(-1)
    loss = torch.sum(errors * all_atom_mask, dim=-1) / (eps + torch.sum(all_atom_mask, dim=-1))
    # Supervise on high-resolution structures only.
    loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution))
    return loss
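

# Example (a minimal sketch with dummy tensors; `num_bins` must match the
# last dimension of the logits):
#
#     logits = torch.randn(1, 64, 50)     # [*, num_res, num_bins]
#     pred = torch.randn(1, 64, 37, 3)
#     true = torch.randn(1, 64, 37, 3)
#     mask = torch.ones(1, 64, 37)
#     resolution = torch.tensor([2.5])    # within [0.1, 3.0], so loss is kept
#     loss = plddt_loss(logits, pred, true, mask, resolution)  # [1]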
def _calculate_expected_aligned_error(
    alignment_confidence_breaks: torch.Tensor,
    aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the expected aligned error per residue pair and the largest bin center."""
    bin_centers = calculate_bin_centers(alignment_confidence_breaks)
    return (
        torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
        bin_centers[-1],
    )
def compute_predicted_aligned_error(
    logits: torch.Tensor,
    max_bin: int = 31,
    num_bins: int = 64,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
        logits: [*, num_res, num_res, num_bins] the logits output from PredictedAlignedErrorHead.
        max_bin: Maximum bin value.
        num_bins: Number of bins.

    Returns:
        aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
            aligned error probabilities over bins for each residue pair.
        predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
            error for each pair of residues.
        max_predicted_aligned_error: [*] the maximum predicted error possible.
    """
    boundaries = torch.linspace(0, max_bin, steps=(num_bins - 1), device=logits.device)
    aligned_confidence_probs = torch.softmax(logits, dim=-1)
    (
        predicted_aligned_error,
        max_predicted_aligned_error,
    ) = _calculate_expected_aligned_error(
        alignment_confidence_breaks=boundaries,
        aligned_distance_error_probs=aligned_confidence_probs,
    )
    return {
        "aligned_confidence_probs": aligned_confidence_probs,
        "predicted_aligned_error": predicted_aligned_error,
        "max_predicted_aligned_error": max_predicted_aligned_error,
    }
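

# Example (a minimal sketch; by default 64 bins span 0..31 Angstrom):
#
#     logits = torch.randn(1, 64, 64, 64)  # [*, num_res, num_res, num_bins]
#     out = compute_predicted_aligned_error(logits)
#     out["predicted_aligned_error"]       # [1, 64, 64], expected error per pair
#     out["max_predicted_aligned_error"]   # largest representable error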
def compute_tm(
    logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    asym_id: Optional[torch.Tensor] = None,
    interface: bool = False,
    max_bin: int = 31,
    num_bins: int = 64,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Compute the predicted TM score (pTM, or ipTM if `interface` is set) from logits."""
    if residue_weights is None:
        residue_weights = logits.new_ones(logits.shape[-2])
    boundaries = torch.linspace(
        start=0,
        end=max_bin,
        steps=(num_bins - 1),
        device=logits.device,
    )
    bin_centers = calculate_bin_centers(boundaries)
    # TM-score normalization constant d0, with the effective length clipped at 19.
    clipped_n = max(torch.sum(residue_weights), 19)
    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
    probs = torch.softmax(logits, dim=-1)
    tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
    n = residue_weights.shape[-1]
    pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
    if interface and (asym_id is not None):
        if len(asym_id.shape) > 1:
            assert len(asym_id.shape) <= 2
            batch_size = asym_id.shape[0]
            pair_mask = residue_weights.new_ones((batch_size, n, n), dtype=torch.int32)
        # For ipTM, score only residue pairs from different chains.
        pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
    predicted_tm_term *= pair_mask
    pair_residue_weights = pair_mask * (residue_weights[..., None, :] * residue_weights[..., :, None])
    denom = eps + torch.sum(pair_residue_weights, dim=-1, keepdim=True)
    normed_residue_mask = pair_residue_weights / denom
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
    weighted = per_alignment * residue_weights
    # Return the score of the best-scoring alignment anchor residue.
    argmax = (weighted == torch.max(weighted)).nonzero()[0]
    return per_alignment[tuple(argmax)]
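

# Example (a minimal sketch): pTM from pairwise aligned-error logits, and ipTM
# for a two-chain complex via `asym_id` and `interface=True`.
#
#     logits = torch.randn(128, 128, 64)   # [num_res, num_res, num_bins]
#     ptm = compute_tm(logits)             # scalar
#     asym_id = torch.cat([torch.zeros(64), torch.ones(64)])
#     iptm = compute_tm(logits, asym_id=asym_id, interface=True)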
def tm_loss(
    logits: torch.Tensor,
    final_affine_tensor: torch.Tensor,
    backbone_rigid_tensor: torch.Tensor,
    backbone_rigid_mask: torch.Tensor,
    resolution: torch.Tensor,
    max_bin: int = 31,
    num_bins: int = 64,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Calculate the TM loss supervising the predicted-aligned-error logits."""
    pred_affine = Rigid.from_tensor_7(final_affine_tensor)
    backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)

    def _points(affine):
        # Express every frame origin in the local frame of every residue.
        pts = affine.get_trans()[..., None, :, :]
        return affine.invert()[..., None].apply(pts)

    # Squared aligned-distance errors are fixed targets; no gradient flows back.
    sq_diff = torch.sum((_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1)
    sq_diff = sq_diff.detach()
    boundaries = torch.linspace(start=0, end=max_bin, steps=(num_bins - 1), device=logits.device)
    boundaries = boundaries**2
    # Bin the squared errors and supervise with cross-entropy.
    true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)
    errors = softmax_cross_entropy(
        logits=logits,
        labels=F.one_hot(true_bins, num_bins),
    )
    square_mask = backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
    loss = torch.sum(errors * square_mask, dim=-1)
    scale = 0.5  # hack to help FP16 training along
    denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
    loss = loss / denom[..., None]
    loss = torch.sum(loss, dim=-1)
    loss = loss * scale
    # Supervise on high-resolution structures only.
    loss = loss * ((resolution >= min_resolution) & (resolution <= max_resolution))
    return loss
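

# Example (a minimal sketch with identity frames; it assumes, as a convention
# not confirmed here, that Rigid.from_tensor_7 takes a quaternion-first [*, 7]
# tensor and Rigid.from_tensor_4x4 a [*, 4, 4] homogeneous transform):
#
#     n = 64
#     quat_affine = torch.zeros(n, 7)
#     quat_affine[:, 0] = 1.0                  # identity rotation, zero translation
#     gt_frames = torch.eye(4).expand(n, 4, 4)
#     logits = torch.randn(n, n, 64)
#     mask = torch.ones(n)
#     resolution = torch.tensor([2.0])
#     loss = tm_loss(logits, quat_affine, gt_frames, mask, resolution)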