from dataclasses import asdict
from typing import Dict, Optional
import torch
import torch.nn as nn
from deepfold.config import AuxiliaryHeadsConfig
from deepfold.losses.confidence import compute_plddt, compute_predicted_aligned_error, compute_tm
from deepfold.modules.layer_norm import LayerNorm
from deepfold.modules.linear import Linear
from deepfold.utils.precision import is_fp16_enabled
class AuxiliaryHeads(nn.Module):
"""Auxiliary Heads module."""
def __init__(self, config: AuxiliaryHeadsConfig) -> None:
super().__init__()
self.plddt = PerResidueLDDTCaPredictor(
**asdict(config.per_residue_lddt_ca_predictor_config),
)
self.distogram = DistogramHead(
**asdict(config.distogram_head_config),
)
self.masked_msa = MaskedMSAHead(
**asdict(config.masked_msa_head_config),
)
self.experimentally_resolved = ExperimentallyResolvedHead(
**asdict(config.experimentally_resolved_head_config),
)
self.tm_score_head_enabled = config.tm_score_head_enabled
if self.tm_score_head_enabled:
self.tm = TMScoreHead(
**asdict(config.tm_score_head_config),
)
self.ptm_weight = config.ptm_weight
self.iptm_weight = config.iptm_weight
def forward(
self,
outputs: Dict[str, torch.Tensor],
seq_mask: Optional[torch.Tensor] = None,
asym_id: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
aux_outputs = {}
aux_outputs["lddt_logits"] = self.plddt(s=outputs["sm_single"])
aux_outputs["plddt"] = compute_plddt(logits=aux_outputs["lddt_logits"])
aux_outputs["distogram_logits"] = self.distogram(outputs["pair"])
aux_outputs["masked_msa_logits"] = self.masked_msa(outputs["msa"])
aux_outputs["experimentally_resolved_logits"] = self.experimentally_resolved(outputs["single"])
if self.tm_score_head_enabled:
aux_outputs["tm_logits"] = self.tm(outputs["pair"])
aux_outputs["ptm_score"] = compute_tm(
logits=aux_outputs["tm_logits"],
residue_weights=seq_mask,
max_bin=self.tm.max_bin,
num_bins=self.tm.num_bins,
)
if asym_id is not None:
aux_outputs["iptm_score"] = compute_tm(
logits=aux_outputs["tm_logits"],
residue_weights=seq_mask,
asym_id=asym_id,
interface=True,
max_bin=self.tm.max_bin,
num_bins=self.tm.num_bins,
)
aux_outputs["weighted_ptm_score"] = self.iptm_weight * aux_outputs["iptm_score"] + self.ptm_weight * aux_outputs["ptm_score"]
aux_outputs.update(
compute_predicted_aligned_error(
logits=aux_outputs["tm_logits"],
max_bin=self.tm.max_bin,
num_bins=self.tm.num_bins,
)
)
return aux_outputs
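

# A minimal usage sketch of `AuxiliaryHeads` (hedged): the dictionary keys
# match the forward pass above, but the tensor shapes and channel sizes are
# illustrative assumptions and must match the `AuxiliaryHeadsConfig` used to
# build `heads`; they are not values taken from the DeepFold configs.
def _demo_auxiliary_heads(heads: "AuxiliaryHeads") -> Dict[str, torch.Tensor]:
    n_seq, n_res, c_s, c_z, c_m = 8, 16, 384, 128, 256  # assumed sizes
    outputs = {
        "sm_single": torch.randn(n_res, c_s),  # structure-module single repr.
        "single": torch.randn(n_res, c_s),
        "pair": torch.randn(n_res, n_res, c_z),
        "msa": torch.randn(n_seq, n_res, c_m),
    }
    seq_mask = torch.ones(n_res)
    return heads(outputs, seq_mask=seq_mask)
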
class PerResidueLDDTCaPredictor(nn.Module):
"""Per-Residue LDDT-Ca Predictor module.
Supplementary '1.9.6 Model confidence prediction (pLDDT)': Algorithm 29.
Args:
c_s: Single representation dimension (channels).
c_hidden: Hidden dimension (channels).
        num_bins: Number of pLDDT bins (output channels).
"""
def __init__(
self,
c_s: int,
c_hidden: int,
num_bins: int,
) -> None:
super().__init__()
self.layer_norm = LayerNorm(c_s)
self.linear_1 = Linear(c_s, c_hidden, bias=True, init="relu")
self.linear_2 = Linear(c_hidden, c_hidden, bias=True, init="relu")
self.linear_3 = Linear(c_hidden, num_bins, bias=True, init="final")
def forward(self, s: torch.Tensor) -> torch.Tensor:
s = self.layer_norm(s)
s = self.linear_1(s)
s = torch.relu(s)
s = self.linear_2(s)
s = torch.relu(s)
s = self.linear_3(s)
return s
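

# Hedged sketch of how per-residue pLDDT is typically recovered from these
# logits (the standard AlphaFold recipe: an expectation over equal-width
# confidence bins, scaled to [0, 100]); the imported `compute_plddt` may
# differ in detail.
def _plddt_from_logits(logits: torch.Tensor) -> torch.Tensor:
    num_bins = logits.shape[-1]
    bin_width = 1.0 / num_bins
    # Bin centers in [0, 1]: 0.5*w, 1.5*w, ..., 1 - 0.5*w.
    centers = torch.arange(0.5 * bin_width, 1.0, bin_width, device=logits.device)
    probs = torch.softmax(logits, dim=-1)
    return 100.0 * torch.sum(probs * centers, dim=-1)
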
class DistogramHead(nn.Module):
"""Distogram Head module.
Computes a distogram probability distribution.
Supplementary '1.9.8 Distogram prediction'.
Args:
c_z: Pair representation dimension (channels).
        num_bins: Number of distance bins (output channels).
"""
def __init__(
self,
c_z: int,
num_bins: int,
) -> None:
super().__init__()
self.linear = Linear(c_z, num_bins, bias=True, init="final")
def _forward(self, z: torch.Tensor) -> torch.Tensor:
logits = self.linear(z)
logits = logits + logits.transpose(-2, -3)
return logits
def forward(self, z: torch.Tensor) -> torch.Tensor:
if is_fp16_enabled():
with torch.cuda.amp.autocast(enabled=False):
return self._forward(z.float())
else:
return self._forward(z)
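

# Hedged sketch: deriving a contact-probability map from distogram logits.
# The AlphaFold bin layout (first edge 2.3125 Å, last edge 21.6875 Å) is an
# assumption about how this head was trained, not read from the configs.
def _contact_map_from_distogram(logits: torch.Tensor, cutoff: float = 8.0) -> torch.Tensor:
    num_bins = logits.shape[-1]
    edges = torch.linspace(2.3125, 21.6875, num_bins - 1, device=logits.device)
    probs = torch.softmax(logits, dim=-1)
    # Keep the first (closest) bin plus every bin whose lower edge is under the cutoff.
    keep = torch.cat([torch.ones(1, dtype=torch.bool, device=logits.device), edges < cutoff])
    return (probs * keep).sum(dim=-1)
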
class MaskedMSAHead(nn.Module):
"""Masked MSA Head module.
Supplementary '1.9.9 Masked MSA prediction'.
Args:
c_m: MSA representation dimension (channels).
c_out: Output dimension (channels).
"""
def __init__(
self,
c_m: int,
c_out: int,
) -> None:
super().__init__()
self.linear = Linear(c_m, c_out, bias=True, init="final")
def forward(self, m: torch.Tensor) -> torch.Tensor:
logits = self.linear(m)
return logits
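

# Hedged sketch of the BERT-style loss these logits typically feed:
# cross-entropy over masked MSA positions only. `true_msa` (long-dtype class
# indices) and `bert_mask` are hypothetical inputs, not produced by this module.
def _masked_msa_loss(
    logits: torch.Tensor, true_msa: torch.Tensor, bert_mask: torch.Tensor
) -> torch.Tensor:
    log_probs = torch.log_softmax(logits, dim=-1)
    nll = -torch.gather(log_probs, -1, true_msa.unsqueeze(-1)).squeeze(-1)
    # Average the negative log-likelihood over masked positions only.
    return (nll * bert_mask).sum() / bert_mask.sum().clamp(min=1.0)
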
class ExperimentallyResolvedHead(nn.Module):
"""Experimentally Resolved Head module.
Supplementary '1.9.10 Experimentally resolved prediction'.
Args:
c_s: Single representation dimension (channels).
c_out: Output dimension (channels).
"""
def __init__(
self,
c_s: int,
c_out: int,
) -> None:
super().__init__()
self.linear = Linear(c_s, c_out, bias=True, init="final")
def forward(self, s: torch.Tensor) -> torch.Tensor:
logits = self.linear(s)
return logits
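

# Hedged sketch: a sigmoid turns these logits into independent per-atom
# "position was experimentally resolved" probabilities. Reading `c_out` as a
# per-residue atom layout (e.g. atom37) is an assumption, not from the source.
def _resolved_probs(logits: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(logits)
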
class TMScoreHead(nn.Module):
"""TM-Score Head module.
Supplementary '1.9.7 TM-score prediction'.
Args:
c_z: Pair representation dimension (channels).
        num_bins: Number of aligned-error bins (output channels).
        max_bin: Upper bound (in Å) of the binned error range.
"""
def __init__(
self,
c_z: int,
num_bins: int,
max_bin: int,
) -> None:
super().__init__()
self.num_bins = num_bins
self.max_bin = max_bin
self.linear = Linear(c_z, num_bins, bias=True, init="final")
def forward(self, z: torch.Tensor) -> torch.Tensor:
logits = self.linear(z)
return logits
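

# Hedged sketch mirroring an AlphaFold-style `compute_tm`: softmax the binned
# aligned-error logits, take the expected TM kernel per residue pair, average
# over aligned residues, and keep the best anchor. The imported `compute_tm`
# may differ (e.g. in residue weighting and interface handling).
def _ptm_from_logits(logits: torch.Tensor, max_bin: int, num_bins: int) -> torch.Tensor:
    n_res = logits.shape[-2]
    # Bin centers for the predicted aligned-error distribution over [0, max_bin].
    edges = torch.linspace(0.0, float(max_bin), steps=num_bins - 1, device=logits.device)
    step = edges[1] - edges[0]
    centers = torch.cat([edges + step / 2, (edges[-1] + step).unsqueeze(0)])
    # Standard TM-score normalization, clamped for short chains.
    d0 = 1.24 * (max(n_res, 19) - 15.0) ** (1.0 / 3.0) - 1.8
    tm_per_bin = 1.0 / (1.0 + (centers / d0) ** 2)
    probs = torch.softmax(logits, dim=-1)
    expected_tm = (probs * tm_per_bin).sum(dim=-1)  # E[f(e_ij)] per pair (i, j)
    per_residue = expected_tm.mean(dim=-1)          # average over aligned residues j
    return per_residue.max(dim=-1).values           # best alignment anchor i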