get_model_input

behavenet.plotting.cond_ae_utils.get_model_input(data_generator, hparams, model, trial=None, trial_idx=None, sess_idx=0, max_frames=200, compute_latents=False, compute_2d_labels=True, compute_scaled_labels=False, mask_labels=False, dtype='test')

Return images, latents, and labels for a given trial.

Parameters:
  • data_generator (ConcatSessionGenerator) – for generating model input

  • hparams (dict) – needs to contain enough information to specify both a model and the associated data

  • model (behavenet.models object) – model instance (not a model class); used, for example, to compute latents when compute_latents=True

  • trial (int, optional) – trial index into all possible trials (train, val, test); one of trial or trial_idx must be specified; trial takes precedence over trial_idx

  • trial_idx (int, optional) – trial index into trial type defined by dtype; one of trial or trial_idx must be specified; trial takes precedence over trial_idx

  • sess_idx (int, optional) – session index into data generator

  • max_frames (int, optional) – maximum number of frames (i.e. batch size) to return

  • compute_latents (bool, optional) – True to return latents

  • compute_2d_labels (bool, optional) – True to return 2d label tensors of shape (batch, n_labels, y_pix, x_pix)

  • compute_scaled_labels (bool, optional) – ignored if compute_2d_labels is True; otherwise, True to return scaled labels of shape (batch, n_labels) rather than 2d labels of shape (batch, n_labels, y_pix, x_pix)

  • mask_labels (bool, optional) – True to return numpy labels where nan values indicate masked time points

  • dtype (str, optional) – data type that is indexed by trial_idx; ‘train’ | ‘val’ | ‘test’
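The precedence rules described above (trial over trial_idx, and compute_2d_labels over compute_scaled_labels) can be sketched in plain Python. This is a hypothetical illustration, not behavenet's implementation: resolve_trial, use_scaled_labels and the batch_idxs mapping (split name to the global trial indices in that split) are names assumed here for clarity.

```python
from typing import Dict, Optional, Sequence


def resolve_trial(
    batch_idxs: Dict[str, Sequence[int]],
    trial: Optional[int] = None,
    trial_idx: Optional[int] = None,
    dtype: str = "test",
) -> int:
    """Return a global trial index following the documented precedence.

    Hypothetical helper: `batch_idxs` maps 'train' | 'val' | 'test' to the
    global trial indices belonging to that split (an assumed data layout).
    """
    # `trial` indexes all trials directly and takes precedence. Compare against
    # None explicitly, since trial=0 is a valid index.
    if trial is not None:
        return trial
    if trial_idx is None:
        raise ValueError("one of trial or trial_idx must be specified")
    if dtype not in ("train", "val", "test"):
        raise ValueError(f"invalid dtype: {dtype!r}")
    # `trial_idx` indexes only the trials in split `dtype`.
    return batch_idxs[dtype][trial_idx]


def use_scaled_labels(compute_2d_labels: bool, compute_scaled_labels: bool) -> bool:
    """True if scaled (batch, n_labels) labels would be returned instead of 2d labels.

    compute_scaled_labels is ignored whenever compute_2d_labels is True.
    """
    return compute_scaled_labels and not compute_2d_labels
```

For example, with splits = {"train": [0, 2, 3], "val": [4], "test": [1, 5]}, resolve_trial(splits, trial_idx=1, dtype="test") gives 5, while resolve_trial(splits, trial=3, trial_idx=1) gives 3 because trial wins.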

Returns:

  • ims_pt (torch.Tensor) of shape (max_frames, n_channels, y_pix, x_pix)

  • ims_np (np.ndarray) of shape (max_frames, n_channels, y_pix, x_pix)

  • latents_np (np.ndarray) of shape (max_frames, n_latents)

  • labels_pt (torch.Tensor) of shape (max_frames, n_labels)

  • labels_2d_pt (torch.Tensor) of shape (max_frames, n_labels, y_pix, x_pix)

  • labels_2d_np (np.ndarray) of shape (max_frames, n_labels, y_pix, x_pix)

Return type:

tuple
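When mask_labels=True, masked time points are marked with nan, so downstream statistics must skip them. The sketch below computes a per-label mean over unmasked frames; it assumes a numpy label array of shape (n_frames, n_labels) with nan marking each masked entry, and masked_label_means is a hypothetical helper, not part of behavenet.

```python
import numpy as np


def masked_label_means(labels_np: np.ndarray) -> np.ndarray:
    """Per-label mean over unmasked time points (hypothetical helper).

    labels_np: array of shape (n_frames, n_labels); nan marks a masked entry.
    """
    valid = ~np.isnan(labels_np)                # True where the entry is observed
    counts = valid.sum(axis=0)                  # observed frames per label
    sums = np.where(valid, labels_np, 0.0).sum(axis=0)
    # Labels with no observed frames get nan instead of a divide-by-zero warning.
    return np.divide(sums, counts, out=np.full(sums.shape, np.nan), where=counts > 0)
```

np.nanmean would give the same values, but it emits a RuntimeWarning for labels that are masked in every frame; the explicit divide with where= returns nan for those labels silently.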