ConditionalVAE
- class behavenet.models.vaes.ConditionalVAE(hparams)
Bases: VAE
Conditional variational autoencoder class.
This class constructs conditional convolutional variational autoencoders. At the latent layer, an additional set of variables, saved under the ‘labels’ key in the HDF5 data file, is concatenated with the latents before being reshaped into a 2D array for decoding.
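The concatenation step can be pictured with a minimal sketch. Plain Python lists stand in for tensors here, and the helper name `concat_latents_and_labels` and the example values are hypothetical, not part of BehaveNet; the point is only that each frame's decoder input has n_latents + n_labels entries.

```python
def concat_latents_and_labels(latents, labels):
    """Append each frame's labels to that frame's latents (hypothetical helper)."""
    if len(latents) != len(labels):
        raise ValueError("latents and labels must have the same batch size")
    return [list(z) + list(lab) for z, lab in zip(latents, labels)]


latents = [[0.1, -0.3, 0.7], [0.0, 0.5, -0.2]]  # shape (batch=2, n_latents=3)
labels = [[1.0, 2.0], [3.0, 4.0]]               # shape (batch=2, n_labels=2)

decoder_input = concat_latents_and_labels(latents, labels)
# Each row now holds n_latents + n_labels = 5 values.
```

In a tensor library the same operation would be a concatenation along the feature dimension.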
Methods Summary
build_model() – Construct the model using hparams.
forward(x[, dataset, labels, labels_2d, ...]) – Process input data.
loss(data[, dataset, accumulate_grad, ...]) – Calculate ELBO loss for ConditionalVAE.
Methods Documentation
- build_model()
Construct the model using hparams.
The ConditionalAE is initialized when model_class='cond-ae', and currently only supports model_type='conv' (i.e. no linear).
- forward(x, dataset=None, labels=None, labels_2d=None, use_mean=False, **kwargs)
Process input data.
- Parameters:
  - x (torch.Tensor object) – input data of shape (batch, n_channels, y_pix, x_pix)
  - dataset (int) – used with session-specific io layers
  - labels (torch.Tensor object) – continuous labels corresponding to input data, of shape (batch, n_labels)
  - labels_2d (torch.Tensor object) – one-hot labels corresponding to input data, of shape (batch, n_labels, y_pix, x_pix); for a given frame, each channel corresponds to a label and is all zeros with a single value of one in the proper x/y position
  - use_mean (bool) – True to skip sampling step
- Returns:
  - y (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)
  - x (torch.Tensor): hidden representation of shape (n_frames, n_latents)
- Return type:
tuple
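The labels_2d layout described for forward() can be sketched for a single frame as follows. This is an illustration only: the helper `make_labels_2d`, the (x, y) ordering of the positions, and the use of integer pixel coordinates are assumptions, and nested lists stand in for a tensor of shape (n_labels, y_pix, x_pix).

```python
def make_labels_2d(positions, y_pix, x_pix):
    """Build one-hot maps for one frame (hypothetical helper).

    positions: list of (x, y) integer pixel coordinates, one per label.
    Returns a nested list of shape (n_labels, y_pix, x_pix).
    """
    maps = []
    for x, y in positions:
        if not (0 <= x < x_pix and 0 <= y < y_pix):
            raise ValueError(f"position {(x, y)} is outside the frame")
        channel = [[0.0] * x_pix for _ in range(y_pix)]
        channel[y][x] = 1.0  # rows index y, columns index x
        maps.append(channel)
    return maps


# Two labels in a frame that is 4 pixels tall and 3 pixels wide.
frame_maps = make_labels_2d([(1, 0), (2, 3)], y_pix=4, x_pix=3)
```

Stacking such per-frame maps along a leading batch dimension would give the (batch, n_labels, y_pix, x_pix) shape listed above.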
- loss(data, dataset=0, accumulate_grad=True, chunk_size=200)
Calculate ELBO loss for ConditionalVAE.
The batch is split into chunks if it is larger than chunk_size, which keeps memory requirements low; gradients are accumulated across all chunks before a gradient step is taken.
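The chunking idea can be sketched without any deep learning library. The example below uses a toy scalar model with loss mean((w * x - y)**2); the function names are hypothetical, and weighting each sample by 1 / batch_size is an assumption about how chunk contributions are combined, not a statement about BehaveNet's internals. It shows that accumulating gradients chunk by chunk reproduces the full-batch gradient.

```python
def chunk_ranges(batch_size, chunk_size):
    """Yield (start, end) index pairs that cover the batch in order."""
    for start in range(0, batch_size, chunk_size):
        yield start, min(start + chunk_size, batch_size)


def accumulated_gradient(xs, ys, w, chunk_size):
    """Gradient of mean((w * x - y)**2) with respect to w, built up per chunk."""
    batch_size = len(xs)
    grad = 0.0
    for start, end in chunk_ranges(batch_size, chunk_size):
        # Only this chunk needs to be processed at once.
        for x, y in zip(xs[start:end], ys[start:end]):
            # Each sample contributes its share of the full-batch mean.
            grad += 2.0 * (w * x - y) * x / batch_size
    return grad  # a single parameter update would use this total


xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0]
chunked = accumulated_gradient(xs, ys, w=1.0, chunk_size=2)
full = accumulated_gradient(xs, ys, w=1.0, chunk_size=len(xs))
```

With chunk_size=200, a batch of 450 frames would be processed as the ranges (0, 200), (200, 400), and (400, 450).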
- Parameters:
  - data (dict) – batch of data; keys should include ‘images’ and, if necessary, ‘masks’
  - dataset (int, optional) – used for session-specific io layers
  - accumulate_grad (bool, optional) – accumulate gradient for training step
  - chunk_size (int, optional) – batch is split into chunks of this size to keep memory requirements low
- Returns:
  - ‘loss’ (float): full ELBO
  - ‘loss_ll’ (float): log-likelihood portion of ELBO
  - ‘loss_kl’ (float): KL portion of ELBO
  - ‘loss_mse’ (float): MSE (without Gaussian constants)
  - ‘beta’ (float): weight in front of KL term
- Return type:
dict
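To make the returned keys concrete, here is a minimal sketch of how such terms could be computed for one flattened frame. Several things are assumptions rather than BehaveNet's documented behavior: a unit-variance Gaussian likelihood, a standard-normal prior with a diagonal Gaussian posterior, the combination loss = -loss_ll + beta * loss_kl, and the function names themselves.

```python
import math


def gaussian_kl(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, 1) ), summed over latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def elbo_terms(x, x_hat, mu, logvar, beta=1.0):
    """Return a dict shaped like the one described above (hypothetical sketch)."""
    n = len(x)
    sq_err = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    loss_mse = sq_err / n  # no Gaussian constants
    # Unit-variance Gaussian log-likelihood, constants included (assumption).
    loss_ll = -0.5 * sq_err - 0.5 * n * math.log(2.0 * math.pi)
    loss_kl = gaussian_kl(mu, logvar)
    # Assumed sign convention: minimize negative log-likelihood plus weighted KL.
    loss = -loss_ll + beta * loss_kl
    return {"loss": loss, "loss_ll": loss_ll, "loss_kl": loss_kl,
            "loss_mse": loss_mse, "beta": beta}


terms = elbo_terms(x=[0.0, 1.0], x_hat=[0.0, 0.5],
                   mu=[0.0, 1.0], logvar=[0.0, 0.0], beta=2.0)
```

Setting beta above 1 puts more weight on the KL term relative to reconstruction, which is the role the ‘beta’ entry plays in the returned dict.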