total_correlation
- behavenet.fitting.losses.total_correlation(z, mu, logvar)
Estimate total correlation in a batch.
Compute the expectation over a batch of:
E_j [log(q(z(x_j))) - log(prod_l q(z(x_j)_l))]
The additive constant is omitted because it does not affect the minimization. It equals (n_dims - 1) * log(n_frames * dataset_size).
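The constant's origin can be sketched as follows. This assumes log q is estimated with the minibatch weighted sampling estimator used in β-TCVAE, which is consistent with the constant stated above. The symbols are introduced only for this derivation: M = n_frames, N = dataset_size, D = n_dims.

```latex
\begin{align}
\log q\big(z(x_j)\big) &\approx \log \sum_{i=1}^{M} q\big(z(x_j) \mid x_i\big) - \log(NM) \\
\log q\big(z(x_j)_l\big) &\approx \log \sum_{i=1}^{M} q\big(z(x_j)_l \mid x_i\big) - \log(NM) \\
\log q\big(z(x_j)\big) - \log \prod_{l=1}^{D} q\big(z(x_j)_l\big)
  &\approx \log \sum_{i=1}^{M} q\big(z(x_j) \mid x_i\big)
  - \sum_{l=1}^{D} \log \sum_{i=1}^{M} q\big(z(x_j)_l \mid x_i\big)
  + (D - 1)\log(NM)
\end{align}
```

The final term does not depend on the model parameters, so dropping it leaves the gradient, and therefore the minimization, unchanged.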
Code modified from https://github.com/julian-carpenter/beta-TCVAE/blob/master/nn/losses.py
- Parameters:
z (torch.Tensor) – sample of shape (n_frames, n_dims)
mu (torch.Tensor) – mean parameter of shape (n_frames, n_dims)
logvar (torch.Tensor) – log variance parameter of shape (n_frames, n_dims)
- Returns:
total correlation for the batch, as a scalar value
- Return type:
torch.Tensor
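The following NumPy sketch shows how the batch estimate could be computed from inputs of shape (n_frames, n_dims). It is for illustration only and is not behavenet's implementation. It assumes q(z|x_i) is a diagonal Gaussian parameterized by mu and logvar, and it omits the constant. The helpers total_correlation_np, log_density_gaussian and logsumexp are hypothetical names introduced here. A PyTorch version would use torch.logsumexp in place of the hand-written logsumexp.

```python
import numpy as np


def log_density_gaussian(x, mu, logvar):
    """Elementwise log N(x; mu, exp(logvar)); inputs broadcast against each other."""
    return -0.5 * (np.log(2 * np.pi) + logvar + (x - mu) ** 2 * np.exp(-logvar))


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis` (that axis is removed)."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def total_correlation_np(z, mu, logvar):
    """Hypothetical sketch: batch total-correlation estimate, constant omitted.

    z, mu, logvar: arrays of shape (n_frames, n_dims).
    """
    # log_qz_cond[j, i, l] = log q(z_jl | x_i); shape (n_frames, n_frames, n_dims)
    log_qz_cond = log_density_gaussian(z[:, None, :], mu[None, :, :], logvar[None, :, :])
    # log q(z_j): sum over dims (product of per-dim densities), then logsumexp over frames i
    log_qz = logsumexp(log_qz_cond.sum(axis=2), axis=1)
    # sum_l log q(z_jl): logsumexp over frames i separately per dim, then sum over dims
    log_prod_qz = logsumexp(log_qz_cond, axis=1).sum(axis=1)
    # expectation over the batch, returned as a Python float
    return float(np.mean(log_qz - log_prod_qz))
```

Broadcasting z[:, None, :] against mu[None, :, :] builds all n_frames × n_frames pairwise densities at once, so memory grows quadratically with the batch size. When every frame in the batch is identical, the sketch returns (1 - n_dims) * log(n_frames). This is the batch-size part of the omitted constant with its sign flipped.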