get_labels_2d_for_trial

behavenet.plotting.cond_ae_utils.get_labels_2d_for_trial(hparams, sess_ids, trial=None, trial_idx=None, sess_idx=0, dtype='test', data_gen=None)

Return scaled labels (in pixel space) for a given trial.

Parameters:
  • hparams (dict) – needs to contain enough information to build a data generator

  • sess_ids (list of dict) – each entry is a session dict with keys ‘lab’, ‘expt’, ‘animal’, ‘session’

  • trial (int, optional) – trial index into all trials (train, val, and test combined); one of trial or trial_idx must be specified, and trial takes precedence if both are given

  • trial_idx (int, optional) – trial index into only the trials of the type given by dtype; one of trial or trial_idx must be specified, and trial takes precedence if both are given

  • sess_idx (int, optional) – index of the session within the data generator

  • dtype (str, optional) – trial type that trial_idx indexes into: ‘train’, ‘val’, or ‘test’

  • data_gen (ConcatSessionGenerator object, optional) – data generator used to produce the labels; if not provided, one is built from hparams and sess_ids
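The difference between trial and trial_idx can be illustrated with a small sketch. The helper resolve_trial and the batch_idxs mapping (dtype to a list of global trial indices) are hypothetical and not part of behavenet; they only model the documented rule that trial indexes all trials, trial_idx indexes the trials of one dtype, and trial wins when both are given.

```python
from typing import Dict, List, Optional


def resolve_trial(
    batch_idxs: Dict[str, List[int]],
    trial: Optional[int] = None,
    trial_idx: Optional[int] = None,
    dtype: str = "test",
) -> int:
    """Hypothetical helper (not behavenet API): return a global trial index."""
    if trial is not None:
        # `trial` already indexes into all trials, so it takes precedence
        return trial
    if trial_idx is None:
        raise ValueError("one of `trial` or `trial_idx` must be specified")
    if dtype not in ("train", "val", "test"):
        raise ValueError(f"unknown dtype: {dtype!r}")
    # `trial_idx` indexes only into the trials of type `dtype`
    return batch_idxs[dtype][trial_idx]


# assumed split of 7 trials into train/val/test
batch_idxs = {"train": [0, 1, 2, 5], "val": [3], "test": [4, 6]}
print(resolve_trial(batch_idxs, trial_idx=1, dtype="test"))  # 6
print(resolve_trial(batch_idxs, trial=2, trial_idx=1))       # 2
```

In this sketch, trial_idx=1 with dtype='test' refers to the second test trial, which is global trial 6.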

Returns:

  • labels_2d_pt (torch.Tensor) – labels of shape (batch, n_labels, y_pix, x_pix)

  • labels_2d_np (np.ndarray) – the same labels as a NumPy array, of shape (batch, n_labels, y_pix, x_pix)

Return type:

tuple
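To make the (batch, n_labels, y_pix, x_pix) layout concrete, the sketch below builds an array of that shape from per-label (x, y) pixel coordinates by placing a single 1 in each label's own channel. This one-hot encoding and the helper labels_to_2d are assumptions for illustration only; behavenet's data generator may encode 2D labels differently. With PyTorch, a NumPy copy such as labels_2d_np would typically be obtained from a tensor via labels_2d_pt.cpu().numpy().

```python
import numpy as np


def labels_to_2d(coords: np.ndarray, y_pix: int, x_pix: int) -> np.ndarray:
    """Hypothetical (not behavenet API): one-hot 2D map per label.

    coords: array of shape (batch, n_labels, 2) holding (x, y) pixel coordinates.
    Returns an array of shape (batch, n_labels, y_pix, x_pix).
    """
    batch, n_labels, _ = coords.shape
    labels_2d = np.zeros((batch, n_labels, y_pix, x_pix), dtype=np.float32)
    # round each coordinate to the nearest pixel and clip it into the frame
    xs = np.clip(np.rint(coords[..., 0]).astype(int), 0, x_pix - 1)
    ys = np.clip(np.rint(coords[..., 1]).astype(int), 0, y_pix - 1)
    # index grids so every (batch, label) pair gets its own channel
    b_idx, l_idx = np.meshgrid(np.arange(batch), np.arange(n_labels), indexing="ij")
    labels_2d[b_idx, l_idx, ys, xs] = 1.0
    return labels_2d


coords = np.array([[[3.2, 1.0], [0.0, 4.9]]])  # batch=1, n_labels=2
labels_2d_np = labels_to_2d(coords, y_pix=5, x_pix=6)
print(labels_2d_np.shape)  # (1, 2, 5, 6)
```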