ConcatSessionsGenerator

class behavenet.data.data_generator.ConcatSessionsGenerator(data_dir, ids_list, signals_list=None, transforms_list=None, paths_list=None, device='cuda', as_numpy=False, batch_load=True, rng_seed=0, trial_splits=None, train_frac=1.0)

Bases: object

Dataset class for multiple sessions.

This class contains a list of single session data generators. It handles shuffling and iterating over these sessions.

Methods Summary

next_batch(dtype)

Return next batch of data.

reset_iterators(dtype)

Reset iterators so that all data is available.

Methods Documentation

next_batch(dtype)

Return next batch of data.

The data generator iterates randomly through sessions and trials. Once a session runs out of trials for the requested split, it is skipped.

Parameters:

dtype (str) – ‘train’ | ‘val’ | ‘test’

Returns:

  • sample (dict): data batch, keyed by the signals passed to the class (signals_list)

  • dataset (int): index of the dataset (session) from which the batch is drawn

Return type:

tuple
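The sampling scheme that next_batch describes (pick a random session, draw its next trial, skip sessions that have run out) can be sketched in plain Python. This is a minimal toy model under stated assumptions, not behavenet's implementation: the classes ToySession and ToyConcatSessions, the trials dictionary, and raising StopIteration once every session is exhausted are all hypothetical choices made for illustration. Each call returns a (sample, dataset) tuple, mirroring the documented return value.

```python
import random


class ToySession:
    """Hypothetical stand-in for a single-session data generator."""

    def __init__(self, name, trials, rng):
        self.name = name
        self.trials = trials  # dict: dtype -> list of trial indices
        self.rng = rng
        self.iters = {}
        for dtype in trials:
            self.reset(dtype)

    def reset(self, dtype):
        order = list(self.trials[dtype])
        self.rng.shuffle(order)  # shuffle trial order within this session
        self.iters[dtype] = iter(order)


class ToyConcatSessions:
    """Hypothetical stand-in for a multi-session generator (not behavenet's code)."""

    def __init__(self, sessions, rng_seed=0):
        self.rng = random.Random(rng_seed)
        self.sessions = sessions
        self.exhausted = {}  # dtype -> set of session indices with no trials left

    def next_batch(self, dtype):
        done = self.exhausted.setdefault(dtype, set())
        while True:
            remaining = [i for i in range(len(self.sessions)) if i not in done]
            if not remaining:
                # assumption: signal the end of an epoch with StopIteration
                raise StopIteration(f"all sessions exhausted for {dtype!r}")
            dataset = self.rng.choice(remaining)  # pick a random session
            try:
                trial = next(self.sessions[dataset].iters[dtype])
            except StopIteration:
                done.add(dataset)  # session is out of trials: skip it from now on
                continue
            sample = {"session": self.sessions[dataset].name, "trial": trial}
            return sample, dataset


# Usage: drain one "epoch" of training batches from two toy sessions.
rng = random.Random(0)
sessions = [
    ToySession("sess-A", {"train": [0, 1, 2]}, rng),
    ToySession("sess-B", {"train": [0, 1]}, rng),
]
gen = ToyConcatSessions(sessions, rng_seed=0)
batches = []
while True:
    try:
        batches.append(gen.next_batch("train"))
    except StopIteration:
        break
```

Every trial from both sessions is returned exactly once, in an order that interleaves sessions at random; the dataset index in each tuple identifies which session the batch came from.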

reset_iterators(dtype)

Reset iterators so that all data is available again, either for a single split or, when dtype is ‘all’, for every split.

Parameters:

dtype (str) – ‘train’ | ‘val’ | ‘test’ | ‘all’
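A reset of this kind can be sketched as rebuilding freshly shuffled per-session iterators, either for one split or for all of them. This is a hypothetical toy model, not behavenet's code: the ToyIterators class, its iters attribute, and the DTYPES tuple are illustrative assumptions. The sketch only shows the documented behavior of the dtype argument, including the ‘all’ option.

```python
import random

DTYPES = ("train", "val", "test")


class ToyIterators:
    """Hypothetical container of per-session, per-dtype trial iterators."""

    def __init__(self, trials_per_session, rng_seed=0):
        # trials_per_session: list of dicts, dtype -> list of trial indices
        self.trials = trials_per_session
        self.rng = random.Random(rng_seed)
        self.iters = [dict() for _ in trials_per_session]
        self.reset_iterators("all")

    def reset_iterators(self, dtype):
        # 'all' resets every split; otherwise only the requested one
        dtypes = DTYPES if dtype == "all" else (dtype,)
        for sess_idx, sess_trials in enumerate(self.trials):
            for dt in dtypes:
                order = list(sess_trials.get(dt, []))
                self.rng.shuffle(order)  # fresh random order on each reset
                self.iters[sess_idx][dt] = iter(order)


# Usage: consume session 0's training trials, then make them available again.
gen = ToyIterators([{"train": [0, 1, 2], "val": [3]}, {"train": [0, 1], "test": [2]}])
used = list(gen.iters[0]["train"])
gen.reset_iterators("train")
restored = sorted(gen.iters[0]["train"])
```

Resetting only ‘train’ leaves the validation and test iterators wherever they were; passing ‘all’ restores every split at once.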