UnitGaussianPrior

class torchelie.nn.UnitGaussianPrior(in_channels, num_latents, strength=1, kl_reduction='mean')

Force a representation to fit a unit Gaussian prior. An nn.Linear projects the input vector to mu and sigma, which parameterize a Gaussian distribution from which the output is sampled. The backward pass includes a KL divergence loss between N(mu, sigma) and N(0, 1).
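The KL term added in the backward pass has a closed form for a diagonal Gaussian. The sketch below computes it for a single sample in pure Python. It assumes that the projection yields a log-variance (log_var, with sigma^2 = exp(log_var)) rather than sigma itself. This page does not confirm that detail, and the helper name kl_to_unit_gaussian is hypothetical.

```python
import math


def kl_to_unit_gaussian(mu, log_var):
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) for a diagonal Gaussian,
    # summed over latent dimensions, with sigma^2 = exp(log_var) (assumed
    # parameterization): 0.5 * sum(mu^2 + sigma^2 - 1 - log(sigma^2))
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv
                     for m, lv in zip(mu, log_var))


# The penalty vanishes when the latent already matches the prior...
print(kl_to_unit_gaussian([0.0, 0.0], [0.0, 0.0]))  # 0.0
# ...and grows as mu moves away from 0 or sigma moves away from 1.
print(kl_to_unit_gaussian([0.5, -1.0], [0.0, math.log(4.0)]))
```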

This can be used to implement VAEs or information bottlenecks.

In train mode, the output is sampled from N(mu, sigma). In eval mode (at test time), mu is returned directly.
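The train/eval switch can be sketched as follows. In a PyTorch module the flag would normally be self.training, which module.train() and module.eval() toggle. The helper latent_output and the log-variance parameterization are assumptions for illustration, not torchelie code. The training-mode branch uses the usual reparameterized sample z = mu + sigma * eps.

```python
import math
import random


def latent_output(mu, log_var, training, rng=None):
    # Hypothetical helper mirroring the documented behaviour.
    if not training:
        # Eval mode: return the mean, no noise.
        return list(mu)
    rng = rng if rng is not None else random.Random()
    # Train mode: z = mu + sigma * eps with eps ~ N(0, 1) and
    # sigma = exp(log_var / 2). In an autodiff library this form keeps z
    # differentiable with respect to mu and log_var.
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]


mu, log_var = [1.0, -2.0], [0.0, 0.0]
print(latent_output(mu, log_var, training=False))  # [1.0, -2.0]
print(latent_output(mu, log_var, training=True, rng=random.Random(0)))  # noisy sample
```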

Parameters
  • in_channels (int) – dimension of input channels

  • num_latents (int) – dimension of output latents

  • strength (float) – weight of the KL loss. When using this to implement a VAE, either set strength to 1 / (number of output dimensions of the model), or set it to 1 and make sure the reconstruction loss is summed over output dimensions but averaged over the batch.

  • kl_reduction (str) – how the implicit KL loss is reduced over the batch samples. ‘sum’ adds up the per-sample KL terms, while ‘mean’ divides that sum by the number of samples.
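The sketch below shows how strength and kl_reduction might combine, and why the two VAE recipes for strength are equivalent. The helper kl_penalty is hypothetical. It assumes the penalty is strength times the batch-reduced per-sample KL. The reconstruction errors and KL values are toy numbers. With kl_reduction='mean', summing the reconstruction loss over output dimensions and using strength=1 gives exactly D times the objective you get by averaging over all elements with strength=1/D. The two objectives therefore share a minimizer, and only the gradient scale differs.

```python
def kl_penalty(per_sample_kl, strength=1.0, kl_reduction="mean"):
    # Hypothetical helper: reduce per-sample KL terms over the batch,
    # then scale the result by strength.
    if kl_reduction == "sum":
        total = sum(per_sample_kl)
    elif kl_reduction == "mean":
        total = sum(per_sample_kl) / len(per_sample_kl)
    else:
        raise ValueError("kl_reduction must be 'sum' or 'mean'")
    return strength * total


# Toy batch: N = 2 samples, D = 4 output dimensions of the decoder.
recon_err = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]  # per-element loss
kls = [1.0, 3.0]  # per-sample KL terms
N, D = len(recon_err), len(recon_err[0])
recon_total = sum(sum(row) for row in recon_err)

# Recipe A: reconstruction summed over dims, averaged over batch; strength = 1.
loss_a = recon_total / N + kl_penalty(kls, 1.0, "mean")
# Recipe B: reconstruction averaged over all elements; strength = 1 / D.
loss_b = recon_total / (N * D) + kl_penalty(kls, 1.0 / D, "mean")
print(loss_a, D * loss_b)  # equal up to floating-point rounding
```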

forward(x)
Parameters

x (Tensor) – A 2D (N, in_channels) tensor

Returns

A 2D (N, num_latents) tensor sampled from the implicit Gaussian distribution in train mode, or mu in eval mode.
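As a shape check, the sketch below maps an (N, in_channels) batch to (N, num_latents) in pure Python. It assumes a single linear layer with 2 * num_latents outputs whose result is split in half into mu and log-variance. That layout is an assumption, since the page only says that an nn.Linear produces mu and sigma. The function project_and_split is hypothetical.

```python
import random


def project_and_split(x, weight, bias, num_latents):
    # x: N rows of in_channels floats. weight: 2 * num_latents rows of
    # in_channels floats, mimicking nn.Linear(in_channels, 2 * num_latents).
    out = [[sum(w * v for w, v in zip(w_row, row)) + b
            for w_row, b in zip(weight, bias)] for row in x]
    # Assumed layout: first half is mu, second half is log-variance.
    mu = [r[:num_latents] for r in out]
    log_var = [r[num_latents:] for r in out]
    return mu, log_var


rng = random.Random(0)
N, in_channels, num_latents = 3, 5, 2
x = [[rng.uniform(-1, 1) for _ in range(in_channels)] for _ in range(N)]
weight = [[rng.uniform(-1, 1) for _ in range(in_channels)]
          for _ in range(2 * num_latents)]
bias = [0.0] * (2 * num_latents)
mu, log_var = project_and_split(x, weight, bias, num_latents)
# In eval mode mu is the output: N rows of num_latents values.
print(len(mu), len(mu[0]))  # 3 2
```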
