AEMSP
- class behavenet.models.aes.AEMSP(hparams)
Bases: AE
Autoencoder class with matrix subspace projection for disentangling the latent space.
This class constructs an autoencoder whose latent space is forced to learn a subspace that reconstructs a set of supervised labels; this subspace should be orthogonal to another subspace that does not contain information about the labels. These labels are saved under the ‘labels’ key in the hdf5 data file. For more information see: Li et al 2019, Latent Space Factorisation and Manipulation via Matrix Subspace Projection https://arxiv.org/pdf/1907.12385.pdf
Note: the data in the hdf5 group labels should be mean/median centered, as no bias is learned in the transformation from the original latent space to the predicted labels.
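Because no bias is learned, the labels need to be centered before they are written to the data file. The following is a minimal sketch of that preprocessing step using NumPy; the helper name `center_labels` and the synthetic `raw_labels` array are assumptions for illustration and are not part of behavenet.

```python
import numpy as np


def center_labels(labels, method="mean"):
    """Center each label column of an array of shape (n_frames, n_labels)."""
    labels = np.asarray(labels, dtype=float)
    if method == "mean":
        offset = labels.mean(axis=0)
    elif method == "median":
        offset = np.median(labels, axis=0)
    else:
        raise ValueError("method must be 'mean' or 'median'")
    # return the offset too, so the original units can be recovered later
    return labels - offset, offset


# synthetic stand-in for labels that would be stored under the 'labels' key
rng = np.random.default_rng(0)
raw_labels = rng.normal(loc=5.0, scale=2.0, size=(100, 4))
centered, offset = center_labels(raw_labels, method="mean")
```

Keeping the returned offset makes it possible to map predicted labels back to their original coordinates.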
Methods Summary
build_model() – Construct the model using hparams.
create_orthogonal_matrix() – Use the learned projection matrix to construct a full rank orthogonal matrix.
forward(x[, dataset]) – Process input data.
get_inverse_transformed_latents(latents[, ...]) – Take latents in transformed space to original space to push through decoder.
get_transformed_latents(inputs[, dataset, ...]) – Return latents after they have been transformed using the orthogonal matrix U.
loss(data[, dataset, accumulate_grad, ...]) – Calculate MSE loss for autoencoder.
sample([x, dataset, latents, labels, labels_2d]) – Generate output given an input x and arbitrary labels and/or latents.
save(filepath) – Save model parameters.
Methods Documentation
- build_model()
Construct the model using hparams.
The AEMSP is initialized when model_class='cond-ae-msp', and currently only supports model_type='conv' (i.e. no linear).
- create_orthogonal_matrix()
Use the learned projection matrix to construct a full rank orthogonal matrix.
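One way to picture this step, as a sketch rather than the library's actual code: stack the projection matrix on top of an orthonormal basis of its null space. The function name `complete_to_square` and the randomly generated matrix `M` are hypothetical, and the example assumes the rows of `M` are already orthonormal, which is what makes the stacked result orthogonal.

```python
import numpy as np


def complete_to_square(M, tol=1e-10):
    """Stack M (n_labels, n_latents) on an orthonormal basis of its null space."""
    # full SVD: rows of Vt beyond the rank of M span the null space of M
    _, s, Vt = np.linalg.svd(M, full_matrices=True)
    rank = int(np.sum(s > tol))
    N = Vt[rank:]  # shape (n_latents - rank, n_latents)
    return np.concatenate([M, N], axis=0)


rng = np.random.default_rng(0)
n_labels, n_latents = 3, 8

# hypothetical projection matrix with orthonormal rows (assumption)
Q, _ = np.linalg.qr(rng.normal(size=(n_latents, n_labels)))
M = Q.T  # shape (n_labels, n_latents)

U = complete_to_square(M)  # shape (n_latents, n_latents)
```

If the rows of the learned projection are not orthonormal, the stacked matrix is still full rank but is only approximately orthogonal.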
- forward(x, dataset=None, **kwargs)
Process input data.
- Parameters:
x (torch.Tensor object) – input data of shape (batch, n_channels, y_pix, x_pix)
dataset (int, optional) – used with session-specific io layers
- Returns:
x_hat (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)
z (torch.Tensor): hidden representation of shape (n_frames, n_latents)
y (torch.Tensor): predicted labels of shape (n_frames, n_labels)
- Return type:
tuple
- get_inverse_transformed_latents(latents, as_numpy=True)
Take latents in transformed space to original space to push through decoder.
- Parameters:
latents (torch.Tensor object) – shape (batch, n_ae_latents)
as_numpy (bool, optional) – True to return as numpy array, False to return as torch Tensor
- Returns:
array of latents in original latent space
- Return type:
np.ndarray or torch.Tensor object
- get_transformed_latents(inputs, dataset=None, as_numpy=True)
Return latents after they have been transformed using the orthogonal matrix U.
- Parameters:
inputs (torch.Tensor object) – either an image tensor of shape (batch, n_channels, y_pix, x_pix) or a latents tensor of shape (batch, n_ae_latents)
dataset (int, optional) – used with session-specific io layers
as_numpy (bool, optional) – True to return as numpy array, False to return as torch Tensor
- Returns:
array of latents in transformed latent space
- Return type:
np.ndarray or torch.Tensor object
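The forward and inverse transforms can be illustrated with plain matrix products. This is a sketch under the assumption that the transform is a right-multiplication by the transpose of an orthogonal matrix U, so the inverse is a multiplication by U itself. The functions `transform` and `inverse_transform` and the random matrix `U` are stand-ins, not behavenet calls.

```python
import numpy as np

rng = np.random.default_rng(1)
n_latents = 6

# random orthogonal matrix as a stand-in for the learned U (assumption)
U, _ = np.linalg.qr(rng.normal(size=(n_latents, n_latents)))


def transform(latents, U):
    """Map latents of shape (batch, n_latents) into the transformed space."""
    return latents @ U.T


def inverse_transform(latents_tr, U):
    """Map transformed latents back; valid because U.T @ U is the identity."""
    return latents_tr @ U


z = rng.normal(size=(10, n_latents))  # stand-in for encoder output
z_tr = transform(z, U)
z_back = inverse_transform(z_tr, U)
```

Since U is orthogonal, the round trip recovers the original latents up to floating point error, and vector norms are preserved by the transform.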
- loss(data, dataset=0, accumulate_grad=True, chunk_size=200)
Calculate MSE loss for autoencoder.
The batch is split into chunks if it is larger than chunk_size (default 200) to keep memory requirements low; gradients are accumulated across all chunks before a gradient step is taken.
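The chunking idea can be checked on a toy problem. The sketch below uses a linear model with an analytic gradient in NumPy instead of the autoencoder, and the names `sse_and_grad` and `chunked_mse` are hypothetical. It shows that summing per-chunk losses and gradients, then dividing by the batch size, matches the full-batch result.

```python
import numpy as np


def sse_and_grad(w, x, y):
    """Sum of squared errors and its gradient for a linear model x @ w."""
    resid = x @ w - y
    return float(np.sum(resid ** 2)), 2.0 * x.T @ resid


def chunked_mse(w, x, y, chunk_size=200):
    """Accumulate loss and gradient over chunks, then normalize once."""
    n = x.shape[0]
    loss = 0.0
    grad = np.zeros_like(w)
    for start in range(0, n, chunk_size):
        sl = slice(start, start + chunk_size)
        chunk_loss, chunk_grad = sse_and_grad(w, x[sl], y[sl])
        loss += chunk_loss  # accumulate across chunks
        grad += chunk_grad
    return loss / n, grad / n


rng = np.random.default_rng(0)
x = rng.normal(size=(450, 5))
y = rng.normal(size=450)
w = rng.normal(size=5)

loss_chunked, grad_chunked = chunked_mse(w, x, y, chunk_size=200)
full_sse, full_grad = sse_and_grad(w, x, y)
loss_full, grad_full = full_sse / 450, full_grad / 450
```

In PyTorch the same effect comes from calling `backward()` on each chunk's loss without zeroing gradients in between, since gradients add up in `.grad` until they are cleared.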
- Parameters:
data (dict) – batch of data; keys should include ‘images’ and ‘masks’, if necessary
dataset (int, optional) – used for session-specific io layers
accumulate_grad (bool, optional) – accumulate gradient for training step
chunk_size (int, optional) – batch is split into chunks of this size to keep memory requirements low
- Returns:
‘loss’ (float): total loss
‘loss_mse’ (float): pixel mse loss
‘loss_msp’ (float): combined msp loss
‘labels_r2’ (float): variance-weighted $R^2$ of reconstructed labels
- Return type:
dict
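For readers unfamiliar with the ‘labels_r2’ entry, a variance-weighted $R^2$ averages the per-label $R^2$ values with weights proportional to each label's variance. The following NumPy sketch shows one common definition; the function name `variance_weighted_r2` is hypothetical, and whether behavenet computes the value exactly this way is an assumption.

```python
import numpy as np


def variance_weighted_r2(y_true, y_pred):
    """Variance-weighted R^2 for arrays of shape (n_frames, n_labels)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # residual and total sums of squares, one value per label
    ss_res = np.sum((y_true - y_pred) ** 2, axis=0)
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2, axis=0)
    r2_per_label = 1.0 - ss_res / ss_tot
    # weight each label by its share of the total variance
    weights = ss_tot / ss_tot.sum()
    return float(np.sum(weights * r2_per_label))


rng = np.random.default_rng(0)
labels = rng.normal(size=(200, 3)) * np.array([1.0, 5.0, 0.2])
noisy_pred = labels + 0.1 * rng.normal(size=labels.shape)
r2 = variance_weighted_r2(labels, noisy_pred)
```

With these weights the score reduces to one minus the ratio of total residual to total sum of squares, so high-variance labels dominate the result.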
- sample(x=None, dataset=None, latents=None, labels=None, labels_2d=None)
Generate output given an input x and arbitrary labels and/or latents.
How output image is generated:
if latents is not None and labels is not None, these are concatenated, transformed to the original latent space, and pushed through the decoder
if latents is not None and labels is None, the input x is pushed through the encoder to produce the latents, these are transformed with the projection layer, and the resulting latents (n_latents - n_labels dimensions) are replaced with the user-defined latents. This vector (labels + latents) is then transformed back into the original latent space, and pushed through the decoder.
if latents is None and labels is not None, the input x is pushed through the encoder to produce the latents, these are transformed with the projection layer, and the resulting labels (n_labels dimensions) are replaced with the user-defined labels. This vector (latents + labels) is then transformed back into the original latent space, and pushed through the decoder.
- Parameters:
x (torch.Tensor object, optional) – input data of shape (batch, n_channels, y_pix, x_pix)
dataset (int, optional) – used with session-specific io layers
latents (np.ndarray object, optional) – transformed latents of shape (batch, n_latents - n_labels)
labels (np.ndarray object, optional) – continuous labels corresponding to input data, of shape (batch, n_labels)
labels_2d (torch.Tensor object, optional) – one-hot labels corresponding to input data, of shape (batch, n_labels, y_pix, x_pix); for a given frame, each channel corresponds to a label and is all zeros with a single value of one in the proper x/y position
- Returns:
output of shape (n_frames, n_channels, y_pix, x_pix)
- Return type:
torch.Tensor
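The label-replacement case of sample (latents is None, labels is not None) can be sketched without the encoder or decoder. The example below assumes the labels occupy the first n_labels dimensions of the transformed vector and uses a random orthogonal matrix in place of the learned U; `swap_labels` is a hypothetical helper, not a behavenet function.

```python
import numpy as np

rng = np.random.default_rng(2)
n_latents, n_labels, batch = 6, 2, 4

# random orthogonal matrix as a stand-in for the learned U (assumption)
U, _ = np.linalg.qr(rng.normal(size=(n_latents, n_latents)))
z = rng.normal(size=(batch, n_latents))  # stand-in for encoder output


def swap_labels(z, U, new_labels):
    """Replace the label dimensions in transformed space, then map back."""
    z_tr = (z @ U.T).copy()  # to transformed space
    z_tr[:, :new_labels.shape[1]] = new_labels  # overwrite label dimensions
    return z_tr @ U  # back to original latent space


new_labels = np.zeros((batch, n_labels))  # user-defined labels
z_new = swap_labels(z, U, new_labels)  # would then be pushed through the decoder
```

The complementary case, where user-defined latents replace the remaining n_latents - n_labels dimensions, follows the same pattern with the other slice overwritten.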