def step(self, iteration: int) -> None:
    """Select and apply the learning rate for the given 1-based iteration.

    Three phases: linear warmup (precomputed ramp), constant init LR,
    then constant final LR.
    """
    if iteration <= self.warmup_lr_length:
        # Warmup phase: look up the precomputed ramp value
        # (1-based iteration maps to index iteration - 1).
        new_lr = round(self.warmup_linspace[iteration - 1].item(), 10)
    elif iteration <= self.init_lr_length:
        new_lr = self.init_lr
    else:
        new_lr = self.final_lr
    # Touch the optimizer only when the value differs from the previous call.
    if new_lr != self.prev_lr_value:
        set_learning_rate(optimizer=self.optimizer, lr_value=new_lr)
        self.prev_lr_value = new_lr
class OpenFoldBenchmarkLRScheduler:
    """Per-iteration LR scheduler with linear warmup followed by a constant rate.

    Ramps linearly from ``warmup_lr_init`` to ``base_lr`` over
    ``warmup_lr_iters`` iterations, then holds ``base_lr``. Call the
    instance with the current 1-based iteration number; the optimizer's
    learning rate is updated only when the value actually changes.
    """

    def __init__(
        self,
        base_lr: float,
        warmup_lr_init: float,
        warmup_lr_iters: int,
        optimizer: torch.optim.Optimizer,
    ) -> None:
        """Precompute the warmup ramp and remember the target optimizer.

        Args:
            base_lr: Learning rate used after warmup completes.
            warmup_lr_init: Learning rate at the first iteration.
            warmup_lr_iters: Number of warmup iterations (must be >= 0).
            optimizer: Optimizer whose param groups receive the LR updates.

        Raises:
            ValueError: If ``warmup_lr_iters`` is negative.
        """
        self.base_lr = base_lr
        self.warmup_lr_init = warmup_lr_init
        self.warmup_lr_iters = warmup_lr_iters
        self.optimizer = optimizer
        # Validate explicitly: an `assert` here would be silently stripped
        # under `python -O`, letting a negative value through to linspace.
        if warmup_lr_iters < 0:
            raise ValueError(
                f"warmup_lr_iters must be >= 0, got {warmup_lr_iters}"
            )
        # Precompute LR values for the warm-up in float64 for a smooth ramp.
        self._warmup_linspace = torch.linspace(
            start=warmup_lr_init,
            end=base_lr,
            steps=warmup_lr_iters,
            dtype=torch.float64,
        )
        # Last LR applied to the optimizer; None until the first update.
        self._prev_lr_value = None

    def __call__(self, iteration: int) -> None:
        """Apply the learning rate for the given 1-based iteration."""
        # Determine lr_value for given iteration:
        if iteration <= self.warmup_lr_iters:
            # 1-based iteration maps to ramp index iteration - 1.
            lr_value = self._warmup_linspace[iteration - 1].item()
            lr_value = round(lr_value, 10)
        else:
            lr_value = self.base_lr
        # Set only if differs from the previous call:
        if lr_value != self._prev_lr_value:
            set_learning_rate(optimizer=self.optimizer, lr_value=lr_value)
            self._prev_lr_value = lr_value