split_trials

behavenet.data.data_generator.split_trials(n_trials, rng_seed=0, train_tr=8, val_tr=1, test_tr=1, gap_tr=0)

Split trials into train/val/test blocks.

The trials are split into consecutive blocks, and within each block gap trials separate the train, validation, and test trials:

train tr | gap tr | val tr | gap tr | test tr | gap tr
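Each block therefore spans train_tr + val_tr + test_tr + 3 * gap_tr trials. With the defaults (8, 1, 1, 0) that is 10 trials per block. The sketch below uses a hypothetical helper, block_offsets, which is not part of behavenet. It computes where each segment of one block starts and stops, following the layout above.

```python
def block_offsets(train_tr=8, val_tr=1, test_tr=1, gap_tr=0):
    """Return (start, stop) offsets of each segment within one block.

    Hypothetical helper for illustration only; it follows the layout
    train | gap | val | gap | test | gap.
    """
    segments = [('train', train_tr), ('gap1', gap_tr),
                ('val', val_tr), ('gap2', gap_tr),
                ('test', test_tr), ('gap3', gap_tr)]
    offsets = {}
    pos = 0
    for name, size in segments:
        offsets[name] = (pos, pos + size)
        pos += size
    # After the loop, pos equals the total number of trials in one block
    return offsets, pos


offsets, block_len = block_offsets(train_tr=8, val_tr=1, test_tr=1, gap_tr=2)
print(block_len)         # 16
print(offsets['val'])    # (10, 11)
print(offsets['test'])   # (13, 14)
```

With gap_tr=2, trials 8-9, 11-12, and 14-15 of each block are gap trials and belong to no split.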

Parameters:
  • n_trials (int) – total number of trials to be split

  • rng_seed (int, optional) – random seed for reproducibility

  • train_tr (int, optional) – number of train trials per block

  • val_tr (int, optional) – number of validation trials per block

  • test_tr (int, optional) – number of test trials per block

  • gap_tr (int, optional) – number of gap trials inserted after each of the train, validation, and test segments, so each block contains 3 * gap_tr gap trials in total; set to zero for no gap trials.
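The following is a minimal sketch of how these parameters could combine into a split. It is an illustration under stated assumptions, not the behavenet implementation. The function name split_trials_sketch is hypothetical. The sketch assumes three things: trials left over after the last complete block are added to train; rng_seed (here fed to NumPy's default_rng) only shuffles the order in which blocks are visited; and too few trials for a single block raise a ValueError.

```python
import numpy as np


def split_trials_sketch(n_trials, rng_seed=0, train_tr=8, val_tr=1,
                        test_tr=1, gap_tr=0):
    """Illustrative block split; NOT the behavenet implementation.

    Assumptions: leftover trials go to train, and rng_seed only shuffles
    the order in which blocks are visited.
    """
    block_len = train_tr + val_tr + test_tr + 3 * gap_tr
    n_blocks = n_trials // block_len
    if n_blocks == 0:
        raise ValueError('n_trials=%i is too small for one block of %i trials'
                         % (n_trials, block_len))

    rng = np.random.default_rng(rng_seed)
    idxs = {'train': [], 'val': [], 'test': []}
    for block in rng.permutation(n_blocks):
        start = block * block_len
        idxs['train'].append(np.arange(start, start + train_tr))
        start += train_tr + gap_tr      # skip first gap segment
        idxs['val'].append(np.arange(start, start + val_tr))
        start += val_tr + gap_tr        # skip second gap segment
        idxs['test'].append(np.arange(start, start + test_tr))
        # the third gap segment is simply never assigned

    # Assumption: trials that do not fill a complete block go to train
    if n_blocks * block_len < n_trials:
        idxs['train'].append(np.arange(n_blocks * block_len, n_trials))

    return {k: np.concatenate(v) for k, v in idxs.items()}


split = split_trials_sketch(30, train_tr=8, val_tr=1, test_tr=1, gap_tr=1)
print({k: len(v) for k, v in split.items()})  # {'train': 20, 'val': 2, 'test': 2}
```

Here 30 trials give two 13-trial blocks plus 4 leftover trials. Train receives 2 * 8 + 4 = 20 trials, val and test 2 each, and the remaining 6 are gap trials.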

Returns:

Split trial indices, stored in a dict with keys train, val, and test; each key holds the indices of the trials assigned to that split.

Return type:

dict
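The next example shows one way to sanity-check the returned dict and apply it to per-trial data. It assumes each value is an array of integer trial indices, since the return type above only specifies a dict. It also assumes the first axis of the data array indexes trials. The helper check_and_apply is hypothetical, and the split is hand-written here rather than produced by split_trials.

```python
import numpy as np


def check_and_apply(split, data):
    """Check that train/val/test are disjoint, then index per-trial data.

    Hypothetical helper; assumes split maps 'train', 'val', 'test' to
    integer index arrays and that data's first axis indexes trials.
    """
    sets = {k: set(np.asarray(v).tolist()) for k, v in split.items()}
    for a, b in [('train', 'val'), ('train', 'test'), ('val', 'test')]:
        if sets[a] & sets[b]:
            raise ValueError('%s and %s share trials' % (a, b))
    return {k: data[np.asarray(v)] for k, v in split.items()}


# Hand-written split for 20 trials: two 10-trial blocks with 8/1/1, gap_tr=0
split = {'train': np.r_[0:8, 10:18],
         'val': np.array([8, 18]),
         'test': np.array([9, 19])}
data = np.arange(20 * 3).reshape(20, 3)  # 20 trials, 3 features per trial
parts = check_and_apply(split, data)
print(parts['train'].shape)  # (16, 3)
```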