"""Classes for splitting and serving data to models.
The data generator classes contained in this module inherit from the
:class:`torch.utils.data.Dataset` class. The user-facing class is the
:class:`ConcatSessionsGenerator`, which can manage one or more datasets. Each dataset is composed
of trials, which are split into training, validation, and testing trials using the
:func:`split_trials`. The default data generator can handle the following data types:
* **images**: individual frames of the behavioral video
* **masks**: binary mask for each frame
* **labels**: i.e. DLC labels
* **neural activity**
* **AE latents**
* **AE predictions**: predictions of AE latents from neural activity
* **ARHMM states**
* **ARHMM predictions**: predictions of ARHMM states from neural activity
Please see the online documentation at
`Read the Docs <https://behavenet.readthedocs.io/en/latest/index.html>`_ for detailed examples of
how to use the data generators.
"""
from collections import OrderedDict
import h5py
import numpy as np
import os
import pickle
import torch
from torch.utils import data
from torch.utils.data import SubsetRandomSampler
__all__ = [
'split_trials',
'SingleSessionDatasetBatchedLoad',
'SingleSessionDataset',
'ConcatSessionsGenerator',
'ConcatSessionsGeneratorMulti']
[docs]def split_trials(n_trials, rng_seed=0, train_tr=8, val_tr=1, test_tr=1, gap_tr=0):
"""Split trials into train/val/test blocks.
The data is split into blocks that have gap trials between tr/val/test:
:obj:`train tr | gap tr | val tr | gap tr | test tr | gap tr`
Parameters
----------
n_trials : :obj:`int`
total number of trials to be split
rng_seed : :obj:`int`, optional
random seed for reproducibility
train_tr : :obj:`int`, optional
number of train trials per block
val_tr : :obj:`int`, optional
number of validation trials per block
test_tr : :obj:`int`, optional
number of test trials per block
gap_tr : :obj:`int`, optional
number of gap trials between tr/val/test; there will be a total of 3 * `gap_tr` gap trials
per block; can be zero if no gap trials are desired.
Returns
-------
:obj:`dict`
Split trial indices are stored in a dict with keys `train`, `test`, and `val`
"""
# same random seed for reproducibility
np.random.seed(rng_seed)
tr_per_block = train_tr + gap_tr + val_tr + gap_tr + test_tr + gap_tr
n_blocks = int(np.floor(n_trials / tr_per_block))
if n_blocks == 0:
raise ValueError(
'Not enough trials (n=%i) for the train/test/val/gap values %i/%i/%i/%i' %
(n_trials, train_tr, val_tr, test_tr, gap_tr))
leftover_trials = n_trials - tr_per_block * n_blocks
if leftover_trials > 0:
offset = np.random.randint(0, high=leftover_trials)
else:
offset = 0
idxs_block = np.random.permutation(n_blocks)
batch_idxs = {'train': [], 'test': [], 'val': []}
for block in idxs_block:
curr_tr = block * tr_per_block + offset
batch_idxs['train'].append(np.arange(curr_tr, curr_tr + train_tr))
curr_tr += (train_tr + gap_tr)
batch_idxs['val'].append(np.arange(curr_tr, curr_tr + val_tr))
curr_tr += (val_tr + gap_tr)
batch_idxs['test'].append(np.arange(curr_tr, curr_tr + test_tr))
for dtype in ['train', 'val', 'test']:
batch_idxs[dtype] = np.concatenate(batch_idxs[dtype], axis=0)
return batch_idxs
def _load_pkl_dict(path, key, idx=None, dtype='float32'):
"""Helper function to load pickled data.
Parameters
----------
path : :obj:`str`
full file name including `.pkl` extention
key : :obj:`str`
data is returned from this key of the pickled dictionary
idx : :obj:`int` or :obj:`NoneType`
if :obj:`NoneType` return all data, else return data from this index
dtype : :obj:`str`
numpy data type of data
Returns
-------
:obj:`list` of :obj:`numpy.ndarray` if :obj:`idx=None`
:obj:`numpy.ndarray` is :obj:`idx=int`
"""
with open(path, 'rb') as f:
data_dict = pickle.load(f)
if idx is None:
samp = [data.astype(dtype) for data in data_dict[key]]
else:
samp = [data_dict[key][idx].astype(dtype)]
return samp
[docs]class SingleSessionDatasetBatchedLoad(data.Dataset):
"""Dataset class for a single session with batch loading of data."""
def __init__(
self, data_dir, lab='', expt='', animal='', session='', signals=None, transforms=None,
paths=None, device='cpu', as_numpy=False):
"""
Parameters
----------
data_dir : :obj:`str`
root directory of data
lab : :obj:`str`
lab id
expt : :obj:`str`
expt id
animal : :obj:`str`
animal id
session : :obj:`str`
session id
signals : :obj:`list` of :obj:`str`
e.g. 'images' | 'masks' | 'neural' | .... See
:func:`behavenet.fitting.utils.get_data_generator_inputs` for examples.
transforms : :obj:`list` of :obj:`behavenet.data.transform` objects
each element corresponds to an entry in :obj:`signals`; for multiple transforms, chain
together using :obj:`behavenet.data.transform.Compose` class. See
:mod:`behavenet.data.transforms` for available transform options.
paths : :obj:`list` of :obj:`str`
each element corresponds to an entry in :obj:`signals`; filename (using absolute path)
of data
device : :obj:`str`, optional
location of data; options are :obj:`cpu | cuda`
as_numpy : bool
if :obj:`True` return data as a numpy array, else return as a torch tensor
"""
# specify data
self.lab = lab
self.expt = expt
self.animal = animal
self.session = session
self.data_dir = os.path.join(
data_dir, self.lab, self.expt, self.animal, self.session)
self.name = os.path.join(self.lab, self.expt, self.animal, self.session)
self.sess_str = str('%s_%s_%s_%s' % (self.lab, self.expt, self.animal, self.session))
# get data paths
self.signals = signals
self.transforms = OrderedDict()
self.paths = OrderedDict()
for signal, transform, path in zip(signals, transforms, paths):
self.transforms[signal] = transform
self.paths[signal] = path
# get total number of trials by loading images/neural data
self.n_trials = None
for i, signal in enumerate(signals):
if signal == 'images' or signal == 'neural' or signal == 'labels' or \
signal == 'labels_sc' or signal == 'labels_masks':
data_file = paths[i]
with h5py.File(data_file, 'r', libver='latest', swmr=True) as f:
self.n_trials = len(f[signal])
break
elif signal == 'ae_latents':
try:
latents = _load_pkl_dict(self.paths[signal], 'latents')
except FileNotFoundError:
raise NotImplementedError(
('Could not open %s\nMust create ae latents from model;' +
' currently not implemented') % self.paths[signal])
self.n_trials = len(latents)
# meta data about train/test/xv splits; set by ConcatSessionsGenerator
self.batch_idxs = None
self.n_batches = None
self.device = device
self.as_numpy = as_numpy
def __str__(self):
"""Pretty printing of dataset info"""
format_str = str('%s\n' % self.sess_str)
format_str += str(' signals: {}\n'.format(self.signals))
format_str += str(' transforms: {}\n'.format(self.transforms))
format_str += str(' paths: {}\n'.format(self.paths))
return format_str
def __len__(self):
return self.n_trials
def __getitem__(self, idx):
"""Return batch of data; if idx is None, return all data
Parameters
----------
idx : :obj:`int` or :obj:`NoneType`
trial index to load; if :obj:`NoneType`, return all data.
Returns
-------
:obj:`dict`
data sample
"""
if idx is None and not self.as_numpy:
raise NotImplementedError('Cannot currently load all data as torch tensors')
sample = OrderedDict()
for signal in self.signals:
# index correct trial
if signal == 'images':
dtype = 'float32'
with h5py.File(self.paths[signal], 'r', libver='latest', swmr=True) as f:
if idx is None:
print('Warning: loading all images!')
temp_data = []
for tr in range(self.n_trials):
temp_data.append(f[signal][str(
'trial_%04i' % tr)][()].astype(dtype) / 255)
sample[signal] = temp_data
else:
sample[signal] = [f[signal][str(
'trial_%04i' % idx)][()].astype(dtype) / 255]
elif signal == 'masks':
dtype = 'float32'
with h5py.File(self.paths[signal], 'r', libver='latest', swmr=True) as f:
if idx is None:
print('Warning: loading all masks!')
temp_data = []
for tr in range(self.n_trials):
temp_data.append(f[signal][str(
'trial_%04i' % tr)][()].astype(dtype))
sample[signal] = temp_data
else:
sample[signal] = f[signal][str('trial_%04i' % idx)][()].astype(dtype)
elif signal == 'neural' or signal == 'labels' or signal == 'labels_sc' \
or signal == 'labels_masks':
dtype = 'float32'
with h5py.File(self.paths[signal], 'r', libver='latest', swmr=True) as f:
if idx is None:
temp_data = []
for tr in range(self.n_trials):
temp_data.append(f[signal][str(
'trial_%04i' % tr)][()].astype(dtype))
sample[signal] = temp_data
else:
sample[signal] = [f[signal][str('trial_%04i' % idx)][()].astype(dtype)]
elif signal == 'ae_latents' or signal == 'latents':
dtype = 'float32'
sample[signal] = self._try_to_load(signal, key='latents', idx=idx, dtype=dtype)
elif signal == 'ae_predictions':
dtype = 'float32'
sample[signal] = self._try_to_load(signal, key='predictions', idx=idx, dtype=dtype)
elif signal == 'arhmm' or signal == 'arhmm_states':
dtype = 'int32'
sample[signal] = self._try_to_load(signal, key='states', idx=idx, dtype=dtype)
elif signal == 'arhmm_predictions':
dtype = 'float32'
sample[signal] = self._try_to_load(signal, key='predictions', idx=idx, dtype=dtype)
else:
raise ValueError('"%s" is an invalid signal type' % signal)
# apply transforms
if self.transforms[signal]:
sample[signal] = [self.transforms[signal](samp) for samp in sample[signal]]
# transform into tensor
if not self.as_numpy:
if dtype == 'float32':
sample[signal] = torch.from_numpy(sample[signal][0]).float()
else:
sample[signal] = torch.from_numpy(sample[signal][0]).long()
sample['batch_idx'] = idx
return sample
def _try_to_load(self, signal, key, idx, dtype):
# try:
# data = _load_pkl_dict(self.paths[signal], key, idx=idx, dtype=dtype)
# except FileNotFoundError:
# # try prepending session string
# try:
# self.paths[signal] = _prepend_sess_id(self.paths[signal], self.sess_str)
# data = _load_pkl_dict(self.paths[signal], key, idx=idx, dtype=dtype)
# except FileNotFoundError:
# raise NotImplementedError(
# ('Could not open %s\nMust create %s from model;' +
# ' currently not implemented') % (self.paths[signal], key))
try:
data = _load_pkl_dict(self.paths[signal], key, idx=idx, dtype=dtype)
except FileNotFoundError:
raise NotImplementedError(
('Could not open %s\nMust create %s from model;' +
' currently not implemented') % (self.paths[signal], key))
return data
[docs]class SingleSessionDataset(SingleSessionDatasetBatchedLoad):
"""Dataset class for a single session.
Loads all data during Dataset creation and saves as an attribute. Batches are then sampled from
this stored data. All data transformations are applied to the full dataset upon load, *not*
for each batch. This automatically returns data as lists of numpy arrays.
Note
----
This data loader cannot be used to fit pytorch models, only ssm models.
"""
def __init__(
self, data_dir, lab='', expt='', animal='', session='', signals=None, transforms=None,
paths=None, device='cuda', as_numpy=False):
"""
Parameters
----------
data_dir : :obj:`str`
root directory of data
lab : :obj:`str`
lab id
expt : :obj:`str`
expt id
animal : :obj:`str`
animal id
session : :obj:`str`
session id
signals : :obj:`list` of :obj:`str`
e.g. 'images' | 'masks' | 'neural' | .... See
:func:`behavenet.fitting.utils.get_data_generator_inputs` for examples.
transforms : :obj:`list` of :obj:`behavenet.data.transform` objects
each element corresponds to an entry in :obj:`signals`; for multiple transforms, chain
together using :obj:`behavenet.data.transform.Compose` class. See
:mod:`behavenet.data.transforms` for available transform options.
paths : :obj:`list` of :obj:`str`
each element corresponds to an entry in :obj:`signals`; filename (using absolute path)
of data
device : :obj:`str`, optional
location of data; options are :obj:`cpu | cuda`
"""
super().__init__(data_dir, lab, expt, animal, session, signals, transforms, paths, device)
# grab all data as a single batch
self.as_numpy = as_numpy
self.data = super(SingleSessionDataset, self).__getitem__(idx=None)
_ = self.data.pop('batch_idx')
# collect dims for easy reference
# self.dims = OrderedDict()
# for signal, data in self.data.items():
# self.dims[signal] = data.shape
# if self.n_trials is None:
# self.n_trials = self.dims[signal][0]
def __len__(self):
return self.n_trials
def __getitem__(self, idx):
"""Return batch of data.
Parameters
----------
idx : :obj:`int` or :obj:`NoneType`
trial index to load; if :obj:`NoneType`, return all data.
Returns
-------
:obj:`dict`
data sample
"""
sample = OrderedDict()
for signal in self.signals:
sample[signal] = [self.data[signal][idx]]
sample['batch_idx'] = idx
return sample
[docs]class ConcatSessionsGenerator(object):
"""Dataset class for multiple sessions.
This class contains a list of single session data generators. It handles shuffling and
iterating over these sessions.
"""
_dtypes = {'train', 'val', 'test'}
def __init__(
self, data_dir, ids_list, signals_list=None, transforms_list=None, paths_list=None,
device='cuda', as_numpy=False, batch_load=True, rng_seed=0, trial_splits=None,
train_frac=1.0):
"""
Parameters
----------
data_dir : :obj:`str`
root directory of data
ids_list : :obj:`list` of :obj:`dict`
each element has the following keys: 'lab', 'expt', 'animal', and 'session'; the data
(images, masks, neural activity) is assumed to be located in:
:obj:`data_dir/lab/expt/animal/session/data.hdf5`
signals_list : :obj:`list` of :obj:`list`
list of signals for each session
transforms_list : :obj:`list` of :obj:`list`
list of transforms for each session
paths_list : :obj:`list` of :obj:`list`
list of paths for each session
device : :obj:`str`, optional
location of data; options are :obj:`cpu | cuda`
as_numpy : bool, optional
if :obj:`True` return data as a numpy array, else return as a torch tensor
batch_load : :obj:`bool`, optional
:obj:`True` to load data one batch at a time, :obj:`False` to load all data at once and
store in memory (data is still served one trial at a time).
rng_seed : :obj:`int`, optional
controls split of train/val/test trials
trial_splits : :obj:`dict`, optional
determines number of train/val/test trials using the keys 'train_tr', 'val_tr',
'test_tr', and 'gap_tr'; see :func:`split_trials` for how these are used.
train_frac : :obj:`float`, optional
if :obj:`0 < train_frac < 1.0`, defines the fraction of assigned training trials to
actually use; if :obj:`train_frac > 1.0`, defines the number of assigned training
trials to actually use
"""
if isinstance(ids_list, dict):
ids_list = [ids_list]
self.ids = ids_list
self.as_numpy = as_numpy
self.device = device
self.batch_load = batch_load
if self.batch_load:
SingleSession = SingleSessionDatasetBatchedLoad
else:
SingleSession = SingleSessionDataset
self.datasets = []
self.datasets_info = []
self.signals = signals_list
self.transforms = transforms_list
self.paths = paths_list
for ids, signals, transforms, paths in zip(
ids_list, signals_list, transforms_list, paths_list):
self.datasets.append(SingleSession(
data_dir, lab=ids['lab'], expt=ids['expt'], animal=ids['animal'],
session=ids['session'], signals=signals, transforms=transforms, paths=paths,
device=device, as_numpy=self.as_numpy))
self.datasets_info.append({
'lab': ids['lab'], 'expt': ids['expt'], 'animal': ids['animal'],
'session': ids['session']})
# collect info about datasets
self.n_datasets = len(self.datasets)
# get train/val/test batch indices for each dataset
if trial_splits is None:
trial_splits = {'train_tr': 8, 'val_tr': 1, 'test_tr': 1, 'gap_tr': 0}
self.batch_ratios = [None] * self.n_datasets
for i, dataset in enumerate(self.datasets):
dataset.batch_idxs = split_trials(len(dataset), rng_seed=rng_seed, **trial_splits)
dataset.n_batches = {}
for dtype in self._dtypes:
if dtype == 'train':
# subsample training data if requested
if train_frac != 1.0:
n_batches = len(dataset.batch_idxs[dtype])
if train_frac < 1.0:
# subsample as fraction of total batches
n_idxs = int(np.floor(train_frac * n_batches))
if n_idxs <= 0:
print(
'warning: attempting to use invalid number of training ' +
'batches; defaulting to all training batches')
n_idxs = n_batches
else:
# subsample fixed number of batches
train_frac = n_batches if train_frac > n_batches else train_frac
n_idxs = int(train_frac)
idxs_rand = np.random.choice(n_batches, size=n_idxs, replace=False)
dataset.batch_idxs[dtype] = dataset.batch_idxs[dtype][idxs_rand]
self.batch_ratios[i] = len(dataset.batch_idxs[dtype])
dataset.n_batches[dtype] = len(dataset.batch_idxs[dtype])
self.batch_ratios = np.array(self.batch_ratios) / np.sum(self.batch_ratios)
# find total number of batches per data type; this will be iterated over in the train loop
self.n_tot_batches = {}
for dtype in self._dtypes:
self.n_tot_batches[dtype] = np.sum(
[dataset.n_batches[dtype] for dataset in self.datasets])
# create data loaders (will shuffle/batch/etc datasets)
self.dataset_loaders = [None] * self.n_datasets
for i, dataset in enumerate(self.datasets):
self.dataset_loaders[i] = {}
for dtype in self._dtypes:
self.dataset_loaders[i][dtype] = torch.utils.data.DataLoader(
dataset,
batch_size=1,
sampler=SubsetRandomSampler(dataset.batch_idxs[dtype]),
num_workers=0,
pin_memory=False)
# create all iterators (will iterate through data loaders)
self.dataset_iters = [None] * self.n_datasets
for i in range(self.n_datasets):
self.dataset_iters[i] = {}
for dtype in self._dtypes:
self.dataset_iters[i][dtype] = iter(self.dataset_loaders[i][dtype])
def __str__(self):
"""Pretty printing of dataset info"""
if self.batch_load:
dataset_type = 'SingleSessionDatasetBatchedLoad'
else:
dataset_type = 'SingleSessionDataset'
format_str = str('Generator contains %i %s objects:\n' % (self.n_datasets, dataset_type))
for dataset in self.datasets:
format_str += dataset.__str__()
return format_str
def __len__(self):
return self.n_datasets
[docs] def reset_iterators(self, dtype):
"""Reset iterators so that all data is available.
Parameters
----------
dtype : :obj:`str`
'train' | 'val' | 'test' | 'all'
"""
for i in range(self.n_datasets):
if dtype == 'all':
for dtype_ in self._dtypes:
self.dataset_iters[i][dtype_] = iter(self.dataset_loaders[i][dtype_])
else:
self.dataset_iters[i][dtype] = iter(self.dataset_loaders[i][dtype])
[docs] def next_batch(self, dtype):
"""Return next batch of data.
The data generator iterates randomly through sessions and trials. Once a session runs out
of trials it is skipped.
Parameters
----------
dtype : :obj:`str`
'train' | 'val' | 'test'
Returns
-------
:obj:`tuple`
- **sample** (:obj:`dict`): data batch with keys given by :obj:`signals` input to class
- **dataset** (:obj:`int`): dataset from which data batch is drawn
"""
while True:
# get next session
dataset = int(np.random.choice(np.arange(self.n_datasets), p=self.batch_ratios))
# get this session data
try:
sample = next(self.dataset_iters[dataset][dtype])
break
except StopIteration:
continue
if self.as_numpy:
for i, signal in enumerate(sample):
if signal != 'batch_idx':
sample[signal] = [ss.cpu().detach().numpy() for ss in sample[signal]]
else:
if self.device == 'cuda':
sample = {key: val.to('cuda') for key, val in sample.items()}
return sample, dataset
[docs]class ConcatSessionsGeneratorMulti(ConcatSessionsGenerator):
"""Dataset class for multiple sessions, which returns multiple sessions per training batch.
This class contains a list of single session data generators. It handles shuffling and
iterating over these sessions.
"""
_dtypes = {'train', 'val', 'test'}
def __init__(
self, data_dir, ids_list, signals_list=None, transforms_list=None, paths_list=None,
device='cuda', as_numpy=False, batch_load=True, rng_seed=0, trial_splits=None,
train_frac=1.0, n_sessions_per_batch=2):
"""
Parameters
----------
data_dir : :obj:`str`
root directory of data
ids_list : :obj:`list` of :obj:`dict`
each element has the following keys: 'lab', 'expt', 'animal', and 'session'; the data
(images, masks, neural activity) is assumed to be located in:
:obj:`data_dir/lab/expt/animal/session/data.hdf5`
signals_list : :obj:`list` of :obj:`list`
list of signals for each session
transforms_list : :obj:`list` of :obj:`list`
list of transforms for each session
paths_list : :obj:`list` of :obj:`list`
list of paths for each session
device : :obj:`str`, optional
location of data; options are :obj:`cpu | cuda`
as_numpy : bool, optional
if :obj:`True` return data as a numpy array, else return as a torch tensor
batch_load : :obj:`bool`, optional
:obj:`True` to load data one batch at a time, :obj:`False` to load all data at once and
store in memory (data is still served one trial at a time).
rng_seed : :obj:`int`, optional
controls split of train/val/test trials
trial_splits : :obj:`dict`, optional
determines number of train/val/test trials using the keys 'train_tr', 'val_tr',
'test_tr', and 'gap_tr'; see :func:`split_trials` for how these are used.
train_frac : :obj:`float`, optional
if :obj:`0 < train_frac < 1.0`, defines the fraction of assigned training trials to
actually use; if :obj:`train_frac > 1.0`, defines the number of assigned training
trials to actually use
n_sessions_per_batch : :obj:`int`, optional
number of session per training batch to serve model; the combination of datasets and
batches will be shuffled when the data iterator is reset
"""
if n_sessions_per_batch > 4:
# requires more implementation in behavenet.fitting.losses.triplet_loss()
raise NotImplementedError
self.n_sessions_per_batch = n_sessions_per_batch
super().__init__(
data_dir, ids_list, signals_list=signals_list, transforms_list=transforms_list,
paths_list=paths_list, device=device, as_numpy=as_numpy, batch_load=batch_load,
rng_seed=rng_seed, trial_splits=trial_splits, train_frac=train_frac)
# redefine total number of training batches to reflect the fact that multiple batches are
# served per iteration (but only for training data)
self.n_tot_batches['train'] = int(self.n_tot_batches['train'] / n_sessions_per_batch)
def __str__(self):
"""Pretty printing of dataset info"""
if self.batch_load:
dataset_type = 'SingleSessionDatasetBatchedLoad'
else:
dataset_type = 'SingleSessionDataset'
format_str = 'MultiGenerator contains %i %s objects:\n' % (self.n_datasets, dataset_type)
for dataset in self.datasets:
format_str += dataset.__str__()
return format_str
def __len__(self):
return self.n_datasets
[docs] def next_batch(self, dtype, return_multiple=True):
"""Return next batch of data.
The data generator iterates randomly through sessions and trials. Once a session runs out
of trials it is skipped.
Parameters
----------
dtype : :obj:`str`
'train' | 'val' | 'test'
return_multiple : :obj:`bool`
True to return multiple batches for train data
Returns
-------
:obj:`tuple`
- **samples** (:obj:`dict`): data batch with keys given by :obj:`signals` input to class
- **datasets** (:obj:`int`): dataset from which data batch is drawn
"""
def renormalize(array):
if np.sum(array) == 0:
return array
else:
return array / np.sum(array)
if dtype == 'train' and return_multiple:
samples = []
datasets = []
curr_batch_ratios = np.copy(self.batch_ratios)
for sess in range(self.n_sessions_per_batch):
while True:
# check to see if there are enough available batches
if np.sum(curr_batch_ratios > 0) < (self.n_sessions_per_batch - sess):
return None, None
# get next dataset
dataset = np.random.choice(np.arange(self.n_datasets), p=curr_batch_ratios)
# don't choose this dataset in the future
curr_batch_ratios[dataset] = 0
curr_batch_ratios = renormalize(curr_batch_ratios)
# get this session data
try:
sample = next(self.dataset_iters[dataset][dtype])
break
except StopIteration:
continue
if self.as_numpy:
raise NotImplementedError
# for i, signal in enumerate(sample):
# if signal != 'batch_idx':
# sample[signal] = [ss.cpu().detach().numpy() for ss in sample[signal]]
else:
if self.device == 'cuda':
sample = {key: val.to('cuda') for key, val in sample.items()}
samples.append(sample)
datasets.append(dataset)
# print(datasets)
# print([s['batch_idx'].item() for s in samples])
else:
while True:
# get next session
dataset = np.random.choice(np.arange(self.n_datasets), p=self.batch_ratios)
# get this session data
try:
sample = next(self.dataset_iters[dataset][dtype])
break
except StopIteration:
continue
if self.as_numpy:
for i, signal in enumerate(sample):
if signal != 'batch_idx':
sample[signal] = [ss.cpu().detach().numpy() for ss in sample[signal]]
else:
if self.device == 'cuda':
sample = {key: val.to('cuda') for key, val in sample.items()}
datasets = int(dataset)
samples = sample
return samples, datasets