BetaTCVAE

class behavenet.models.vaes.BetaTCVAE(hparams)

Bases: VAE

Beta Total Correlation VAE class.

This class constructs convolutional variational autoencoders and decomposes the KL divergence term in the ELBO into three terms:

  1. index-code mutual information

  2. total correlation

  3. dimension-wise KL
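
For reference, the decomposition can be written as below, in the notation of the β-TCVAE paper (arXiv 1802.04942). Here n indexes training examples, q(z) = Σ_n q(z|n) p(n) is the aggregate posterior, and p(z) is the prior. The weights α, β and γ are the paper's; β-TCVAE keeps α = γ = 1 and tunes β. How these weights map onto this class's hparams is an assumption not stated in this section.

```latex
\mathbb{E}_{p(n)}\!\left[\mathrm{KL}\big(q(z \mid n)\,\|\,p(z)\big)\right]
  = \underbrace{I_q(z; n)}_{\text{index-code MI}}
  + \underbrace{\mathrm{KL}\Big(q(z)\,\Big\|\,\textstyle\prod_j q(z_j)\Big)}_{\text{total correlation}}
  + \underbrace{\textstyle\sum_j \mathrm{KL}\big(q(z_j)\,\|\,p(z_j)\big)}_{\text{dimension-wise KL}}

\mathcal{L}_{\beta\text{-TC}}
  = \mathbb{E}_{q(z \mid n)\,p(n)}\big[\log p(n \mid z)\big]
  - \alpha\, I_q(z; n)
  - \beta\, \mathrm{KL}\Big(q(z)\,\Big\|\,\textstyle\prod_j q(z_j)\Big)
  - \gamma \textstyle\sum_j \mathrm{KL}\big(q(z_j)\,\|\,p(z_j)\big)
```

Setting α = β = γ = 1 recovers the standard ELBO.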

The total correlation term is up-weighted to encourage “disentangled” latents; for more information, see https://arxiv.org/pdf/1802.04942.pdf.
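
The three terms depend on the aggregate posterior q(z), so they have no closed form and must be estimated from a minibatch. A minimal NumPy sketch of the minibatch-weighted-sampling estimator described in that paper follows. It assumes a diagonal Gaussian encoder, a standard normal prior, and a known dataset size N (the `dataset_size` argument). The names `gaussian_log_density` and `decomposed_kl` are hypothetical, and this is not necessarily how behavenet computes these terms.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)


def gaussian_log_density(x, mu, logvar):
    """Elementwise log N(x; mu, exp(logvar))."""
    return -0.5 * (LOG_2PI + logvar + (x - mu) ** 2 / np.exp(logvar))


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    a_max = np.max(a, axis=axis, keepdims=True)
    out = a_max + np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def decomposed_kl(z, mu, logvar, dataset_size):
    """Hypothetical helper: minibatch-weighted-sampling estimates of
    (index-code MI, total correlation, dimension-wise KL).

    Assumes a diagonal Gaussian q(z|x) and a standard normal prior p(z).
    z, mu, logvar: arrays of shape (M, D); z[i] is a sample from q(z | x_i).
    dataset_size: N, the number of training examples.
    """
    m = z.shape[0]
    # log q(z_i | x_i), summed over latent dims -> (M,)
    log_qz_given_x = gaussian_log_density(z, mu, logvar).sum(axis=1)
    # log p(z_i) under the standard normal prior (mean 0, logvar 0) -> (M,)
    log_pz = gaussian_log_density(z, 0.0, 0.0).sum(axis=1)
    # mat[i, j, d] = log q(z_{i,d} | x_j) -> (M, M, D)
    mat = gaussian_log_density(z[:, None, :], mu[None, :, :], logvar[None, :, :])
    log_nm = np.log(dataset_size * m)
    # log q(z_i): joint density over dims, aggregated over the batch -> (M,)
    log_qz = logsumexp(mat.sum(axis=2), axis=1) - log_nm
    # log prod_d q(z_{i,d}): aggregate each dim separately, then sum -> (M,)
    log_prod_qzd = (logsumexp(mat, axis=1) - log_nm).sum(axis=1)

    mi = np.mean(log_qz_given_x - log_qz)
    tc = np.mean(log_qz - log_prod_qzd)
    dwkl = np.mean(log_prod_qzd - log_pz)
    return mi, tc, dwkl


# Usage with random encoder outputs (illustrative data only)
rng = np.random.default_rng(0)
mu = rng.normal(size=(64, 4))
logvar = rng.normal(scale=0.1, size=(64, 4))
z = mu + np.exp(0.5 * logvar) * rng.normal(size=mu.shape)  # reparameterized sample
mi, tc, dwkl = decomposed_kl(z, mu, logvar, dataset_size=10_000)
```

Because the three estimates telescope, their sum equals the batch Monte Carlo estimate of the full KL term. Note that the intermediate `mat` array has shape (M, M, D), so its memory cost grows quadratically with the number of samples it is computed over.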

Methods Summary

loss(data[, dataset, accumulate_grad, ...])

Calculate (decomposed) ELBO loss for VAE.

Methods Documentation

loss(data, dataset=0, accumulate_grad=True, chunk_size=200)

Calculate (decomposed) ELBO loss for VAE.

If the batch is larger than chunk_size, it is split into chunks to keep memory requirements low; gradients are accumulated across all chunks before a gradient step is taken.
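
The chunking idea can be illustrated independently of the VAE. The NumPy sketch below is a minimal example under stated assumptions: a toy linear model with a hand-derived MSE gradient stands in for the VAE loss and autograd, and each chunk's mean loss is weighted by its share of the batch (an assumption of this sketch) so that the accumulated loss and gradient match a single full-batch pass. The parameter update happens once, after all chunks. The function names are hypothetical.

```python
import numpy as np


def mse_loss_and_grad(w, x, y):
    """Hypothetical stand-in loss: MSE of a linear model x @ w, and its gradient wrt w."""
    resid = x @ w - y
    loss = np.mean(resid ** 2)
    grad = 2.0 * x.T @ resid / x.shape[0]
    return loss, grad


def chunked_loss_and_grad(w, x, y, chunk_size=200):
    """Split the batch into chunks and accumulate loss and gradient across them.

    Each chunk's mean loss is weighted by its share of the batch (an assumption
    of this sketch), so the totals match a single full-batch pass.
    """
    batch_size = x.shape[0]
    total_loss = 0.0
    total_grad = np.zeros_like(w)
    for start in range(0, batch_size, chunk_size):
        stop = min(start + chunk_size, batch_size)
        loss, grad = mse_loss_and_grad(w, x[start:stop], y[start:stop])
        weight = (stop - start) / batch_size
        total_loss += weight * loss
        total_grad += weight * grad  # accumulate instead of stepping per chunk
    return total_loss, total_grad


rng = np.random.default_rng(0)
x = rng.normal(size=(450, 3))  # 450 samples -> chunks of 200, 200, 50
y = rng.normal(size=450)
w = rng.normal(size=3)

full_loss, full_grad = mse_loss_and_grad(w, x, y)
chunk_loss, chunk_grad = chunked_loss_and_grad(w, x, y, chunk_size=200)
# A single gradient step, taken only after every chunk has been processed
w_new = w - 0.1 * chunk_grad
```

Only one chunk's intermediate values need to be held at a time, while the final update is identical to the full-batch one.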

Parameters:
  • data (dict) – batch of data; keys should include ‘images’ and, if necessary, ‘masks’

  • dataset (int, optional) – dataset index, used for session-specific I/O layers

  • accumulate_grad (bool, optional) – if True, accumulate gradients for a training step

  • chunk_size (int, optional) – batch is split into chunks of this size to keep memory requirements low

Returns:

  • ‘loss’ (float): full ELBO

  • ‘loss_ll’ (float): log-likelihood portion of the ELBO

  • ‘loss_mi’ (float): index-code mutual information portion of the KL term of the ELBO

  • ‘loss_tc’ (float): total correlation portion of the KL term of the ELBO

  • ‘loss_dwkl’ (float): dimension-wise KL portion of the KL term of the ELBO

  • ‘loss_mse’ (float): MSE (without Gaussian constants)

  • ‘beta’ (float): weight in front of the total correlation term

Return type:

dict