ConditionalAE
- class behavenet.models.aes.ConditionalAE(hparams)
Bases: AE
Conditional autoencoder class.
This class constructs conditional convolutional autoencoders. At the latent layer, an additional set of variables (saved under the ‘labels’ key in the hdf5 data file) is concatenated with the latents before the result is reshaped into a 2D array for decoding.
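The latent-layer conditioning described above can be sketched in PyTorch as follows. This is a minimal sketch, not behavenet's implementation: the sizes (`n_latents`, `n_labels`, and the 2D decoder input shape) and the use of a single `nn.Linear` layer to map the concatenated vector to the 2D array are assumptions made for illustration.

```python
import torch
from torch import nn

# Hypothetical sizes for illustration only.
batch, n_latents, n_labels = 4, 12, 3
dec_channels, dec_h, dec_w = 8, 4, 4  # assumed shape of the 2D array fed to the decoder

latents = torch.randn(batch, n_latents)  # stand-in for the encoder output
labels = torch.randn(batch, n_labels)    # stand-in for continuous labels from the 'labels' key

# Concatenate labels with latents along the feature dimension.
z = torch.cat([latents, labels], dim=1)  # shape (batch, n_latents + n_labels)

# Assumption: a linear layer maps the conditioned vector to the decoder's input size,
# and the result is then reshaped into a (channels, height, width) array per sample.
to_2d = nn.Linear(n_latents + n_labels, dec_channels * dec_h * dec_w)
z_2d = to_2d(z).view(batch, dec_channels, dec_h, dec_w)
```

Because the labels enter after the encoder, the decoder sees them for every frame, while the latents are free to encode whatever the labels do not explain.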
Methods Summary
build_model() – Construct the model using hparams.
forward(x[, dataset, labels, labels_2d]) – Process input data.
loss(data[, dataset, accumulate_grad, ...]) – Calculate MSE loss for autoencoder.
Methods Documentation
- build_model()
Construct the model using hparams.
The ConditionalAE is initialized when model_class='cond-ae', and it currently only supports model_type='conv' (i.e., no linear model type).
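The constraint above can be expressed as a small check on the hparams dictionary. The helper `check_cond_ae_hparams` is hypothetical and not part of behavenet; only the two keys (`model_class`, `model_type`) and their required values come from this documentation.

```python
def check_cond_ae_hparams(hparams: dict) -> None:
    """Hypothetical guard: raise if hparams do not describe a supported ConditionalAE."""
    if hparams.get("model_class") != "cond-ae":
        raise ValueError("ConditionalAE expects model_class='cond-ae'")
    if hparams.get("model_type") != "conv":
        # Only convolutional autoencoders are supported (no linear variant).
        raise ValueError("ConditionalAE currently only supports model_type='conv'")


# A supported configuration passes silently.
check_cond_ae_hparams({"model_class": "cond-ae", "model_type": "conv"})
```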
- forward(x, dataset=None, labels=None, labels_2d=None, **kwargs)
Process input data.
- Parameters:
  - x (torch.Tensor object) – input data of shape (batch, n_channels, y_pix, x_pix)
  - dataset (int, optional) – used with session-specific io layers
  - labels (torch.Tensor object, optional) – continuous labels corresponding to input data, of shape (batch, n_labels)
  - labels_2d (torch.Tensor object, optional) – one-hot labels corresponding to input data, of shape (batch, n_labels, y_pix, x_pix); for a given frame, each channel corresponds to a label and is all zeros with a single value of one in the proper x/y position
- Returns:
  - y (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)
  - x (torch.Tensor): hidden representation of shape (n_frames, n_latents)
- Return type:
tuple
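A one-hot tensor with the layout described for `labels_2d` can be built from integer pixel coordinates as below. This is a sketch under stated assumptions: the helper `make_labels_2d` is hypothetical, it takes a `(batch, n_labels, 2)` tensor of `(x, y)` pixel indices that already lie inside the frame, and it does not reflect how behavenet itself derives these maps from the stored labels.

```python
import torch


def make_labels_2d(xy: torch.Tensor, y_pix: int, x_pix: int) -> torch.Tensor:
    """Hypothetical helper: one-hot spatial maps of shape (batch, n_labels, y_pix, x_pix).

    xy: integer tensor of shape (batch, n_labels, 2) holding (x, y) pixel indices.
    """
    batch, n_labels, _ = xy.shape
    out = torch.zeros(batch, n_labels, y_pix, x_pix)
    # Index grids that pair every (batch, label) entry with its own channel.
    b_idx = torch.arange(batch).unsqueeze(1).expand(batch, n_labels)
    l_idx = torch.arange(n_labels).unsqueeze(0).expand(batch, n_labels)
    # Each channel is all zeros except a single one at its label's (y, x) position.
    out[b_idx, l_idx, xy[..., 1], xy[..., 0]] = 1.0
    return out


# One frame with two labels: label 0 at (x=2, y=1), label 1 at (x=0, y=3).
xy = torch.tensor([[[2, 1], [0, 3]]])
labels_2d = make_labels_2d(xy, y_pix=4, x_pix=5)
```

Note that the row index comes from the y coordinate and the column index from the x coordinate, matching the (y_pix, x_pix) ordering of the spatial dimensions.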
- loss(data, dataset=0, accumulate_grad=True, chunk_size=200)
Calculate MSE loss for autoencoder.
If the batch is larger than chunk_size, it is split into chunks of that size to keep memory requirements low; gradients are accumulated across all chunks before a gradient step is taken.
- Parameters:
  - data (dict) – batch of data; keys should include ‘images’ and, if necessary, ‘masks’
  - dataset (int, optional) – used for session-specific io layers
  - accumulate_grad (bool, optional) – accumulate gradient for training step
  - chunk_size (int, optional) – batch is split into chunks of this size to keep memory requirements low
- Returns:
  - ‘loss’ (float): MSE loss
- Return type:
dict
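The chunking strategy described for `loss` can be illustrated with a generic autoencoder. This is a sketch of the technique, not behavenet's code: `chunked_mse_loss` and `ToyAE` are hypothetical, the model is assumed to return `(reconstruction, latents)` as `forward` does, masks and labels are omitted, and each chunk's loss is weighted by its share of the batch so that the accumulated gradient matches the gradient of the full-batch mean.

```python
import torch
from torch import nn


def chunked_mse_loss(model: nn.Module, data: dict, accumulate_grad: bool = True,
                     chunk_size: int = 200) -> dict:
    """Hypothetical sketch: MSE reconstruction loss computed over batch chunks."""
    images = data["images"]
    batch_size = images.shape[0]
    total = 0.0
    for start in range(0, batch_size, chunk_size):
        chunk = images[start:start + chunk_size]
        y, _ = model(chunk)
        # Weight by chunk size so the summed chunk losses equal the full-batch mean.
        loss = nn.functional.mse_loss(y, chunk) * chunk.shape[0] / batch_size
        if accumulate_grad:
            loss.backward()  # gradients add up in each parameter's .grad across chunks
        total += loss.item()
    return {"loss": total}


class ToyAE(nn.Module):
    """Hypothetical stand-in autoencoder returning (reconstruction, latents)."""

    def __init__(self, n_pix: int = 16, n_latents: int = 4):
        super().__init__()
        self.enc = nn.Linear(n_pix, n_latents)
        self.dec = nn.Linear(n_latents, n_pix)

    def forward(self, x):
        z = self.enc(x.flatten(1))
        return self.dec(z).view_as(x), z


torch.manual_seed(0)
model = ToyAE()
data = {"images": torch.randn(10, 1, 4, 4)}
out = chunked_mse_loss(model, data, chunk_size=3)  # chunks of 3, 3, 3, 1 frames
```

The optimizer step is left to the caller; since backward() is called once per chunk, the parameters' `.grad` fields hold the full-batch gradient only after the loop finishes.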