decomposed_kl
- behavenet.fitting.losses.decomposed_kl(z, mu, logvar)
Decompose KL term in VAE loss.
Decomposes the KL divergence loss term of the variational autoencoder into three terms:
1. index code mutual information
2. total correlation
3. dimension-wise KL
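For reference, the decomposition from the linked paper can be written as below. Here $n$ indexes the $N$ training frames (sampled uniformly, so $p(n) = 1/N$), $q(z) = \frac{1}{N}\sum_n q(z \mid n)$ is the aggregate posterior, and the prior is assumed to factorize across dimensions, $p(z) = \prod_j p(z_j)$, which the decomposition requires.

```latex
\mathbb{E}_{p(n)}\!\left[\mathrm{KL}\big(q(z \mid n)\,\|\,p(z)\big)\right]
= \underbrace{\mathrm{KL}\big(q(z, n)\,\|\,q(z)\,p(n)\big)}_{\text{index-code mutual information}}
+ \underbrace{\mathrm{KL}\Big(q(z)\,\Big\|\,\prod_{j} q(z_j)\Big)}_{\text{total correlation}}
+ \underbrace{\sum_{j} \mathrm{KL}\big(q(z_j)\,\|\,p(z_j)\big)}_{\text{dimension-wise KL}}
```

The identity follows by writing the left side as $\mathbb{E}_{q(z,n)}[\log q(z \mid n) - \log p(z)]$ and adding and subtracting $\log q(z)$ and $\log \prod_j q(z_j)$ inside the expectation.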
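None of these terms can be computed exactly when training with minibatch stochastic gradient descent, because each depends on the aggregate posterior q(z) over the full dataset. This function instead computes minibatch approximations as detailed in https://arxiv.org/pdf/1802.04942.pdf.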
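As a rough illustration of what such a minibatch approximation involves, the sketch below evaluates every frame's sample under every frame's posterior. It is not the behavenet source: the name `decomposed_kl_sketch` and the helper `gaussian_log_density` are hypothetical, the prior is assumed to be a standard normal, and the aggregate posterior is approximated by the mixture of the M posteriors in the batch, (1/M) Σ_j q(z | x_j). The paper's minibatch weighted sampling estimator also accounts for the dataset size, which should only shift these estimates by additive constants.

```python
import math

import torch


def gaussian_log_density(x, mu, logvar):
    """Elementwise log N(x; mu, exp(logvar)) for a diagonal Gaussian."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 / torch.exp(logvar))


def decomposed_kl_sketch(z, mu, logvar):
    """Hypothetical minibatch estimate of (mutual info, total correlation, dim-wise KL).

    z, mu, logvar: tensors of shape (n_frames, n_dims).
    Assumes a standard normal prior and approximates the aggregate posterior
    q(z) by the mixture (1/M) * sum_j q(z | x_j) over the M frames in the batch.
    """
    n_frames = z.shape[0]
    log_m = math.log(n_frames)

    # log q(z_i,d | x_j,d) for every pair of frames (i, j) and dimension d: (M, M, D)
    log_q_pairs = gaussian_log_density(z.unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0))

    # log q(z_i | x_i): each sample under its own posterior, summed over dims -> (M,)
    log_q_z_given_x = gaussian_log_density(z, mu, logvar).sum(dim=1)

    # log q(z_i) ~= logsumexp_j sum_d log q(z_i,d | x_j,d) - log M
    log_q_z = torch.logsumexp(log_q_pairs.sum(dim=2), dim=1) - log_m

    # log prod_d q(z_i,d) ~= sum_d [logsumexp_j log q(z_i,d | x_j,d) - log M]
    log_prod_q_zd = (torch.logsumexp(log_q_pairs, dim=1) - log_m).sum(dim=1)

    # log p(z_i) under the assumed standard normal prior
    log_p_z = gaussian_log_density(z, torch.zeros_like(z), torch.zeros_like(z)).sum(dim=1)

    mutual_info = (log_q_z_given_x - log_q_z).mean()
    total_corr = (log_q_z - log_prod_q_zd).mean()
    dim_wise_kl = (log_prod_q_zd - log_p_z).mean()
    return mutual_info, total_corr, dim_wise_kl


if __name__ == "__main__":
    torch.manual_seed(0)
    mu = torch.randn(64, 4)
    logvar = 0.1 * torch.randn(64, 4)
    # one reparameterized sample z ~ q(z | x) per frame
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
    print(decomposed_kl_sketch(z, mu, logvar))
```

Because the three estimates telescope, their sum equals the single-sample Monte Carlo estimate of the full KL term, mean(log q(z|x) - log p(z)), whatever normalization constant is used for q(z); only the split between the terms depends on the aggregate-posterior approximation.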
- Parameters:
  - z (torch.Tensor) – sample of shape (n_frames, n_dims)
  - mu (torch.Tensor) – mean parameter of shape (n_frames, n_dims)
  - logvar (torch.Tensor) – log variance parameter of shape (n_frames, n_dims)
- Returns:
  - index code mutual information (torch.Tensor)
  - total correlation (torch.Tensor)
  - dimension-wise KL (torch.Tensor)
- Return type:
tuple
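The returned terms are usually recombined with separate weights; the β-TCVAE objective in the linked paper penalizes α·(mutual information) + β·(total correlation) + γ·(dimension-wise KL). A minimal sketch of that recombination follows. The helper name `weighted_kl_penalty` and its default weights are hypothetical, not part of behavenet, and the input tensors are arbitrary placeholder values.

```python
import torch


def weighted_kl_penalty(kl_terms, alpha=1.0, beta=1.0, gamma=1.0):
    """Hypothetical helper: weight (mutual info, total correlation, dim-wise KL).

    With alpha = beta = gamma = 1 this is the plain sum of the three terms,
    i.e. the undecomposed KL estimate; beta > 1 upweights the total
    correlation term as in the beta-TCVAE objective.
    """
    mutual_info, total_corr, dim_wise_kl = kl_terms
    return alpha * mutual_info + beta * total_corr + gamma * dim_wise_kl


# placeholder values standing in for the tuple returned by decomposed_kl
terms = (torch.tensor(0.5), torch.tensor(1.5), torch.tensor(2.0))
penalty = weighted_kl_penalty(terms, beta=4.0)  # 0.5 + 4.0 * 1.5 + 2.0 = 8.5
```

Keeping the weights outside the decomposition function lets the same three estimates serve a standard VAE (all weights 1) or a β-TCVAE (only β changed).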