torchelie.distributions

class torchelie.distributions.GaussianMixture(weights: torch.Tensor, locs: torch.Tensor, scales: torch.Tensor)

Mixture of Gaussian distributions. Each tensor has an additional dimension whose size is the number of mixture components.

Parameters
  • weights (tensor) – unnormalized weights of the component distributions

  • locs (tensor) – means of the component distributions

  • scales (tensor) – scales of the component distributions

  • dim (int) – dimension indexing the component distributions, over which they are weighted and averaged.

log_prob(x: torch.Tensor) → torch.Tensor
property mean

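The mixture log-density is the log of the weighted sum of the component densities, where the unnormalized weights become probabilities through a softmax over the component dimension; the mixture mean is the probability-weighted average of the component means. A minimal sketch of that math for a scalar x in plain Python, with hypothetical helper names and lists standing in for the component dimension; it is not torchelie's tensor implementation:

```python
import math

def logsumexp(values):
    # Numerically stable log(sum(exp(v))) over a list of floats.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def log_softmax(weights):
    # Turn unnormalized weights into log-probabilities over the components.
    lse = logsumexp(weights)
    return [w - lse for w in weights]

def normal_log_prob(x, loc, scale):
    # Log-density of N(loc, scale**2) evaluated at x.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

def gaussian_mixture_log_prob(x, weights, locs, scales):
    # log p(x) = logsumexp_k [log pi_k + log N(x; loc_k, scale_k)]
    log_pi = log_softmax(weights)
    return logsumexp([lp + normal_log_prob(x, m, s)
                      for lp, m, s in zip(log_pi, locs, scales)])

def gaussian_mixture_mean(weights, locs):
    # Mean of the mixture: sum_k pi_k * loc_k.
    probs = [math.exp(lp) for lp in log_softmax(weights)]
    return sum(p * m for p, m in zip(probs, locs))

# Two equally weighted components centred at -1 and +1.
print(gaussian_mixture_mean([0.0, 0.0], [-1.0, 1.0]))  # 0.0
```

Working in log space with logsumexp avoids underflow when individual component densities are tiny far from their means.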
class torchelie.distributions.Logistic(loc: torch.Tensor, scale: torch.Tensor)

Logistic distribution

Parameters
  • loc (tensor) – mean of the distribution

  • scale (tensor) – scale of the distribution
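For location μ and scale s, with z = (x − μ) / s, the logistic density is p(x) = e^(−z) / (s (1 + e^(−z))²), its mean is μ, and it can be sampled by applying the logit to a uniform draw. A minimal plain-Python sketch, assuming the standard continuous logistic; the function names are hypothetical and this is not torchelie's implementation:

```python
import math
import random

def logistic_log_prob(x, loc, scale):
    # log p(x) = -z - log(scale) - 2 * log(1 + exp(-z)), with z = (x - loc) / scale
    z = (x - loc) / scale
    # Stable softplus(-z) = log(1 + exp(-z)); avoids overflow for very negative z.
    softplus = max(-z, 0.0) + math.log1p(math.exp(-abs(z)))
    return -z - math.log(scale) - 2.0 * softplus

def logistic_sample(loc, scale, rng=random):
    # Inverse-CDF sampling: logit(u) of a uniform u is standard logistic.
    u = rng.random()
    while u == 0.0:  # random() may return 0.0 (never 1.0); avoid log(0)
        u = rng.random()
    return loc + scale * math.log(u / (1.0 - u))
```

At x = μ the log-density is −log s − 2 log 2, the peak of the distribution.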

class torchelie.distributions.LogisticMixture(weights, locs, scales, dim)

Mixture of logistic distributions. Each tensor has an additional dimension whose size is the number of mixture components.

Parameters
  • weights (tensor) – unnormalized weights of the component distributions

  • locs (tensor) – means of the component distributions

  • scales (tensor) – scales of the component distributions

  • dim (int) – dimension indexing the component distributions, over which they are weighted and averaged.

log_prob(x: torch.Tensor) → torch.Tensor
property mean

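To illustrate how dim selects the component axis, here is a minimal plain-Python sketch of a batched logistic-mixture log-density and mean. It assumes the components sit on the last axis (as dim=-1 would select on a tensor), uses nested lists of shape [N][K] in place of tensors, and its function names are hypothetical rather than torchelie's API:

```python
import math

def _logsumexp(values):
    # Numerically stable log(sum(exp(v))).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def _logistic_log_prob(x, loc, scale):
    # Log-density of a logistic(loc, scale) at x, using a stable softplus.
    z = (x - loc) / scale
    softplus = max(-z, 0.0) + math.log1p(math.exp(-abs(z)))
    return -z - math.log(scale) - 2.0 * softplus

def logistic_mixture_log_prob(xs, weights, locs, scales):
    # xs: shape [N]; weights, locs, scales: shape [N][K] (K components last).
    out = []
    for x, w, m, s in zip(xs, weights, locs, scales):
        lse_w = _logsumexp(w)  # softmax normalizer over the component axis
        out.append(_logsumexp([wk - lse_w + _logistic_log_prob(x, mk, sk)
                               for wk, mk, sk in zip(w, m, s)]))
    return out

def logistic_mixture_mean(weights, locs):
    # Per batch element: sum_k softmax(w)_k * loc_k (a logistic's mean is its loc).
    means = []
    for w, m in zip(weights, locs):
        lse_w = _logsumexp(w)
        means.append(sum(math.exp(wk - lse_w) * mk for wk, mk in zip(w, m)))
    return means
```

For example, weights [0, log 3] give mixture probabilities 0.25 and 0.75, so components at 0 and 4 have a mixture mean of 3.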
torchelie.distributions.parameterized_truncated_normal(uniform: torch.Tensor, mu: float, sigma: float, a: float, b: float) → torch.Tensor

Experimental

Warning

parameterized_truncated_normal() is experimental, and may change or be deleted soon if not already broken.
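A function with this signature typically turns uniform draws into truncated-normal samples by inverting the CDF of the truncated distribution. The sketch below assumes uniform holds draws from U(0, 1) and that a and b bound the sample itself, so they are standardized as α = (a − μ) / σ and β = (b − μ) / σ. It uses Python's statistics.NormalDist on plain lists instead of tensors, and the function name is hypothetical:

```python
from statistics import NormalDist

def parameterized_truncated_normal_sketch(uniform, mu, sigma, a, b):
    # Map each u in [0, 1) to a sample of N(mu, sigma**2) restricted to [a, b].
    std = NormalDist()  # standard normal N(0, 1)
    alpha = (a - mu) / sigma
    beta = (b - mu) / sigma
    cdf_a, cdf_b = std.cdf(alpha), std.cdf(beta)
    out = []
    for u in uniform:
        # Rescale u into [Phi(alpha), Phi(beta)] ...
        p = cdf_a + u * (cdf_b - cdf_a)
        # ... then invert Phi; inv_cdf needs 0 < p < 1, so clamp the endpoints.
        p = min(max(p, 1e-12), 1.0 - 1e-12)
        x = mu + sigma * std.inv_cdf(p)
        out.append(min(max(x, a), b))  # guard against rounding outside [a, b]
    return out
```

Because the map from u to x is monotone, u = 0 lands on the lower bound a, and u = 0.5 lands on mu whenever the bounds are symmetric around it.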

torchelie.distributions.sample_truncated_normal(*shape, cutoff: float = 2)

Experimental

Warning

sample_truncated_normal() is experimental, and may change or be deleted soon if not already broken.
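A minimal sketch, assuming this function draws standard normal samples restricted to [-cutoff, cutoff] for the requested shape. It uses rejection sampling, a swapped-in technique that is not necessarily what torchelie does; the function name and the flat-list return value are hypothetical:

```python
import math
import random

def sample_truncated_normal_sketch(*shape, cutoff=2.0, rng=random):
    # Rejection sampling: draw N(0, 1) and keep only draws with |x| <= cutoff.
    # Returns a flat list of math.prod(shape) floats instead of a tensor.
    n = math.prod(shape)
    out = []
    while len(out) < n:
        x = rng.gauss(0.0, 1.0)
        if abs(x) <= cutoff:
            out.append(x)
    return out
```

With the default cutoff of 2, about 95% of draws are accepted, so rejection is cheap; it slows down as cutoff shrinks, where an inverse-CDF approach is the better choice.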

torchelie.distributions.truncated_normal(uniform: torch.Tensor, a: float, b: float) → torch.Tensor

Experimental

Warning

truncated_normal() is experimental, and may change or be deleted soon if not already broken.
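Assuming truncated_normal maps uniform draws u in [0, 1) to a standard normal restricted to [a, b] (the source does not document its semantics), the usual inverse-CDF construction with the standard normal CDF Φ is:

```latex
x = \Phi^{-1}\!\left(\Phi(a) + u\,\bigl(\Phi(b) - \Phi(a)\bigr)\right),
\qquad
P(x \le t) = \frac{\Phi(t) - \Phi(a)}{\Phi(b) - \Phi(a)}, \quad a \le t \le b
```

The second identity holds because x ≤ t exactly when u ≤ (Φ(t) − Φ(a)) / (Φ(b) − Φ(a)), and a uniform u falls below that threshold with exactly that probability, which is the CDF of the normal truncated to [a, b].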