# Copyright 2024 DeepFold Team
"""DeepFold2 model configuration."""

from __future__ import annotations

from dataclasses import asdict, dataclass, field
from typing import List, Optional

import dacite

# Placeholder dimension names used in the shape dictionaries below;
# the feature pipeline resolves them to concrete sizes at run time.
NUM_RES = "NUM_RES"
NUM_MSA_SEQ = "NUM_MSA_SEQ"
NUM_EXTRA_SEQ = "NUM_EXTRA_SEQ"
NUM_TEMPLATES = "NUM_TEMPLATES"


@dataclass
class RecyclingEmbedderConfig:
    c_m: int = 256
    c_z: int = 128
    min_bin: float = 3.25
    max_bin: float = 20.75
    num_bins: int = 15
    inf: float = 1e8


@dataclass
class TemplateAngleEmbedderConfig:
    ta_dim: int = 57  # 34 in the multimer model (see `AlphaFoldConfig.from_preset`)
    c_m: int = 256


@dataclass
class TemplatePairEmbedderConfig:
    tp_dim: int = 88
    c_t: int = 64
    c_z: int = 128
    c_dgram: int = 39
    c_aatype: int = 22


@dataclass
class TemplatePairStackConfig:
    c_t: int = 64
    c_hidden_tri_att: int = 16
    c_hidden_tri_mul: int = 64
    num_blocks: int = 2
    num_heads_tri: int = 4
    pair_transition_n: int = 2
    dropout_rate: float = 0.25
    inf: float = 1e9
    chunk_size_tri_att: Optional[int] = None
    block_size_tri_mul: Optional[int] = None
    tri_att_first: bool = True  # False in the multimer model


@dataclass
class TemplatePointwiseAttentionConfig:
    c_t: int = 64
    c_z: int = 128
    c_hidden: int = 16
    num_heads: int = 4
    inf: float = 1e5
    chunk_size: Optional[int] = None


@dataclass
class TemplateProjectionConfig:
    c_t: int = 64
    c_z: int = 128


@dataclass
class StructureModuleConfig:
    c_s: int = 384
    c_z: int = 128
    c_hidden_ipa: int = 16
    c_hidden_ang_res: int = 128
    num_heads_ipa: int = 12
    num_qk_points: int = 4
    num_v_points: int = 8
    is_multimer: bool = False
    dropout_rate: float = 0.1
    num_blocks: int = 8
    num_ang_res_blocks: int = 2
    num_angles: int = 7
    scale_factor: float = 10.0  # 20.0 in the multimer model
    inf: float = 1e5
    eps: float = 1e-8


@dataclass
class PerResidueLDDTCaPredictorConfig:
    c_s: int = 384
    c_hidden: int = 128
    num_bins: int = 50


@dataclass
class DistogramHeadConfig:
    c_z: int = 128
    num_bins: int = 64


@dataclass
class MaskedMSAHeadConfig:
    c_m: int = 256
    c_out: int = 23  # 22 in the multimer model


@dataclass
class ExperimentallyResolvedHeadConfig:
    c_s: int = 384
    c_out: int = 37


@dataclass
class TMScoreHeadConfig:
    c_z: int = 128
    num_bins: int = 64
    max_bin: int = 31


@dataclass
class AuxiliaryHeadsConfig:
    per_residue_lddt_ca_predictor_config: PerResidueLDDTCaPredictorConfig = field(
        default_factory=PerResidueLDDTCaPredictorConfig,
    )
    distogram_head_config: DistogramHeadConfig = field(
        default_factory=DistogramHeadConfig,
    )
    masked_msa_head_config: MaskedMSAHeadConfig = field(
        default_factory=MaskedMSAHeadConfig,
    )
    experimentally_resolved_head_config: ExperimentallyResolvedHeadConfig = field(
        default_factory=ExperimentallyResolvedHeadConfig,
    )
    tm_score_head_config: TMScoreHeadConfig = field(
        default_factory=TMScoreHeadConfig,
    )
    tm_score_head_enabled: bool = False
    # pTM/ipTM mixing weights (weighted_ptm = iptm_weight * iptm + ptm_weight * ptm):
    ptm_weight: float = 0.2
    iptm_weight: float = 0.8


@dataclass
class FAPELossConfig:
    weight: float = 1.0
    # The `backbone_*` values act as `intra_chain_backbone_*` in the multimer model.
    backbone_clamp_distance: float = 10.0
    backbone_loss_unit_distance: float = 10.0
    backbone_weight: float = 0.5
    interface_clamp_distance: float = 30.0
    interface_loss_unit_distance: float = 20.0
    interface_weight: float = 0.5
    sidechain_clamp_distance: float = 10.0
    sidechain_length_scale: float = 10.0
    sidechain_weight: float = 0.5
    eps: float = 1e-4


@dataclass
class SupervisedChiLossConfig:
    weight: float = 1.0
    chi_weight: float = 0.5
    angle_norm_weight: float = 0.01
    eps: float = 1e-8


@dataclass
class DistogramLossConfig:
    weight: float = 0.3
    min_bin: float = 2.3125
    max_bin: float = 21.6875
    num_bins: int = 64
    eps: float = 1e-8


@dataclass
class MaskedMSALossConfig:
    weight: float = 2.0
    eps: float = 1e-8
    num_classes: int = 23


@dataclass
class PLDDTLossConfig:
    weight: float = 0.01
    cutoff: float = 15.0
    min_resolution: float = 0.1
    max_resolution: float = 3.0
    num_bins: int = 50
    eps: float = 1e-8


@dataclass
class ExperimentallyResolvedLossConfig:
    weight: float = 0.0
    min_resolution: float = 0.1
    max_resolution: float = 3.0
    eps: float = 1e-8


@dataclass
class ViolationLossConfig:
    weight: float = 0.0
    violation_tolerance_factor: float = 12.0
    average_clashes: bool = False
    clash_overlap_tolerance: float = 1.5
    eps: float = 1e-6  # 1e-8


@dataclass
class TMLossConfig:
    enabled: bool = False
    weight: float = 0.0
    min_resolution: float = 0.1
    max_resolution: float = 3.0
    num_bins: int = 64
    max_bin: int = 31
    eps: float = 1e-8


@dataclass
class CenterOfMassConfig:
    enabled: bool = False
    clamp_distance: float = -4.0
    eps: float = 1e-10
    weight: float = 0.05


@dataclass
class LossConfig:
    fape_loss_config: FAPELossConfig = field(
        default_factory=FAPELossConfig,
    )
    supervised_chi_loss_config: SupervisedChiLossConfig = field(
        default_factory=SupervisedChiLossConfig,
    )
    distogram_loss_config: DistogramLossConfig = field(
        default_factory=DistogramLossConfig,
    )
    masked_msa_loss_config: MaskedMSALossConfig = field(
        default_factory=MaskedMSALossConfig,
    )
    plddt_loss_config: PLDDTLossConfig = field(
        default_factory=PLDDTLossConfig,
    )
    experimentally_resolved_loss_config: ExperimentallyResolvedLossConfig = field(
        default_factory=ExperimentallyResolvedLossConfig,
    )
    violation_loss_config: ViolationLossConfig = field(
        default_factory=ViolationLossConfig,
    )
    tm_loss_config: TMLossConfig = field(
        default_factory=TMLossConfig,
    )
    chain_center_of_mass_config: CenterOfMassConfig = field(
        default_factory=CenterOfMassConfig,
    )


@dataclass
class AlphaFoldConfig:
    is_multimer: bool = False
    # AlphaFold modules configuration:
    input_embedder_config: InputEmbedderConfig = field(
        default_factory=InputEmbedderConfig,
    )
    recycling_embedder_config: RecyclingEmbedderConfig = field(
        default_factory=RecyclingEmbedderConfig,
    )
    template_angle_embedder_config: TemplateAngleEmbedderConfig = field(
        default_factory=TemplateAngleEmbedderConfig,
    )
    template_pair_embedder_config: TemplatePairEmbedderConfig = field(
        default_factory=TemplatePairEmbedderConfig,
    )
    template_pair_stack_config: TemplatePairStackConfig = field(
        default_factory=TemplatePairStackConfig,
    )
    template_pointwise_attention_config: TemplatePointwiseAttentionConfig = field(
        default_factory=TemplatePointwiseAttentionConfig,
    )
    template_projection_config: TemplateProjectionConfig = field(
        default_factory=TemplateProjectionConfig,
    )
    extra_msa_embedder_config: ExtraMSAEmbedderConfig = field(
        default_factory=ExtraMSAEmbedderConfig,
    )
    extra_msa_stack_config: ExtraMSAStackConfig = field(
        default_factory=ExtraMSAStackConfig,
    )
    evoformer_stack_config: EvoformerStackConfig = field(
        default_factory=EvoformerStackConfig,
    )
    structure_module_config: StructureModuleConfig = field(
        default_factory=StructureModuleConfig,
    )
    auxiliary_heads_config: AuxiliaryHeadsConfig = field(
        default_factory=AuxiliaryHeadsConfig,
    )
    # Training loss configuration:
    loss_config: LossConfig = field(default_factory=LossConfig)
    # Recycling (last dimension in the batch dict):
    recycle_early_stop_enabled: bool = False
    recycle_early_stop_tolerance: float = 0.5
    # Template features configuration:
    templates_enabled: bool = True
    embed_template_torsion_angles: bool = True
    # max_templates: int = 4  # Number of templates (N_templ)
    # Template pair features embedder configuration:
    template_pair_feat_distogram_min_bin: float = 3.25
    template_pair_feat_distogram_max_bin: float = 50.75
    template_pair_feat_distogram_num_bins: int = 39
    template_pair_feat_use_unit_vector: bool = False  # True
    template_pair_feat_inf: float = 1e5
    template_pair_feat_eps: float = 1e-6

    @classmethod
    def from_preset(
        cls,
        is_multimer: bool = False,
        precision: str = "fp32",
        enable_ptm: bool = False,
        enable_templates: bool = False,
        inference_chunk_size: Optional[int] = 4,
        inference_block_size: Optional[int] = None,
        **additional_options,
    ) -> AlphaFoldConfig:
        cfg = {
            "is_multimer": is_multimer,
            "templates_enabled": enable_templates,
            "embed_template_torsion_angles": enable_templates,
        }
        if is_multimer:
            cfg = _update(
                cfg,
                {
                    "input_embedder_config": {
                        "tf_dim": 21,
                        "max_relative_chain": 2,
                        "max_relative_index": 32,
                        "use_chain_relative": True,
                    },
                    "template_angle_embedder_config": {
                        "ta_dim": 34,
                    },
                    "template_pair_embedder_config": {
                        "c_dgram": 39,
                        "c_aatype": 22,
                    },
                    "template_pair_stack_config": {
                        "tri_att_first": False,
                    },
                    "evoformer_stack_config": {
                        "outer_product_mean_first": True,
                    },
                    "extra_msa_stack_config": {
                        "outer_product_mean_first": True,
                    },
                    "structure_module_config": {"scale_factor": 20.0, "is_multimer": True},
                    "auxiliary_heads_config": {
                        "masked_msa_head_config": {
                            "c_out": 22,
                        },
                    },
                    "loss_config": {
                        "chain_center_of_mass_config": {
                            "enabled": True,
                        },
                    },
                },
            )
        if inference_chunk_size is not None or inference_block_size is not None:
            cfg = _update(cfg, _inference_stage(chunk_size=inference_chunk_size, block_size=inference_block_size))
        if enable_ptm:
            cfg = _update(cfg, _ptm_preset())
        if precision in {"fp32", "tf32", "bf16"}:
            pass
        elif precision in {"amp", "fp16"}:
            cfg = _update(cfg, _half_precision_settings())
        else:
            raise ValueError(f"Unknown precision={precision!r}")
        cfg = _update(cfg, additional_options)
        return cls.from_dict(cfg)

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, cfg: dict) -> AlphaFoldConfig:
        return dacite.from_dict(cls, cfg, dacite.Config(strict=True, check_types=True))
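

# Example (illustrative): building a model config from a preset. The option
# values below are assumptions for demonstration only; `from_preset` applies
# the multimer overrides, the pTM preset, and inference chunking in turn.
#
#     config = AlphaFoldConfig.from_preset(
#         is_multimer=True,
#         enable_ptm=True,
#         enable_templates=True,
#         precision="bf16",
#         inference_chunk_size=4,
#     )
#     assert config.structure_module_config.scale_factor == 20.0
#     assert config.auxiliary_heads_config.tm_score_head_enabled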


def _inference_stage(
    chunk_size: Optional[int] = None,
    block_size: Optional[int] = None,
) -> dict:
    # Chunk/block sizes for attention and triangular multiplication during
    # inference (None disables chunking/blocking):
    return {
        "template_pair_stack_config": {
            "chunk_size_tri_att": chunk_size,
            "block_size_tri_mul": block_size,
        },
        "template_pointwise_attention_config": {
            "chunk_size": chunk_size,
        },
        "extra_msa_stack_config": {
            "chunk_size_msa_att": chunk_size,
            "chunk_size_opm": chunk_size,
            "chunk_size_tri_att": chunk_size,
            "block_size_tri_mul": block_size,
        },
        "evoformer_stack_config": {
            "chunk_size_msa_att": chunk_size,
            "chunk_size_opm": chunk_size,
            "chunk_size_tri_att": chunk_size,
            "block_size_tri_mul": block_size,
        },
    }


def _ptm_preset() -> dict:
    return {
        "auxiliary_heads_config": {
            "tm_score_head_enabled": True,
        },
        "loss_config": {
            "tm_loss_config": {
                "enabled": True,
                "weight": 0.1,
            },
        },
    }


def _half_precision_settings() -> dict:
    # fp16 overflows above ~6.5e4, so the large `inf` constants are clamped:
    return {
        "recycling_embedder_config": {"inf": 1e4},
        "template_pair_stack_config": {"inf": 1e4},
        "template_pointwise_attention_config": {"inf": 1e4},
        "extra_msa_stack_config": {"inf": 1e4},
        "evoformer_stack_config": {"inf": 1e4},
        "structure_module_config": {"inf": 1e4},
        "template_pair_feat_inf": 1e4,
    }


@dataclass
class FeaturePipelineConfig:
    preset: str = ""
    is_multimer: bool = False
    seed: int = 0
    num_chunks: int = 8
    # Fix input sizes:
    fixed_size: bool = True
    # MSA features configuration:
    max_msa_clusters: int = 128  # Number of clustered MSA sequences (N_clust)
    max_extra_msa: int = 1024  # Number of unclustered extra sequences (N_extra_seq)
    sample_msa_distillation_enabled: bool = False
    max_distillation_msa_clusters: int = 1000
    distillation_prob: float = 0.75  # Referenced by the "train" preset.
    # Supplementary '1.2.6 MSA block deletion':
    # MSA block deletion configuration:
    block_delete_msa_enabled: bool = True
    msa_fraction_per_deletion_block: float = 0.3
    randomize_num_msa_deletion_blocks: bool = False
    num_msa_deletion_blocks: int = 5
    # Supplementary '1.2.7 MSA clustering':
    # Masked MSA configuration:
    masked_msa_enabled: bool = True
    masked_msa_profile_prob: float = 0.1
    masked_msa_same_prob: float = 0.1
    masked_msa_uniform_prob: float = 0.1
    masked_msa_replace_fraction: float = 0.15
    # Recycling (last dimension in the batch dict):
    max_recycling_iters: int = 3
    # uniform_recycling: bool = False
    # Resample MSA in recycling:
    resample_msa_in_recycling: bool = True
    # Concatenate template sequences to MSA clusters:
    # reduce_msa_clusters_by_max_templates: bool = True
    # Sequence crop & pad size (for "train" mode only):
    residue_cropping_enabled: bool = False
    crop_size: int = 256  # N_res
    spatial_crop_prob: float = 0.5
    interface_threshold: float = 10.0
    # Primary sequence and MSA related feature names:
    primary_raw_feature_names: List[str] = field(
        default_factory=lambda: [
            "aatype",
            "residue_index",
            "msa",
            "num_alignments",
            "seq_length",
            "deletion_matrix",
            "num_recycling_iters",
        ]
    )
    msa_cluster_features_enabled: bool = True
    # Template features configuration:
    templates_enabled: bool = True
    embed_template_torsion_angles: bool = True
    max_templates: int = 4  # Number of templates (N_templ)
    # max_template_hits: int = 4
    shuffle_top_k_prefiltered: int = 20
    subsample_templates: bool = False
    # Template related raw feature names:
    template_raw_feature_names: List[str] = field(
        default_factory=lambda: [
            "template_all_atom_positions",
            "template_sum_probs",
            "template_aatype",
            "template_all_atom_mask",
        ]
    )
    # Generate supervised features:
    supervised_features_enabled: bool = False
    # Target and supervised-training related feature names:
    supervised_raw_feature_names: List[str] = field(
        default_factory=lambda: [
            "all_atom_mask",
            "all_atom_positions",
            "resolution",
            "is_distillation",
            "use_clamped_fape",
        ]
    )
    # Training loss configuration:
    clamped_fape_enabled: bool = False
    clamped_fape_probability: float = 0.9
    self_distillation_plddt_threshold: float = 50.0

    def feature_names(self) -> List[str]:
        names = self.primary_raw_feature_names.copy()
        if self.templates_enabled:
            names += self.template_raw_feature_names
        if self.supervised_features_enabled:
            names += self.supervised_raw_feature_names
        return names

    def __post_init__(self):
        if self.is_multimer:
            self.primary_raw_feature_names.extend(
                [
                    "msa_mask",
                    "seq_mask",
                    "asym_id",
                    "entity_id",
                    "sym_id",
                ]
            )
        else:
            self.primary_raw_feature_names.append("between_segment_residues")

    @classmethod
    def from_preset(
        cls,
        preset: str,
        seed: int,
        is_multimer: bool = False,
        **additional_options,
    ) -> FeaturePipelineConfig:
        cfg = {
            "seed": seed,
        }
        if preset == "predict":
            cfg = _update(cfg, _predict_mode(is_multimer))
        elif preset == "eval":
            cfg = _update(cfg, _eval_mode(is_multimer))
        elif preset == "train":
            cfg = _update(cfg, _train_mode(is_multimer))
        else:
            raise ValueError(f"Unknown preset: '{preset}'")
        if is_multimer:
            cfg = _update(
                cfg,
                {
                    "is_multimer": True,
                    "max_recycling_iters": 20,
                },
            )
        cfg = _update(cfg, additional_options)
        return cls.from_dict(cfg)

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, cfg: dict) -> FeaturePipelineConfig:
        return dacite.from_dict(cls, cfg, dacite.Config(strict=True, check_types=True))
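

# Example (illustrative): building a feature pipeline config from a preset.
# The option values below are assumptions for demonstration only.
#
#     cfg = FeaturePipelineConfig.from_preset("predict", seed=0, is_multimer=True)
#     assert cfg.max_extra_msa == 2048 and cfg.max_recycling_iters == 20
#     # `feature_names()` lists the raw features the pipeline expects:
#     names = cfg.feature_names()
#     assert "asym_id" in names and "template_aatype" in names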


def _predict_mode(is_multimer: bool = False) -> dict:
    dic = {
        "preset": "predict",
        "fixed_size": True,
        "subsample_templates": False,
        "block_delete_msa_enabled": False,
        "max_msa_clusters": 508,
        "max_extra_msa": 1024,
        "max_templates": 4,
        "residue_cropping_enabled": False,
        "supervised_features_enabled": False,
        # "uniform_recycling": False,
    }
    if is_multimer:
        dic.update(
            {
                "max_msa_clusters": 508,
                "max_extra_msa": 2048,
            }
        )
    return dic


def _eval_mode(is_multimer: bool = False) -> dict:
    dic = {
        "preset": "eval",
        "fixed_size": True,
        "subsample_templates": False,
        "block_delete_msa_enabled": False,
        "max_msa_clusters": 128,
        "max_extra_msa": 1024,
        "max_templates": 4,
        "residue_cropping_enabled": False,
        "supervised_features_enabled": True,
        # "uniform_recycling": False,
    }
    if is_multimer:
        dic.update(
            {
                "max_msa_clusters": 508,
                "max_extra_msa": 2048,
            }
        )
    return dic


def _train_mode(is_multimer: bool = False) -> dict:
    dic = {
        "preset": "train",
        "fixed_size": True,
        "subsample_templates": True,
        "block_delete_msa_enabled": True,
        "max_msa_clusters": 128,
        "max_extra_msa": 1024,
        "max_templates": 4,
        "shuffle_top_k_prefiltered": 20,
        "residue_cropping_enabled": True,
        "crop_size": 256,
        "supervised_features_enabled": True,
        # "uniform_recycling": True,
        "clamped_fape_enabled": True,
        "clamped_fape_probability": 0.9,
        "sample_msa_distillation_enabled": True,
        "max_distillation_msa_clusters": 1000,
        "distillation_prob": 0.75,
    }
    if is_multimer:
        dic.update(
            {
                "max_msa_clusters": 508,
                "max_extra_msa": 2048,
                "block_delete_msa_enabled": False,
                "crop_size": 640,
                "spatial_crop_prob": 0.5,
                "interface_threshold": 10.0,
                "clamped_fape_enabled": True,
                "clamped_fape_probability": 1.0,
            }
        )
    return dic
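

# Example (illustrative): the "train" preset enables residue cropping and
# distillation sampling; the multimer variant raises the crop size and
# disables MSA block deletion.
#
#     cfg = FeaturePipelineConfig.from_preset("train", seed=1, is_multimer=True)
#     assert cfg.crop_size == 640 and not cfg.block_delete_msa_enabled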


FEATURE_SHAPES = {
    "aatype": (NUM_RES,),
    "all_atom_mask": (NUM_RES, 37),
    "all_atom_positions": (NUM_RES, 37, 3),
    "atom14_alt_gt_exists": (NUM_RES, 14),
    "atom14_alt_gt_positions": (NUM_RES, 14, 3),
    "atom14_atom_exists": (NUM_RES, 14),
    "atom14_atom_is_ambiguous": (NUM_RES, 14),
    "atom14_gt_exists": (NUM_RES, 14),
    "atom14_gt_positions": (NUM_RES, 14, 3),
    "atom37_atom_exists": (NUM_RES, 37),
    "backbone_rigid_mask": (NUM_RES,),
    "backbone_rigid_tensor": (NUM_RES, 4, 4),
    "bert_mask": (NUM_MSA_SEQ, NUM_RES),
    "chi_angles_sin_cos": (NUM_RES, 4, 2),
    "chi_mask": (NUM_RES, 4),
    "extra_deletion_value": (NUM_EXTRA_SEQ, NUM_RES),
    "extra_has_deletion": (NUM_EXTRA_SEQ, NUM_RES),
    "extra_msa_mask": (NUM_EXTRA_SEQ, NUM_RES),
    "extra_msa": (NUM_EXTRA_SEQ, NUM_RES),
    "extra_msa_row_mask": (NUM_EXTRA_SEQ,),
    "is_distillation": (),
    "msa_feat": (NUM_MSA_SEQ, NUM_RES, 49),
    "msa_mask": (NUM_MSA_SEQ, NUM_RES),
    "msa_row_mask": (NUM_MSA_SEQ,),
    "num_alignments": (),
    "num_recycling_iters": (),
    "num_templates": (),
    "pseudo_beta_mask": (NUM_RES,),
    "pseudo_beta": (NUM_RES, 3),
    "residue_index": (NUM_RES,),
    "residx_atom14_to_atom37": (NUM_RES, 14),
    "residx_atom37_to_atom14": (NUM_RES, 37),
    "resolution": (),
    "rigidgroups_alt_gt_frames": (NUM_RES, 8, 4, 4),
    "rigidgroups_group_exists": (NUM_RES, 8),
    "rigidgroups_group_is_ambiguous": (NUM_RES, 8),
    "rigidgroups_gt_exists": (NUM_RES, 8),
    "rigidgroups_gt_frames": (NUM_RES, 8, 4, 4),
    "seq_length": (),
    "seq_mask": (NUM_RES,),
    "target_feat": (NUM_RES, 22),
    "template_aatype": (NUM_TEMPLATES, NUM_RES),
    "template_all_atom_mask": (NUM_TEMPLATES, NUM_RES, 37),
    "template_all_atom_positions": (NUM_TEMPLATES, NUM_RES, 37, 3),
    "template_alt_torsion_angles_sin_cos": (NUM_TEMPLATES, NUM_RES, 7, 2),
    "template_mask": (NUM_TEMPLATES,),
    "template_pseudo_beta_mask": (NUM_TEMPLATES, NUM_RES),
    "template_pseudo_beta": (NUM_TEMPLATES, NUM_RES, 3),
    "template_sum_probs": (NUM_TEMPLATES, 1),
    "template_torsion_angles_mask": (NUM_TEMPLATES, NUM_RES, 7),
    "template_torsion_angles_sin_cos": (NUM_TEMPLATES, NUM_RES, 7, 2),
    "true_msa": (NUM_MSA_SEQ, NUM_RES),
}
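

# Example (illustrative): placeholder dimension names resolve to concrete
# sizes at padding time. `resolve_shape` is a hypothetical helper, not part
# of this module.
#
#     def resolve_shape(shape, num_res, num_msa_seq, num_extra_seq, num_templates):
#         sizes = {
#             NUM_RES: num_res,
#             NUM_MSA_SEQ: num_msa_seq,
#             NUM_EXTRA_SEQ: num_extra_seq,
#             NUM_TEMPLATES: num_templates,
#         }
#         return tuple(sizes.get(dim, dim) for dim in shape)
#
#     resolve_shape(FEATURE_SHAPES["msa_feat"], 256, 128, 1024, 4)
#     # -> (128, 256, 49)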


MULTIMER_FEATURE_SHAPES = {
    "aatype": (NUM_RES,),
    "all_atom_mask": (NUM_RES, 37),
    "all_atom_positions": (NUM_RES, 37, 3),
    "alt_chi_angles": (NUM_RES, 4),
    "assembly_num_chains": (),
    "asym_id": (NUM_RES,),
    "atom14_alt_gt_exists": (NUM_RES, 14),
    "atom14_alt_gt_positions": (NUM_RES, 14, 3),
    "atom14_atom_exists": (NUM_RES, 14),
    "atom14_atom_is_ambiguous": (NUM_RES, 14),
    "atom14_gt_exists": (NUM_RES, 14),
    "atom14_gt_positions": (NUM_RES, 14, 3),
    "atom37_atom_exists": (NUM_RES, 37),
    "backbone_rigid_mask": (NUM_RES,),
    "backbone_rigid_tensor": (NUM_RES, 4, 4),
    "bert_mask": (NUM_MSA_SEQ, NUM_RES),
    "chi_angles": (NUM_RES, 4),
    "chi_mask": (NUM_RES, 4),
    "cluster_bias_mask": (NUM_MSA_SEQ,),
    "cluster_deletion_mean": (NUM_MSA_SEQ, NUM_RES),
    "cluster_profile": (NUM_MSA_SEQ, NUM_RES, 23),
    "deletion_matrix": (NUM_MSA_SEQ, NUM_RES),
    "deletion_mean": (NUM_RES,),
    "entity_id": (NUM_RES,),
    "entity_mask": (NUM_RES,),
    "extra_deletion_matrix": (NUM_EXTRA_SEQ, NUM_RES),
    "extra_deletion_value": (NUM_EXTRA_SEQ, NUM_RES),
    "extra_has_deletion": (NUM_EXTRA_SEQ, NUM_RES),
    "extra_msa_mask": (NUM_EXTRA_SEQ, NUM_RES),
    "extra_msa": (NUM_EXTRA_SEQ, NUM_RES),
    "extra_msa_row_mask": (NUM_EXTRA_SEQ,),
    "is_distillation": (),
    "msa_feat": (NUM_MSA_SEQ, NUM_RES, 49),
    "msa_mask": (NUM_MSA_SEQ, NUM_RES),
    "msa": (NUM_MSA_SEQ, NUM_RES),
    "msa_profile": (NUM_RES, 22),
    "msa_row_mask": (NUM_MSA_SEQ,),
    "num_alignments": (),
    "num_recycling_iters": (),
    "num_templates": (),
    "pseudo_beta_mask": (NUM_RES,),
    "pseudo_beta": (NUM_RES, 3),
    "residue_index": (NUM_RES,),
    "residx_atom14_to_atom37": (NUM_RES, 14),
    "residx_atom37_to_atom14": (NUM_RES, 37),
    "resolution": (),
    "rigidgroups_alt_gt_frames": (NUM_RES, 8, 4, 4),
    "rigidgroups_group_exists": (NUM_RES, 8),
    "rigidgroups_group_is_ambiguous": (NUM_RES, 8),
    "rigidgroups_gt_exists": (NUM_RES, 8),
    "rigidgroups_gt_frames": (NUM_RES, 8, 4, 4),
    "seq_length": (),
    "seq_mask": (NUM_RES,),
    "sym_id": (NUM_RES,),
    "target_feat": (NUM_RES, 21),
    "template_aatype": (NUM_TEMPLATES, NUM_RES),
    "template_all_atom_mask": (NUM_TEMPLATES, NUM_RES, 37),
    "template_all_atom_positions": (NUM_TEMPLATES, NUM_RES, 37, 3),
    "template_backbone_affine_mask": (NUM_TEMPLATES, NUM_RES),
    "template_backbone_affine_tensor": (NUM_TEMPLATES, NUM_RES, 4, 4),
    "template_mask": (NUM_TEMPLATES,),
    "template_pseudo_beta_mask": (NUM_TEMPLATES, NUM_RES),
    "template_pseudo_beta": (NUM_TEMPLATES, NUM_RES, 3),
    "template_sum_probs": (NUM_TEMPLATES, 1),
    "true_msa": (NUM_MSA_SEQ, NUM_RES),
}


MONOMER_OUTPUT_SHAPES = {
    "msa": (NUM_MSA_SEQ, NUM_RES, 256),
    "pair": (NUM_RES, NUM_RES, 128),
    "single": (NUM_RES, 384),
    "sm_frames": (8, NUM_RES, 7),
    "sm_sidechain_frames": (8, NUM_RES, 8, 4, 4),
    "sm_unnormalized_angles": (8, NUM_RES, 7, 2),
    "sm_angles": (8, NUM_RES, 7, 2),
    "sm_positions": (8, NUM_RES, 14, 3),
    "sm_states": (8, NUM_RES, 384),
    "sm_single": (NUM_RES, 384),
    "final_atom_positions": (NUM_RES, 37, 3),
    "final_atom_mask": (NUM_RES, 37),
    "final_affine_tensor": (NUM_RES, 7),
    "lddt_logits": (NUM_RES, 50),
    "plddt": (NUM_RES,),
    "distogram_logits": (NUM_RES, NUM_RES, 64),
    "masked_msa_logits": (NUM_MSA_SEQ, NUM_RES, 23),
    "experimentally_resolved_logits": (NUM_RES, 37),
    "tm_logits": (NUM_RES, NUM_RES, 64),
    "ptm_score": (),
    "aligned_confidence_probs": (NUM_RES, NUM_RES, 64),
    "predicted_aligned_error": (NUM_RES, NUM_RES),
    "max_predicted_aligned_error": (),
    "mean_plddt": (),
}


MULTIMER_OUTPUT_SHAPES = {
    "msa": (NUM_MSA_SEQ, NUM_RES, 256),
    "pair": (NUM_RES, NUM_RES, 128),
    "single": (NUM_RES, 384),
    "sm_frames": (8, NUM_RES, 4, 4),
    "sm_sidechain_frames": (8, NUM_RES, 8, 4, 4),
    "sm_unnormalized_angles": (8, NUM_RES, 7, 2),
    "sm_angles": (8, NUM_RES, 7, 2),
    "sm_positions": (8, NUM_RES, 14, 3),
    "sm_states": (8, NUM_RES, 384),
    "sm_single": (NUM_RES, 384),
    "final_atom_positions": (NUM_RES, 37, 3),
    "final_atom_mask": (NUM_RES, 37),
    "final_affine_tensor": (NUM_RES, 4, 4),
    "lddt_logits": (NUM_RES, 50),
    "plddt": (NUM_RES,),
    "distogram_logits": (NUM_RES, NUM_RES, 64),
    "masked_msa_logits": (NUM_MSA_SEQ, NUM_RES, 22),
    "experimentally_resolved_logits": (NUM_RES, 37),
    "tm_logits": (NUM_RES, NUM_RES, 64),
    "ptm_score": (),
    "iptm_score": (),
    "weighted_ptm_score": (),
    "aligned_confidence_probs": (NUM_RES, NUM_RES, 64),
    "predicted_aligned_error": (NUM_RES, NUM_RES),
    "max_predicted_aligned_error": (),
    "mean_plddt": (),
}


def _update(left: dict, right: dict) -> dict:
    assert isinstance(left, dict)
    assert isinstance(right, dict)
    for k, v in right.items():
        if isinstance(v, dict):
            left[k] = _update(left.get(k, {}), v)
        else:
            left[k] = v
    return left
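

# Example (illustrative): `_update` merges nested dicts recursively and in
# place, with the right-hand dict taking precedence.
#
#     base = {"a": {"x": 1, "y": 2}, "b": 3}
#     _update(base, {"a": {"y": 20}, "c": 4})
#     # base == {"a": {"x": 1, "y": 20}, "b": 3, "c": 4}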


@dataclass
class TrainingConfig:
    # Adam optimizer constants:
    optimizer_adam_beta_1: float = 0.9
    optimizer_adam_beta_2: float = 0.999
    optimizer_adam_eps: float = 1e-6
    optimizer_adam_weight_decay: float = 0.0
    optimizer_adam_amsgrad: bool = False
    # Whether to enable gradient clipping by the max norm value:
    gradient_clipping: bool = True
    clip_grad_max_norm: float = 0.1
    # Whether to enable Stochastic Weight Averaging (SWA):
    swa_enabled: bool = True
    swa_decay_rate: float = 0.9
    # Sequence crop & pad size:
    # train_sequence_crop_size: int = 256  # N_res
    # Recycling (last dimension in the batch dict):
    # num_recycling_iters: int = 3

    @classmethod
    def from_preset(cls, **additional_options) -> TrainingConfig:
        cfg = {**additional_options}
        return cls.from_dict(cfg)

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, cfg: dict) -> TrainingConfig:
        return cls(**cfg)
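

# Example (illustrative): overriding training options via `from_preset`;
# unspecified fields keep their defaults.
#
#     train_cfg = TrainingConfig.from_preset(swa_enabled=False)
#     assert train_cfg.clip_grad_max_norm == 0.1
#     assert train_cfg.optimizer_adam_beta_1 == 0.9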