dimension_wise_kl_to_std_normal

behavenet.fitting.losses.dimension_wise_kl_to_std_normal(z, mu, logvar)

Estimate the dimension-wise KL divergence to a standard normal distribution over a batch.
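For orientation, one common way to define and estimate this quantity is the dimension-wise KL term from the β-TCVAE decomposition of the KL. It compares each dimension j of the aggregate posterior q(z_j) to a standard normal, with q(z_j) approximated by a mixture over the N = n_frames posteriors in the batch (D = n_dims). This is an assumption about the intended estimator, not a transcription of the behavenet source. In particular, whether the 1/N mixture normalization is included is an assumption, and it only shifts the estimate by a constant.

$$
\sum_{j=1}^{D} \mathrm{KL}\big(q(z_j)\,\|\,\mathcal{N}(0,1)\big)
\approx \frac{1}{N}\sum_{n=1}^{N}\sum_{j=1}^{D}\Big[\log \hat{q}\big(z_j^{(n)}\big) - \log \mathcal{N}\big(z_j^{(n)};0,1\big)\Big],
\qquad
\hat{q}\big(z_j^{(n)}\big) = \frac{1}{N}\sum_{m=1}^{N}\mathcal{N}\big(z_j^{(n)};\,\mu_j^{(m)},\,e^{\mathrm{logvar}_j^{(m)}}\big)
$$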

Parameters:
  • z (torch.Tensor) – sample of shape (n_frames, n_dims)

  • mu (torch.Tensor) – mean parameter of shape (n_frames, n_dims)

  • logvar (torch.Tensor) – log variance parameter of shape (n_frames, n_dims)

Returns:

dimension-wise KL divergence to the standard normal for the batch, as a scalar value

Return type:

torch.Tensor
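A minimal PyTorch sketch of the minibatch estimator described above, assuming the inputs have the documented (n_frames, n_dims) shapes. The names gaussian_log_density and dimension_wise_kl_sketch are hypothetical helpers for illustration, not part of the behavenet API, and the library's own implementation may differ in details such as the 1/N normalization.

```python
import math

import torch


def gaussian_log_density(x, mu, logvar):
    """Hypothetical helper: elementwise log N(x; mu, exp(logvar)), broadcasting over inputs."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 / torch.exp(logvar))


def dimension_wise_kl_sketch(z, mu, logvar):
    """Hypothetical sketch of sum_j KL(q(z_j) || N(0, 1)) estimated from one batch.

    z, mu, logvar: tensors of shape (n_frames, n_dims). Returns a 0-dim tensor.
    """
    n_frames = z.shape[0]
    # log N(z[n, j]; mu[m, j], exp(logvar[m, j])) for every pair (n, m):
    # (n, 1, d) against (1, m, d) broadcasts to shape (n_frames, n_frames, n_dims)
    log_q_pairs = gaussian_log_density(z.unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0))
    # log of the batch-mixture density q(z_j): logsumexp over m, minus log N (assumed normalization)
    log_qz = torch.logsumexp(log_q_pairs, dim=1) - math.log(n_frames)
    # log density of each z[n, j] under the standard normal prior
    log_pz = gaussian_log_density(z, torch.zeros_like(z), torch.zeros_like(z))
    # sum over latent dimensions, then average over frames to get a scalar
    return (log_qz - log_pz).sum(dim=1).mean()


# Sanity check: if every posterior is exactly N(0, 1), the batch mixture equals the
# prior in every dimension, so the estimate should be (numerically) zero.
torch.manual_seed(0)
n_frames, n_dims = 64, 4
mu = torch.zeros(n_frames, n_dims)
logvar = torch.zeros(n_frames, n_dims)
z = mu + torch.randn(n_frames, n_dims) * torch.exp(0.5 * logvar)
print(dimension_wise_kl_sketch(z, mu, logvar).item())
```

Because the mixture over the batch stands in for the true aggregate posterior, the result is a stochastic, batch-size-dependent estimate rather than an exact KL, and it can come out slightly negative on small batches.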