Source code for deepfold.train.gradient_clipping

# Copyright 2023 NVIDIA CORPORATION


import math
from typing import Optional

import torch
import torch.distributed


class AsyncGradientClipping:
    def __init__(
        self,
        device: torch.device,
        comm_group: Optional[torch.distributed.ProcessGroup] = None,
        norm_type: float = 2.0,
    ) -> None:
        self.comm_group = comm_group if comm_group is not None else torch.distributed.group.WORLD
        self.norm_type = norm_type
        # Running sum of |g|**norm_type, accumulated bucket by bucket during all-reduce.
        self._norm_acc = torch.tensor(0.0, device=device)

    def get_clip_scale(
        self,
        max_norm: float,
        eps: float = 1e-6,
    ) -> float:
        # The hook accumulates sum(|g|**norm_type); take the p-th root to recover the global norm.
        grad_norm_acc = self._norm_acc.item()
        grad_norm = math.pow(grad_norm_acc + eps, 1.0 / self.norm_type)
        clip_scale = min(max_norm / grad_norm, 1.0)
        # Reset the accumulator for the next backward pass.
        self._norm_acc.zero_()
        return clip_scale
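A quick numeric sanity check of get_clip_scale (illustrative values only, not part of the module): with the default norm_type of 2.0, an accumulated sum of squared gradient entries of 16.0 corresponds to a global norm of 4.0, so max_norm = 1.0 yields a clip scale of 0.25.

# Hypothetical standalone check; _norm_acc is filled by hand here only because
# the communication hook is what normally accumulates it during backward.
clipper = AsyncGradientClipping(device=torch.device("cpu"))
clipper._norm_acc += 16.0
assert abs(clipper.get_clip_scale(max_norm=1.0) - 0.25) < 1e-4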

def update_norm_from_buckets(
    state: AsyncGradientClipping,
    bucket: torch.distributed.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    grad = bucket.buffer()
    world_size = state.comm_group.size() if state.comm_group is not None else 1
    # Pre-divide so that the all-reduce (sum) yields the average gradient.
    grad.div_(world_size)

    def _acc_grad_norm(fut: torch.futures.Future[torch.Tensor]) -> torch.Tensor:
        synced_grad = fut.value()[0]  # fut.value() is a List[torch.Tensor]
        grad_to_power_p = synced_grad.detach().pow(state.norm_type)
        state._norm_acc += grad_to_power_p.sum()
        return synced_grad

    return (
        torch.distributed.all_reduce(
            grad,
            group=state.comm_group,
            async_op=True,
        )
        .get_future()
        .then(_acc_grad_norm)
    )
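A minimal sketch of how these pieces might be wired together, assuming an already constructed DistributedDataParallel model; ddp_model, optimizer, loader, and compute_loss are placeholder names, not part of this module. The hook replaces DDP's default all-reduce, so the norm accumulation overlaps with gradient communication, and the resulting scale is applied after backward.

def _example_training_loop(ddp_model, optimizer, loader, compute_loss, max_norm=1.0):
    # Register the hook once per DDP instance.
    device = next(ddp_model.parameters()).device
    state = AsyncGradientClipping(device=device)
    ddp_model.register_comm_hook(state=state, hook=update_norm_from_buckets)

    for batch in loader:
        optimizer.zero_grad()
        loss = compute_loss(ddp_model, batch)
        loss.backward()  # the comm hook accumulates sum(|g|**norm_type) per bucket
        clip_scale = state.get_clip_scale(max_norm=max_norm)
        if clip_scale < 1.0:
            # Apply the clip by scaling gradients in place, mirroring clip_grad_norm_.
            for param in ddp_model.parameters():
                if param.grad is not None:
                    param.grad.mul_(clip_scale)
        optimizer.step()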