kl_div_to_std_normal

behavenet.fitting.losses.kl_div_to_std_normal(mu, logvar)

Compute KL(q(z) || N(0, 1)), where q(z) is a diagonal normal distribution parameterized by mu and logvar. The KL term is evaluated element-wise, then reduced as described under Returns.
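For reference, the standard closed form of the KL divergence between N(mu, sigma^2) and N(0, 1), combined with the reduction stated under Returns, can be written as follows. Here T = n_frames, D = n_dims, and \ell_{td} denotes logvar for frame t and dimension d; this notation is introduced only for illustration.

```latex
\mathrm{KL} = \frac{1}{T} \sum_{t=1}^{T} \sum_{d=1}^{D}
  \frac{1}{2} \left( e^{\ell_{td}} + \mu_{td}^{2} - 1 - \ell_{td} \right)
```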

Parameters:
  • mu (torch.Tensor) – mean parameter of shape (n_frames, n_dims)

  • logvar (torch.Tensor) – log variance parameter of shape (n_frames, n_dims)

Returns:

KL divergence summed across dims and averaged across the batch (frames)

Return type:

torch.Tensor
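A minimal sketch of the computation described above, assuming mu and logvar are both (n_frames, n_dims) tensors. The name kl_div_to_std_normal_sketch is hypothetical, and this is an illustration of the closed-form KL and its reduction, not the behavenet source.

```python
import torch


def kl_div_to_std_normal_sketch(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation of KL(N(mu, exp(logvar)) || N(0, 1)).

    mu, logvar: tensors of shape (n_frames, n_dims).
    Returns a scalar tensor.
    """
    # Closed-form KL for each element: 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2)
    kl = 0.5 * (logvar.exp() + mu.pow(2) - 1.0 - logvar)
    # Sum across latent dims (dim=1), then average across frames (dim=0)
    return kl.sum(dim=1).mean()


if __name__ == "__main__":
    mu = torch.ones(2, 3)
    logvar = torch.zeros(2, 3)
    # Each element contributes 0.5, so each frame sums to 1.5 over 3 dims
    print(kl_div_to_std_normal_sketch(mu, logvar).item())
```

The result can be cross-checked against torch.distributions by building Normal(mu, (0.5 * logvar).exp()), calling kl_divergence against Normal(0, 1), and applying the same sum-then-mean reduction.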