BetaTCVAE
- class behavenet.models.vaes.BetaTCVAE(hparams)
Bases: VAE
Beta Total Correlation VAE class.
This class constructs convolutional variational autoencoders and decomposes the KL divergence term in the ELBO into three terms:
1. index-code mutual information
2. total correlation
3. dimension-wise KL
The total correlation term is up-weighted to encourage “disentangled” latents; for more information, see https://arxiv.org/pdf/1802.04942.pdf.
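For reference, the decomposition and the weighted objective can be written in the notation of the linked paper. Here n indexes training examples, p(n) is uniform over the dataset, q(z) is the aggregate posterior, and z_j is the j-th latent dimension. This is a reading aid only. Chen et al. keep α = γ = 1 and tune β, so β = 1 recovers the standard ELBO. Whether behavenet fixes α and γ to exactly 1 is an assumption, not something stated on this page.

```latex
\mathbb{E}_{p(n)}\!\left[\mathrm{KL}\big(q(z \mid n)\,\|\,p(z)\big)\right]
  = \underbrace{\mathrm{KL}\big(q(z, n)\,\|\,q(z)\,p(n)\big)}_{\text{1. index-code MI } I_q(z;\,n)}
  + \underbrace{\mathrm{KL}\Big(q(z)\,\Big\|\,\prod_j q(z_j)\Big)}_{\text{2. total correlation}}
  + \underbrace{\sum_j \mathrm{KL}\big(q(z_j)\,\|\,p(z_j)\big)}_{\text{3. dimension-wise KL}}

\mathcal{L}_{\beta\text{-TC}}
  = \mathbb{E}_{q(z \mid n)\,p(n)}\big[\log p(n \mid z)\big]
  - \alpha\, I_q(z; n)
  - \beta\, \mathrm{KL}\Big(q(z)\,\Big\|\,\prod_j q(z_j)\Big)
  - \gamma \sum_j \mathrm{KL}\big(q(z_j)\,\|\,p(z_j)\big)
```

The objective is maximized during training, so the corresponding loss is its negative.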
Methods Summary
loss(data[, dataset, accumulate_grad, ...]) – Calculate (decomposed) ELBO loss for VAE.
Methods Documentation
- loss(data, dataset=0, accumulate_grad=True, chunk_size=200)
Calculate (decomposed) ELBO loss for VAE.
If the batch is larger than chunk_size (200 by default), it is split into chunks to keep memory requirements low; gradients are accumulated across all chunks before a gradient step is taken.
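The chunking scheme can be illustrated with a minimal sketch. It uses NumPy and a linear least-squares model as a stand-in for the VAE. The helpers chunk_bounds, mse_loss_and_grad, and chunked_loss_and_grad are hypothetical and are not part of behavenet. The sketch assumes each chunk's mean loss is weighted by its share of the batch, so that the accumulated loss and gradient match their full-batch values. The exact weighting behavenet uses is not documented here.

```python
import numpy as np


def chunk_bounds(batch_size, chunk_size):
    """(start, end) index pairs that cover range(batch_size) in chunks of chunk_size."""
    n_chunks = (batch_size + chunk_size - 1) // chunk_size  # ceiling division
    return [(i * chunk_size, min((i + 1) * chunk_size, batch_size)) for i in range(n_chunks)]


def mse_loss_and_grad(w, x, y):
    """Mean squared error of the linear model x @ w and its gradient w.r.t. w."""
    resid = x @ w - y
    loss = np.mean(resid ** 2)
    grad = 2.0 * x.T @ resid / len(y)
    return loss, grad


def chunked_loss_and_grad(w, x, y, chunk_size=200):
    """Accumulate loss and gradient over chunks before any parameter update.

    Each chunk's mean loss is weighted by its share of the batch, so the totals
    equal the full-batch values while only one chunk is processed at a time.
    """
    batch_size = len(y)
    total_loss, total_grad = 0.0, np.zeros_like(w)
    for beg, end in chunk_bounds(batch_size, chunk_size):
        loss, grad = mse_loss_and_grad(w, x[beg:end], y[beg:end])
        weight = (end - beg) / batch_size
        total_loss += weight * loss
        total_grad += weight * grad
    return total_loss, total_grad


# A batch of 450 examples is processed as chunks of 200, 200, and 50
rng = np.random.default_rng(0)
x, y, w = rng.normal(size=(450, 3)), rng.normal(size=450), rng.normal(size=3)
loss, grad = chunked_loss_and_grad(w, x, y, chunk_size=200)
w_next = w - 0.1 * grad  # a single gradient step, taken only after all chunks
```

For a plain per-example loss like this, chunking is exact. For the decomposed KL terms, the equivalence is only approximate if they are estimated from within-batch statistics, because the estimates then depend on which samples share a chunk.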
- Parameters:
  - data (dict) – batch of data; keys should include ‘images’ and, if necessary, ‘masks’
  - dataset (int, optional) – dataset index, used to select session-specific I/O layers
  - accumulate_grad (bool, optional) – accumulate gradients for the training step
  - chunk_size (int, optional) – the batch is split into chunks of this size to keep memory requirements low
- Returns:
  - ‘loss’ (float): full ELBO
  - ‘loss_ll’ (float): log-likelihood portion of the ELBO
  - ‘loss_mi’ (float): index-code mutual information portion of the KL term of the ELBO
  - ‘loss_tc’ (float): total correlation portion of the KL term of the ELBO
  - ‘loss_dwkl’ (float): dimension-wise KL portion of the KL term of the ELBO
  - ‘loss_mse’ (float): MSE (without Gaussian constants)
  - ‘beta’ (float): weight in front of the KL term
- Return type:
dict
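The three KL components returned as ‘loss_mi’, ‘loss_tc’, and ‘loss_dwkl’ have no closed form and must be estimated from a batch. The sketch below uses the minibatch weighted sampling (MWS) estimator from Chen et al. (linked above). It assumes a diagonal-Gaussian encoder, a standard-normal prior, one latent sample per data point, and a known dataset size N. The function names gaussian_log_density, log_sum_exp, and decomposed_kl, and the dataset_size argument, are hypothetical. This is not behavenet's code, and behavenet may use a different estimator.

```python
import numpy as np


def log_sum_exp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def gaussian_log_density(z, mu, logvar):
    """Elementwise log N(z; mu, exp(logvar))."""
    return -0.5 * (np.log(2 * np.pi) + logvar + (z - mu) ** 2 / np.exp(logvar))


def decomposed_kl(z, mu, logvar, dataset_size):
    """MWS estimates of index-code MI, total correlation, and dimension-wise KL.

    z, mu, logvar: arrays of shape (M, D); z[i] is one sample from q(z | x_i).
    dataset_size: N, the number of training examples (hypothetical argument).
    """
    batch_size = z.shape[0]
    # log q(z_i | x_i), summed over the D latent dimensions -> shape (M,)
    log_qz_given_x = gaussian_log_density(z, mu, logvar).sum(axis=1)
    # log p(z_i) under a standard-normal prior -> shape (M,)
    log_pz = gaussian_log_density(z, np.zeros_like(z), np.zeros_like(z)).sum(axis=1)
    # mat[i, j, d] = log q(z_i[d] | x_j), evaluated for every pair in the batch
    mat = gaussian_log_density(z[:, None, :], mu[None, :, :], logvar[None, :, :])
    log_nm = np.log(batch_size * dataset_size)
    # log q(z_i) ~= logsumexp_j log q(z_i | x_j) - log(N M) -> shape (M,)
    log_qz = log_sum_exp(mat.sum(axis=2), axis=1) - log_nm
    # sum_d log q(z_i[d]), each marginal estimated the same way -> shape (M,)
    log_prod_qzd = (log_sum_exp(mat, axis=1) - log_nm).sum(axis=1)
    return {
        "mi": np.mean(log_qz_given_x - log_qz),
        "tc": np.mean(log_qz - log_prod_qzd),
        "dwkl": np.mean(log_prod_qzd - log_pz),
    }


# Toy batch: M = 64 samples, D = 5 latent dimensions, N = 10,000 training examples
rng = np.random.default_rng(0)
mu = rng.normal(size=(64, 5))
logvar = rng.normal(scale=0.3, size=(64, 5))
z = mu + np.exp(0.5 * logvar) * rng.normal(size=(64, 5))
terms = decomposed_kl(z, mu, logvar, dataset_size=10_000)
```

Because the three estimates telescope, their sum equals the usual single-sample Monte Carlo estimate of the KL term, which makes a cheap sanity check. The split among them depends on the batch, so small batches give noisier total correlation estimates.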