ConditionalVAE

class behavenet.models.vaes.ConditionalVAE(hparams)

Bases: VAE

Conditional variational autoencoder class.

This class constructs conditional convolutional variational autoencoders. At the latent layer, an additional set of variables (saved under the ‘labels’ key in the HDF5 data file) is concatenated with the latents before the result is reshaped into a 2D array for decoding.
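To make the concatenation step concrete, here is a minimal NumPy sketch. It assumes latents of shape (batch, n_latents) and labels of shape (batch, n_labels). The projection matrix `weight` and the target shape `out_shape` are hypothetical stand-ins for the decoder's first layer; they are not taken from behavenet.

```python
import numpy as np

def concat_and_reshape(latents, labels, weight, out_shape):
    """Hypothetical helper: append labels to latents, project, and reshape.

    latents: (batch, n_latents); labels: (batch, n_labels)
    weight: hypothetical matrix of shape (n_latents + n_labels, n_ch * y * x)
            standing in for the decoder's first fully connected layer
    out_shape: hypothetical (n_ch, y, x) shape expected by the conv decoder
    """
    z_aug = np.concatenate([latents, labels], axis=1)  # (batch, n_latents + n_labels)
    flat = z_aug @ weight                              # (batch, n_ch * y * x)
    return z_aug, flat.reshape(-1, *out_shape)         # (batch, n_ch, y, x)

rng = np.random.default_rng(0)
latents = rng.normal(size=(4, 12))
labels = rng.normal(size=(4, 3))
weight = rng.normal(size=(15, 8 * 5 * 5))
z_aug, feat = concat_and_reshape(latents, labels, weight, (8, 5, 5))
print(z_aug.shape, feat.shape)  # (4, 15) (4, 8, 5, 5)
```

Because the labels are appended rather than encoded, the decoder sees them directly, while the encoder's latents only have to capture what the labels do not.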

Methods Summary

build_model()

Construct the model using hparams.

forward(x[, dataset, labels, labels_2d, ...])

Process input data.

loss(data[, dataset, accumulate_grad, ...])

Calculate ELBO loss for ConditionalVAE.

Methods Documentation

build_model()

Construct the model using hparams.

The ConditionalVAE is initialized when model_class='cond-vae', and currently only supports model_type='conv' (i.e., linear models are not supported).
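The model_type restriction can be enforced with a simple guard before the network is built. The function below is a hypothetical illustration, not part of behavenet; it only inspects the hparams key named in the text.

```python
def validate_model_type(hparams):
    """Hypothetical guard: reject anything other than a convolutional model."""
    model_type = hparams.get('model_type')
    if model_type != 'conv':
        raise NotImplementedError(
            f"ConditionalVAE only supports model_type='conv', got {model_type!r}")
    return hparams

validate_model_type({'model_type': 'conv'})       # passes
# validate_model_type({'model_type': 'linear'})   # raises NotImplementedError
```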

forward(x, dataset=None, labels=None, labels_2d=None, use_mean=False, **kwargs)

Process input data.

Parameters:
  • x (torch.Tensor object) – input data of shape (batch, n_channels, y_pix, x_pix)

  • dataset (int) – used with session-specific io layers

  • labels (torch.Tensor object) – continuous labels corresponding to input data, of shape (batch, n_labels)

  • labels_2d (torch.Tensor object) – one-hot labels corresponding to input data, of shape (batch, n_labels, y_pix, x_pix); for a given frame, each channel corresponds to one label and is all zeros except for a single one at that label's x/y position

  • use_mean (bool) – True to skip the sampling step and use the posterior mean
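The one-hot layout described for labels_2d can be built from integer pixel coordinates. This sketch assumes the coordinates arrive as an array of shape (batch, n_labels, 2) ordered (x, y), and clips out-of-range positions to the image border; both choices are assumptions, and behavenet may derive labels_2d differently.

```python
import numpy as np

def make_labels_2d(xy, y_pix, x_pix):
    """Hypothetical helper: one-hot 2D label maps from pixel coordinates.

    xy: integer array of shape (batch, n_labels, 2) holding (x, y) positions
        (assumed layout)
    returns: (batch, n_labels, y_pix, x_pix), all zeros except a single 1 per
        (frame, label) channel at that label's position
    """
    batch, n_labels, _ = xy.shape
    maps = np.zeros((batch, n_labels, y_pix, x_pix), dtype=np.float32)
    # index grids so every (frame, label) pair addresses its own channel
    b_idx, l_idx = np.meshgrid(np.arange(batch), np.arange(n_labels), indexing='ij')
    x = np.clip(xy[..., 0], 0, x_pix - 1)  # assumed: clip to image bounds
    y = np.clip(xy[..., 1], 0, y_pix - 1)
    maps[b_idx, l_idx, y, x] = 1.0
    return maps

xy = np.array([[[3, 1], [0, 4]]])  # one frame, two labels
labels_2d = make_labels_2d(xy, y_pix=5, x_pix=6)
print(labels_2d.shape)        # (1, 2, 5, 6)
print(labels_2d[0, 0, 1, 3])  # 1.0
```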

Returns:

  • y (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)

  • x (torch.Tensor): hidden representation of shape (n_frames, n_latents)

Return type:

tuple
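Skipping the sampling step with use_mean corresponds, in a standard VAE, to returning the posterior mean instead of a reparameterized sample. The sketch below assumes the encoder produces the mean mu and log-variance logvar of a diagonal Gaussian; sample_latents is a hypothetical helper, not a behavenet function.

```python
import numpy as np

def sample_latents(mu, logvar, use_mean=False, rng=None):
    """Hypothetical helper: latents from a diagonal Gaussian posterior.

    use_mean=True skips sampling and returns mu; otherwise applies the
    reparameterization trick z = mu + sigma * eps with eps ~ N(0, I).
    """
    if use_mean:
        return mu
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps  # sigma = exp(logvar / 2)

mu = np.zeros((2, 4))
logvar = np.zeros((2, 4))
z_mean = sample_latents(mu, logvar, use_mean=True)
z = sample_latents(mu, logvar, rng=np.random.default_rng(0))
```

Using the mean is typical at evaluation time, when deterministic reconstructions and latents are wanted.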

loss(data, dataset=0, accumulate_grad=True, chunk_size=200)

Calculate ELBO loss for ConditionalVAE.

The batch is split into chunks if it is larger than chunk_size, to keep memory requirements low; gradients are accumulated across all chunks before a gradient step is taken.
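The chunking and gradient-accumulation scheme can be sketched with a toy scalar loss. chunk_bounds and accumulated_grad are hypothetical helpers; weighting each chunk's gradient by its share of the batch is one common way to make the accumulated gradient match the full-batch gradient, and the source does not say whether behavenet normalizes chunks this way.

```python
import numpy as np

def chunk_bounds(batch_size, chunk_size):
    """Hypothetical helper: (start, end) pairs covering the batch in chunks."""
    return [(i, min(i + chunk_size, batch_size))
            for i in range(0, batch_size, chunk_size)]

def accumulated_grad(x, w, chunk_size):
    """Accumulate the gradient of mean((w * x) ** 2) w.r.t. w over chunks.

    Each chunk's gradient is weighted by its fraction of the batch, so the
    sum equals the full-batch gradient; one update would follow afterward.
    """
    grad = 0.0
    for start, end in chunk_bounds(len(x), chunk_size):
        xc = x[start:end]
        chunk_grad = np.mean(2 * w * xc ** 2)        # d/dw of mean((w * xc) ** 2)
        grad += chunk_grad * (end - start) / len(x)  # weight by chunk fraction
    return grad

x = np.arange(1.0, 451.0)
print(chunk_bounds(450, 200))  # [(0, 200), (200, 400), (400, 450)]
full = np.mean(2 * 0.5 * x ** 2)
print(np.isclose(accumulated_grad(x, 0.5, 200), full))  # True
```

With the default chunk_size=200, a 450-frame batch would be processed in three chunks, and the optimizer step would be taken only after the last one.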

Parameters:
  • data (dict) – batch of data; keys should include ‘images’, and ‘masks’ if necessary

  • dataset (int, optional) – used for session-specific io layers

  • accumulate_grad (bool, optional) – accumulate gradient for training step

  • chunk_size (int, optional) – batch is split into chunks of this size to keep memory requirements low

Returns:

  • ‘loss’ (float): full ELBO

  • ‘loss_ll’ (float): log-likelihood portion of the ELBO

  • ‘loss_kl’ (float): KL portion of the ELBO

  • ‘loss_mse’ (float): MSE (without Gaussian constants)

  • ‘beta’ (float): weight in front of the KL term

Return type:

dict
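The quantities in the returned dictionary can be illustrated with NumPy. This is a minimal sketch assuming a Gaussian likelihood with fixed standard deviation and a standard-normal prior. The source does not specify whether ‘loss_ll’ stores the log-likelihood or its negation, or how terms are averaged, so elbo_terms and its sign conventions are illustrative assumptions.

```python
import numpy as np

def elbo_terms(x, x_hat, mu, logvar, beta=1.0, sigma=1.0):
    """Hypothetical helper mirroring the keys of the returned dict.

    Assumes a Gaussian likelihood with fixed std `sigma` and an N(0, I) prior;
    sign and normalization conventions here are illustrative only.
    """
    batch = x.shape[0]
    sq_err = (x - x_hat) ** 2
    mse = sq_err.mean()  # plain MSE, no Gaussian constants
    # Gaussian log-likelihood summed over pixels, averaged over frames
    ll = -0.5 * np.sum(sq_err / sigma ** 2 + np.log(2 * np.pi * sigma ** 2)) / batch
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian, averaged over frames
    kl = -0.5 * np.sum(1 + logvar - mu ** 2 - np.exp(logvar)) / batch
    loss = -(ll - beta * kl)  # negative ELBO, minimized during training
    return {'loss': loss, 'loss_ll': ll, 'loss_kl': kl, 'loss_mse': mse, 'beta': beta}

rng = np.random.default_rng(0)
x = rng.random((2, 1, 4, 4))  # (batch, n_channels, y_pix, x_pix)
terms = elbo_terms(x, x, np.zeros((2, 3)), np.zeros((2, 3)), beta=2.0)
```

With a perfect reconstruction and a posterior equal to the prior, ‘loss_mse’ and ‘loss_kl’ are both zero, so the loss reduces to the Gaussian normalizing constant.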