import math

import numpy as np
import torch
import torch.nn as nn
from scipy.stats import truncnorm

# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0.0, scale=1.0)
TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978
class Linear(nn.Linear):
    """Linear transformation with extra non-standard initializations.

    Supplementary '1.11.4 Parameters initialization': Linear layers.

    Args:
        in_features: Last dimension of the input tensor.
        out_features: Last dimension of the output tensor.
        bias: Whether to learn an additive bias. Default: `True`.
        init: Parameter initialization method. One of:

            - "default": LeCun (fan-in) with a truncated normal distribution
            - "relu": He initialization with a truncated normal distribution
            - "glorot": fan-average Glorot uniform initialization
            - "gating": Weights=0, Bias=1
            - "normal": Normal initialization with std=1/sqrt(fan_in)
            - "final": Weights=0, Bias=0

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        init: str = "default",
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
        )
        # By default, the biases of the Linear layers are filled with zeros.
        if bias:
            self.bias.data.fill_(0.0)
        if init == "default":
            lecun_normal_init_(self.weight.data)
        elif init == "relu":
            he_normal_init_(self.weight.data)
        elif init == "glorot":
            glorot_uniform_init_(self.weight.data)
        elif init == "gating":
            gating_init_(self.weight.data)
            if bias:
                self.bias.data.fill_(1.0)
        elif init == "normal":
            normal_init_(self.weight.data)
        elif init == "final":
            final_init_(self.weight.data)
        else:
            raise ValueError(f"unknown init {repr(init)}")
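# Usage sketch (dimensions are illustrative, not from the original source):
# "gating" starts a sigmoid gate nearly open, since weight=0 and bias=1 make
# the pre-activation 1 for any input, while "final" zeroes a projection so a
# residual branch contributes nothing at initialization.
#
#     gate = Linear(256, 256, init="gating")
#     torch.sigmoid(gate(torch.randn(4, 256)))  # all entries = sigmoid(1) ≈ 0.731
#     out = Linear(256, 128, init="final")
#     out(torch.randn(4, 256))                  # all zeros at init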
def _calculate_fan(linear_weight_shape: torch.Size, fan: str = "fan_in") -> float:
    # nn.Linear stores its weight as (out_features, in_features).
    fan_out, fan_in = linear_weight_shape
    if fan == "fan_in":
        fan_value = fan_in
    elif fan == "fan_out":
        fan_value = fan_out
    elif fan == "fan_avg":
        fan_value = (fan_in + fan_out) / 2
    else:
        raise ValueError("Invalid fan option")
    return fan_value
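# The initializer helpers called from Linear.__init__ are not defined in this
# excerpt. What follows is a minimal sketch of plausible implementations,
# assuming truncated-normal sampling scaled by _calculate_fan, in the style of
# AlphaFold-like codebases; the exact bodies in the original source may differ.


def trunc_normal_init_(
    weights: torch.Tensor, scale: float = 1.0, fan: str = "fan_in"
) -> None:
    # Sample from a normal truncated at +/-2 standard deviations, then divide
    # by TRUNCATED_NORMAL_STDDEV_FACTOR so the post-truncation std matches the
    # target sqrt(scale / fan).
    shape = weights.shape
    std = math.sqrt(scale / max(1.0, _calculate_fan(shape, fan)))
    std /= TRUNCATED_NORMAL_STDDEV_FACTOR
    samples = truncnorm.rvs(a=-2.0, b=2.0, loc=0.0, scale=std, size=math.prod(shape))
    with torch.no_grad():
        weights.copy_(torch.from_numpy(np.reshape(samples, shape)).to(weights))


def lecun_normal_init_(weights: torch.Tensor) -> None:
    trunc_normal_init_(weights, scale=1.0)  # variance 1 / fan_in


def he_normal_init_(weights: torch.Tensor) -> None:
    trunc_normal_init_(weights, scale=2.0)  # variance 2 / fan_in


def glorot_uniform_init_(weights: torch.Tensor) -> None:
    nn.init.xavier_uniform_(weights, gain=1.0)  # fan-average uniform


def gating_init_(weights: torch.Tensor) -> None:
    with torch.no_grad():
        weights.fill_(0.0)  # the matching bias=1 is set in Linear.__init__


def normal_init_(weights: torch.Tensor) -> None:
    nn.init.kaiming_normal_(weights, nonlinearity="linear")  # std = 1/sqrt(fan_in)


def final_init_(weights: torch.Tensor) -> None:
    with torch.no_grad():
        weights.fill_(0.0)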