def forward(self, batch):
    """Run ``batch`` through the weight-averaged copy of the model.

    Args:
        batch: Input forwarded unchanged to ``self.averaged_model``.

    Returns:
        Whatever ``self.averaged_model(batch)`` returns.

    Raises:
        RuntimeError: If ``self.averaged_model`` is ``None``, i.e. weight
            averaging was not enabled.
    """
    averaged = self.averaged_model
    if averaged is not None:
        return averaged(batch)
    # No averaged model was set up, so there is nothing to evaluate with.
    raise RuntimeError("Weight averaging is not enabled")
class swa_avg_fn:
    """Exponential-moving-average (EMA) update rule with a configurable decay.

    Suppl. '1.11.7 Evaluator setup'.

    Each call blends the running average toward the current parameter:
    ``averaged + (model - averaged) * (1 - decay_rate)``, which equals
    ``decay_rate * averaged + (1 - decay_rate) * model``. A ``decay_rate`` of
    1.0 leaves the average untouched, and 0.0 copies the current parameter.

    Args:
        decay_rate: Weight retained by the existing average on every update.
    """

    def __init__(self, decay_rate: float) -> None:
        self.decay_rate = decay_rate

    def __call__(
        self,
        averaged_model_parameter: torch.Tensor,
        model_parameter: torch.Tensor,
        num_averaged: int,
    ) -> torch.Tensor:
        """Return the EMA-updated value of one parameter tensor.

        Args:
            averaged_model_parameter: Current running-average value.
            model_parameter: Latest value of the same parameter.
            num_averaged: Ignored. It is accepted only so the object fits a
                three-argument averaging-callback signature (presumably
                ``torch.optim.swa_utils.AveragedModel``'s ``avg_fn``; confirm
                against caller).

        Returns:
            A new tensor holding the updated running average.
        """
        # Fraction of the gap to the current parameter that is closed per step.
        step = 1.0 - self.decay_rate
        delta = model_parameter - averaged_model_parameter
        return averaged_model_parameter + delta * step