Data handling documentation

Module

Classes for splitting and serving data to models.

The data generator classes contained in this module inherit from the class. The user-facing class is ConcatSessionsGenerator, which can manage one or more datasets. Each dataset is composed of trials, which are split into training, validation, and testing trials using the split_trials() function. The default data generator can handle the following data types:

  • images: individual frames of the behavioral video

  • masks: binary mask for each frame

  • labels: e.g., DeepLabCut (DLC) labels

  • neural activity

  • AE latents

  • AE predictions: predictions of AE latents from neural activity

  • ARHMM states

  • ARHMM predictions: predictions of ARHMM states from neural activity

Please see the online documentation at Read the Docs for detailed examples of how to use the data generators.


split_trials(n_trials[, rng_seed, train_tr, …])

Split trials into train/val/test blocks.
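The block-wise split can be illustrated with a minimal sketch. It assumes default block counts of 8 train, 1 validation, and 1 test block (hypothetical values), cuts the trial indices into that many contiguous blocks, and shuffles which block goes to which split with a fixed seed. `split_trials_sketch` is a hypothetical name, and this is not necessarily how split_trials() assigns blocks.

```python
import numpy as np


def split_trials_sketch(n_trials, rng_seed=0, train_tr=8, val_tr=1, test_tr=1):
    """Hypothetical sketch of a block-wise train/val/test trial split.

    Trial indices are cut into (train_tr + val_tr + test_tr) contiguous blocks;
    each block is assigned to one split in the ratio train_tr:val_tr:test_tr,
    and the block-to-split assignment is shuffled with a fixed seed.
    """
    n_blocks = train_tr + val_tr + test_tr
    labels = ['train'] * train_tr + ['val'] * val_tr + ['test'] * test_tr
    # reproducible shuffle of the block labels
    order = np.random.default_rng(rng_seed).permutation(n_blocks)
    labels = [labels[i] for i in order]
    # contiguous blocks whose sizes differ by at most one trial
    blocks = np.array_split(np.arange(n_trials), n_blocks)
    idxs = {'train': [], 'val': [], 'test': []}
    for label, block in zip(labels, blocks):
        idxs[label].extend(block.tolist())
    return {k: np.array(sorted(v), dtype=int) for k, v in idxs.items()}


splits = split_trials_sketch(100, rng_seed=0)
print({k: len(v) for k, v in splits.items()})  # {'train': 80, 'val': 10, 'test': 10}
```

Splitting by contiguous blocks rather than by individual trials keeps neighboring (and likely correlated) trials within the same split.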


SingleSessionDatasetBatchedLoad(data_dir[, …])

Dataset class for a single session with batch loading of data.

SingleSessionDataset(data_dir[, lab, expt, …])

Dataset class for a single session.
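A single-session dataset can be pictured as an indexable object that returns every requested signal for one trial. This in-memory sketch assumes each signal is stored as a list with one array per trial. The class name and the `trial_idx` key are hypothetical, and the real class loads its data from disk rather than holding it in memory.

```python
import numpy as np


class SingleSessionSketch:
    """Hypothetical in-memory stand-in for a single-session dataset.

    `data` maps a signal name (e.g. 'images', 'neural') to a list with one
    array per trial; indexing returns every signal for that trial.
    """

    def __init__(self, data, trial_idxs):
        self.data = data
        self.trial_idxs = list(trial_idxs)  # e.g. the 'train' trials of a split

    def __len__(self):
        return len(self.trial_idxs)

    def __getitem__(self, i):
        trial = self.trial_idxs[i]
        sample = {name: trials[trial] for name, trials in self.data.items()}
        sample['trial_idx'] = trial  # remember which trial this sample came from
        return sample


data = {'neural': [np.full((100, 20), i) for i in range(4)]}
ds = SingleSessionSketch(data, trial_idxs=[1, 3])
sample = ds[1]  # all signals for trial 3
```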

ConcatSessionsGenerator(data_dir, ids_list)

Dataset class for multiple sessions.
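One way to serve several sessions is to visit every (session, trial) pair exactly once per epoch in random order, so each batch holds one trial from one session. The sketch below assumes that scheme; the class name, the `epoch()` method, and the seeding are hypothetical and are not the ConcatSessionsGenerator interface.

```python
import numpy as np


class ConcatSessionsSketch:
    """Hypothetical sketch: serve trials from several sessions, one trial per batch.

    `sessions` is a list of sequences (one per session); each element is a trial.
    Every epoch visits each trial of each session exactly once, in random order.
    """

    def __init__(self, sessions, rng_seed=0):
        self.sessions = sessions
        self.rng = np.random.default_rng(rng_seed)

    def __len__(self):
        return sum(len(s) for s in self.sessions)

    def epoch(self):
        # (session index, trial position) pairs for every trial in every session
        pairs = [(s, t) for s, sess in enumerate(self.sessions) for t in range(len(sess))]
        for k in self.rng.permutation(len(pairs)):
            s, t = pairs[k]
            yield s, self.sessions[s][t]


gen = ConcatSessionsSketch([['a0', 'a1'], ['b0', 'b1', 'b2']])
for sess_idx, trial in gen.epoch():
    pass  # a model would be trained on `trial` from session `sess_idx` here
```

Keeping the session index next to each trial lets a model apply session-specific parameters to each batch.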

ConcatSessionsGeneratorMulti(data_dir, ids_list)

Dataset class for multiple sessions, which returns multiple sessions per training batch.

Module

Utility functions for automatically constructing hdf5 files.


build_hdf5(save_file, video_file[, …])

Build a BehaveNet-style HDF5 file from a video file and an optional label file.

load_raw_labels(file_path, pose_algo[, …])

Load labels and build masks from a variety of standardized source files.
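Pose-estimation tools such as DeepLabCut report an x coordinate, a y coordinate, and a likelihood for each label in each frame. This sketch shows one plausible way to turn those values into a label array and a binary mask. The likelihood threshold of 0.9, the column layout (all x coordinates, then all y coordinates), and the function name are assumptions, not the behavior of load_raw_labels().

```python
import numpy as np


def labels_and_masks_sketch(xs, ys, likelihoods, likelihood_thresh=0.9):
    """Hypothetical sketch: stack x/y label coordinates and build a binary mask.

    xs, ys, likelihoods: arrays of shape (n_frames, n_labels).
    Returns labels of shape (n_frames, 2 * n_labels) (x coordinates first, then y)
    and a mask of the same shape that is 1 where the label is trusted.
    """
    labels = np.hstack([xs, ys])
    keep = (np.asarray(likelihoods) >= likelihood_thresh).astype(float)
    masks = np.hstack([keep, keep])  # the same mask applies to x and y
    return labels, masks
```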

resize_labels(labels, xpix_new, ypix_new, …)

Update label values to reflect scale of corresponding images.
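When frames are resized, label coordinates must be scaled by the same horizontal and vertical factors. This minimal sketch assumes labels are laid out as all x coordinates followed by all y coordinates, and it takes the original frame size as explicit arguments (`xpix_old`, `ypix_old`, hypothetical names). The summary above elides the remaining arguments of resize_labels(), so the real signature may differ.

```python
import numpy as np


def resize_labels_sketch(labels, xpix_new, ypix_new, xpix_old, ypix_old):
    """Hypothetical sketch: rescale label coordinates to a new frame size.

    labels: array of shape (n_frames, 2 * n_labels), x coordinates in the first
    half of the columns and y coordinates in the second half (assumed layout).
    """
    n = labels.shape[1] // 2
    resized = labels.astype(float).copy()
    resized[:, :n] *= xpix_new / xpix_old  # horizontal scale factor
    resized[:, n:] *= ypix_new / ypix_old  # vertical scale factor
    return resized


# 640x480 frames downsampled to 320x240: every coordinate is halved
small = resize_labels_sketch(np.array([[100., 50., 40., 20.]]), 320, 240, 640, 480)
```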

get_frames_from_idxs(cap, idxs)

Helper function to load video segments.

Module

Transform classes to process data.

Data generator objects can apply these transforms to batches of data upon loading.



Shuffle blocks of contiguous discrete states within each trial.
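A sketch of block shuffling for one trial, assuming that "blocks" means runs of consecutive identical states and that shuffling permutes the order of those runs. That interpretation and the function name are assumptions, and the library transform may shuffle differently. The idea is that the overall state statistics are preserved while the temporal structure is destroyed, which makes the result useful as a control.

```python
import numpy as np


def block_shuffle_sketch(states, rng_seed=0):
    """Hypothetical sketch: shuffle the order of contiguous runs of discrete states.

    states: 1D integer array for a single trial. Each run is kept intact and
    only the order of the runs changes (runs of the same state may end up
    adjacent after shuffling).
    """
    states = np.asarray(states)
    # indices where the state changes mark the start of a new run
    change = np.flatnonzero(np.diff(states)) + 1
    runs = np.split(states, change)
    order = np.random.default_rng(rng_seed).permutation(len(runs))
    return np.concatenate([runs[i] for i in order])
```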


Clip upper level of signal and divide by clip value.
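The clip-and-normalize operation caps values at a clip value and then divides by that same value. A minimal sketch follows (the function name is hypothetical):

```python
import numpy as np


def clip_normalize_sketch(signal, clip_val):
    """Clip values above `clip_val`, then divide by `clip_val`.

    For non-negative signals (e.g. spike counts) the output lies in [0, 1].
    """
    if clip_val <= 0:
        raise ValueError('clip_val must be positive')
    return np.minimum(signal, clip_val) / clip_val


normed = clip_normalize_sketch(np.array([0., 2., 5., 10.]), 4.0)
```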


Composes several transforms together.
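Composing transforms means applying them in sequence, with each transform receiving the output of the previous one. A minimal sketch (the class name is hypothetical):

```python
class ComposeSketch:
    """Apply a sequence of callables in order: x -> t_n(...t_2(t_1(x)))."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, sample):
        for t in self.transforms:
            sample = t(sample)
        return sample


pipeline = ComposeSketch([lambda x: x + 1, lambda x: x * 10])
print(pipeline(2))  # 30
```

Order matters: reversing the two transforms above would give 2 * 10 + 1 = 21 instead.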


Turn a categorical vector into a one-hot vector.
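One-hot encoding replaces each integer category with a row that is 1 in that category's column and 0 elsewhere. This sketch assumes categories are integers starting at 0 (the function name is hypothetical):

```python
import numpy as np


def make_one_hot_sketch(states, n_classes=None):
    """Turn a vector of integer categories into a (T, n_classes) one-hot array."""
    states = np.asarray(states, dtype=int)
    if n_classes is None:
        n_classes = int(states.max()) + 1  # assumes categories are 0..max
    one_hot = np.zeros((states.shape[0], n_classes))
    one_hot[np.arange(states.shape[0]), states] = 1
    return one_hot


one_hot = make_one_hot_sketch([0, 2, 1])  # rows: [1,0,0], [0,0,1], [0,1,0]
```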

MakeOneHot2D(y_pixels, x_pixels)

Turn an array of continuous values into an array of one-hot 2D arrays.
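A 2D one-hot encoding places a single 1 in a `y_pixels` x `x_pixels` grid at the location of each continuous coordinate. This sketch assumes inputs of shape (T, n_labels, 2) holding (x, y) pairs, and it assumes that coordinates are rounded to the nearest pixel and clipped to the frame. Both the layout and the rounding rule are assumptions.

```python
import numpy as np


def make_one_hot_2d_sketch(coords, y_pixels, x_pixels):
    """Turn continuous (x, y) coordinates into one-hot 2D arrays.

    coords: array of shape (T, n_labels, 2) holding (x, y) per label (assumed layout).
    Returns an array of shape (T, n_labels, y_pixels, x_pixels) with a single 1
    at the rounded, clipped pixel location of each label.
    """
    coords = np.asarray(coords, dtype=float)
    T, n_labels, _ = coords.shape
    xs = np.clip(np.rint(coords[..., 0]).astype(int), 0, x_pixels - 1)
    ys = np.clip(np.rint(coords[..., 1]).astype(int), 0, y_pixels - 1)
    out = np.zeros((T, n_labels, y_pixels, x_pixels))
    # index grids so every (time, label) pair gets its own pixel assignment
    t_idx, l_idx = np.meshgrid(np.arange(T), np.arange(n_labels), indexing='ij')
    out[t_idx, l_idx, ys, xs] = 1
    return out
```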


Compute motion energy across batch dimension.
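Motion energy is commonly computed as the absolute difference between consecutive samples. This sketch takes that difference along the first (batch) axis. To keep the output the same shape as the input, it assumes the first sample gets zero motion energy; that padding choice is an assumption and may not match the library.

```python
import numpy as np


def motion_energy_sketch(x):
    """Absolute sample-to-sample difference along the first (batch/time) axis.

    Assumption: the first sample has no predecessor, so its motion energy is
    set to zero to keep the output the same shape as the input.
    """
    x = np.asarray(x, dtype=float)
    me = np.abs(np.diff(x, axis=0))
    return np.concatenate([np.zeros_like(x[:1]), me], axis=0)
```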

SelectIdxs(idxs[, sample_name])

Index-based subsampling of neural activity.
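Index-based subsampling amounts to keeping a subset of channels (columns). This sketch assumes activity of shape (T, n_channels). It accepts integer indices or a boolean mask, and the function name is hypothetical.

```python
import numpy as np


def select_idxs_sketch(signal, idxs):
    """Keep only the channels (columns) of `signal` listed in `idxs`.

    signal: array of shape (T, n_channels); idxs: integer indices or a boolean mask.
    """
    return np.asarray(signal)[:, idxs]


subset = select_idxs_sketch(np.arange(12).reshape(3, 4), [0, 2])  # columns 0 and 2
```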

Threshold(threshold, bin_size)

Remove channels of neural activity whose mean value is below a threshold.
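A sketch of channel thresholding. It assumes the activity holds spike counts per time bin, `bin_size` is the bin width in milliseconds, and `threshold` is a firing rate in Hz. Under those assumptions, the mean count per bin is converted to a rate before comparison. The units are not confirmed by the summary above.

```python
import numpy as np


def threshold_sketch(signal, threshold, bin_size):
    """Drop channels whose mean firing rate falls below `threshold`.

    Assumptions: `signal` holds spike counts of shape (T, n_channels),
    `bin_size` is the bin width in milliseconds, and `threshold` is in Hz.
    Returns the kept channels and the boolean mask of kept channels.
    """
    signal = np.asarray(signal, dtype=float)
    rates = signal.mean(axis=0) / (bin_size / 1000.0)  # counts per bin -> Hz
    keep = rates >= threshold
    return signal[:, keep], keep
```

Returning the mask as well as the data makes it easy to apply the same channel selection to other trials.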


Abstract base class for transforms.
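An abstract base class for transforms can require subclasses to implement `__call__`. The sketch below uses Python's `abc` module and adds a per-channel z-score transform as an example subclass. Both class names are hypothetical, and the handling of constant channels is an assumed convention.

```python
from abc import ABC, abstractmethod

import numpy as np


class TransformSketch(ABC):
    """Hypothetical abstract base: every transform is a callable on a signal."""

    @abstractmethod
    def __call__(self, signal):
        raise NotImplementedError

    def __repr__(self):
        return self.__class__.__name__


class ZScoreSketch(TransformSketch):
    """Z-score each channel (column) of a (T, n_channels) array."""

    def __call__(self, signal):
        signal = np.asarray(signal, dtype=float)
        mean = signal.mean(axis=0)
        std = signal.std(axis=0)
        std[std == 0] = 1.0  # constant channels are centered but not scaled
        return (signal - mean) / std
```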


Z-score channel activity.

Module

Utility functions for constructing inputs to data generators.


get_data_generator_inputs(hparams, sess_ids)

Helper function for generating signals, transforms and paths.

build_data_generator(hparams, sess_ids[, …])

Helper function to build data generator from hparams dict.

check_same_training_split(model_path, hparams)

Ensure data rng seed and trial splits are same for two models.
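Two models share a training split when every setting that determines the split matches. This sketch compares a chosen set of keys in two hyperparameter dicts. The key names `rng_seed_data` and `trial_splits` are assumptions for illustration; use whichever keys control the data split in your configuration.

```python
def same_training_split_sketch(hparams_a, hparams_b,
                               keys=('rng_seed_data', 'trial_splits')):
    """Hypothetical check that two hparams dicts would yield the same trial split.

    The default key names are assumptions; a missing key compares as None.
    """
    return all(hparams_a.get(k) == hparams_b.get(k) for k in keys)
```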

get_transforms_paths(data_type, hparams, sess_id)

Helper function for generating session-specific transforms and paths.

load_labels_like_latents(hparams, sess_ids, …)

Load labels from hdf5 in the same dictionary format that latents are saved.

get_region_list(hparams[, group_0, group_1])

Get brain regions and their indices into neural data.
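Mapping brain regions to channel indices can be sketched as grouping channels by a per-channel region label. The input format (one region name per neural channel) and the function name are hypothetical.

```python
import numpy as np


def region_indices_sketch(channel_regions):
    """Map each brain region name to the indices of its channels.

    channel_regions: sequence with one region label per neural channel
    (hypothetical input format). Regions are returned in sorted order.
    """
    channel_regions = np.asarray(channel_regions)
    return {str(r): np.flatnonzero(channel_regions == r)
            for r in np.unique(channel_regions)}


regions = region_indices_sketch(['CA1', 'VISp', 'CA1'])  # CA1 -> [0, 2], VISp -> [1]
```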