VAE
- class behavenet.models.vaes.VAE(hparams)
Bases: AEBase
Variational autoencoder class.
This class constructs convolutional variational autoencoders. The convolutional autoencoder architecture is defined by various keys in the hparams dict passed to the constructor; see the behavenet.fitting.ae_model_architecture_generator module for examples of how this is done.
The VAE class can also be used to fit β-VAE models (see https://arxiv.org/pdf/1804.03599.pdf) by changing the value of the vae.beta parameter in the ae_model.json file. A value of 1 corresponds to a standard VAE; a value >1 upweights the KL divergence term, which in some cases can lead to a disentangled latent representation.
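As a rough illustration of the β-weighting, the sketch below computes a negative β-ELBO for a single frame in plain Python. It assumes a diagonal Gaussian posterior, a standard normal prior, and a precomputed reconstruction term. kl_to_std_normal and beta_vae_loss are hypothetical helpers, not part of behavenet.

```python
import math


def kl_to_std_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions.

    Hypothetical helper: assumes a diagonal Gaussian posterior and a standard normal prior.
    """
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def beta_vae_loss(neg_log_lik, mu, logvar, beta=1.0):
    """Negative beta-ELBO: reconstruction term plus beta-weighted KL term.

    beta == 1 gives the standard (negative) ELBO; beta > 1 upweights the KL term.
    """
    return neg_log_lik + beta * kl_to_std_normal(mu, logvar)


mu, logvar = [0.5, -0.2], [0.0, -1.0]
for beta in (1.0, 4.0):
    print(beta, beta_vae_loss(10.0, mu, logvar, beta))
```

Only the KL term scales with beta, so raising beta trades reconstruction quality for a posterior that stays closer to the prior.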
Methods Summary
forward(x[, dataset, use_mean]) – Process input data.
loss(data[, dataset, accumulate_grad, ...]) – Calculate ELBO loss for VAE.
Methods Documentation
- forward(x, dataset=None, use_mean=False, **kwargs)
Process input data.
- Parameters:
x (torch.Tensor object) – input data
dataset (int) – used with session-specific io layers
use_mean (bool) – True to skip sampling step
- Returns:
x_hat (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)
z (torch.Tensor): sampled latent variable of shape (n_frames, n_latents)
mu (torch.Tensor): mean parameter of shape (n_frames, n_latents)
logvar (torch.Tensor): log-variance parameter of shape (n_frames, n_latents)
- Return type:
tuple
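The sampling step that use_mean skips can be sketched with the standard reparameterization trick, z = mu + exp(0.5 * logvar) * eps with eps drawn from N(0, 1). This is a minimal per-frame sketch in plain Python under that assumption; sample_latent is a hypothetical helper, and behavenet's internal implementation may differ.

```python
import math
import random


def sample_latent(mu, logvar, use_mean=False, rng=random):
    """Draw z ~ N(mu, diag(exp(logvar))) via the reparameterization trick.

    Hypothetical helper. With use_mean=True the sampling step is skipped and the
    mean is returned, mirroring the documented role of forward's use_mean argument.
    """
    if use_mean:
        return list(mu)
    # sigma = exp(0.5 * logvar); eps ~ N(0, 1) independently per latent dimension
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


rng = random.Random(0)
print(sample_latent([0.0, 1.0], [0.0, -2.0], rng=rng))       # stochastic sample
print(sample_latent([0.0, 1.0], [0.0, -2.0], use_mean=True))  # deterministic mean
```

Passing a seeded random.Random makes the sample reproducible, which is handy when comparing reconstructions.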
- loss(data, dataset=0, accumulate_grad=True, chunk_size=200)
Calculate ELBO loss for VAE.
If the batch is larger than chunk_size (200 frames by default), it is split into chunks of at most chunk_size frames to keep memory requirements low; gradients are accumulated across all chunks before a gradient step is taken.
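A minimal sketch of the chunking idea in plain Python follows. It assumes each chunk's mean loss is weighted by that chunk's share of the batch, which is one way to make the accumulated total (and, in PyTorch, the accumulated gradients) match a single full-batch step; behavenet's exact weighting is not documented here. chunk_bounds and chunked_mean_loss are hypothetical helpers.

```python
import math


def chunk_bounds(batch_size, chunk_size=200):
    """Return (start, end) index pairs covering a batch in chunks of at most chunk_size."""
    n_chunks = math.ceil(batch_size / chunk_size)
    return [(i * chunk_size, min((i + 1) * chunk_size, batch_size)) for i in range(n_chunks)]


def chunked_mean_loss(per_frame_losses, chunk_size=200):
    """Average a per-frame loss chunk by chunk.

    Assumption: each chunk mean is weighted by its fraction of the batch, so the
    accumulated total equals the full-batch mean. In a PyTorch loop, each weighted
    chunk loss would be backpropagated before one optimizer step.
    """
    n = len(per_frame_losses)
    total = 0.0
    for start, end in chunk_bounds(n, chunk_size):
        chunk = per_frame_losses[start:end]
        chunk_mean = sum(chunk) / len(chunk)
        total += chunk_mean * len(chunk) / n  # weight by fraction of frames in this chunk
    return total


print(chunk_bounds(450, 200))
print(chunked_mean_loss([float(i) for i in range(450)], 200))
```

Without the per-chunk weighting, a short final chunk would count as much as a full one and bias the accumulated gradient.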
- Parameters:
data (dict) – batch of data; keys should include 'images' and, if necessary, 'masks'
dataset (int, optional) – used for session-specific io layers
accumulate_grad (bool, optional) – accumulate gradient for training step
chunk_size (int, optional) – batch is split into chunks of this size to keep memory requirements low
- Returns:
'loss' (float): full ELBO
'loss_ll' (float): log-likelihood portion of ELBO
'loss_kl' (float): KL portion of ELBO
'loss_mse' (float): MSE (without Gaussian constants)
'beta' (float): weight in front of KL term
- Return type:
dict
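The relationship between the 'loss_ll' and 'loss_mse' entries can be illustrated for a unit-variance Gaussian likelihood: the negative log-likelihood is half the summed squared error plus a constant, and the MSE drops that constant. The sketch below assembles a dict with the same keys as the one documented above; gaussian_nll and loss_dict are hypothetical, and the sign convention (storing 'loss_ll' as a negative log-likelihood so that 'loss' is the quantity minimized) and per-element averaging are assumptions rather than behavenet's documented behavior.

```python
import math


def gaussian_nll(x, x_hat, var=1.0):
    """Negative log-likelihood of x under N(x_hat, var * I), summed over elements."""
    d = len(x)
    sq = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    return 0.5 * sq / var + 0.5 * d * math.log(2.0 * math.pi * var)


def loss_dict(x, x_hat, kl, beta=1.0):
    """Hypothetical: build a dict with the documented keys from one flattened frame.

    Assumes 'loss_ll' holds a negative log-likelihood and 'loss' = loss_ll + beta * kl.
    """
    nll = gaussian_nll(x, x_hat)
    mse = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)  # no Gaussian constants
    return {
        "loss": nll + beta * kl,
        "loss_ll": nll,
        "loss_kl": kl,
        "loss_mse": mse,
        "beta": beta,
    }


print(loss_dict([0.0, 1.0, 2.0], [1.0, 1.0, 0.0], kl=0.3, beta=4.0))
```

Because the Gaussian constant does not depend on the reconstruction, 'loss_mse' is often the easier number to compare across models.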