weight_lambda

torchelie.nn.utils.weight_lambda(module: torch.nn.modules.module.Module, hook_name: str, function, name: str = 'weight')

Apply function to getattr(module, name) on each forward pass.

This makes it possible to implement techniques such as weight normalization or equalized learning rate weight scaling.
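To illustrate the mechanism, here is a minimal sketch in plain PyTorch. It is a hypothetical helper, weight_lambda_sketch, and not torchelie's actual implementation. It makes three assumptions. The raw trainable parameter is re-registered as '<name>_orig', following the renaming style of torch.nn.utils.weight_norm. The transformed tensor is recomputed in a forward pre-hook. The hook handle is stored under hook_name so it could be removed later.

```python
import torch
from torch import nn


def weight_lambda_sketch(module: nn.Module, hook_name: str, function,
                         name: str = 'weight') -> nn.Module:
    """Hypothetical re-implementation for illustration; not torchelie's code."""
    raw = getattr(module, name)
    # Keep the trainable Parameter, but register it under '<name>_orig'
    # (an assumed naming scheme) so the optimizer still updates it.
    delattr(module, name)
    module.register_parameter(name + '_orig', raw)

    def hook(mod: nn.Module, inputs) -> None:
        # Recompute the derived tensor from the raw parameter before every
        # forward call. It is a plain attribute, not a Parameter, and
        # gradients flow back to '<name>_orig' through it.
        setattr(mod, name, function(getattr(mod, name + '_orig')))

    # Compute once now so `module.<name>` exists even before the first call.
    hook(module, ())
    # Keep the handle under `hook_name` (assumed bookkeeping) for later removal.
    handles = module.__dict__.setdefault('_weight_lambda_handles', {})
    handles[hook_name] = module.register_forward_pre_hook(hook)
    return module


# Usage: the Linear layer now multiplies its inputs by 2 * weight_orig.
layer = weight_lambda_sketch(nn.Linear(4, 3), 'double', lambda w: 2 * w)
out = layer(torch.randn(5, 4))
```

Recomputing in a pre-hook, rather than once at wrapping time, means every optimizer update to the raw parameter is reflected in the next forward pass. The wrapped module's own forward() needs no changes.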

Parameters
  • module (nn.Module) – the module to hook on

  • hook_name (str) – an identifier for that WeightLambda hook, such as ‘l2normalize’, ‘weight_norm’, etc.

  • function (Callable) – a function of the form (torch.Tensor) -> torch.Tensor that applies the desired computation to the module’s parameter.

  • name (str) – the name of the module’s parameter to apply the hook on. Default: ‘weight’.

Returns

the module with the hook
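The function argument can be any (torch.Tensor) -> torch.Tensor callable. The two helpers below are hypothetical examples for the use cases named above. l2normalize assumes that normalization is done per output unit, that is, per row of a Linear weight or per filter of a Conv weight. equalized_lr applies the usual equalized learning rate scaling, gain / sqrt(fan_in), with an assumed default gain of sqrt(2). torchelie may define its own versions with different conventions.

```python
import math

import torch
import torch.nn.functional as F


def l2normalize(w: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: rescale each output unit's weights (dim 0 slices)
    # to unit L2 norm, whatever the number of trailing dimensions.
    return F.normalize(w.flatten(1), dim=1).view_as(w)


def equalized_lr(w: torch.Tensor, gain: float = math.sqrt(2)) -> torch.Tensor:
    # Hypothetical helper: equalized learning rate scaling, i.e. multiply the
    # stored weights by gain / sqrt(fan_in) at run time.
    fan_in = w[0].numel()  # inputs per output unit (in_features * kernel size)
    return w * (gain / math.sqrt(fan_in))
```

Either callable could then be passed through the documented signature, for example weight_lambda(nn.Linear(4, 3), 'l2normalize', l2normalize).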