index_code_mi

behavenet.fitting.losses.index_code_mi(z, mu, logvar)

Estimate index code mutual information in a batch.

The additive constant is omitted because it does not affect the minimization. That constant is log(n_frames * dataset_size).
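As a hedged illustration of where that constant comes from, assume (this page does not confirm it) that the estimate follows the minibatch estimator used in β-TCVAE-style decompositions. Take a batch of M = n_frames samples z_i, each drawn from a factorized Gaussian posterior q(z | x_i) with mean μ_i and log variance logvar_i, and let N = dataset_size. Then:

```latex
\[
I(z; n) \approx \frac{1}{M}\sum_{i=1}^{M}\Big[\log q(z_i \mid x_i) - \log\sum_{j=1}^{M} q(z_i \mid x_j)\Big] + \log(MN),
\qquad
\log q(z_i \mid x_j) = \sum_{k=1}^{d} \log \mathcal{N}\big(z_{ik};\, \mu_{jk},\, e^{\mathrm{logvar}_{jk}}\big)
\]
```

Here d = n_dims. The trailing log(MN) term is the constant that is dropped, so minimizing the remaining bracketed average is equivalent.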

Parameters:
  • z (torch.Tensor) – sample of shape (n_frames, n_dims)

  • mu (torch.Tensor) – mean parameter of shape (n_frames, n_dims)

  • logvar (torch.Tensor) – log variance parameter of shape (n_frames, n_dims)

Returns:

index code mutual information for the batch (scalar), without the constant term

Return type:

torch.Tensor
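The following is a minimal PyTorch sketch of such a batch estimate. It is not the behavenet source. It assumes a factorized Gaussian posterior and the β-TCVAE-style minibatch estimator. `gaussian_log_density` and `index_code_mi_sketch` are hypothetical names introduced for illustration.

```python
import math

import torch


def gaussian_log_density(x, mu, logvar):
    """Elementwise log N(x; mu, exp(logvar)). Hypothetical helper for this sketch."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 / torch.exp(logvar))


def index_code_mi_sketch(z, mu, logvar):
    """Batch index-code MI estimate, omitting the constant log(n_frames * dataset_size).

    Assumes z, mu, logvar all have shape (n_frames, n_dims).
    """
    # log_q[i, j, k] = log q(z_k(x_i) | x_j): broadcast samples (i) against posteriors (j)
    log_q = gaussian_log_density(z.unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0))
    # Sum over latent dims: log q(z(x_i) | x_j) for a factorized Gaussian posterior
    log_q_joint = log_q.sum(dim=2)  # shape (n_frames, n_frames)
    # log q(z(x_i) | x_i): each sample evaluated under its own posterior (diagonal)
    log_qz_given_x = torch.diagonal(log_q_joint)
    # log sum_j q(z(x_i) | x_j): aggregate posterior, up to the dropped constant
    log_qz = torch.logsumexp(log_q_joint, dim=1)
    return (log_qz_given_x - log_qz).mean()


# Usage: reparameterized samples from random posteriors
torch.manual_seed(0)
mu = torch.randn(8, 3)
logvar = 0.1 * torch.randn(8, 3)
z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
mi = index_code_mi_sketch(z, mu, logvar)
print(mi.shape)  # torch.Size([])
```

The diagonal term log q(z_i | x_i) is one of the summands inside the log-sum-exp. As a result, the value returned without the constant is always at most 0. Adding log(n_frames * dataset_size) recovers the full estimate.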