deepfold.train package
Submodules
deepfold.train.gradient_clipping module
- class deepfold.train.gradient_clipping.AsyncGradientClipping(device: device, comm_group: ProcessGroup | None = None, norm_type: float = 2.0)
Bases: object
- deepfold.train.gradient_clipping.update_norm_from_buckets(state: AsyncGradientClipping, bucket: GradBucket) → Future[Tensor]
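The signature of `update_norm_from_buckets`, `(state, bucket: GradBucket) → Future[Tensor]`, has the shape PyTorch expects for a DDP communication hook passed to `DistributedDataParallel.register_comm_hook(state, hook)`. This suggests, though the page does not say so, that `AsyncGradientClipping` accumulates the gradient norm bucket by bucket while gradients are being reduced. The sketch below illustrates only that idea, without any framework: `BucketNormAccumulator` is a hypothetical name, buckets are plain lists of floats, and the clipping coefficient follows the rule used by `torch.nn.utils.clip_grad_norm_` (`max_norm / (total_norm + eps)`, capped at 1). It is not deepfold's implementation.

```python
import math
from typing import Sequence


class BucketNormAccumulator:
    """Hypothetical helper that builds a global gradient p-norm bucket by bucket.

    Each bucket is a flat list of gradient values, standing in for the
    flattened tensor held by a DDP GradBucket. Illustration only; this is
    not deepfold's AsyncGradientClipping.
    """

    def __init__(self, norm_type: float = 2.0) -> None:
        self.norm_type = norm_type
        # Running sum of |g|**p, or running max of |g| for the infinity norm.
        self._partial = 0.0

    def update(self, bucket: Sequence[float]) -> None:
        """Fold one bucket's gradients into the running partial norm."""
        if math.isinf(self.norm_type):
            self._partial = max([self._partial] + [abs(g) for g in bucket])
        else:
            self._partial += sum(abs(g) ** self.norm_type for g in bucket)

    def total_norm(self) -> float:
        """Finish the reduction: take the p-th root (nothing to do for the infinity norm)."""
        if math.isinf(self.norm_type):
            return self._partial
        return self._partial ** (1.0 / self.norm_type)

    def clip_coefficient(self, max_norm: float, eps: float = 1e-6) -> float:
        """Factor to multiply every gradient by, using clip_grad_norm_'s rule."""
        return min(1.0, max_norm / (self.total_norm() + eps))


acc = BucketNormAccumulator(norm_type=2.0)
for bucket in ([3.0, 0.0], [4.0]):  # two toy buckets
    acc.update(bucket)
print(acc.total_norm())  # 5.0
print(acc.clip_coefficient(max_norm=1.0))  # slightly below 0.2 because of eps
```

In a multi-process run, each rank's partial sum of |g|^p would also have to be summed across ranks (or max-reduced for the infinity norm) before the root is taken; that communication step is deliberately left out of the sketch.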
deepfold.train.lr_scheduler module
- class deepfold.train.lr_scheduler.AlphaFoldLRScheduler(init_lr: float, final_lr: float, warmup_lr_length: int, init_lr_length: int, optimizer: Optimizer)
Bases: object
AlphaFold learning-rate schedule, as described in the AlphaFold supplementary information, Section 1.11.3 ‘Optimization details’.
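The constructor arguments suggest a piecewise schedule: a warm-up lasting `warmup_lr_length` steps, a plateau at `init_lr` lasting `init_lr_length` steps, and `final_lr` from then on. The function below is a minimal sketch under those assumptions. In particular, that the warm-up rises linearly from 0 and that the plateau is counted after the warm-up are guesses that the source does not confirm. `alphafold_lr` is a hypothetical name, and the numbers in the usage loop are illustrative only.

```python
def alphafold_lr(
    step: int,
    init_lr: float,
    final_lr: float,
    warmup_lr_length: int,
    init_lr_length: int,
) -> float:
    """Hypothetical piecewise schedule: warm-up, plateau at init_lr, then final_lr.

    Assumptions (not confirmed by deepfold): the warm-up rises linearly from 0,
    and the init_lr plateau starts only after the warm-up has finished.
    """
    if step < 0:
        raise ValueError("step must be non-negative")
    if step < warmup_lr_length:
        # Linear warm-up: 0 at step 0, approaching init_lr at the end of warm-up.
        return init_lr * step / warmup_lr_length
    if step < warmup_lr_length + init_lr_length:
        # Constant plateau at the initial (peak) learning rate.
        return init_lr
    # Every step after the plateau uses the final learning rate.
    return final_lr


# Illustrative values only: 10 warm-up steps followed by 100 plateau steps.
for step in (0, 5, 10, 109, 110):
    print(step, alphafold_lr(step, 1e-3, 5e-4, 10, 100))
```

With a PyTorch optimizer, a value computed this way would typically be written to `group["lr"]` for every entry of `optimizer.param_groups` once per training step.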
deepfold.train.validation_metrics module
- deepfold.train.validation_metrics.compute_validation_metrics(predicted_atom_positions: Tensor, target_atom_positions: Tensor, atom_mask: Tensor, metrics_names: Set[str]) → Dict[str, Tensor]
- deepfold.train.validation_metrics.drmsd(structure_1: Tensor, structure_2: Tensor, mask: Tensor | None = None) → Tensor
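dRMSD (distance RMSD) compares two structures through their internal pairwise distances rather than their raw coordinates, so it needs no superposition and does not change under rigid motions of either structure. The plain-Python sketch below implements one common definition: the root-mean-square difference of the two distance matrices over ordered pairs i ≠ j of masked positions, normalized by N(N − 1). deepfold's `drmsd` operates on batched `Tensor`s and may differ in details such as how the mask enters the denominator. `compute_metrics_sketch` and the metric name `"drmsd"` are hypothetical. They only illustrate the "set of metric names in, dict of values out" shape of `compute_validation_metrics`.

```python
import math
from typing import Dict, List, Optional, Sequence, Set

Point = Sequence[float]  # one (x, y, z) coordinate


def _pairwise_distances(coords: Sequence[Point]) -> List[List[float]]:
    """Full N x N matrix of Euclidean distances."""
    return [[math.dist(a, b) for b in coords] for a in coords]


def drmsd(
    structure_1: Sequence[Point],
    structure_2: Sequence[Point],
    mask: Optional[Sequence[float]] = None,
) -> float:
    """Distance RMSD: RMS difference between the two internal distance matrices.

    dRMSD = sqrt( sum_{i != j} m_i m_j (d1_ij - d2_ij)^2 / (N (N - 1)) ),
    where N = sum(mask). No superposition is needed.
    """
    n_total = len(structure_1)
    if len(structure_2) != n_total:
        raise ValueError("structures must have the same number of positions")
    if mask is None:
        mask = [1.0] * n_total
    d1 = _pairwise_distances(structure_1)
    d2 = _pairwise_distances(structure_2)
    sq_sum = 0.0
    for i in range(n_total):
        for j in range(n_total):
            if i != j:  # ordered pairs; the diagonal is zero in both matrices anyway
                sq_sum += mask[i] * mask[j] * (d1[i][j] - d2[i][j]) ** 2
    n = sum(mask)
    if n <= 1:
        return 0.0  # fewer than two valid positions: no pairs to compare
    return math.sqrt(sq_sum / (n * (n - 1)))


def compute_metrics_sketch(
    predicted: Sequence[Point],
    target: Sequence[Point],
    mask: Optional[Sequence[float]],
    metrics_names: Set[str],
) -> Dict[str, float]:
    """Hypothetical dispatcher: metric names in, {name: value} out."""
    available = {"drmsd": lambda: drmsd(predicted, target, mask)}
    unknown = set(metrics_names) - set(available)
    if unknown:
        raise KeyError(f"unsupported metrics: {sorted(unknown)}")
    return {name: available[name]() for name in metrics_names}


pred_xyz = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
target_xyz = [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
print(compute_metrics_sketch(pred_xyz, target_xyz, None, {"drmsd"}))  # {'drmsd': 2.0}
```

Because only distances enter the formula, translating or rotating either structure leaves the result unchanged. A mask entry of 0 removes every pair involving that position from both the sum and the pair count.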