WeightLambda

class torchelie.nn.utils.WeightLambda(hook_name: str, name: str, function)

Apply a lambda function as a hook to the weight matrix of a layer before a forward pass.

Don’t use it directly; use the functions weight_lambda() and remove_weight_lambda() instead.

Parameters
  • hook_name (str) – an identifier for that WeightLambda hook, such as ‘l2normalize’, ‘weight_norm’, etc.

  • name (str) – the name of the module’s parameter to apply the hook on

  • function (Callable) – a function of the form (torch.Tensor) -> torch.Tensor that applies the desired computation to the module’s parameter.

static apply(module, hook_name: str, name: str, function) → torchelie.nn.utils.WeightLambda
remove(module: torch.nn.modules.module.Module) → None
name: str
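
The general idea of transforming a parameter just before each forward pass can be sketched in plain PyTorch. This is a minimal illustration of the technique, not torchelie's implementation: the helper add_weight_hook and the "_raw" suffix are hypothetical names chosen for this example, and the only library feature relied on is torch.nn.Module.register_forward_pre_hook.

```python
import torch
import torch.nn as nn


def add_weight_hook(module: nn.Module, name: str, function):
    """Hypothetical helper (not part of torchelie): recompute the attribute
    `name` from a raw trainable copy before every forward pass."""
    raw = getattr(module, name)
    # Move the trainable tensor to `<name>_raw` so optimizers still see it.
    del module._parameters[name]
    module.register_parameter(name + "_raw", raw)

    def pre_hook(mod, inputs):
        # Store the transformed tensor as a plain attribute (not a Parameter).
        setattr(mod, name, function(getattr(mod, name + "_raw")))

    # Run once so the attribute exists even before the first forward pass.
    pre_hook(module, None)
    return module.register_forward_pre_hook(pre_hook)


# Usage: rescale each row of a Linear layer's weight to unit L2 norm.
layer = nn.Linear(4, 3)
handle = add_weight_hook(
    layer, "weight", lambda w: w / w.norm(dim=1, keepdim=True)
)
out = layer(torch.randn(2, 4))
row_norms = layer.weight.norm(dim=1)
param_names = dict(layer.named_parameters()).keys()

# Gradients flow through the transformation back to the raw parameter.
out.sum().backward()
```

In this sketch the layer uses the transformed tensor during the forward pass, while the optimizer updates the untransformed parameter stored under the "_raw" name. Calling handle.remove() detaches the hook; restoring the original parameter name would need extra code that is omitted here.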