get_labels_2d_for_trial
- behavenet.plotting.cond_ae_utils.get_labels_2d_for_trial(hparams, sess_ids, trial=None, trial_idx=None, sess_idx=0, dtype='test', data_gen=None)
Return scaled labels (in pixel space) for a given trial.
- Parameters:
hparams (dict) – needs to contain enough information to build a data generator
sess_ids (list of dict) – each entry is a session dict with keys ‘lab’, ‘expt’, ‘animal’, ‘session’
trial (int, optional) – trial index into all possible trials (train, val, test); one of trial or trial_idx must be specified; trial takes precedence over trial_idx
trial_idx (int, optional) – trial index into trial type defined by dtype; one of trial or trial_idx must be specified; trial takes precedence over trial_idx
sess_idx (int, optional) – session index into data generator
dtype (str, optional) – data type that is indexed by trial_idx; ‘train’ | ‘val’ | ‘test’
data_gen (ConcatSessionGenerator object, optional) – for generating labels
- Returns:
labels_2d_pt (torch.Tensor) of shape (batch, n_labels, y_pix, x_pix)
labels_2d_np (np.ndarray) of shape (batch, n_labels, y_pix, x_pix)
- Return type:
tuple
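The relationship between trial, trial_idx and dtype described under Parameters can be sketched in plain Python. This is an illustrative sketch only, not BehaveNet's internal code: the helper `resolve_trial` and the `example_split` dictionary are hypothetical names introduced here, and the sketch assumes that each data type maps to an ordered list of absolute trial numbers.

```python
from typing import Dict, List, Optional


def resolve_trial(
    trial: Optional[int],
    trial_idx: Optional[int],
    dtype: str,
    trials_by_dtype: Dict[str, List[int]],
) -> int:
    """Hypothetical helper: return the absolute trial number to load."""
    if trial is not None:
        # `trial` indexes all trials directly and takes precedence over `trial_idx`.
        return trial
    if trial_idx is None:
        raise ValueError("one of `trial` or `trial_idx` must be specified")
    if dtype not in trials_by_dtype:
        raise ValueError("dtype must be one of: " + ", ".join(trials_by_dtype))
    # `trial_idx` indexes only the trials that belong to `dtype`.
    return trials_by_dtype[dtype][trial_idx]


# Made-up split of ten trials, for illustration only (an assumption).
example_split = {"train": [0, 1, 2, 3, 4, 5], "val": [6, 7], "test": [8, 9]}

# Second test trial, selected through `trial_idx` and `dtype`.
resolved_by_idx = resolve_trial(None, 1, "test", example_split)

# Both arguments given: `trial` wins and `trial_idx` is ignored.
resolved_by_trial = resolve_trial(3, 1, "test", example_split)
```

Under these assumptions, passing only trial_idx=1 with dtype='test' selects the second entry of the test list, while passing trial as well makes trial_idx and dtype irrelevant to the selection.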