ConditionalAE

class behavenet.models.aes.ConditionalAE(hparams)[source]

Bases: AE

Conditional autoencoder class.

This class constructs conditional convolutional autoencoders. At the latent layer, an additional set of variables, saved under the ‘labels’ key in the HDF5 data file, is concatenated with the latents; the result is then reshaped into a 2D array for decoding.
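The concatenation step described above can be sketched in plain PyTorch. This is a minimal illustration, not the class's actual code: the tensor sizes and the names `latents`, `labels`, and `decoder_input` are assumptions chosen for the example, and random tensors stand in for the encoder output and the ‘labels’ data.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch, n_latents, n_labels = 4, 8, 3

latents = torch.randn(batch, n_latents)  # stand-in for the encoder output
labels = torch.randn(batch, n_labels)    # stand-in for the 'labels' data

# Append the labels to the latents along the feature dimension.
decoder_input = torch.cat([latents, labels], dim=1)
print(decoder_input.shape)  # torch.Size([4, 11])
```

The decoder then receives n_latents + n_labels values per frame, which is why the labels must be supplied at both training and inference time.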

Methods Summary

build_model()

Construct the model using hparams.

forward(x[, dataset, labels, labels_2d])

Process input data.

loss(data[, dataset, accumulate_grad, ...])

Calculate MSE loss for autoencoder.

Methods Documentation

build_model()[source]

Construct the model using hparams.

The ConditionalAE is initialized when model_class='cond-ae', and currently only supports model_type='conv' (i.e. no linear).

forward(x, dataset=None, labels=None, labels_2d=None, **kwargs)[source]

Process input data.

Parameters:
  • x (torch.Tensor object) – input data of shape (batch, n_channels, y_pix, x_pix)

  • dataset (int) – used with session-specific io layers

  • labels (torch.Tensor object) – continuous labels corresponding to input data, of shape (batch, n_labels)

  • labels_2d (torch.Tensor object) – one-hot labels corresponding to input data, of shape (batch, n_labels, y_pix, x_pix); for a given frame, each channel corresponds to a label and is all zeros with a single value of one in the proper x/y position

Returns:

  • y (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)

  • x (torch.Tensor): hidden representation of shape (n_frames, n_latents)

Return type:

tuple
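To make the labels_2d layout concrete, the following sketch builds a one-hot tensor of shape (batch, n_labels, y_pix, x_pix) from integer pixel coordinates. The helper `make_one_hot_2d` and the (x, y) ordering of `coords` are assumptions made for this example; they are not part of the documented API.

```python
import torch


def make_one_hot_2d(coords, y_pix, x_pix):
    """Hypothetical helper: build one-hot 2D labels.

    coords is an integer tensor of shape (batch, n_labels, 2) holding
    (x, y) pixel positions; this layout is an assumption of the sketch.
    """
    batch, n_labels, _ = coords.shape
    out = torch.zeros(batch, n_labels, y_pix, x_pix)
    # Index grids that pair every frame with every label channel.
    b_idx = torch.arange(batch).unsqueeze(1).expand(batch, n_labels)
    l_idx = torch.arange(n_labels).unsqueeze(0).expand(batch, n_labels)
    # Set a single one per channel at the label's (y, x) position.
    out[b_idx, l_idx, coords[..., 1], coords[..., 0]] = 1.0
    return out


# One frame, two labels, given as (x, y) positions.
coords = torch.tensor([[[2, 1], [0, 3]]])
labels_2d = make_one_hot_2d(coords, y_pix=4, x_pix=5)
```

Each channel of the result is all zeros except for a single one, matching the description of the labels_2d parameter.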

loss(data, dataset=0, accumulate_grad=True, chunk_size=200)[source]

Calculate MSE loss for autoencoder.

The batch is split into chunks if it is larger than chunk_size to keep memory requirements low; gradients are accumulated across all chunks before a gradient step is taken.

Parameters:
  • data (dict) – batch of data; keys should include ‘images’ and, if necessary, ‘masks’

  • dataset (int, optional) – used for session-specific io layers

  • accumulate_grad (bool, optional) – accumulate gradient for training step

  • chunk_size (int, optional) – batch is split into chunks of this size to keep memory requirements low

Returns:

  • ‘loss’ (float): MSE loss

Return type:

dict
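The chunking and gradient accumulation described for loss() can be illustrated with a generic PyTorch sketch. It is not the library's implementation: `chunked_mse_loss` is a hypothetical helper, a small linear layer stands in for the autoencoder, and the choice to normalize each chunk by the full batch size is an assumption made so that the chunk losses sum to the batch MSE.

```python
import torch
from torch import nn


def chunked_mse_loss(model, images, chunk_size=200, accumulate_grad=True):
    """Hypothetical helper: MSE over a batch, computed chunk by chunk."""
    total = 0.0
    for start in range(0, images.shape[0], chunk_size):
        chunk = images[start:start + chunk_size]
        recon = model(chunk)
        # Squared error of this chunk, normalized by the full batch size
        # so that the per-chunk losses add up to the batch MSE.
        loss = ((recon - chunk) ** 2).sum() / images.numel()
        if accumulate_grad:
            # backward() adds to .grad, so gradients accumulate over chunks.
            loss.backward()
        total += loss.item()
    return {'loss': total}


torch.manual_seed(0)
model = nn.Linear(6, 6)          # stand-in for an autoencoder
images = torch.randn(10, 6)      # stand-in for a batch of flattened frames
out = chunked_mse_loss(model, images, chunk_size=4)
```

Only one chunk's activations are held in memory at a time, while the accumulated gradients match those of a single pass over the whole batch; the caller would take the optimizer step after the function returns.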