get_reconstruction

behavenet.fitting.eval.get_reconstruction(model, inputs, dataset=None, return_latents=False, labels=None, labels_2d=None, apply_inverse_transform=True, use_mean=True)

Reconstruct images from either image inputs or latent inputs.

Parameters:
  • model (AE object) – PyTorch autoencoder model

  • inputs (torch.Tensor object) – either:

    • an image tensor of shape (batch, channels, y_pix, x_pix), or

    • a latent tensor of shape (batch, n_ae_latents)

  • dataset (int or NoneType, optional) – dataset index, for use with session-specific I/O layers

  • return_latents (bool, optional) – if True, return a tuple of (recon, latents)

  • labels (torch.Tensor object or NoneType, optional) – label tensor of shape (batch, n_labels)

  • labels_2d (torch.Tensor object or NoneType, optional) – label tensor of shape (batch, n_labels, y_pix, x_pix)

  • apply_inverse_transform (bool, optional) – if inputs are latents (and the model class is ‘cond-ae-msp’ or ‘ps-vae’), apply the inverse transform to map them back to the original latent space

  • use_mean (bool, optional) – if inputs are images (and the model class is variational), use the mean of the approximate posterior instead of sampling from it
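For a variational model the encoder outputs the parameters of an approximate posterior; the usual choice is a diagonal Gaussian with mean mu and log-variance logvar. The sketch below illustrates what a flag like use_mean selects between, assuming that Gaussian parameterization. select_latents is a hypothetical NumPy helper written for illustration only; it is not part of behavenet, and behavenet's own code operates on torch tensors.

```python
import numpy as np


def select_latents(mu, logvar, use_mean=True, rng=None):
    """Pick latents from a diagonal-Gaussian approximate posterior.

    Hypothetical helper (not part of behavenet): with use_mean=True the
    posterior mean is returned unchanged; otherwise one sample is drawn with
    the reparameterization z = mu + sigma * eps, where eps ~ N(0, I).
    """
    mu = np.asarray(mu, dtype=float)
    if use_mean:
        return mu.copy()
    rng = np.random.default_rng() if rng is None else rng
    sigma = np.exp(0.5 * np.asarray(logvar, dtype=float))  # std from log-variance
    eps = rng.standard_normal(mu.shape)                     # standard normal noise
    return mu + sigma * eps


# batch of 2 samples with 3 latents each
mu = np.array([[0.0, 1.0, -1.0], [2.0, 0.5, 0.0]])
logvar = np.full_like(mu, -2.0)

z_mean = select_latents(mu, logvar, use_mean=True)  # deterministic
z_samp = select_latents(mu, logvar, use_mean=False, rng=np.random.default_rng(0))
```

Using the mean makes repeated reconstructions of the same images deterministic; sampling adds posterior noise, so reconstructions vary from call to call unless the random generator is seeded.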

Returns:

reconstructed images of shape (batch, channels, y_pix, x_pix); if return_latents is True, a tuple of (recon, latents)

Return type:

np.ndarray, or tuple if return_latents is True
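The two accepted input shapes differ in rank (4-D images versus 2-D latents), so one plausible way to tell them apart is to dispatch on the number of dimensions. The sketch below assumes exactly that; this page does not say how behavenet makes the distinction. It uses a NumPy stand-in so it runs without a trained PyTorch model. ToyLinearAE and reconstruct are hypothetical names for illustration, not behavenet APIs, and the sketch ignores dataset, labels, labels_2d, apply_inverse_transform and use_mean.

```python
import numpy as np


class ToyLinearAE:
    """Hypothetical stand-in for an AE model: a linear encoder/decoder in NumPy."""

    def __init__(self, channels, y_pix, x_pix, n_latents, seed=0):
        rng = np.random.default_rng(seed)
        self.img_shape = (channels, y_pix, x_pix)
        n_pix = channels * y_pix * x_pix
        self.W = rng.standard_normal((n_pix, n_latents)) / np.sqrt(n_pix)

    def encode(self, images):
        flat = images.reshape(images.shape[0], -1)  # (batch, n_pix)
        return flat @ self.W                        # (batch, n_latents)

    def decode(self, latents):
        flat = latents @ self.W.T                   # (batch, n_pix)
        return flat.reshape((latents.shape[0],) + self.img_shape)


def reconstruct(model, inputs, return_latents=False):
    """Sketch of rank-based dispatch: 4-D inputs are images, 2-D inputs are latents."""
    inputs = np.asarray(inputs, dtype=float)
    if inputs.ndim == 4:    # (batch, channels, y_pix, x_pix): encode first
        latents = model.encode(inputs)
    elif inputs.ndim == 2:  # (batch, n_ae_latents): decode directly
        latents = inputs
    else:
        raise ValueError(f"expected 2-D latents or 4-D images, got shape {inputs.shape}")
    recon = model.decode(latents)
    return (recon, latents) if return_latents else recon


model = ToyLinearAE(channels=1, y_pix=4, x_pix=5, n_latents=3)
images = np.random.default_rng(1).standard_normal((2, 1, 4, 5))

recon = reconstruct(model, images)                            # np.ndarray
recon2, latents = reconstruct(model, images, return_latents=True)
recon_from_latents = reconstruct(model, latents)              # latent input path
```

In behavenet itself the inputs are torch.Tensor objects and the model is a trained AE; the stand-in only mirrors the input and return shapes documented on this page.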