Loading a trained model

After you’ve fit one or more models, often you’ll want to load these models and their associated data generator to perform further analyses. BehaveNet provides three methods for doing so:

  • Method 1: load the “best” model from a test-tube experiment

  • Method 2: specify the model version in a test-tube experiment

  • Method 3: specify the model hyperparameters in a test-tube experiment

To illustrate these three methods, we’ll use an autoencoder as an example. Let’s assume that we’ve trained 5 convolutional autoencoders with 10 latents, each with a different random seed for initializing the weights, and that these have all been saved in the test-tube experiment ae-example.

Method 1: load best model

The first option is to load the best model from ae-example. The “best” model is defined as the one with the smallest loss value computed on validation data. If you set the parameter val_check_interval in the AE training json to a nonzero value before fitting, the validation losses were already computed during training and saved in a csv file, so this is a relatively fast option. The following code block shows how to load the best model, as well as the associated data generator, from ae-example.

# imports
from behavenet import get_user_dir
from behavenet.fitting.utils import get_best_model_and_data
from behavenet.fitting.utils import get_expt_dir
from behavenet.fitting.utils import get_lab_example
from behavenet.fitting.utils import get_session_dir
from behavenet.models import AE as Model

# define necessary hyperparameters
hparams = {
    'data_dir': get_user_dir('data'),
    'save_dir': get_user_dir('save'),
    'experiment_name': 'ae-example',
    'model_class': 'ae',
    'model_type': 'conv',
    'n_ae_latents': 10,
}

# programmatically fill out other hparams options
get_lab_example(hparams, 'musall', 'vistrained')
hparams['session_dir'], sess_ids = get_session_dir(hparams)
hparams['expt_dir'] = get_expt_dir(hparams)

# use helper function to load model and data generator
model, data_generator = get_best_model_and_data(hparams, Model, version='best')
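The definition of “best” above can be made concrete with a small, self-contained sketch. It assumes, purely for illustration, that the validation losses logged for each test-tube version have already been read into a dictionary mapping version number to a list of losses (one per validation check). BehaveNet does its own bookkeeping from the saved csv files, which may differ in detail; the helper `select_best_version` is hypothetical and simply returns the version whose lowest validation loss is smallest.

```python
from typing import Dict, List


def select_best_version(val_losses: Dict[int, List[float]]) -> int:
    """Return the version whose lowest recorded validation loss is smallest.

    `val_losses` maps a test-tube version number to the validation losses
    logged during training (hypothetical input format, for illustration).
    """
    # lowest validation loss reached by each version that logged any losses
    best_per_version = {v: min(losses) for v, losses in val_losses.items() if losses}
    if not best_per_version:
        raise ValueError('no validation losses recorded')
    # version with the overall smallest validation loss
    return min(best_per_version, key=best_per_version.get)


# toy numbers for 5 versions (one per random seed); not real results
example = {
    0: [0.052, 0.041, 0.039],
    1: [0.050, 0.038, 0.040],
    2: [0.055, 0.045, 0.044],
    3: [0.049, 0.037, 0.036],
    4: [0.051, 0.042, 0.041],
}
print(select_best_version(example))  # → 3
```

Comparing each version's minimum (rather than its final) validation loss matches the idea of keeping the best checkpoint seen during training; if your setup records only a final value per version, each list simply has one entry.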

Method 2: specify the model version

The next option requires that you know in advance which test-tube version you want to load. In this example, we’ll load version 3. All you need to do is replace version='best' with version=3 in the final line above.

# use helper function to load model and data generator
model, data_generator = get_best_model_and_data(hparams, Model, version=3)
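If you are not sure which version numbers exist, you can list them from the experiment directory. The sketch below assumes that test-tube stores each version in a subdirectory named `version_<n>` inside `hparams['expt_dir']`; check your own save directory to confirm this layout. The helper `list_versions` is hypothetical and not part of BehaveNet.

```python
import re
import tempfile
from pathlib import Path
from typing import List

# assumed naming convention for test-tube version folders
VERSION_PATTERN = re.compile(r'^version_(\d+)$')


def list_versions(expt_dir) -> List[int]:
    """Return the version numbers of all `version_<n>` subdirectories, sorted numerically."""
    versions = []
    for path in Path(expt_dir).iterdir():
        match = VERSION_PATTERN.match(path.name)
        if match and path.is_dir():
            versions.append(int(match.group(1)))
    return sorted(versions)


# demo on a throwaway directory that mimics the assumed layout
with tempfile.TemporaryDirectory() as tmp:
    for name in ['version_0', 'version_1', 'version_10', 'version_3', 'notes']:
        (Path(tmp) / name).mkdir()
    print(list_versions(tmp))  # → [0, 1, 3, 10]
```

Any number in the returned list can then be passed as `version=` to `get_best_model_and_data`. Sorting the parsed integers, rather than the folder names, keeps `version_10` after `version_3`.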

Method 3: specify model hyperparameters

The final option gives you the most control: you can specify all relevant hyperparameters needed to define the model and the data generator, and load that specific model.

# imports
from behavenet import get_user_dir
from behavenet.fitting.utils import experiment_exists
from behavenet.fitting.utils import get_best_model_and_data
from behavenet.fitting.utils import get_expt_dir
from behavenet.fitting.utils import get_lab_example
from behavenet.fitting.utils import get_session_dir
from behavenet.models import AE as Model

# define necessary hyperparameters
hparams = {
    'data_dir': get_user_dir('data'),
    'save_dir': get_user_dir('save'),
    'experiment_name': 'ae-example',
    'model_class': 'ae',
    'model_type': 'conv',
    'n_ae_latents': 10,
    'rng_seed_data': 0,
    'trial_splits': '8;1;1;0',
    'train_frac': 1,
    'rng_seed_model': 0,
    'fit_sess_io_layers': False,
    'learning_rate': 1e-4,
    'l2_reg': 0,
}

# programmatically fill out other hparams options
get_lab_example(hparams, 'musall', 'vistrained')
hparams['session_dir'], sess_ids = get_session_dir(hparams)
hparams['expt_dir'] = get_expt_dir(hparams)

# find the version for these hyperparameters; `version` is None if no match exists
exists, version = experiment_exists(hparams, which_version=True)
if not exists:
    raise ValueError('no model in %s matches these hyperparameters' % hparams['expt_dir'])

# use helper function to load model and data generator
model, data_generator = get_best_model_and_data(hparams, Model, version=version)
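Conceptually, finding a version from its hyperparameters amounts to comparing the requested values against the hyperparameters saved with each version. The sketch below illustrates that idea with in-memory dictionaries. It is not BehaveNet's `experiment_exists` implementation, which reads the saved hyperparameters from disk and may compare values differently; `find_matching_version` and the `saved` dictionary are hypothetical.

```python
from typing import Dict, Optional


def find_matching_version(requested: Dict, saved: Dict[int, Dict]) -> Optional[int]:
    """Return the lowest version whose saved hparams agree with every requested key.

    Returns None when no version matches.
    """
    for version in sorted(saved):
        stored = saved[version]
        # a requested key that is missing from the stored hparams counts as a mismatch
        if all(key in stored and stored[key] == value for key, value in requested.items()):
            return version
    return None


# toy saved hyperparameters: 5 versions differing only in the model seed
saved = {v: {'model_class': 'ae', 'n_ae_latents': 10, 'rng_seed_model': v} for v in range(5)}

print(find_matching_version({'n_ae_latents': 10, 'rng_seed_model': 3}, saved))  # → 3
print(find_matching_version({'n_ae_latents': 12, 'rng_seed_model': 0}, saved))  # → None
```

In this sketch, leaving a key out of the request matches any stored value for it, so an incomplete request could silently return the wrong version. That is one reason to fill in every required entry listed below.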

You will need to specify the following entries in hparams regardless of the model class:

  • ‘rng_seed_data’

  • ‘trial_splits’

  • ‘train_frac’

  • ‘rng_seed_model’

  • ‘model_class’

  • ‘model_type’

For the autoencoder, you additionally need to specify n_ae_latents, fit_sess_io_layers, learning_rate, and l2_reg. Check out the source code for behavenet.fitting.utils.get_model_params() to see which entries are required for other model classes.
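Before calling `experiment_exists`, it can help to check that `hparams` contains every required entry. The sketch below encodes only the keys listed on this page (the six common entries plus the four autoencoder-specific ones); the helper `missing_hparams` and the `COMMON_KEYS`/`EXTRA_KEYS` tables are hypothetical, and for other model classes the authoritative list is in `get_model_params()`.

```python
from typing import Dict, List

# keys required for every model class, as listed above
COMMON_KEYS = ['rng_seed_data', 'trial_splits', 'train_frac',
               'rng_seed_model', 'model_class', 'model_type']

# extra keys per model class; only the autoencoder entry comes from this page
EXTRA_KEYS = {
    'ae': ['n_ae_latents', 'fit_sess_io_layers', 'learning_rate', 'l2_reg'],
}


def missing_hparams(hparams: Dict) -> List[str]:
    """Return the required keys (as listed above) that are absent from `hparams`."""
    required = COMMON_KEYS + EXTRA_KEYS.get(hparams.get('model_class'), [])
    return [key for key in required if key not in hparams]


# an autoencoder config that forgot 'l2_reg'
partial_hparams = {
    'model_class': 'ae', 'model_type': 'conv', 'n_ae_latents': 10,
    'rng_seed_data': 0, 'trial_splits': '8;1;1;0', 'train_frac': 1,
    'rng_seed_model': 0, 'fit_sess_io_layers': False, 'learning_rate': 1e-4,
}
print(missing_hparams(partial_hparams))  # → ['l2_reg']
```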

Iterating through the data

Below is an example of how to iterate through the data generator and load batches of data:

# select data type to load
dtype = 'train'  # 'train' | 'val' | 'test'

# reset data iterator for this data type
data_generator.reset_iterators(dtype)

# loop through all batches for this data type
for i in range(data_generator.n_tot_batches[dtype]):

    batch, sess = data_generator.next_batch(dtype)
    # "batch" is a dict with keys for the relevant signal, e.g. 'images', 'neural', etc.
    # "sess" is an integer denoting the dataset this batch comes from

    # ... perform analyses ...
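To try the loop without any data on disk, the sketch below defines a stand-in class that mimics only the parts of the data generator interface the loop relies on: an `n_tot_batches` dict, a `next_batch(dtype)` method returning a batch dict and a session index, and a `reset_iterators(dtype)` method for the reset step. `ToyDataGenerator` is hypothetical; the real BehaveNet generator loads actual data, and its batches contain arrays rather than the placeholder strings used here.

```python
from typing import Dict, List, Tuple


class ToyDataGenerator:
    """Minimal stand-in exposing the interface used in the loop above."""

    def __init__(self, batches: Dict[str, List[Tuple[dict, int]]]):
        # batches[dtype] is a list of (batch_dict, session_index) pairs
        self._batches = batches
        self.n_tot_batches = {dtype: len(items) for dtype, items in batches.items()}
        self._iters = {}
        for dtype in batches:
            self.reset_iterators(dtype)

    def reset_iterators(self, dtype: str) -> None:
        # restart iteration for this data type from the first batch
        self._iters[dtype] = iter(self._batches[dtype])

    def next_batch(self, dtype: str) -> Tuple[dict, int]:
        # raises StopIteration if called more than n_tot_batches[dtype] times
        return next(self._iters[dtype])


data_generator = ToyDataGenerator({
    'train': [({'images': 'batch-a'}, 0), ({'images': 'batch-b'}, 1)],
    'val': [({'images': 'batch-c'}, 0)],
})

dtype = 'train'
data_generator.reset_iterators(dtype)
seen = []
for i in range(data_generator.n_tot_batches[dtype]):
    batch, sess = data_generator.next_batch(dtype)
    seen.append((batch['images'], sess))
print(seen)  # → [('batch-a', 0), ('batch-b', 1)]
```

Resetting first means the loop starts at the first batch even if the generator was partially consumed earlier, and looping exactly `n_tot_batches[dtype]` times keeps it from running past the end.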