get_reconstruction
- behavenet.fitting.eval.get_reconstruction(model, inputs, dataset=None, return_latents=False, labels=None, labels_2d=None, apply_inverse_transform=True, use_mean=True)
Reconstruct an image from either image or latent inputs.
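Because the function accepts either images or latents, it must decide whether to run the encoder. The sketch below shows one plausible way to make that decision, by tensor rank. It is a minimal illustration, not behavenet's implementation. `ToyLinearAE`, its `encode`/`decode` methods and the `reconstruct` helper are hypothetical, and NumPy arrays stand in for torch tensors so the example runs without PyTorch.

```python
import numpy as np


class ToyLinearAE:
    """Hypothetical stand-in for an autoencoder; NOT behavenet's AE class."""

    def __init__(self, n_latents, channels, y_pix, x_pix, seed=0):
        rng = np.random.default_rng(seed)
        self.img_shape = (channels, y_pix, x_pix)
        n_pix = channels * y_pix * x_pix
        # random linear encoder/decoder weights, for illustration only
        self.w_enc = rng.standard_normal((n_pix, n_latents))
        self.w_dec = rng.standard_normal((n_latents, n_pix))

    def encode(self, images):
        # (batch, channels, y_pix, x_pix) -> (batch, n_latents)
        return images.reshape(images.shape[0], -1) @ self.w_enc

    def decode(self, latents):
        # (batch, n_latents) -> (batch, channels, y_pix, x_pix)
        return (latents @ self.w_dec).reshape((latents.shape[0],) + self.img_shape)


def reconstruct(model, inputs, return_latents=False):
    """Hypothetical sketch: dispatch on input rank, then decode."""
    if inputs.ndim == 4:
        # image inputs: encode to latents first
        latents = model.encode(inputs)
    elif inputs.ndim == 2:
        # latent inputs: skip the encoder entirely
        latents = inputs
    else:
        raise ValueError(f"expected 2D latents or 4D images, got shape {inputs.shape}")
    recon = model.decode(latents)
    return (recon, latents) if return_latents else recon
```

For example, `reconstruct(ToyLinearAE(3, 1, 4, 5), np.zeros((2, 1, 4, 5)), return_latents=True)` returns a reconstruction of shape (2, 1, 4, 5) and latents of shape (2, 3), while passing a (2, 3) latent array skips the encoder.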
- Parameters:
  - model (AE object) – PyTorch model
  - inputs (torch.Tensor object) – either
    - an image tensor of shape (batch, channels, y_pix, x_pix), or
    - a latents tensor of shape (batch, n_ae_latents)
  - dataset (int or NoneType, optional) – for use with session-specific I/O layers
  - return_latents (bool, optional) – if True, return a tuple (recon, latents)
  - labels (torch.Tensor object or NoneType, optional) – label tensor of shape (batch, n_labels)
  - labels_2d (torch.Tensor object or NoneType, optional) – label tensor of shape (batch, n_labels, y_pix, x_pix)
  - apply_inverse_transform (bool) – if inputs are latents (and the model class is ‘cond-ae-msp’ or ‘ps-vae’), apply the inverse transform to map them back to the original latent space
  - use_mean (bool) – if inputs are images (and the model class is variational), use the mean of the approximate posterior instead of sampling from it
- Returns:
reconstructed images of shape (batch, channels, y_pix, x_pix); if return_latents is True, a tuple (recon, latents)
- Return type:
np.ndarray, or tuple if return_latents is True
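The use_mean flag distinguishes a deterministic reconstruction from a stochastic one. The sketch below shows that distinction for a diagonal-Gaussian approximate posterior. It assumes, hypothetically, that the encoder produces a mean `mu` and a log-variance `logvar` of shape (batch, n_ae_latents); behavenet's variational models may parameterize the posterior differently. `select_latents` is a hypothetical helper, and NumPy stands in for torch.

```python
import numpy as np


def select_latents(mu, logvar, use_mean=True, rng=None):
    """Hypothetical helper: pick latents from q(z|x) = N(mu, exp(logvar)).

    mu, logvar: arrays of shape (batch, n_ae_latents), assumed encoder outputs.
    """
    if use_mean:
        # deterministic: use the posterior mean, no sampling
        return mu
    rng = np.random.default_rng() if rng is None else rng
    # reparameterized sample: z = mu + sigma * eps, with eps ~ N(0, I)
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps
```

With use_mean=True, repeated calls on the same images give identical latents and therefore identical reconstructions. With sampling, each call draws a fresh eps, so reconstructions vary from call to call.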