kl_div_to_std_normal
- behavenet.fitting.losses.kl_div_to_std_normal(mu, logvar)
Compute the element-wise KL divergence KL(q(z) || N(0, 1)), where q(z) is a normal distribution parameterized by mu and logvar.
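For reference, the standard closed-form KL divergence between a univariate normal and N(0, 1) is shown below. It is applied to each element and then reduced as described under Returns, with a sum over n_dims and a mean over n_frames. This is the textbook formula rather than a transcription of the behavenet source. The symbols N (n_frames), D (n_dims), and \ell (logvar) are introduced here only for illustration.

```latex
\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big)
  = \tfrac{1}{2}\left(\mu^2 + \sigma^2 - 1 - \log \sigma^2\right),
  \qquad \sigma^2 = e^{\ell}

\mathrm{KL}_{\text{reported}}
  = \frac{1}{N} \sum_{n=1}^{N} \sum_{d=1}^{D}
    \tfrac{1}{2}\left(\mu_{nd}^2 + e^{\ell_{nd}} - 1 - \ell_{nd}\right)
```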
- Parameters:
mu (torch.Tensor) – mean parameter of shape (n_frames, n_dims)
logvar – log variance parameter of shape (n_frames, n_dims)
- Returns:
KL divergence summed across dims, averaged across batch
- Return type:
torch.Tensor
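The following minimal PyTorch sketch reproduces the documented behavior. It assumes the standard closed-form KL between a normal distribution and N(0, 1). It computes the per-element terms, sums them over n_dims (dim 1), and averages over n_frames (dim 0). The function name `kl_div_to_std_normal_sketch` is hypothetical and is not the behavenet implementation.

```python
import torch


def kl_div_to_std_normal_sketch(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of KL(N(mu, exp(logvar)) || N(0, 1)).

    Assumes mu and logvar both have shape (n_frames, n_dims).
    """
    # Closed-form KL for each element: 0.5 * (mu^2 + sigma^2 - 1 - log sigma^2)
    kl_elementwise = 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar)
    # Sum over latent dimensions (dim 1), then average over frames (dim 0)
    return kl_elementwise.sum(dim=1).mean()


# Usage: a posterior equal to N(0, 1) has zero KL divergence
mu = torch.zeros(4, 3)
logvar = torch.zeros(4, 3)
print(kl_div_to_std_normal_sketch(mu, logvar))  # prints tensor(0.)
```

Parameterizing the variance through its logarithm keeps sigma^2 = exp(logvar) positive without constraining the values that produce logvar.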