weight_lambda
class torchelie.nn.utils.weight_lambda(module: torch.nn.modules.module.Module, hook_name: str, function, name: str = 'weight')
Apply function() to getattr(module, name) on each forward pass. This makes it possible to implement techniques such as weight normalization or equalized learning rate weight scaling.
- Parameters
module (nn.Module) – the module to hook on
hook_name (str) – an identifier for that WeightLambda hook, such as ‘l2normalize’, ‘weight_norm’, etc.
function (Callable) – a function of the form (torch.Tensor) -> torch.Tensor that applies the desired computation to the module’s parameter.
name (str) – the name of the module’s parameter to apply the hook on. Default: ‘weight’.
- Returns
the module with the hook
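The behaviour described above (recomputing a parameter through function() before every forward pass) can be sketched in plain PyTorch. This is a hypothetical stand-in, not torchelie's source: the name weight_lambda_sketch, the `<name>_raw` parameter and the `_weight_lambda_handles` dict are assumptions made for illustration. The usage part applies equalized learning rate scaling with the He constant sqrt(2 / fan_in), which is also an assumed choice of scale.

```python
import math

import torch
from torch import nn


def weight_lambda_sketch(module: nn.Module, hook_name: str, function,
                         name: str = 'weight') -> nn.Module:
    # Hypothetical stand-in for torchelie.nn.utils.weight_lambda (not its source).
    raw = getattr(module, name)
    # Move the trainable parameter to `<name>_raw` (assumed naming) so that
    # `<name>` can be overwritten by a plain, recomputed tensor.
    del module._parameters[name]
    module.register_parameter(name + '_raw', raw)

    def recompute(mod, _inputs):
        # Forward pre-hook: runs before every forward() and sets
        # `<name>` to function(<name>_raw), so gradients reach the raw parameter.
        setattr(mod, name, function(getattr(mod, name + '_raw')))

    # Keep the hook handle under `hook_name` so it could be removed later.
    if not hasattr(module, '_weight_lambda_handles'):
        module._weight_lambda_handles = {}
    module._weight_lambda_handles[hook_name] = (
        module.register_forward_pre_hook(recompute))
    # Compute once now so `module.<name>` exists before the first forward pass.
    recompute(module, None)
    return module


# Usage: equalized learning rate scaling for a Linear layer with fan_in = 8.
fan_in = 8
layer = weight_lambda_sketch(nn.Linear(fan_in, 4), 'equalized_lr',
                             lambda w: w * math.sqrt(2 / fan_in))
x = torch.randn(3, fan_in)
y = layer(x)
```

Storing the trainable tensor under a separate name keeps the optimizer working on the raw weights, while the forward pass only ever sees the transformed tensor.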