decomposed_kl

behavenet.fitting.losses.decomposed_kl(z, mu, logvar)

Decompose the KL term in the VAE loss.

Decomposes the KL divergence loss term of the variational autoencoder into three terms:
  1. index code mutual information
  2. total correlation
  3. dimension-wise KL
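For reference, the decomposition from Chen et al. (2018) can be written as follows. Here n indexes data points drawn uniformly from the dataset, q(z | n) is the encoder posterior, q(z) = E_{p(n)}[q(z | n)] is the aggregate posterior, and the prior p(z) is assumed to factorize across latent dimensions (as a standard normal prior does):

```latex
\mathbb{E}_{p(n)}\!\left[\mathrm{KL}\big(q(z \mid n)\,\|\,p(z)\big)\right]
= \underbrace{\mathrm{KL}\big(q(z, n)\,\|\,q(z)\,p(n)\big)}_{\text{index code mutual information}}
+ \underbrace{\mathrm{KL}\Big(q(z)\,\Big\|\,\prod_{j} q(z_j)\Big)}_{\text{total correlation}}
+ \underbrace{\sum_{j} \mathrm{KL}\big(q(z_j)\,\|\,p(z_j)\big)}_{\text{dimension-wise KL}}
```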

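None of these terms can be computed exactly when training with minibatch stochastic gradient descent, because each depends on the aggregate posterior q(z), which is defined over the entire dataset. This function instead computes the minibatch approximations described in Chen et al. (2018), https://arxiv.org/pdf/1802.04942.pdf.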
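The approximations can be sketched with the minibatch weighted sampling (MWS) estimator from that paper, in which log q(z_i) is approximated by log sum_j q(z_i | x_j) - log(N * M) for a minibatch of M frames drawn from a dataset of N frames. This is a hypothetical sketch, not the behavenet source: `decomposed_kl_sketch`, `gaussian_log_density`, and the extra `dataset_size` argument are assumptions introduced for illustration, and a diagonal Gaussian posterior with a standard normal prior is assumed.

```python
import math

import torch


def gaussian_log_density(z, mu, logvar):
    """Elementwise log N(z; mu, exp(logvar)); broadcasts like ordinary tensor ops."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (z - mu) ** 2 * torch.exp(-logvar))


def decomposed_kl_sketch(z, mu, logvar, dataset_size):
    """Hypothetical MWS estimates of (index code MI, total correlation, dimension-wise KL).

    z, mu, logvar: tensors of shape (n_frames, n_dims); z[i] is a sample from q(z | x_i).
    dataset_size: number of frames N in the full training set (assumed argument).
    """
    n_frames = z.shape[0]
    # log q(z_i[d] | x_j[d]) for every sample i, posterior j, and dimension d:
    # shape (n_frames, n_frames, n_dims) via broadcasting
    log_q_pairs = gaussian_log_density(
        z.unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0)
    )
    log_norm = math.log(n_frames * dataset_size)
    # MWS estimate of log q(z_i): sum over dims first (joint density), then logsumexp over j
    log_qz = torch.logsumexp(log_q_pairs.sum(dim=2), dim=1) - log_norm
    # MWS estimate of log prod_d q(z_i[d]): one marginal per dim, then sum over dims
    log_qz_product = (torch.logsumexp(log_q_pairs, dim=1) - log_norm).sum(dim=1)
    # log q(z_i | x_i) under the encoder and log p(z_i) under a standard normal prior
    log_qz_given_x = gaussian_log_density(z, mu, logvar).sum(dim=1)
    log_pz = gaussian_log_density(z, torch.zeros_like(z), torch.zeros_like(z)).sum(dim=1)

    index_code_mi = (log_qz_given_x - log_qz).mean()
    total_correlation = (log_qz - log_qz_product).mean()
    dimension_wise_kl = (log_qz_product - log_pz).mean()
    return index_code_mi, total_correlation, dimension_wise_kl


# Example on random encoder outputs (illustrative values only)
torch.manual_seed(0)
n_frames, n_dims = 32, 4
mu = torch.randn(n_frames, n_dims)
logvar = 0.1 * torch.randn(n_frames, n_dims)
z = mu + torch.exp(0.5 * logvar) * torch.randn(n_frames, n_dims)
mi, tc, dwkl = decomposed_kl_sketch(z, mu, logvar, dataset_size=1000)
```

The three estimates telescope: their sum equals the single-sample Monte Carlo estimate of the average KL, mean(log q(z|x) - log p(z)), whatever value is used for dataset_size. With a single latent dimension the total correlation estimate is exactly zero.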

Parameters:
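  • z (torch.Tensor) – latent sample drawn from the approximate posterior defined by mu and logvar, of shape (n_frames, n_dims)

  • mu (torch.Tensor) – mean of the approximate posterior, of shape (n_frames, n_dims)

  • logvar (torch.Tensor) – log variance of the approximate posterior, of shape (n_frames, n_dims)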
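The three arguments are typically produced together by a diagonal Gaussian encoder. A minimal sketch of preparing inputs with the expected shapes, assuming z is drawn with the reparameterization trick; the helper name `sample_latents` is hypothetical and not part of behavenet:

```python
import torch


def sample_latents(mu, logvar):
    """Reparameterized draw z = mu + exp(0.5 * logvar) * eps, with eps ~ N(0, I)."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


# Illustrative encoder outputs for a batch of frames
n_frames, n_dims = 128, 6
mu = torch.randn(n_frames, n_dims)
logvar = torch.full((n_frames, n_dims), -1.0)
z = sample_latents(mu, logvar)  # same shape as mu: (n_frames, n_dims)
```

Sampling this way keeps z differentiable with respect to mu and logvar, so gradients of the KL terms flow back into the encoder.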

Returns:

  • index code mutual information (torch.Tensor)

  • total correlation (torch.Tensor)

  • dimension-wise KL (torch.Tensor)

Return type:

tuple
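In a β-TCVAE-style objective (Chen et al., 2018), the returned terms are weighted separately so that total correlation can be penalized more heavily than the other two. A minimal sketch, assuming the tuple order listed under Returns; `tc_weighted_kl` and the weights `alpha`, `beta`, `gamma` are hypothetical names, not behavenet parameters:

```python
import torch


def tc_weighted_kl(kl_terms, alpha=1.0, beta=1.0, gamma=1.0):
    """Weight (index code MI, total correlation, dimension-wise KL) separately."""
    index_code_mi, total_correlation, dimension_wise_kl = kl_terms
    return alpha * index_code_mi + beta * total_correlation + gamma * dimension_wise_kl


# Stand-in values; in practice kl_terms would come from decomposed_kl(z, mu, logvar)
kl_terms = (torch.tensor(0.5), torch.tensor(0.25), torch.tensor(1.0))
penalty = tc_weighted_kl(kl_terms, beta=4.0)  # 0.5 + 4.0 * 0.25 + 1.0 = 2.5
```

With alpha = beta = gamma = 1 the weighted sum reduces to the (approximate) KL term of the standard VAE loss; the β-TCVAE keeps alpha and gamma at 1 and tunes only beta.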