ConcatSessionsGeneratorMulti

class behavenet.data.data_generator.ConcatSessionsGeneratorMulti(data_dir, ids_list, signals_list=None, transforms_list=None, paths_list=None, device='cuda', as_numpy=False, batch_load=True, rng_seed=0, trial_splits=None, train_frac=1.0, n_sessions_per_batch=2)

Bases: ConcatSessionsGenerator

Dataset class for multiple sessions, which returns multiple sessions per training batch.

This class contains a list of single session data generators. It handles shuffling and iterating over these sessions.

Methods Summary

next_batch(dtype[, return_multiple]) – Return next batch of data.

Methods Documentation

next_batch(dtype, return_multiple=True)

Return next batch of data.

The data generator iterates through sessions and trials in random order. Once a session runs out of trials, it is skipped.

Parameters:
  • dtype (str) – data split to draw from: 'train' | 'val' | 'test'

  • return_multiple (bool) – if True, return multiple batches per call when dtype is 'train'

Returns:

  • samples (dict): data batch, with keys given by the signals input to the class

  • datasets (int): dataset from which data batch is drawn

Return type:

tuple
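
Example:

The iteration described above (random order over sessions and trials, skipping a session once its trials are used up, and drawing several sessions per training step) can be illustrated with a small standalone sketch. This is not BehaveNet's implementation: the function iterate_sessions and its argument trials_per_session are hypothetical names chosen for illustration, and only n_sessions_per_batch mirrors a parameter from the class signature. It tracks indices only and loads no data.

```python
import random


def iterate_sessions(trials_per_session, n_sessions_per_batch=2, seed=0):
    """Yield groups of (session_idx, trial_idx) pairs until all trials are used.

    Hypothetical helper for illustration only; not part of behavenet.
    """
    rng = random.Random(seed)

    # Shuffle the trial order of each session independently.
    # Sessions with no trials are never eligible.
    remaining = {}
    for sess, n_trials in enumerate(trials_per_session):
        if n_trials > 0:
            order = list(range(n_trials))
            rng.shuffle(order)
            remaining[sess] = order

    while remaining:
        group = []
        for _ in range(n_sessions_per_batch):
            if not remaining:
                break  # every session is exhausted
            # Pick a random session that still has trials left.
            sess = rng.choice(sorted(remaining))
            trial = remaining[sess].pop()
            if not remaining[sess]:
                # Session ran out of trials: skip it from now on.
                del remaining[sess]
            group.append((sess, trial))
        yield group


if __name__ == "__main__":
    # Three sessions with 3, 1 and 2 trials, respectively.
    for group in iterate_sessions([3, 1, 2], n_sessions_per_batch=2):
        print(group)
```

Each yielded group plays the role of one training step with return_multiple=True: it holds up to n_sessions_per_batch (session, trial) pairs, and every trial appears exactly once per pass.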