index_code_mi
- behavenet.fitting.losses.index_code_mi(z, mu, logvar)
Estimate index code mutual information in a batch.
The constant term is omitted because it does not affect the minimization; it should equal log(n_frames * dataset_size).
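Where the constant comes from can be seen in the minibatch weighted sampling estimator of Chen et al. (2018, beta-TCVAE). This is an assumption: it matches the constant stated above, but the behavenet source is not quoted here. Take N = n_frames and M = dataset_size, and let q(z | x_j) be the diagonal Gaussian with mean mu_j and variance exp(logvar_j). The function would then return the estimate below without the final log(NM) term:

```latex
\widehat{I}(x; z) = \frac{1}{N}\sum_{i=1}^{N}\Big[\log q(z_i \mid x_i) - \log \sum_{j=1}^{N} q(z_i \mid x_j)\Big] + \log(N M)
```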
- Parameters:
z (torch.Tensor) – sample of shape (n_frames, n_dims)
mu (torch.Tensor) – mean parameter of shape (n_frames, n_dims)
logvar (torch.Tensor) – log variance parameter of shape (n_frames, n_dims)
- Returns:
index code mutual information for the batch, as a scalar
- Return type:
torch.Tensor
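The sketch below shows how such a batch estimate can be computed. It uses NumPy rather than torch so that it stays self-contained, and it is not the behavenet implementation. It assumes a diagonal Gaussian posterior q(z | x) = N(mu, exp(logvar)) and the minibatch weighted sampling estimator. The helper names `gaussian_log_density`, `logsumexp` and `index_code_mi_sketch` are hypothetical. As in the docstring above, the constant log(n_frames * dataset_size) is left out.

```python
import numpy as np


def gaussian_log_density(z, mu, logvar):
    """Elementwise log N(z; mu, exp(logvar)); inputs broadcast against each other."""
    return -0.5 * (np.log(2 * np.pi) + logvar + (z - mu) ** 2 / np.exp(logvar))


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    a_max = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(a_max, axis=axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))


def index_code_mi_sketch(z, mu, logvar):
    """Hypothetical NumPy version of the batch estimate (constant omitted).

    z, mu, logvar: arrays of shape (n_frames, n_dims).
    """
    # log q(z_i | x_i): sum the per-dimension Gaussian log densities
    log_qz_x = gaussian_log_density(z, mu, logvar).sum(axis=1)
    # log q(z_i | x_j) for every pair (i, j) in the batch: shape (n_frames, n_frames)
    pairwise = gaussian_log_density(
        z[:, None, :], mu[None, :, :], logvar[None, :, :]
    ).sum(axis=2)
    # minibatch estimate of log q(z_i), up to the omitted log(n_frames * dataset_size)
    log_qz = logsumexp(pairwise, axis=1)
    return float(np.mean(log_qz_x - log_qz))
```

The j = i term is always included in the log-sum-exp, so the value returned without the constant is never positive. Adding np.log(n_frames * dataset_size) recovers the full estimate. Because the constant does not depend on the model parameters, it has no effect on the gradients used during minimization.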