VAE

class behavenet.models.vaes.VAE(hparams)[source]

Bases: AE

Base variational autoencoder class.

This class constructs convolutional variational autoencoders. The architecture is defined by keys in the hparams dict passed to the constructor. See the behavenet.fitting.ae_model_architecture_generator module for examples.

The VAE class can also be used to fit β-VAE models (see https://arxiv.org/pdf/1804.03599.pdf) by changing the value of the vae.beta parameter in the ae_model.json file. A value of 1 corresponds to a standard VAE. A value >1 upweights the KL divergence term, which in some cases can lead to disentangling of the latent representation.
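For reference, the objective that vae.beta controls can be written as below. This is the usual β-VAE form from the cited paper; the symbols (encoder q_φ, decoder p_θ, prior p(z)) are notation assumed here for illustration and are not names taken from BehaveNet.

```latex
\mathcal{L}_{\beta}(x) =
  \mathbb{E}_{q_\phi(z \mid x)}\left[ \log p_\theta(x \mid z) \right]
  - \beta \, D_{\mathrm{KL}}\left( q_\phi(z \mid x) \,\|\, p(z) \right)
```

Setting β = 1 recovers the standard ELBO; larger values penalize divergence from the prior more heavily.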

Methods Summary

forward(x[, dataset, use_mean])

Process input data.

loss(data[, dataset, accumulate_grad, ...])

Calculate ELBO loss for VAE.

Methods Documentation

forward(x, dataset=None, use_mean=False, **kwargs)[source]

Process input data.

Parameters:
  • x (torch.Tensor object) – input data

  • dataset (int) – used with session-specific io layers

  • use_mean (bool) – True to skip sampling step

Returns:

  • x_hat (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)

  • z (torch.Tensor): sampled latent variable of shape (n_frames, n_latents)

  • mu (torch.Tensor): mean parameter of shape (n_frames, n_latents)

  • logvar (torch.Tensor): logvar parameter of shape (n_frames, n_latents)

Return type:

tuple
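The relationship between z, mu, logvar and use_mean can be sketched with the reparameterization trick commonly used in VAEs. This is a minimal standard-library illustration for a single frame, not BehaveNet's implementation; the helper name sample_latent and the use of plain lists instead of torch tensors are assumptions made for the example.

```python
import math
import random


def sample_latent(mu, logvar, use_mean=False, rng=None):
    """Return z for one frame given per-latent mean and log-variance lists.

    Hypothetical helper for illustration; not part of BehaveNet.
    """
    if use_mean:
        # Deterministic path: skip the sampling step and use the mean.
        return list(mu)
    rng = rng or random.Random()
    # Reparameterization: z = mu + sigma * eps, with sigma = exp(0.5 * logvar)
    # and eps drawn from a standard normal.
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]


mu = [0.5, -1.0, 2.0]
logvar = [0.0, -2.0, 1.0]
z_mean = sample_latent(mu, logvar, use_mean=True)           # equals mu
z_sample = sample_latent(mu, logvar, rng=random.Random(0))  # noisy draw around mu
```

With use_mean=True the result is reproducible, which is typically what is wanted when extracting latents for downstream analysis rather than training.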

loss(data, dataset=0, accumulate_grad=True, chunk_size=200)[source]

Calculate ELBO loss for VAE.

The batch is split into chunks if it is larger than chunk_size (default 200) to keep memory requirements low; gradients are accumulated across all chunks before a gradient step is taken.
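The chunking idea can be sketched as follows. This is a standard-library illustration under stated assumptions, not the library's code: the helper names chunk_bounds and chunked_mean_loss are hypothetical, and weighting each chunk by its share of the batch is one common way to make the accumulated value match the full-batch mean.

```python
def chunk_bounds(batch_size, chunk_size=200):
    """Return (start, end) index pairs that cover a batch in chunks."""
    return [(start, min(start + chunk_size, batch_size))
            for start in range(0, batch_size, chunk_size)]


def chunked_mean_loss(per_frame_losses, chunk_size=200):
    """Accumulate a batch-mean loss one chunk at a time (illustrative)."""
    batch_size = len(per_frame_losses)
    total = 0.0
    for start, end in chunk_bounds(batch_size, chunk_size):
        chunk = per_frame_losses[start:end]
        chunk_mean = sum(chunk) / len(chunk)
        # Weight by the chunk's share of the batch so that a short final
        # chunk is not over-counted.
        total += chunk_mean * (len(chunk) / batch_size)
    return total


losses = [float(i % 7) for i in range(450)]
print(chunk_bounds(450))  # [(0, 200), (200, 400), (400, 450)]
```

In a training loop, the same weighting would be applied to each chunk's loss before calling backward, so that the accumulated gradient corresponds to the whole batch.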

Parameters:
  • data (dict) – batch of data; keys should include 'images' and, if necessary, 'masks'

  • dataset (int, optional) – used for session-specific io layers

  • accumulate_grad (bool, optional) – accumulate gradient for training step

  • chunk_size (int, optional) – batch is split into chunks of this size to keep memory requirements low

Returns:

  • 'loss' (float): full ELBO

  • 'loss_ll' (float): log-likelihood portion of ELBO

  • 'loss_kl' (float): KL portion of ELBO

  • 'loss_mse' (float): MSE (without Gaussian constants)

  • 'beta' (float): weight in front of KL term

Return type:

dict
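How the returned entries might relate to one another can be sketched for a single frame. Everything here is an assumption made for illustration: the helper name elbo_terms, the unit-variance Gaussian likelihood, the standard normal prior, and the sign and normalization conventions (the value stored under 'loss_ll' is written as a negative log-likelihood so that 'loss' is something to minimize). The source does not specify these details, so check the BehaveNet code before relying on them.

```python
import math


def elbo_terms(x, x_hat, mu, logvar, beta=1.0):
    """Illustrative per-frame loss terms; x and x_hat are flat pixel lists.

    Hypothetical helper; conventions are assumptions, not BehaveNet's.
    """
    n_pix = len(x)
    sq_err = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    # Mean squared error, without any Gaussian constants.
    mse = sq_err / n_pix
    # Negative Gaussian log-likelihood with unit variance, constant included.
    nll = 0.5 * sq_err + 0.5 * n_pix * math.log(2.0 * math.pi)
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latents.
    kl = 0.5 * sum(math.exp(lv) + m ** 2 - 1.0 - lv
                   for m, lv in zip(mu, logvar))
    return {
        'loss': nll + beta * kl,
        'loss_ll': nll,
        'loss_kl': kl,
        'loss_mse': mse,
        'beta': beta,
    }


terms = elbo_terms(x=[0.0, 1.0], x_hat=[0.0, 0.5], mu=[1.0], logvar=[0.0], beta=5.0)
```

In this sketch, raising beta changes only 'loss'; the 'loss_ll', 'loss_kl' and 'loss_mse' entries are unaffected, which makes them convenient for comparing models fit with different beta values.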