torchelie.nn.utils.edit_model(m: T_Module, f: Callable[[torch.nn.modules.module.Module], torch.nn.modules.module.Module])

Allows editing any part of a model by recursively applying a mapping function f to its modules and submodules.
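The recursion can be pictured with the minimal sketch below. It is a hypothetical, torch-only re-implementation named `edit_model_sketch`, not torchelie's actual code. It assumes that f is applied bottom-up: every child is edited first and written back into its parent, and then f is called on the parent itself. The same kind of mapping function as in the example that follows is applied to a small nn.Sequential.

```python
import torch
import torch.nn as nn
from typing import Callable


def edit_model_sketch(m: nn.Module, f: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """Hypothetical sketch of a recursive edit (not torchelie's implementation)."""
    # Snapshot the children so we can safely reassign them while looping.
    for name, child in list(m.named_children()):
        # Edit the child's subtree, then store the (possibly new) module back.
        setattr(m, name, edit_model_sketch(child, f))
    # Finally let f decide what to do with the current module itself.
    return f(m)


def make_leaky_no_dropout(m: nn.Module) -> nn.Module:
    if isinstance(m, nn.ReLU):
        return nn.LeakyReLU(inplace=True)
    if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
        return nn.Identity()
    return m


model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    nn.Sequential(nn.Dropout2d(0.5), nn.Conv2d(8, 8, 3), nn.ReLU()),
)
model = edit_model_sketch(model, make_leaky_no_dropout)
print(model)
```

Editing children before calling f on their parent means f always sees a parent whose submodules have already been rewritten. A top-down variant would instead call f first and then recurse into whatever it returned.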

For instance, to remove all dropout layers and turn every ReLU into a LeakyReLU:

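import torch.nn as nn
from torchelie.nn.utils import edit_model

def make_leaky_no_dropout(m):
    # Swap every ReLU for an in-place LeakyReLU
    if isinstance(m, nn.ReLU):
        return nn.LeakyReLU(inplace=True)
    # Replace dropout layers with a no-op
    if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
        return nn.Identity()
    # Leave every other module untouched
    return m

# model is any existing nn.Module
model = edit_model(model, make_leaky_no_dropout)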
Parameters
  • m (nn.Module) – the model to edit

  • f (Callable: nn.Module -> nn.Module) – a mapping function applied to all modules and submodules

Returns
  The edited model.