edit_model
torchelie.nn.utils.edit_model(m: T_Module, f: Callable[[torch.nn.modules.module.Module], torch.nn.modules.module.Module])
Allows any part of a model to be edited by recursively applying a mapping function to each of its modules.
For instance, to turn every ReLU into a LeakyReLU and remove every Dropout2d layer (by replacing it with nn.Identity):
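    import torch.nn as nn
    from torchelie.nn.utils import edit_model

    def make_leaky_no_dropout(m):
        # Swap ReLUs for in-place LeakyReLUs.
        if isinstance(m, nn.ReLU):
            return nn.LeakyReLU(inplace=True)
        # Replace Dropout2d with a no-op, which effectively deletes it.
        if isinstance(m, nn.Dropout2d):
            return nn.Identity()
        # Leave every other module untouched.
        return m

    # `model` is any existing nn.Module.
    model = edit_model(model, make_leaky_no_dropout)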
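The recursion described above can be sketched in plain PyTorch. This is a hypothetical re-implementation named `edit_model_sketch`, not torchelie's actual code: it assumes children are edited before their parent (post-order) and that `f` is also applied to the root module, either of which may differ from the library. Each replacement goes through `setattr`, which `nn.Module` uses to re-register a submodule under an existing name.

```python
from typing import Callable

import torch
import torch.nn as nn


def edit_model_sketch(m: nn.Module, f: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """Hypothetical sketch: edit all children first, then the module itself."""
    # named_children() yields only direct submodules; recursion reaches deeper ones.
    for name, child in list(m.named_children()):
        # setattr on an nn.Module replaces the registered submodule under the same name.
        setattr(m, name, edit_model_sketch(child, f))
    # Finally map the (already edited) module itself; f may return a new object.
    return f(m)


def make_leaky_no_dropout(m: nn.Module) -> nn.Module:
    if isinstance(m, nn.ReLU):
        return nn.LeakyReLU(inplace=True)
    if isinstance(m, nn.Dropout2d):
        return nn.Identity()
    return m


# A small nested model to edit: the inner Sequential checks that recursion works.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    nn.Sequential(nn.Dropout2d(0.5), nn.Conv2d(8, 8, 3), nn.ReLU()),
)
model = edit_model_sketch(model, make_leaky_no_dropout)

# The edited model still runs: two unpadded 3x3 convs shrink 16x16 to 12x12.
with torch.no_grad():
    out = model(torch.randn(1, 3, 16, 16))
```

Because `setattr` rewrites the parent's registered children, the original container is mutated as well as returned; reassigning the result still matters in case `f` replaces the root module itself.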
- Parameters
m (nn.Module) – the model to edit
f (Callable[[nn.Module], nn.Module]) – a mapping function applied to all modules and submodules; it must return an nn.Module for every input (for example the input itself, to leave it unchanged)
- Returns
The edited model.
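Because `f` receives the old module, a mapping can read that module's configuration when building its replacement. A minimal sketch, assuming a hypothetical helper `bn_to_gn` (not part of torchelie) that swaps `nn.BatchNorm2d` for `nn.GroupNorm` over the same number of channels, with an arbitrary default of 8 groups:

```python
import torch.nn as nn


def bn_to_gn(m: nn.Module, groups: int = 8) -> nn.Module:
    """Hypothetical mapping: replace BatchNorm2d with GroupNorm, keep everything else."""
    if isinstance(m, nn.BatchNorm2d):
        # GroupNorm needs num_channels divisible by num_groups; fall back to one group.
        g = groups if m.num_features % groups == 0 else 1
        # Reuse the old layer's settings so the replacement matches its shape and options.
        return nn.GroupNorm(g, m.num_features, eps=m.eps, affine=m.affine)
    # Returning the input unchanged keeps the mapping total, as `f` requires.
    return m


gn = bn_to_gn(nn.BatchNorm2d(64))
odd = bn_to_gn(nn.BatchNorm2d(10))
relu = nn.ReLU()
kept = bn_to_gn(relu)
```

Passing it as `edit_model(model, bn_to_gn)` would swap every BatchNorm2d in `model`; since `f` is called with a single argument, a different group count could be bound with `functools.partial(bn_to_gn, groups=4)`. The single-group fallback is a choice made for this sketch only.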