get_model_input
- behavenet.plotting.cond_ae_utils.get_model_input(data_generator, hparams, model, trial=None, trial_idx=None, sess_idx=0, max_frames=200, compute_latents=False, compute_2d_labels=True, compute_scaled_labels=False, mask_labels=False, dtype='test')
Return images, latents, and labels for a given trial.
- Parameters:
  - data_generator (ConcatSessionGenerator) – for generating model input
  - hparams (dict) – needs to contain enough information to specify both a model and the associated data
  - model (behavenet.models object) – model type
  - trial (int, optional) – trial index into all possible trials (train, val, test); one of trial or trial_idx must be specified; trial takes precedence over trial_idx
  - trial_idx (int, optional) – trial index into trial type defined by dtype; one of trial or trial_idx must be specified; trial takes precedence over trial_idx
  - sess_idx (int, optional) – session index into data generator
  - max_frames (int, optional) – maximum size of batch to return
  - compute_latents (bool, optional) – True to return latents
  - compute_2d_labels (bool, optional) – True to return 2d label tensors of shape (batch, n_labels, y_pix, x_pix)
  - compute_scaled_labels (bool, optional) – ignored if compute_2d_labels is True; if compute_scaled_labels=True, return scaled labels as shape (batch, n_labels) rather than 2d labels as shape (batch, n_labels, y_pix, x_pix)
  - mask_labels (bool, optional) – True to return numpy labels where nan values indicate masked time points
  - dtype (str, optional) – data type that is indexed by trial_idx; ‘train’ | ‘val’ | ‘test’
- Returns:
  - ims_pt (torch.Tensor) of shape (max_frames, n_channels, y_pix, x_pix)
  - ims_np (np.ndarray) of shape (max_frames, n_channels, y_pix, x_pix)
  - latents_np (np.ndarray) of shape (max_frames, n_latents)
  - labels_pt (torch.Tensor) of shape (max_frames, n_labels)
  - labels_2d_pt (torch.Tensor) of shape (max_frames, n_labels, y_pix, x_pix)
  - labels_2d_np (np.ndarray) of shape (max_frames, n_labels, y_pix, x_pix)
- Return type:
tuple
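
The interaction between trial, trial_idx, and dtype, and between the two label flags, can be easier to follow in code. The sketch below is illustrative only and is not BehaveNet's implementation: `resolve_trial`, `label_format`, and the `trials_by_dtype` mapping are hypothetical names, and the trial split is made up. It only mirrors the rules stated in the parameter descriptions above.

```python
from typing import Dict, List, Optional


def resolve_trial(
    trial: Optional[int],
    trial_idx: Optional[int],
    trials_by_dtype: Dict[str, List[int]],
    dtype: str = "test",
) -> int:
    """Hypothetical helper mirroring the documented precedence rule."""
    if trial is not None:
        # `trial` indexes all trials directly and takes precedence.
        return trial
    if trial_idx is None:
        raise ValueError("one of `trial` or `trial_idx` must be specified")
    if dtype not in trials_by_dtype:
        raise ValueError(f"dtype must be one of {sorted(trials_by_dtype)}")
    # `trial_idx` indexes only the trials that belong to `dtype`.
    return trials_by_dtype[dtype][trial_idx]


def label_format(
    compute_2d_labels: bool = True,
    compute_scaled_labels: bool = False,
) -> str:
    """Hypothetical helper: which label layout the documented flags select."""
    if compute_2d_labels:
        # (batch, n_labels, y_pix, x_pix); compute_scaled_labels is ignored.
        return "2d"
    if compute_scaled_labels:
        # (batch, n_labels)
        return "scaled"
    # The documentation does not describe this case.
    return "unspecified"


# Made-up split of ten trials, for illustration only.
trials_by_dtype = {"train": [0, 1, 2, 4, 5, 7], "val": [3, 8], "test": [6, 9]}

# Second test trial, selected through trial_idx.
print(resolve_trial(None, 1, trials_by_dtype, dtype="test"))
# `trial` is given, so trial_idx and dtype are not consulted.
print(resolve_trial(3, 1, trials_by_dtype, dtype="test"))
print(label_format(compute_2d_labels=True, compute_scaled_labels=True))
```

When calling the real function, note that the returned tuple has six entries in the order listed under Returns, so it can be unpacked positionally.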