ModuleGraph

class torchelie.nn.ModuleGraph(outputs: Union[str, List[str]])

Allows networks to be described as computation graphs. The graph is built by labelling the inputs and outputs of each node. Nodes run in declaration order. Each node fetches its input values from a pool of named values, which is populated by the keyword arguments passed to forward and by the output values of the nodes that ran before it.
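The evaluation rule above can be modelled without PyTorch. The sketch below is a hypothetical, plain-Python stand-in: the class TinyGraph is illustrative and is not torchelie's implementation. It keeps nodes in a list, seeds the value pool with the keyword arguments, and unpacks a returned tuple positionally when a node declares several output labels. It drops the name argument for brevity. Returning a dict when outputs is a list is an assumption chosen to match the ['y'] indexing used in the examples; ModuleGraph's actual return type may differ.

```python
from typing import Any, Callable, Dict, List, Tuple, Union


class TinyGraph:
    """Hypothetical model of ModuleGraph's evaluation order (not torchelie code)."""

    def __init__(self, outputs: Union[str, List[str]]) -> None:
        self.outputs = outputs
        # Each node is (input labels, callable, output labels), kept in declaration order.
        self.nodes: List[Tuple[List[str], Callable[..., Any], List[str]]] = []

    def add_operation(self, inputs: List[str], outputs: List[str],
                      operation: Callable[..., Any]) -> "TinyGraph":
        self.nodes.append((inputs, operation, outputs))
        return self

    def __call__(self, **kwargs: Any) -> Any:
        # The pool of named values starts out as the keyword arguments.
        pool: Dict[str, Any] = dict(kwargs)
        for in_labels, op, out_labels in self.nodes:
            # Fetch this node's arguments from the pool, in label order.
            result = op(*[pool[label] for label in in_labels])
            if len(out_labels) == 1:
                pool[out_labels[0]] = result
            else:
                # Several output labels: bind the returned values positionally.
                pool.update(zip(out_labels, result))
        if isinstance(self.outputs, str):
            return pool[self.outputs]
        # Assumption: list-valued outputs are returned as a dict keyed by label.
        return {label: pool[label] for label in self.outputs}


g = TinyGraph(outputs=["q", "r"])
g.add_operation(["a", "b"], ["s"], lambda a, b: a + b)
g.add_operation(["s", "b"], ["q", "r"], divmod)  # divmod returns a 2-tuple
print(g(a=7, b=3))  # {'q': 3, 'r': 1}
```

Because the second node reads "s", it must be declared after the node that produces "s"; swapping the two add_operation calls would raise a KeyError.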

Simple example:

>>> import torch
>>> import torch.nn as nn
>>> import torchelie.nn as tnn
>>> m = tnn.ModuleGraph(outputs='y')
>>> m.add_operation(
...         inputs=['x'],
...         operation=nn.Linear(10, 20),
...         name='linear',
...         outputs=['y'])
>>> m(x=torch.randn(1, 10))
<a bunch of numbers>

Multiple inputs example:

If a layer takes more than one input, its input labels can be given as a tuple or a list of labels. The same applies to output labels when a module returns more than one value.

>>> m = tnn.ModuleGraph(outputs=['x1', 'y'])
>>> m.add_operation(
...         inputs=['x0'],
...         operation=nn.Conv2d(10, 20, 3, padding=1),
...         name='conv',
...         outputs=['x1'])
>>> m.add_operation(
...         inputs=['x1', 'z'],
...         operation=tnn.AdaIN2d(20, 3),
...         name='adain',
...         outputs=['y'])
>>> m(x0=torch.randn(1, 10, 8, 8), z=torch.randn(1, 3))['y']
<a bunch of numbers>
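Since nodes run in declaration order, a node can only read labels that were passed to forward or produced by an earlier node. The helper below is a hypothetical pre-flight check (missing_inputs is not part of torchelie) that walks a list of (name, inputs, outputs) declarations and reports every label that would be unavailable when its node runs. The node list reuses the label wiring of the multiple-inputs example: x0 produces x1, then x1 and z produce y.

```python
from typing import Iterable, List, Sequence, Tuple


def missing_inputs(
    nodes: Sequence[Tuple[str, List[str], List[str]]],
    forward_kwargs: Iterable[str],
) -> List[Tuple[str, str]]:
    """Return (node name, label) pairs that would be unavailable at run time.

    A label is available only if it was passed to forward as a keyword
    argument or produced by a node declared earlier.
    """
    available = set(forward_kwargs)
    missing = []
    for name, inputs, outputs in nodes:
        for label in inputs:
            if label not in available:
                missing.append((name, label))
        # Outputs become visible only to nodes declared after this one.
        available.update(outputs)
    return missing


# Same label wiring as the multiple-inputs example: (name, inputs, outputs).
nodes = [
    ("first", ["x0"], ["x1"]),
    ("adain", ["x1", "z"], ["y"]),
]
print(missing_inputs(nodes, ["x0", "z"]))  # []
print(missing_inputs(nodes, ["x0"]))       # [('adain', 'z')]
```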
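
add_operation(inputs: List[str], outputs: List[str], name: str, operation: torch.nn.modules.module.Module) -> torchelie.nn.graph.ModuleGraph

Registers operation under name. When the graph runs, the operation reads its arguments from the values labelled inputs and stores its results under the labels in outputs.

forward(**args)

Runs every node in declaration order. The graph's inputs are passed as keyword arguments named after their labels.

to_dot() -> str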

Experimental

Warning

ModuleGraph.to_dot() is experimental and may change or be removed soon, if it is not already broken.

training: bool