UnitGaussianPrior(in_channels, num_latents, strength=1, kl_reduction='mean')
Force a representation to fit a unit Gaussian prior. An nn.Linear projects the input vector to mu and sigma, the parameters of a Gaussian distribution from which the output is sampled. The backward pass includes a KL divergence loss between N(mu, sigma) and N(0, 1).
This can be used to implement VAEs or information bottlenecks.
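For reference, the KL divergence between a diagonal Gaussian and the unit Gaussian has a well-known closed form. Here sigma is read as the standard deviation (so N(mu, sigma) has variance sigma squared) and the sum runs over the num_latents dimensions of one sample. That the module evaluates exactly this expression internally is an assumption, not something stated in this page.

```latex
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big)
  = \frac{1}{2} \sum_{j=1}^{\mathrm{num\_latents}} \left( \mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2 \right)
```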
In train mode, the output is sampled from N(mu, sigma); at test time (eval mode), mu is returned directly.
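The two modes can be illustrated with a minimal pure-Python sketch for a single sample. It assumes mu and sigma have already been produced by the linear projection (not shown) and that sigma is a positive standard deviation; how the module keeps sigma positive is not documented here. The helper names kl_to_unit_gaussian and sample_latent are hypothetical and not part of the library. Sampling uses the usual reparameterization z = mu + sigma * eps with eps ~ N(0, 1), which in an autograd framework keeps z differentiable with respect to mu and sigma.

```python
import math
import random


def kl_to_unit_gaussian(mu, sigma):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ) for one sample,
    summed over latent dimensions. sigma is a standard deviation (assumed > 0)."""
    return 0.5 * sum(m * m + s * s - 1.0 - math.log(s * s) for m, s in zip(mu, sigma))


def sample_latent(mu, sigma, training, rng=random):
    """Train mode: reparameterized draw z = mu + sigma * eps, eps ~ N(0, 1).
    Eval mode: return mu unchanged (hypothetical helper, for illustration)."""
    if not training:
        return list(mu)
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


mu, sigma = [0.5, -1.0], [1.0, 0.5]
print(sample_latent(mu, sigma, training=False))  # [0.5, -1.0]
print(sample_latent(mu, sigma, training=True, rng=random.Random(0)))  # random draw
print(kl_to_unit_gaussian([0.0, 0.0], [1.0, 1.0]))  # 0.0: already a unit Gaussian
```

The KL term is zero only when every mu_j is 0 and every sigma_j is 1, which is exactly the pressure toward the unit Gaussian prior described above.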
in_channels (int) – dimension of input channels
num_latents (int) – dimension of output latents
strength (float) – weight of the KL loss. When using this to implement a VAE, either set strength to 1 / (number of output dimensions of the model), or keep it at 1 and make sure the reconstruction loss is summed over the output dimensions but averaged over the batch.
kl_reduction (str) – how the implicit KL loss is reduced over the batch samples: ‘sum’ adds up the KL terms of all samples, while ‘mean’ divides that sum by the number of examples.
Input: x (Tensor) – a 2D (N, in_channels) tensor.
Returns: a 2D (N, num_latents) tensor sampled from the implicit Gaussian in train mode, or mu in eval mode.
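The following pure-Python sketch shows how strength and kl_reduction might combine per-sample KL terms into one scalar, and why the two VAE settings suggested for strength are interchangeable. The functions reduce_kl and vae_loss are hypothetical, kl_reduction is left at its default 'mean', and the first VAE option assumes the reconstruction loss is a plain mean over all N x D output elements (D being the number of output dimensions of the model).

```python
def reduce_kl(per_sample_kl, strength=1.0, kl_reduction="mean"):
    """Reduce per-sample KL terms over the batch, then scale by strength.

    'sum' adds the N terms; 'mean' divides that sum by N.
    (Hypothetical helper mirroring the documented parameters.)
    """
    total = sum(per_sample_kl)
    if kl_reduction == "sum":
        reduced = total
    elif kl_reduction == "mean":
        reduced = total / len(per_sample_kl)
    else:
        raise ValueError(f"unknown kl_reduction: {kl_reduction!r}")
    return strength * reduced


def vae_loss(recon_errors, per_sample_kl, average_over_dims):
    """recon_errors: N rows of D per-dimension reconstruction errors."""
    n, d = len(recon_errors), len(recon_errors[0])
    total_recon = sum(sum(row) for row in recon_errors)
    if average_over_dims:
        # Option 1: reconstruction averaged over all N * D elements, strength = 1 / D
        return total_recon / (n * d) + reduce_kl(per_sample_kl, strength=1.0 / d)
    # Option 2: reconstruction summed over the D dims, averaged over the batch, strength = 1
    return total_recon / n + reduce_kl(per_sample_kl, strength=1.0)


recon = [[1.0, 3.0], [2.0, 2.0]]  # N = 2 samples, D = 2 output dims
kl = [0.5, 1.5]
print(reduce_kl(kl, kl_reduction="sum"))  # 2.0
print(reduce_kl(kl, kl_reduction="mean"))  # 1.0
print(vae_loss(recon, kl, average_over_dims=True))  # 2.5
print(vae_loss(recon, kl, average_over_dims=False))  # 5.0, i.e. D * 2.5
```

Both options yield losses that differ only by the constant factor D, so they share the same minimizer; what changes is the overall gradient scale, which interacts with the learning rate.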