torchelie.callbacks

Loggers

class torchelie.callbacks.StdoutLogger(log_every=10, prefix='')

Log metrics to stdout. It logs scalars, scalar tensors and strings.

Parameters
  • log_every (int) – batch logging frequency; -1 logs only at epoch end.

  • prefix (str) – prefix for all metric names

class torchelie.callbacks.VisdomLogger(visdom_env='main', log_every=10, prefix='', post_epoch_ends=True)

Log metrics to Visdom. It logs scalars and scalar tensors as plots, 3D and 4D tensors as images, and strings as HTML.

Parameters
  • visdom_env (str) – name of the target visdom env

  • log_every (int) – batch logging frequency; -1 logs only at epoch end.

  • prefix (str) – prefix for all metric names
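
A minimal sketch of creating both loggers; the commented attachment call is assumed from torchelie's recipe API and may differ in your setup:

    import torchelie.callbacks as tcb

    # Print metrics to stdout every 50 batches, prefixing metric names with "train_"
    stdout_log = tcb.StdoutLogger(log_every=50, prefix='train_')
    # Plot the same metrics in the "my_experiment" Visdom environment
    visdom_log = tcb.VisdomLogger(visdom_env='my_experiment', log_every=50)
    # recipe.callbacks.add_callbacks([stdout_log, visdom_log])  # assumed attachment point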

Logging

class torchelie.callbacks.TopkAccAvg(k: int = 5, post_each_batch: bool = True, avg_type: str = 'window')

Log the average top-k accuracy to the metrics. The true classes are expected in state['batch'][1], and the logit predictions in state['preds'].

Parameters
  • k (int) – top k

  • post_each_batch (bool) – whether to post on each batch or on epoch end. Default: True.

  • avg_type (str) – type of averaging. One of ‘running’ for a running average over each epoch, ‘moving’ for an exponential moving average, or ‘window’ for an average over the last 100 values.

class torchelie.callbacks.AccAvg(post_each_batch=True, avg_type='window')

Log the average accuracy to the metrics. The true classes are expected in state['batch'][1], and the logit predictions in state['preds'].

Parameters
  • post_each_batch (bool) – whether to post on each batch or on epoch end. Default: True.

  • avg_type (str) – type of averaging. One of ‘running’ for a running average over each epoch, ‘moving’ for an exponential moving average, or ‘window’ for an average over the last 100 values.
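
A minimal sketch combining both accuracy callbacks, assuming a classification recipe where state['batch'][1] holds the labels and state['preds'] the logits:

    import torchelie.callbacks as tcb

    acc = tcb.AccAvg(avg_type='window')            # windowed top-1 accuracy
    top5 = tcb.TopkAccAvg(k=5, avg_type='moving')  # exponential moving average of top-5 accuracy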

class torchelie.callbacks.EpochMetricAvg(name, post_each_batch=True)

Log to the metrics a value from the current state, averaged over the epoch.

Parameters
  • name (str) – the name of the value to log

  • post_each_batch (bool) – whether to post on each batch (True, default), or only on epoch ends (False). Notice that posting on each batch necessarily yields an approximate average.

class torchelie.callbacks.WindowedMetricAvg(name, post_each_batch=True, k=100)

Log to the metrics a value from the current state, averaged over a sliding window of the last k values.

Parameters
  • name (str) – the name of the value to log

  • post_each_batch (bool) – whether to post on each batch (True, default), or only on epoch ends (False)
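
These callbacks average a scalar that the training step places in the state. A minimal sketch, assuming the training function returns {'loss': loss} so that state['loss'] exists:

    import torchelie.callbacks as tcb

    # Exact average of state['loss'] over the epoch, posted only at epoch end
    epoch_loss = tcb.EpochMetricAvg('loss', post_each_batch=False)
    # Smoother view: average of the last 100 values of state['loss'], posted every batch
    window_loss = tcb.WindowedMetricAvg('loss', k=100)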

class torchelie.callbacks.Log(from_k, to, post_each_batch=True)

Move a value from the state to the metrics.

Parameters
  • from_k (str) – path in the state of the value to log (as accepted by torchelie.utils.dict_by_key())

  • to (str) – metrics name
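
A minimal sketch, assuming the training step stores a scalar under state['loss']; the dotted path syntax of torchelie.utils.dict_by_key() also allows reaching into nested values such as 'batch.0':

    import torchelie.callbacks as tcb

    # Copy state['loss'] into the metrics under the name 'raw_loss', without averaging
    log_loss = tcb.Log('loss', 'raw_loss')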

class torchelie.callbacks.Counter

Count iterations and epochs. It must be the first callback, as other callbacks may depend on it.

It writes to the state the current epoch (as ‘epoch’), the current iteration within the epoch (‘epoch_batch’), and the overall iteration number (‘iters’).
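
A minimal sketch showing Counter placed first in the callback list; the commented attachment call is assumed from torchelie's recipe API:

    import torchelie.callbacks as tcb

    callbacks = [
        tcb.Counter(),                   # must come first: fills 'epoch', 'epoch_batch' and 'iters'
        tcb.WindowedMetricAvg('loss'),   # assumes state['loss'] exists
        tcb.StdoutLogger(log_every=100),
    ]
    # recipe.callbacks.add_callbacks(callbacks)  # assumed attachment point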

class torchelie.callbacks.Throughput

For debugging and benchmarking purposes: compute the throughput (samples/s) of both the forward pass and the full iteration.

It creates 4 new metrics:

  • iter_time: number of seconds between two consecutive batch_start events.

  • iter_throughput: imgs/s for the whole iteration, computed as batch_size / iter_time.

  • forward_time: number of seconds between the batch_start and batch_end events.

  • forward_throughput: imgs/s for the forward pass, computed as batch_size / forward_time.
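
For instance, with a batch size of 32 and an iter_time of 0.05 s, iter_throughput is 32 / 0.05 = 640 imgs/s; if forward_time is 0.02 s, forward_throughput is 32 / 0.02 = 1600 imgs/s.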

Visualization

class torchelie.callbacks.ConfusionMatrix(labels, normalize=False)

Generate an HTML confusion matrix.

Parameters
  • labels (list of str) – class names

  • normalize (bool) – whether to show absolute or relative frequencies

class torchelie.callbacks.MetricsTable(post_each_batch=True, epoch_ends=True)

Generate an HTML table with all the current metrics, to be displayed in Visdom.

Parameters
  • post_each_batch (bool) – whether to post on each batch or on epoch end. Default: True.

class torchelie.callbacks.ClassificationInspector(nb_show, classes, post_each_batch=True)

For image classification tasks, create an HTML report with the best classified samples, the worst classified samples, and the samples closest to the decision boundary.

Parameters
  • nb_show (int) – how many samples to show

  • classes (list of str) – class names

  • post_each_batch (bool) – whether to post on each batch or only on epoch end.
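
A minimal sketch combining the visualization callbacks for a 10-class problem (the class names are placeholders), to be rendered by a VisdomLogger:

    import torchelie.callbacks as tcb

    classes = ['class_{}'.format(i) for i in range(10)]   # placeholder class names
    viz = [
        tcb.ConfusionMatrix(classes, normalize=True),     # relative frequencies
        tcb.MetricsTable(),                               # HTML table of all current metrics
        tcb.ClassificationInspector(16, classes),         # 16 best / worst / borderline samples
    ]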

class torchelie.callbacks.ImageGradientVis

Log gradients backpropagated to the input as a means of feature visualization. Works only for image data.

class torchelie.callbacks.GANMetrics(real_key: str = 'batch.0', fake_key: str = 'fake', device='cpu', metrics: List[str] = ['fid', 'kid', 'precision', 'recall'])

Compute the Kernel Inception Distance and the Fréchet Inception Distance (https://arxiv.org/abs/1706.08500) on batches.

Parameters
  • real_key (str) – name of the batch of real images in the state

  • fake_key (str) – name of the batch of generated images in the state

  • device – a torch device on which to compute the InceptionV3 activations.

  • metrics (List[str]) – which metrics to use; any of ‘kid’, ‘fid’, ‘precision’, ‘recall’, ‘ids’.

Warning

The FID is heavily biased but stable. This means the scale of the FID will heavily vary according to your batch size. If you want to compare your FID with papers, do it on 50k samples to match the scale. However, you can use the FID on much smaller batches to track your training and relative improvements.

tl;dr: FID values can only be compared for a similar number of samples; papers use 50k, but for your own experiments any number works as long as it stays the same.

See page 8 Fig. 1 a for more details.

Warning

The KID is not biased but slightly unstable with fewer than 500 samples. With more than 500 samples, the KID can be compared between papers. With fewer than that, you can expect a variance of about +/- 1e-3.

See page 8 Fig. 1 a for more details.

Note

When doing distributed training, the computation is run on all gathered batches.

Note

IDS is the Paired Inception Discriminative Score from CoModGAN.
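
A minimal sketch for a GAN recipe, assuming the real images are the first element of the batch and the generator's output is stored as state['fake']:

    import torchelie.callbacks as tcb

    gan_metrics = tcb.GANMetrics(
        real_key='batch.0',        # real images: first element of the batch
        fake_key='fake',           # assumed: generated images stored as state['fake']
        device='cuda',             # compute the InceptionV3 activations on the GPU
        metrics=['fid', 'kid'],    # skip precision / recall / ids here
    )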

class torchelie.callbacks.SegmentationInspector(nb_show, classes, post_each_batch=True)

For image binary segmentation tasks, create an HTML report with the best segmented samples, the worst segmented samples, and the samples closest to the decision boundary.

Parameters
  • nb_show (int) – how many samples to show

  • classes (list of str) – class names

  • post_each_batch (bool) – whether to post on each batch or only on epoch end.

Training

class torchelie.callbacks.Optimizer(opt, accumulation=1, clip_grad_norm=None, centralize_grad=False, log_lr=False, log_mom=False, scaler=None, grad_multiplier=1)

Apply an optimizer’s step() and zero_grad().

Parameters
  • opt (Optimizer) – the optimizer to use

  • accumulation (int) – number of batches to accumulate gradients over

  • clip_grad_norm (float or None) – maximal norm of gradients to clip, before applying opt.step()

  • log_lr (bool) – whether to log the current learning rates in the metrics

  • log_mom (bool) – whether to log the current momentum / beta1 in the metrics
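
A minimal sketch wrapping a standard PyTorch optimizer, with gradient accumulation and clipping:

    import torch
    import torchelie.callbacks as tcb

    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)   # `model` is defined elsewhere
    opt_cb = tcb.Optimizer(
        opt,
        accumulation=4,        # accumulate gradients over 4 batches before stepping
        clip_grad_norm=1.0,    # clip gradients to norm 1.0 before opt.step()
        log_lr=True,           # expose the current learning rate in the metrics
    )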

class torchelie.callbacks.LRSched(sched, metric='loss', step_each_batch=False)

Call lr_sched.step().

Parameters
  • sched (Scheduler) – the scheduler to run

  • metric (str or None) – if step() takes a value as an argument, the name under which that value is found in the state; use None if step() takes no argument.

  • step_each_batch (bool) – whether to call step() on each batch or on each epoch.
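
A minimal sketch for both scheduler styles, reusing the `opt` optimizer from the previous sketch:

    import torch
    import torchelie.callbacks as tcb

    # Plateau scheduler: step() needs a value, taken from state['loss'] at each epoch end
    plateau = tcb.LRSched(torch.optim.lr_scheduler.ReduceLROnPlateau(opt), metric='loss')

    # Cosine schedule: step() takes no value and is called on every batch
    cosine = tcb.LRSched(
        torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10000),
        metric=None,
        step_each_batch=True,
    )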

Model

class torchelie.callbacks.Polyak(original: torch.nn.modules.module.Module, copy: torch.nn.modules.module.Module, beta: float = 0.999)

Polyak averaging (Exponential moving average).

Parameters
  • original (nn.Module) – source module

  • copy (nn.Module) – averaged model

  • beta (float) – decay value
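
A minimal sketch maintaining an exponential moving average of a model's weights:

    import copy
    import torchelie.callbacks as tcb

    ema_model = copy.deepcopy(model)    # `model` is defined elsewhere; this copy holds the averaged weights
    polyak = tcb.Polyak(model, ema_model, beta=0.999)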

class torchelie.callbacks.Checkpoint(filename_base, objects, max_saves=10, key_best=None)

Save objects to disk at epoch end.

Parameters
  • filename_base (str) – a format string that is the filename. The format string can have keyword parameters that will be indexed in the state.

  • objects – what to save. It must have a state_dict() member

  • max_saves (int) – maximum number of checkpoints to save. Older checkpoints will be removed.

  • key_best (func) – function used to determine the best checkpoint. It takes a single argument and returns a key to use for sorting.
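
A minimal sketch, assuming the state contains 'epoch' (filled by Counter) so it can be interpolated into the filename:

    import torchelie.callbacks as tcb

    ckpt = tcb.Checkpoint(
        './checkpoints/model_{epoch}',   # '{epoch}' is looked up in the state
        model,                           # anything with a state_dict(); `model` is defined elsewhere
        max_saves=5,                     # keep only the 5 most recent checkpoints
        # key_best=lambda state: -state['loss'],  # assumed: ranks checkpoints from the state
    )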

Misc

class torchelie.callbacks.CallRecipe(loop, run_every=100, prefix='test', init_fun=None)

Call another recipe.

Parameters
  • loop (Recipe) – the recipe to call

  • run_every (int) – how often to call that recipe

  • prefix (str) – prefix of the metrics of the recipe

  • init_fun (Callable or None) – a function to call before running the recipe
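
A minimal sketch running a separate evaluation recipe every 500 iterations, assuming `test_recipe` is a torchelie recipe built elsewhere:

    import torchelie.callbacks as tcb

    evaluate = tcb.CallRecipe(test_recipe, run_every=500, prefix='test')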