kl_div_to_std_normal
behavenet.fitting.losses.kl_div_to_std_normal(mu, logvar)

Compute KL(q(z) || N(0, 1)) element-wise, where q(z) is a normal distribution parameterized by mu and logvar. The element-wise terms are then reduced to a single value (see Returns).
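For reference, each element-wise term has the standard closed form for the KL divergence between a univariate Gaussian and N(0, 1). This assumes logvar holds log σ², so σ² = exp(logvar). The formula is textbook material and is not taken from the behavenet source.

```latex
\mathrm{KL}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\right)
  = -\frac{1}{2}\left(1 + \log\sigma^2 - \mu^2 - \sigma^2\right),
  \qquad \sigma^2 = \exp(\mathrm{logvar})
```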
- Parameters
mu (torch.Tensor) – mean parameter of shape (n_frames, n_dims)
logvar – log variance parameter of shape (n_frames, n_dims)
- Returns
KL divergence, summed across dims and averaged across the batch
- Return type
torch.Tensor
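A minimal PyTorch sketch of the computation described above. It assumes the closed-form element-wise term and the reduction stated under Returns: a sum over n_dims, then a mean over n_frames. The function name kl_div_to_std_normal_sketch is hypothetical, and this is not the behavenet implementation.

```python
import torch


def kl_div_to_std_normal_sketch(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Hypothetical illustration, not behavenet's source.

    Computes KL(N(mu, exp(logvar)) || N(0, 1)) for each element,
    sums over the latent dims, and averages over frames.
    """
    # Element-wise KL terms, shape (n_frames, n_dims); assumes logvar = log(sigma^2)
    kl_elem = -0.5 * (1.0 + logvar - mu.pow(2) - logvar.exp())
    # Sum across latent dims (dim=1), then average across the batch of frames
    return kl_elem.sum(dim=1).mean()


# Usage: mu = 1 and logvar = 0 give 0.5 per element, so 3 dims sum to 1.5 per frame
mu = torch.ones(2, 3)
logvar = torch.zeros(2, 3)
print(kl_div_to_std_normal_sketch(mu, logvar).item())
```

Summing over dims follows from the diagonal-Gaussian assumption: the KL divergence of a factorized distribution is the sum of its per-dimension terms. Averaging over frames keeps the loss scale independent of batch size.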