AEMSP

class behavenet.models.aes.AEMSP(hparams)

Bases: AE

Autoencoder class with matrix subspace projection for disentangling the latent space.

This class constructs an autoencoder whose latent space is forced to learn a subspace that reconstructs a set of supervised labels; this subspace should be orthogonal to a second subspace that contains no information about the labels. The labels are stored under the ‘labels’ key in the HDF5 data file. For more information see Li et al. 2019, Latent Space Factorisation and Manipulation via Matrix Subspace Projection, https://arxiv.org/pdf/1907.12385.pdf

Note: the data in the HDF5 group ‘labels’ should be mean- or median-centered, because no bias term is learned in the transformation from the original latent space to the predicted labels.
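Why the centering matters can be illustrated with a small NumPy sketch. It uses hypothetical random data: zero-mean latents and labels that depend linearly on them plus a constant offset. It fits a bias-free linear map by least squares (standing in for the bias-free projection) and compares uncentered with mean-centered labels; median centering would use np.median in the same way.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: zero-mean latents z and labels y = z W^T + constant offset.
n_frames, n_latents, n_labels = 500, 6, 2
z = rng.standard_normal((n_frames, n_latents))
z -= z.mean(axis=0)                      # make latents exactly zero-mean
W = rng.standard_normal((n_labels, n_latents))
y = z @ W.T + 5.0                        # labels with a nonzero offset


def bias_free_fit_mse(z, y):
    """Fit y approximately as z M^T with no intercept and return the MSE of the fit."""
    M_T, *_ = np.linalg.lstsq(z, y, rcond=None)
    return float(np.mean((z @ M_T - y) ** 2))


# Center each label dimension before fitting (np.median works the same way).
y_centered = y - y.mean(axis=0)

mse_raw = bias_free_fit_mse(z, y)
mse_centered = bias_free_fit_mse(z, y_centered)
print(f"uncentered MSE: {mse_raw:.3f}, centered MSE: {mse_centered:.2e}")
```

Because the latents are zero-mean, no bias-free linear map can reproduce the constant offset, so the uncentered fit leaves a residual equal to the squared offset, while the centered labels are fit essentially exactly.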

Methods Summary

build_model()

Construct the model using hparams.

create_orthogonal_matrix()

Use the learned projection matrix to construct a full rank orthogonal matrix.

forward(x[, dataset])

Process input data.

get_inverse_transformed_latents(latents[, ...])

Take latents in transformed space to original space to push through decoder.

get_transformed_latents(inputs[, dataset, ...])

Return latents after they have been transformed using the orthogonal matrix U.

loss(data[, dataset, accumulate_grad, ...])

Calculate the pixel MSE and MSP losses for the autoencoder.

sample([x, dataset, latents, labels, labels_2d])

Generate output given an input x and arbitrary labels and/or latents.

save(filepath)

Save model parameters.

Methods Documentation

build_model()

Construct the model using hparams.

The AEMSP is initialized when model_class='cond-ae-msp', and currently only supports model_type='conv' (i.e., model_type='linear' is not supported).

create_orthogonal_matrix()

Use the learned projection matrix to construct a full rank orthogonal matrix.
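One standard way to complete a projection matrix to a full-rank orthogonal matrix is a complete QR decomposition, sketched below in NumPy. This is an illustrative assumption, not necessarily how behavenet implements create_orthogonal_matrix: the projection matrix M is taken to have shape (n_labels, n_latents) with linearly independent rows, and complete_orthogonal is a hypothetical helper.

```python
import numpy as np


def complete_orthogonal(M):
    """Return a square orthogonal U whose first rows span the row space of M.

    M: (n_labels, n_latents) projection matrix with linearly independent rows (assumed).
    """
    # Complete QR of M^T: Q is (n_latents, n_latents) orthogonal; its first
    # n_labels columns span the column space of M^T (= row space of M), and the
    # remaining columns span the orthogonal complement (the null space of M).
    Q, _ = np.linalg.qr(M.T, mode="complete")
    return Q.T


rng = np.random.default_rng(1)
M = rng.standard_normal((2, 6))          # hypothetical learned projection matrix
U = complete_orthogonal(M)
print(np.allclose(U @ U.T, np.eye(6)))   # U is orthogonal
print(np.allclose(M @ U[2:].T, 0.0))     # trailing rows lie in the null space of M
```

In this construction the first n_labels rows are an orthonormalized basis for the row space of M rather than M itself, and the trailing rows span the label-free subspace.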

forward(x, dataset=None, **kwargs)

Process input data.

Parameters:
  • x (torch.Tensor object) – input data of shape (batch, n_channels, y_pix, x_pix)

  • dataset (int, optional) – used with session-specific io layers

Returns:

  • x_hat (torch.Tensor): output of shape (n_frames, n_channels, y_pix, x_pix)

  • z (torch.Tensor): hidden representation of shape (n_frames, n_latents)

  • y (torch.Tensor): predicted labels of shape (n_frames, n_labels)

Return type:

tuple

get_inverse_transformed_latents(latents, as_numpy=True)

Take latents in transformed space to original space to push through decoder.

Parameters:
  • latents (torch.Tensor object) – shape (batch, n_ae_latents)

  • as_numpy (bool, optional) – True to return as numpy array, False to return as torch Tensor

Returns:

array of latents in original latent space

Return type:

np.ndarray or torch.Tensor object
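Because U is orthogonal, mapping latents back from the transformed space only requires its transpose. A minimal NumPy sketch, assuming transformed latents are computed row-wise as z_t = z U^T for a batch z of shape (batch, n_ae_latents); the random U is a hypothetical stand-in for the learned matrix.

```python
import numpy as np

rng = np.random.default_rng(2)
n_ae_latents = 6

# Hypothetical orthogonal matrix U (stand-in for the learned one).
U, _ = np.linalg.qr(rng.standard_normal((n_ae_latents, n_ae_latents)))

z = rng.standard_normal((4, n_ae_latents))   # batch of original latents
z_t = z @ U.T                                # transformed latents (U applied to each row)
z_back = z_t @ U                             # inverse: U^{-1} = U^T, applied to each row

print(np.allclose(z_back, z))
```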

get_transformed_latents(inputs, dataset=None, as_numpy=True)

Return latents after they have been transformed using the orthogonal matrix U.

Parameters:
  • inputs (torch.Tensor object) – one of the following:

    • an image tensor of shape (batch, n_channels, y_pix, x_pix)

    • a latents tensor of shape (batch, n_ae_latents)

  • dataset (int, optional) – used with session-specific io layers

  • as_numpy (bool, optional) – True to return as numpy array, False to return as torch Tensor

Returns:

array of latents in transformed latent space

Return type:

np.ndarray or torch.Tensor object
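A sketch of how a batch of transformed latents might be split into the label subspace and the remaining, label-free subspace. It assumes, hypothetically, that the first n_labels rows of U correspond to the labels (the ordering used internally is not stated here), and the random U is a stand-in for the learned matrix.

```python
import numpy as np

rng = np.random.default_rng(3)
n_ae_latents, n_labels = 6, 2

# Hypothetical orthogonal U; rows 0..n_labels-1 are assumed to be the label directions.
U, _ = np.linalg.qr(rng.standard_normal((n_ae_latents, n_ae_latents)))

z = rng.standard_normal((4, n_ae_latents))   # latents, shape (batch, n_ae_latents)
z_t = z @ U.T                                # transformed latents, same shape

label_part = z_t[:, :n_labels]               # (batch, n_labels): label subspace
other_part = z_t[:, n_labels:]               # (batch, n_ae_latents - n_labels): label-free subspace
print(label_part.shape, other_part.shape)

# An orthogonal transform preserves the norm of each latent vector.
print(np.allclose(np.linalg.norm(z_t, axis=1), np.linalg.norm(z, axis=1)))
```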

loss(data, dataset=0, accumulate_grad=True, chunk_size=200)

Calculate the pixel MSE and MSP losses for the autoencoder.

If the batch is larger than chunk_size, it is split into chunks of that size to keep memory requirements low; gradients are accumulated across all chunks before a gradient step is taken.
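The chunking idea can be sketched framework-agnostically in NumPy. In this sketch each chunk's pixel MSE is weighted by the chunk's share of the batch, so the accumulated value equals the full-batch mean even when the last chunk is smaller; in a PyTorch loop the same weighting would be applied to each chunk's loss before its backward pass. The helper name and the weighting scheme are assumptions, not behavenet's code.

```python
import numpy as np


def chunked_mse(x_hat, x, chunk_size=200):
    """Mean squared error over a batch, computed chunk by chunk (hypothetical helper).

    Each chunk's mean is weighted by its fraction of the batch, so the sum over
    chunks equals the full-batch mean.
    """
    n = x.shape[0]
    total = 0.0
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        chunk_loss = np.mean((x_hat[start:stop] - x[start:stop]) ** 2)
        total += chunk_loss * (stop - start) / n
    return total


rng = np.random.default_rng(4)
x = rng.standard_normal((450, 1, 8, 8))      # (batch, n_channels, y_pix, x_pix)
x_hat = x + 0.1 * rng.standard_normal(x.shape)
print(np.isclose(chunked_mse(x_hat, x), np.mean((x_hat - x) ** 2)))
```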

Parameters:
  • data (dict) – batch of data; keys should include ‘images’ and ‘masks’, if necessary

  • dataset (int, optional) – used for session-specific io layers

  • accumulate_grad (bool, optional) – accumulate gradient for training step

  • chunk_size (int, optional) – batch is split into chunks of this size to keep memory requirements low

Returns:

  • ‘loss’ (float): total loss

  • ‘loss_mse’ (float): pixel MSE loss

  • ‘loss_msp’ (float): combined MSP loss

  • ‘labels_r2’ (float): variance-weighted $R^2$ of reconstructed labels

Return type:

dict
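Assuming labels_r2 follows the same definition as scikit-learn's r2_score with multioutput='variance_weighted', weighting each label dimension's $R^2$ by that dimension's variance reduces to one minus the ratio of the summed residual sum of squares to the summed total sum of squares. A NumPy sketch with a hypothetical function name:

```python
import numpy as np


def variance_weighted_r2(y_true, y_pred):
    """Variance-weighted R^2 over label dimensions.

    Weighting each dimension's R^2 by its variance is equivalent to
    1 - sum_j SS_res_j / sum_j SS_tot_j.
    """
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot


rng = np.random.default_rng(5)
y = rng.standard_normal((100, 3)) * np.array([1.0, 2.0, 0.5])      # (n_frames, n_labels)
print(variance_weighted_r2(y, y))                                  # perfect prediction
print(variance_weighted_r2(y, np.tile(y.mean(axis=0), (100, 1))))  # always predict the mean
```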

sample(x=None, dataset=None, latents=None, labels=None, labels_2d=None)

Generate output given an input x and arbitrary labels and/or latents.

How the output image is generated:

  • if latents is not None and labels is not None, these are concatenated, transformed to the original latent space, and pushed through the decoder

  • if latents is not None and labels is None, the input x is pushed through the encoder to produce the latents, these are transformed with the projection layer, and the resulting latents (n_latents - n_labels dimensions) are replaced with the user-defined latents. This vector (labels + latents) is then transformed back into the original latent space, and pushed through the decoder.

  • if latents is None and labels is not None, the input x is pushed through the encoder to produce the latents, these are transformed with the projection layer, and the resulting labels (n_labels dimensions) are replaced with the user-defined labels. This vector (latents + labels) is then transformed back into the original latent space, and pushed through the decoder.
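The label-replacement case can be sketched with plain linear algebra. The encoder and decoder are omitted, the orthogonal matrix U is a random stand-in, replace_labels is a hypothetical helper, and the label coordinates are assumed to occupy the first n_labels transformed dimensions.

```python
import numpy as np

rng = np.random.default_rng(6)
n_latents, n_labels, batch = 6, 2, 3

# Hypothetical orthogonal matrix U (stand-in for the learned transform).
U, _ = np.linalg.qr(rng.standard_normal((n_latents, n_latents)))


def replace_labels(z, new_labels, U, n_labels):
    """Swap the label coordinates of latents z for user-defined labels.

    z: (batch, n_latents) latents from the encoder; new_labels: (batch, n_labels).
    Returns latents in the original space, ready for the decoder.
    """
    z_t = z @ U.T                    # to the transformed space (new array)
    z_t[:, :n_labels] = new_labels   # replace label coordinates (assumed to come first)
    return z_t @ U                   # back to the original latent space


z = rng.standard_normal((batch, n_latents))     # stand-in for encoder output
new_labels = np.array([[0.5, -1.0]] * batch)
z_new = replace_labels(z, new_labels, U, n_labels)

z_t_new = z_new @ U.T
print(np.allclose(z_t_new[:, :n_labels], new_labels))               # labels were set
print(np.allclose(z_t_new[:, n_labels:], (z @ U.T)[:, n_labels:]))  # other latents kept
```

Replacing the other coordinates instead (the case where user-defined latents are supplied) only changes the slice that is overwritten.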

Parameters:
  • x (torch.Tensor object, optional) – input data of shape (batch, n_channels, y_pix, x_pix)

  • dataset (int, optional) – used with session-specific io layers

  • latents (np.ndarray object, optional) – transformed latents of shape (batch, n_latents - n_labels)

  • labels (np.ndarray object, optional) – continuous labels corresponding to input data, of shape (batch, n_labels)

  • labels_2d (torch.Tensor object, optional) – one-hot labels corresponding to input data, of shape (batch, n_labels, y_pix, x_pix); for a given frame, each channel corresponds to a label and is all zeros with a single value of one in the proper x/y position

Returns:

output of shape (n_frames, n_channels, y_pix, x_pix)

Return type:

torch.Tensor
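A minimal sketch of building one-hot maps in the layout described for labels_2d, assuming integer pixel coordinates given as (row, col) pairs per channel; the helper name and the coordinate format are hypothetical.

```python
import numpy as np


def make_labels_2d(coords, y_pix, x_pix):
    """Build one-hot label maps of shape (batch, n_channels, y_pix, x_pix).

    coords: integer array of shape (batch, n_channels, 2) holding (row, col)
    pixel positions; each channel gets a single 1 at that position.
    """
    batch, n_channels, _ = coords.shape
    labels_2d = np.zeros((batch, n_channels, y_pix, x_pix), dtype=np.float32)
    # Index grids so every (frame, channel) pair selects its own pixel.
    b_idx, c_idx = np.meshgrid(np.arange(batch), np.arange(n_channels), indexing="ij")
    labels_2d[b_idx, c_idx, coords[..., 0], coords[..., 1]] = 1.0
    return labels_2d


coords = np.array([[[2, 3], [0, 7]],
                   [[5, 1], [4, 4]]])       # (batch=2, n_channels=2, 2)
maps = make_labels_2d(coords, y_pix=8, x_pix=8)
print(maps.shape)                           # (2, 2, 8, 8)
print(maps.sum(axis=(2, 3)))                # one nonzero pixel per channel
```

If a tensor is required, the array can be converted with torch.from_numpy(maps).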

save(filepath)

Save model parameters.