dimension_wise_kl_to_std_normal

behavenet.fitting.losses.dimension_wise_kl_to_std_normal(z, mu, logvar)
Estimate the dimension-wise KL divergence to a standard normal in a batch.
- Parameters
z (torch.Tensor) – sample of shape (n_frames, n_dims)
mu (torch.Tensor) – mean parameter of shape (n_frames, n_dims)
logvar (torch.Tensor) – log variance parameter of shape (n_frames, n_dims)
- Returns
dimension-wise KL divergence to a standard normal, estimated over the batch (scalar value)
- Return type
torch.Tensor
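
The following is a minimal sketch of how such an estimate could be computed. It is not taken from the behavenet source. It assumes that the per-dimension marginal of the aggregate posterior is approximated by averaging the Gaussian densities of every frame in the batch, with no dataset-size correction. The function name `dimension_wise_kl_sketch` is hypothetical and chosen for illustration only.

```python
import math

import torch


def dimension_wise_kl_sketch(z, mu, logvar):
    """Hypothetical sketch of a batch estimate of sum_d KL(q(z_d) || N(0, 1)).

    Assumption: q(z_d) is approximated by the average of the per-frame
    Gaussians N(mu_j[d], exp(logvar_j[d])) over the batch.
    """
    n_frames = z.shape[0]
    log_2pi = math.log(2.0 * math.pi)

    # log q(z_i[d] | frame j) for every pair (i, j):
    # shape (n_frames, n_frames, n_dims)
    diff = z.unsqueeze(1) - mu.unsqueeze(0)
    lv = logvar.unsqueeze(0)
    log_q_pairs = -0.5 * (log_2pi + lv + diff.pow(2) / lv.exp())

    # Per-dimension marginal: average the densities over frames j
    # (logsumexp minus log n), shape (n_frames, n_dims)
    log_q_marginal = torch.logsumexp(log_q_pairs, dim=1) - math.log(n_frames)

    # Standard normal log density per dimension, shape (n_frames, n_dims)
    log_p = -0.5 * (log_2pi + z.pow(2))

    # Sum over dimensions, average over frames -> scalar tensor
    return (log_q_marginal - log_p).sum(dim=1).mean()


# Usage with random inputs of shape (n_frames, n_dims)
torch.manual_seed(0)
mu = torch.randn(16, 4)
logvar = torch.zeros(16, 4)
z = mu + torch.randn(16, 4) * torch.exp(0.5 * logvar)  # reparameterized sample
kl = dimension_wise_kl_sketch(z, mu, logvar)
```

The pairwise tensor costs O(n_frames² · n_dims) memory, which is fine for typical batch sizes but worth keeping in mind for long batches.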