MSPSVAE
class behavenet.models.vaes.MSPSVAE(hparams)
Bases: behavenet.models.vaes.PSVAE
Partitioned subspace variational autoencoder class for multiple sessions.
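The "partitioned subspace" idea can be pictured with a small NumPy sketch. Following the get_inverse_transformed_latents description below, it assumes the first n_labels latent dimensions form the supervised subspace; all remaining dimensions are lumped into one block for simplicity. The function name split_latents and the value n_labels = 2 are hypothetical and not part of behavenet.

```python
import numpy as np


def split_latents(z, n_labels):
    """Split (n_frames, n_latents) latents into two blocks.

    The first n_labels columns form the supervised subspace; all other
    columns are returned together as a single remaining block (a
    simplification made for this sketch).
    """
    return z[:, :n_labels], z[:, n_labels:]


n_labels = 2                                   # hypothetical value
z = np.arange(12, dtype=float).reshape(2, 6)   # 2 frames, 6 latents
z_s, z_rest = split_latents(z, n_labels)
print(z_s.shape, z_rest.shape)  # (2, 2) (2, 4)
```

Because slicing returns views, concatenating the two blocks along the latent axis recovers the original array.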
Methods Summary
Construct the model using hparams.
export_latents(data_gen[, filename]) – Need to create standard data generator in order to export latents.
forward(x[, dataset, use_mean]) – Process input data.
get_inverse_transformed_latents(inputs[, …]) – Return latents after they have been transformed using the diagonal mapping D.
get_predicted_labels(x[, dataset, use_mean]) – Process input data to get predicted labels.
get_transformed_latents(inputs[, dataset, …]) – Return latents after supervised subspace has been transformed to original label space.
loss(datas[, dataset, accumulate_grad, …]) – Calculate modified ELBO loss for MSPSVAE.
Methods Documentation
export_latents(data_gen, filename=None)
Need to create standard data generator in order to export latents.
forward(x, dataset=None, use_mean=False, **kwargs)
Process input data.
Parameters
x (torch.Tensor object) – input data
dataset (int) – used with session-specific I/O layers
use_mean (bool) – True to skip the sampling step and use the posterior mean
Returns
x_hat (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)
y_hat (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)
z (torch.Tensor): sampled latent variable of shape (n_frames, n_latents)
mu (torch.Tensor): mean parameter of shape (n_frames, n_latents)
logvar (torch.Tensor): log-variance parameter of shape (n_frames, n_latents)
Return type
tuple
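The use_mean flag chooses between the posterior mean and a random sample. A minimal NumPy sketch of that choice, assuming the standard VAE reparameterization with a diagonal Gaussian posterior given by mu and logvar (sample_latents is a hypothetical name, not a behavenet function):

```python
import numpy as np


def sample_latents(mu, logvar, use_mean=False, rng=None):
    """Return latents for a diagonal Gaussian posterior.

    use_mean=True skips sampling and returns mu unchanged. Otherwise the
    reparameterization z = mu + sigma * eps, eps ~ N(0, I), is applied,
    with sigma = exp(0.5 * logvar).
    """
    if use_mean:
        return mu
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps


mu = np.zeros((4, 3))      # (n_frames, n_latents)
logvar = np.zeros((4, 3))  # unit variance
z_det = sample_latents(mu, logvar, use_mean=True)
z_rand = sample_latents(mu, logvar, rng=np.random.default_rng(0))
```

Writing the sample as a deterministic function of mu, logvar and external noise is what lets gradients flow through mu and logvar during training; skipping the noise gives a deterministic encoding, which is useful at inference time.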
get_inverse_transformed_latents(inputs, dataset=None, as_numpy=True)
Return latents after they have been transformed using the diagonal mapping D.
Parameters
inputs (torch.Tensor object) – either an image tensor of shape (batch, n_channels, y_pix, x_pix), or a latents tensor of shape (batch, n_ae_latents) whose first n_labels entries are assumed to be labels in the original pixel space
dataset (int, optional) – used with session-specific I/O layers
as_numpy (bool, optional) – True to return a numpy array, False to return a torch Tensor
Returns
array of latents in transformed latent space
Return type
np.ndarray or torch.Tensor object
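To make the inverse mapping concrete, here is a minimal NumPy sketch. It assumes labels y and supervised latents z_s are related by a purely diagonal linear map, y = d * z_s elementwise, with a hypothetical diagonal d; the actual model may include additional terms such as an offset. The name inverse_transform is illustrative, not a behavenet function.

```python
import numpy as np


def inverse_transform(latents, d, n_labels):
    """Map the first n_labels entries from label space back to latent space.

    Assumes y = z_s * d (elementwise), so z_s = y / d. Entries after the
    first n_labels are copied through unchanged.
    """
    out = latents.copy()
    out[:, :n_labels] = latents[:, :n_labels] / d
    return out


d = np.array([2.0, -0.5])               # hypothetical diagonal of D
latents = np.array([[4.0, 1.0, 7.0]])   # [y_1, y_2, remaining latent]
print(inverse_transform(latents, d, n_labels=2))  # [[ 2. -2.  7.]]
```

Since D is diagonal, inverting it only requires an elementwise division, which is defined as long as no diagonal entry is zero.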
get_predicted_labels(x, dataset=None, use_mean=True)
Process input data to get predicted labels.
Parameters
x (torch.Tensor object) – input data
dataset (int) – used with session-specific I/O layers
use_mean (bool) – True to skip the sampling step and use the posterior mean
Returns
x_hat (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)
y_hat (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)
z (torch.Tensor): sampled latent variable of shape (n_frames, n_latents)
mu (torch.Tensor): mean parameter of shape (n_frames, n_latents)
logvar (torch.Tensor): log-variance parameter of shape (n_frames, n_latents)
Return type
tuple
get_transformed_latents(inputs, dataset=None, as_numpy=True)
Return latents after the supervised subspace has been transformed to the original label space.
Parameters
inputs (torch.Tensor object) – either an image tensor of shape (batch, n_channels, y_pix, x_pix), or a latents tensor of shape (batch, n_ae_latents)
dataset (int, optional) – used with session-specific I/O layers
as_numpy (bool, optional) – True to return a numpy array, False to return a torch Tensor
Returns
array of latents in transformed latent space
Return type
np.ndarray or torch.Tensor object
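A minimal NumPy sketch of the forward direction, mapping supervised latents into label space. It assumes a purely diagonal linear map y = d * z_s (elementwise) with a hypothetical diagonal d, and that the supervised subspace occupies the first n_labels entries; the name transform is illustrative, not a behavenet function.

```python
import numpy as np


def transform(latents, d, n_labels):
    """Map supervised latents (first n_labels entries) into label space.

    Assumes a purely diagonal linear map y = z_s * d (elementwise);
    the remaining (unsupervised) entries are returned unchanged.
    """
    out = np.array(latents, dtype=float, copy=True)
    out[:, :n_labels] = out[:, :n_labels] * d
    return out


d = np.array([2.0, -0.5])             # hypothetical diagonal of D
z = np.array([[1.0, 4.0, 0.3],
              [0.0, -2.0, 1.5]])      # (batch, n_ae_latents)
y = transform(z, d, n_labels=2)
```

Because D is diagonal, each supervised latent is rescaled independently of the others, so each supervised dimension stays tied to exactly one label.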
loss(datas, dataset=None, accumulate_grad=True, chunk_size=None)
Calculate the modified ELBO loss for MSPSVAE.
The batch is split into chunks if it is larger than a hard-coded chunk_size to keep memory requirements low; gradients are accumulated across all chunks before a gradient step is taken.
Parameters
datas (list of dict) – batch of data; keys should include ‘images’ and, if necessary, ‘masks’
dataset (list of int) – used for the embedding loss
accumulate_grad (bool, optional) – accumulate gradient for training step
chunk_size (int, optional) – deprecated
Returns
‘loss’ (float): full ELBO
‘loss_ll’ (float): log-likelihood portion of the ELBO
‘loss_kl’ (float): KL portion of the ELBO
‘loss_mse’ (float): MSE (without Gaussian constants)
‘beta’ (float): weight in front of the KL term
Return type
dict
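The chunked loss relies on gradient accumulation. A minimal NumPy sketch of why accumulating over chunks can reproduce the full-batch gradient, using a toy linear model with a mean-squared-error loss as a stand-in for the MSPSVAE loss. The function names are hypothetical, and weighting each chunk by its share of the batch is an assumption about how per-chunk losses are combined.

```python
import numpy as np


def mse_grad(w, x, y):
    """Gradient of the mean squared error of a linear model y ~ x @ w."""
    residual = x @ w - y
    return 2.0 * x.T @ residual / len(y)


def accumulated_grad(w, x, y, chunk_size):
    """Accumulate per-chunk gradients, weighted by chunk size.

    Weighting by len(chunk) / len(batch) makes the sum equal to the
    gradient of the mean loss over the whole batch, so a single
    optimizer step afterwards matches an unchunked step.
    """
    n = len(y)
    grad = np.zeros_like(w)
    for start in range(0, n, chunk_size):
        xc, yc = x[start:start + chunk_size], y[start:start + chunk_size]
        grad += mse_grad(w, xc, yc) * len(yc) / n
    return grad


rng = np.random.default_rng(0)
x = rng.standard_normal((10, 3))
y = rng.standard_normal(10)
w = rng.standard_normal(3)
g_full = mse_grad(w, x, y)
g_chunked = accumulated_grad(w, x, y, chunk_size=4)  # chunks of 4, 4, 2
```

Only one chunk's activations need to be held in memory at a time, which is the memory saving the loss docstring describes; the final gradient is identical up to floating-point rounding.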