ConcatSessionsGeneratorMulti
- class behavenet.data.data_generator.ConcatSessionsGeneratorMulti(data_dir, ids_list, signals_list=None, transforms_list=None, paths_list=None, device='cuda', as_numpy=False, batch_load=True, rng_seed=0, trial_splits=None, train_frac=1.0, n_sessions_per_batch=2)
Bases: ConcatSessionsGenerator
Dataset class for multiple sessions, which returns multiple sessions per training batch.
This class contains a list of single-session data generators and handles shuffling and iterating over these sessions.
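The wrapper structure described above can be pictured with a small pure-Python sketch. It does not import behavenet: SingleSessionStub and MultiSessionSketch are hypothetical names, and the per-split trial lists stand in for the real single-session generators. Only the rng_seed argument and the ‘train’ | ‘val’ | ‘test’ splits come from this page; everything else is an assumption made for illustration.

```python
import random
from typing import Dict, List


class SingleSessionStub:
    """Hypothetical stand-in for one single-session data generator.

    Only the trial indices of each split are stored; a real generator would
    also load the requested signals for every trial.
    """

    def __init__(self, trials_per_split: Dict[str, List[int]]) -> None:
        self.trials_per_split = trials_per_split


class MultiSessionSketch:
    """Hypothetical wrapper that owns several single-session generators and
    keeps one shuffled queue of unused trials per session and split."""

    SPLITS = ("train", "val", "test")

    def __init__(self, sessions: List[SingleSessionStub], rng_seed: int = 0) -> None:
        self.sessions = sessions
        self.rng = random.Random(rng_seed)  # seeded so shuffles are reproducible
        self.queues: Dict[str, List[List[int]]] = {}
        self.reset()

    def reset(self) -> None:
        """Reshuffle the trial order of every session for every split."""
        self.queues = {}
        for dtype in self.SPLITS:
            per_session = []
            for session in self.sessions:
                trials = list(session.trials_per_split.get(dtype, []))
                self.rng.shuffle(trials)
                per_session.append(trials)
            self.queues[dtype] = per_session

    def sessions_with_trials(self, dtype: str) -> List[int]:
        """Indices of sessions that still have unused trials in this split."""
        return [i for i, queue in enumerate(self.queues[dtype]) if queue]


sessions = [
    SingleSessionStub({"train": [0, 1, 2], "val": [3]}),
    SingleSessionStub({"train": [0, 1]}),
]
gen = MultiSessionSketch(sessions, rng_seed=0)
print(gen.sessions_with_trials("val"))  # [0]
```

Keeping one queue per session (rather than one pooled queue) is what makes it possible to know which session each batch came from and to skip a session once its queue is empty.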
Methods Summary
next_batch(dtype[, return_multiple]) – Return next batch of data.
Methods Documentation
- next_batch(dtype, return_multiple=True)
Return next batch of data.
The data generator iterates randomly through sessions and trials. Once a session runs out of trials, it is skipped.
- Parameters:
  dtype (str) – ‘train’ | ‘val’ | ‘test’
  return_multiple (bool) – True to return multiple batches for train data
- Returns:
  samples (dict) – data batch with keys given by the signals input to the class
  datasets (int) – dataset from which the data batch is drawn
- Return type:
tuple
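The iteration rule described for next_batch can be sketched as follows, under stated assumptions. next_batch_sketch is a hypothetical function, not the behavenet implementation: it chooses among sessions that still have trials, skips exhausted ones, and draws from up to n_sessions_per_batch sessions only when dtype is ‘train’ and return_multiple is True. Returning lists of samples and dataset indices, and returning empty lists once every session is exhausted, are both assumptions; this page does not say what the real method does at the end of an epoch.

```python
import random
from typing import Dict, List, Tuple


def next_batch_sketch(
    queues: List[List[int]],
    dtype: str,
    rng: random.Random,
    return_multiple: bool = True,
    n_sessions_per_batch: int = 2,
) -> Tuple[List[Dict[str, int]], List[int]]:
    """Hypothetical next_batch logic over per-session trial queues.

    queues[i] holds the unused, already shuffled trial indices of session i
    for the requested split; trials are consumed with pop().
    """
    if dtype not in ("train", "val", "test"):
        raise ValueError(f"dtype must be 'train', 'val' or 'test', got {dtype!r}")

    # Sessions that have run out of trials are skipped.
    available = [i for i, queue in enumerate(queues) if queue]
    if not available:
        # Assumption: empty lists mark the end of the epoch.
        return [], []

    # Draw from several sessions only for train data with return_multiple=True.
    n_draw = n_sessions_per_batch if (dtype == "train" and return_multiple) else 1
    chosen = rng.sample(available, min(n_draw, len(available)))

    samples: List[Dict[str, int]] = []
    datasets: List[int] = []
    for session in chosen:
        trial = queues[session].pop()
        # A real sample would be a dict of arrays keyed by the requested signals.
        samples.append({"trial": trial})
        datasets.append(session)
    return samples, datasets


queues = [[0, 1], [5]]
rng = random.Random(0)
print(next_batch_sketch(queues, "train", rng))  # one batch from each session
print(next_batch_sketch(queues, "train", rng))  # only session 0 is left
```

Sampling without replacement (rng.sample) guarantees that a multi-session training batch never contains two batches from the same session.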