"""Utility functions for managing model paths and the hparams dict."""
import os
import pickle
import numpy as np
# Public API of this module. Declaring `__all__` keeps sphinx-autoapidoc from documenting the
# names this module merely imports (`os`, `pickle`, `np`).
__all__ = [
'get_subdirs', 'get_session_dir', 'get_expt_dir', 'read_session_info_from_csv',
'export_session_info_to_csv', 'contains_session', 'find_session_dirs', 'experiment_exists',
'get_model_params', 'export_hparams', 'get_lab_example', 'get_region_dir',
'create_tt_experiment', 'get_best_model_version',
'get_best_model_and_data']
def get_subdirs(path):
    """Get all first-level subdirectories in a given path (no recursion).

    Parameters
    ----------
    path : :obj:`str`
        absolute path

    Returns
    -------
    :obj:`list`
        first-level subdirectories in :obj:`path`

    Raises
    ------
    NotADirectoryError
        if :obj:`path` does not exist
    StopIteration
        if :obj:`path` does not contain any subdirectories

    """
    if not os.path.exists(path):
        raise NotADirectoryError('%s is not a path' % path)
    # only the first entry produced by os.walk (the top level) is needed; it is a
    # (dirpath, dirnames, filenames) tuple, or missing entirely if nothing can be walked
    top_level = next(os.walk(path), None)
    if top_level is None or len(top_level[1]) == 0:
        raise StopIteration('%s does not contain any subdirectories' % path)
    return top_level[1]
def _get_multisession_paths(base_dir, lab='', expt='', animal=''):
    """Return all first-level directories under the requested level that start with `multi`.

    The directory that is searched is :obj:`base_dir/lab/expt/animal`; use empty strings to
    drop one of the session id components from that path.

    Parameters
    ----------
    base_dir : :obj:`str`
    lab : :obj:`str`, optional
    expt : :obj:`str`, optional
    animal : :obj:`str`, optional

    Returns
    -------
    :obj:`list`
        list of absolute paths; empty (after printing a warning) if the search directory is
        missing or has no subdirectories

    """
    search_dir = os.path.join(base_dir, lab, expt, animal)
    try:
        candidates = get_subdirs(search_dir)
    except ValueError:
        print('warning: did not find requested multisession(s)')
        return []
    except (NotADirectoryError, StopIteration):
        print('warning: did not find any sessions')
        return []
    # record top-level multi-session directories only
    return [os.path.join(search_dir, name) for name in candidates if name.startswith('multi')]
def _get_single_sessions(base_dir, depth, curr_depth):
    """Recursively search through non-multisession directories for all single sessions.

    Parameters
    ----------
    base_dir : :obj:`str`
        directory to search from
    depth : :obj:`int`
        depth of recursion
    curr_depth : :obj:`int`
        current depth in recursion

    Returns
    -------
    :obj:`list` of :obj:`dict`
        session ids for all single sessions in :obj:`base_dir`

    """
    if curr_depth > depth:
        return []

    if curr_depth == depth:
        # bottom of the recursion: the last 4 path components are lab/expt/animal/session
        parts = base_dir.split(os.sep)
        return [{
            'lab': parts[-4],
            'expt': parts[-3],
            'animal': parts[-2],
            'session': parts[-1]}]

    found = []
    for name in get_subdirs(base_dir):
        if name.startswith('multisession'):
            # multisession directories never contain single sessions
            continue
        found.extend(_get_single_sessions(
            os.path.join(base_dir, name), depth=depth, curr_depth=curr_depth + 1))
    return found
def _get_transition_str(hparams):
    """Build the transition component used in arhmm model paths.

    Parameters
    ----------
    hparams : :obj:`dict`
        model hyperparameters; needs key 'transitions' and 'kappa' if using sticky transitions

    Returns
    -------
    :obj:`str`
        arhmm transition string used for model path specification; sticky transitions are
        suffixed with kappa in exponent notation

    """
    transitions = hparams['transitions']
    if transitions != 'sticky':
        return transitions
    return 'sticky_%.0e' % hparams['kappa']
def get_session_dir(hparams, session_source='save'):
    """Get session-level directory for saving model outputs from hparams dict.

    Relies on hparams keys 'sessions_csv', 'multisession', 'lab', 'expt', 'animal', 'session'.

    The :obj:`sessions_csv` key takes precedence. The value for this key is a non-empty string of
    the pattern :obj:`/path/to/session_info.csv`, where `session_info.csv` has 4 columns for lab,
    expt, animal and session. The directory level is chosen hierarchically from the sessions in
    the csv file: the first of lab/expt/animal/session that takes more than one value decides
    where the multisession directory lives (multiple labs are not supported).

    If `sessions_csv` is an empty string, :obj:`None`, or the key is not in `hparams`, the
    following occurs:

    - if :obj:`'lab' == 'all'`, an error is thrown since multiple-lab runs are not currently
      supported
    - if :obj:`'expt' == 'all'`, all sessions from all animals from all expts from the specified
      lab are used; the session_dir will then be :obj:`save_dir/lab/multisession-xx`
    - if :obj:`'animal' == 'all'`, all sessions from all animals in the specified expt are used;
      the session_dir will then be :obj:`save_dir/lab/expt/multisession-xx`
    - if :obj:`'session' == 'all'`, all sessions from the specified animal are used; the
      session_dir will then be :obj:`save_dir/lab/expt/animal/multisession-xx`
    - if none of 'lab', 'expt', 'animal' or 'session' is 'all', session_dir is
      :obj:`save_dir/lab/expt/animal/session`

    The :obj:`session_source` argument defines where the code looks for sessions whenever one
    of 'lab', 'expt', 'animal', or 'session' is :obj:`'all'`; if :obj:`'session_source' = 'data'`,
    the data directory is searched for sessions; if :obj:`'session_source' = 'save'`, the save
    directory is searched for sessions. This means that only sessions that have been previously
    used for fitting models will be included.

    The :obj:`multisession-xx` directory will contain a file :obj:`session_info.csv` which will
    contain information about the sessions that comprise the multisession; this file is used to
    determine whether or not a new multisession directory needs to be created.

    Parameters
    ----------
    hparams : :obj:`dict`
        requires `save_dir` (and `data_dir` if :obj:`session_source='data'`), plus either
        `sessions_csv` or `lab`, `expt`, `animal` and `session`; `multisession` is optional
    session_source : :obj:`str`, optional
        'save' to use hparams['save_dir'], 'data' to use hparams['data_dir'] as base directory
        when searching for single sessions; multisession directories are always searched for
        in hparams['save_dir']

    Returns
    -------
    :obj:`tuple`
        - session_dir (:obj:`str`)
        - sessions_single (:obj:`list`)

    Raises
    ------
    ValueError
        if :obj:`session_source` is invalid, or the sessions csv file contains no sessions
    NotImplementedError
        if sessions from more than one lab are requested

    """
    save_dir = hparams['save_dir']
    if session_source == 'save':
        sess_dir = save_dir
    elif session_source == 'data':
        sess_dir = hparams['data_dir']
    else:
        raise ValueError('"%s" is an invalid session_source' % session_source)

    # a missing key, an empty string and None all mean "no csv file provided"
    sessions_csv = hparams.get('sessions_csv', None) or ''
    use_csv = len(sessions_csv) > 0

    if use_csv:

        # collect all single sessions from csv
        sessions_single = read_session_info_from_csv(sessions_csv)
        if len(sessions_single) == 0:
            raise ValueError('"%s" does not contain any sessions' % sessions_csv)
        for sess in sessions_single:
            sess.pop('save_dir', None)

        # find appropriate session directory; walk down the lab/expt/animal/session hierarchy
        # and stop at the first level that takes more than one value. Checking each level on
        # its own (e.g. "all session names are equal") is not enough, since different animals
        # or experiments may reuse the same animal/session names.
        id_keys = ['lab', 'expt', 'animal', 'session']
        n_unique = {key: len(set(sess[key] for sess in sessions_single)) for key in id_keys}
        first = sessions_single[0]
        lab, expt, animal = '', '', ''
        if n_unique['lab'] > 1:
            raise NotImplementedError('multiple labs not currently supported')
        elif n_unique['expt'] > 1:
            # get all experiments from one lab
            lab = first['lab']
            session_dir_base = os.path.join(save_dir, lab)
        elif n_unique['animal'] > 1:
            # get all animals from one experiment
            lab, expt = first['lab'], first['expt']
            session_dir_base = os.path.join(save_dir, lab, expt)
        elif n_unique['session'] > 1:
            # get all sessions from one animal
            lab, expt, animal = first['lab'], first['expt'], first['animal']
            session_dir_base = os.path.join(save_dir, lab, expt, animal)
        else:
            # get single session from one animal
            lab, expt, animal = first['lab'], first['expt'], first['animal']
            session_dir_base = os.path.join(save_dir, lab, expt, animal, first['session'])

        # find corresponding multisession (ok if they don't exist)
        multisession_paths = _get_multisession_paths(save_dir, lab=lab, expt=expt, animal=animal)

    else:

        # get session dirs (can include multiple sessions)
        lab = hparams['lab']
        if lab == 'all':
            raise NotImplementedError('multiple labs not currently supported')
        elif hparams['expt'] == 'all':
            # get all experiments from one lab
            multisession_paths = _get_multisession_paths(save_dir, lab=lab)
            sessions_single = _get_single_sessions(
                os.path.join(sess_dir, lab), depth=3, curr_depth=0)
            session_dir_base = os.path.join(save_dir, lab)
        elif hparams['animal'] == 'all':
            # get all animals from one experiment
            expt = hparams['expt']
            multisession_paths = _get_multisession_paths(save_dir, lab=lab, expt=expt)
            sessions_single = _get_single_sessions(
                os.path.join(sess_dir, lab, expt), depth=2, curr_depth=0)
            session_dir_base = os.path.join(save_dir, lab, expt)
        elif hparams['session'] == 'all':
            # get all sessions from one animal
            expt = hparams['expt']
            animal = hparams['animal']
            multisession_paths = _get_multisession_paths(
                save_dir, lab=lab, expt=expt, animal=animal)
            sessions_single = _get_single_sessions(
                os.path.join(sess_dir, lab, expt, animal), depth=1, curr_depth=0)
            session_dir_base = os.path.join(save_dir, lab, expt, animal)
        else:
            multisession_paths = []
            sessions_single = [{
                'lab': hparams['lab'], 'expt': hparams['expt'], 'animal': hparams['animal'],
                'session': hparams['session']}]
            session_dir_base = os.path.join(
                save_dir, hparams['lab'], hparams['expt'], hparams['animal'], hparams['session'])

    # construct session_dir
    if hparams.get('multisession', None) is not None and not use_csv:

        session_dir = os.path.join(session_dir_base, 'multisession-%02i' % hparams['multisession'])
        # overwrite sessions_single with whatever is in requested multisession
        sessions_single = read_session_info_from_csv(os.path.join(session_dir, 'session_info.csv'))
        for sess in sessions_single:
            sess.pop('save_dir', None)

    elif len(sessions_single) > 1:

        # check if this combo of experiments exists in previous multi-sessions; sessions are
        # compared as order-independent sets of their (key, value) pairs
        requested = set(tuple(sorted(d.items())) for d in sessions_single)
        multi_idx = None
        for session_multi in multisession_paths:
            csv_file = os.path.join(session_multi, 'session_info.csv')
            sessions_multi = read_session_info_from_csv(csv_file)
            for d in sessions_multi:
                # save path doesn't matter for comparison
                d.pop('save_dir', None)
            existing = set(tuple(sorted(d.items())) for d in sessions_multi)
            if requested == existing:
                # found match; record index
                multi_idx = int(session_multi.split('-')[-1])
                break

        # create new multisession if match was not found
        if multi_idx is None:
            multi_idxs = [
                int(session_multi.split('-')[-1]) for session_multi in multisession_paths]
            multi_idx = max(multi_idxs) + 1 if len(multi_idxs) > 0 else 0

        session_dir = os.path.join(session_dir_base, 'multisession-%02i' % multi_idx)

    else:

        session_dir = session_dir_base

    return session_dir, sessions_single
def get_expt_dir(hparams, model_class=None, model_type=None, expt_name=None):
    """Get output directories associated with a particular model class/type/testtube expt name.

    Examples
    --------
    * autoencoder: :obj:`session_dir/ae/conv/08_latents/expt_name`
    * arhmm: :obj:`session_dir/arhmm/08_latents/16_states/stationary/gaussian/expt_name`
    * arhmm-labels: :obj:`session_dir/arhmm-labels/16_states/stationary/gaussian/expt_name`
    * neural->ae decoder: :obj:`session_dir/neural-ae/08_latents/mlp/mctx/expt_name`
    * neural->arhmm decoder:
      :obj:`session_dir/neural-arhmm/08_latents/16_states/stationary/mlp/mctx/expt_name`
    * bayesian decoder:
      :obj:`session_dir/bayesian-decoding/08_latents/16_states/stationary/gaussian/mctx/expt_name`

    Parameters
    ----------
    hparams : :obj:`dict`
        specify model hyperparameters
    model_class : :obj:`str`, optional
        will search :obj:`hparams` if not present
    model_type : :obj:`str`, optional
        will search :obj:`hparams` if not present
    expt_name : :obj:`str`, optional
        will search :obj:`hparams` if not present

    Returns
    -------
    :obj:`str`
        contains data info (lab/expt/animal/session) as well as model info (e.g. n_ae_latents) and
        expt_name

    Raises
    ------
    ValueError
        if the model class is not recognized

    """
    import copy

    if model_class is None:
        model_class = hparams['model_class']
    if model_type is None:
        model_type = hparams['model_type']
    if expt_name is None:
        expt_name = hparams['experiment_name']

    def _resolve_session_dir(multisession_key):
        """Return session dir, swapping in a multisession dir if `multisession_key` is set."""
        if hparams.get(multisession_key, None) is None:
            return hparams['session_dir']
        # model was fit on a multisession; assumes the multisession is at animal level (rather
        # than experiment level), i.e.
        # - multisession model dir: lab/expt/animal/multisession-xx
        # - single session dir: lab/expt/animal/session
        hparams_ = copy.deepcopy(hparams)
        hparams_['session'] = 'all'
        hparams_['multisession'] = hparams[multisession_key]
        multisession_dir, _ = get_session_dir(hparams_)
        return multisession_dir

    ae_classes = (
        'ae', 'vae', 'beta-tcvae', 'cond-vae', 'cond-ae', 'cond-ae-msp', 'ps-vae', 'msps-vae')

    # get results dir
    if model_class in ae_classes:
        model_path = os.path.join(
            model_class, model_type, '%02i_latents' % hparams['n_ae_latents'])
        session_dir = _resolve_session_dir('ae_multisession')
    elif model_class in ('neural-ae', 'neural-ae-me', 'ae-neural'):
        brain_region = get_region_dir(hparams)
        model_path = os.path.join(
            model_class, '%02i_latents' % hparams['n_ae_latents'], model_type, brain_region)
        session_dir = hparams['session_dir']
    elif model_class in ('neural-labels', 'labels-neural'):
        brain_region = get_region_dir(hparams)
        model_path = os.path.join(model_class, model_type, brain_region)
        session_dir = hparams['session_dir']
    elif model_class in ('neural-arhmm', 'arhmm-neural'):
        brain_region = get_region_dir(hparams)
        model_path = os.path.join(
            model_class, '%02i_latents' % hparams['n_ae_latents'],
            '%02i_states' % hparams['n_arhmm_states'],
            _get_transition_str(hparams), model_type, brain_region)
        session_dir = hparams['session_dir']
    elif model_class in ('arhmm', 'hmm'):
        model_path = os.path.join(
            model_class, '%02i_latents' % hparams['n_ae_latents'],
            '%02i_states' % hparams['n_arhmm_states'],
            _get_transition_str(hparams), hparams['noise_type'])
        session_dir = _resolve_session_dir('arhmm_multisession')
    elif model_class in ('arhmm-labels', 'hmm-labels'):
        model_path = os.path.join(
            model_class, '%02i_states' % hparams['n_arhmm_states'],
            _get_transition_str(hparams), hparams['noise_type'])
        session_dir = _resolve_session_dir('arhmm_multisession')
    elif model_class == 'bayesian-decoding':
        brain_region = get_region_dir(hparams)
        model_path = os.path.join(
            model_class, '%02i_latents' % hparams['n_ae_latents'],
            '%02i_states' % hparams['n_arhmm_states'],
            _get_transition_str(hparams), hparams['noise_type'], brain_region)
        session_dir = hparams['session_dir']
    elif model_class == 'labels-images':
        model_path = os.path.join(model_class, model_type)
        session_dir = hparams['session_dir']
    else:
        raise ValueError('"%s" is an invalid model class' % model_class)

    return os.path.join(session_dir, model_path, expt_name)
def read_session_info_from_csv(session_file):
    """Read csv file that contains lab/expt/animal/session info.

    Parameters
    ----------
    session_file : :obj:`str`
        /full/path/to/session_info.csv

    Returns
    -------
    :obj:`list` of :obj:`dict`
        one dict per row of the csv file, keyed by the column names in the header row (i.e.
        lab/expt/animal/session for files written by :func:`export_session_info_to_csv`)

    """
    import csv
    # newline='' hands line-ending handling over to the csv module, as its documentation
    # requires; otherwise newlines embedded in quoted fields are not read back correctly
    with open(session_file, newline='') as csv_file:
        # load and parse csv file that contains single session info
        return [dict(row) for row in csv.DictReader(csv_file)]
def export_session_info_to_csv(session_dir, ids_list):
    """Export list of sessions to csv file.

    The file is written to :obj:`session_dir/session_info.csv`; :obj:`session_dir` is created
    if it does not exist yet.

    Parameters
    ----------
    session_dir : :obj:`str`
        absolute path for where to save :obj:`session_info.csv` file
    ids_list : :obj:`list` of :obj:`dict`
        dict for each session which contains lab/expt/animal/session; the keys of the first
        dict define the csv columns

    Raises
    ------
    IndexError
        if :obj:`ids_list` is empty; nothing is written to disk in that case

    """
    import csv
    # resolve the header before touching the filesystem so that an empty `ids_list` fails
    # without leaving a new directory or a truncated csv file behind
    fieldnames = list(ids_list[0].keys())
    session_file = os.path.join(session_dir, 'session_info.csv')
    os.makedirs(session_dir, exist_ok=True)
    # newline='' prevents the platform from translating the csv module's own line endings
    # (which would produce blank rows on Windows)
    with open(session_file, mode='w', newline='') as f:
        session_writer = csv.DictWriter(f, fieldnames=fieldnames)
        session_writer.writeheader()
        session_writer.writerows(ids_list)
def contains_session(session_dir, session_id):
    """Determine if session defined by `session_id` dict is in the multi-session `session_dir`.

    Parameters
    ----------
    session_dir : :obj:`str`
        absolute path to multi-session directory that contains a :obj:`session_info.csv` file
    session_id : :obj:`dict`
        must contain keys 'lab', 'expt', 'animal' and 'session'

    Returns
    -------
    :obj:`bool`
        :obj:`True` if a row of :obj:`session_info.csv` (ignoring any 'save_dir' column) is
        equal to :obj:`session_id`

    """
    csv_file = os.path.join(session_dir, 'session_info.csv')
    for candidate in read_session_info_from_csv(csv_file):
        # save path doesn't matter for comparison
        candidate.pop('save_dir', None)
        if candidate == session_id:
            return True
    return False
def find_session_dirs(hparams):
    """Find all session dirs (single- and multi-session) that contain the session in hparams.

    The search covers :obj:`save_dir/lab` and looks for multisession directories at the lab,
    expt and animal levels, as well as the single session directory itself.

    Parameters
    ----------
    hparams : :obj:`dict`
        must contain keys 'save_dir', 'lab', 'expt', 'animal' and 'session'

    Returns
    -------
    :obj:`tuple`
        - session_dirs (:obj:`list` of :obj:`str`): session directories containing the session
          defined in :obj:`hparams`
        - session_ids (:obj:`list` of :obj:`dict`): for each directory, a dict with keys
          'lab', 'expt', 'animal', 'session' and 'multisession' (:obj:`None` for the single
          session directory)

    """
    # TODO: refactor like get_session_dir?
    ids = {s: hparams[s] for s in ['lab', 'expt', 'animal', 'session']}
    lab = hparams['lab']
    save_dir = hparams['save_dir']

    # need to grab all multi-sessions as well as the single session
    session_dirs = []  # full paths
    session_ids = []  # dict of lab/expt/animal/session

    def _multisession_idx(dir_name):
        """Parse index from `multisession-xx`; also works for indices with 3+ digits."""
        # same parsing as in `get_session_dir`; slicing the last two characters would
        # silently truncate indices >= 100
        return int(dir_name.split('-')[-1])

    def _record(session_dir, expt, animal, session, multisession):
        """Store a matching session directory along with its id dict."""
        session_dirs.append(session_dir)
        session_ids.append({
            'lab': lab, 'expt': expt, 'animal': animal, 'session': session,
            'multisession': multisession})

    for expt in get_subdirs(os.path.join(save_dir, lab)):
        expt_dir = os.path.join(save_dir, lab, expt)
        if expt[:5] == 'multi':
            # multisession across experiments
            if contains_session(expt_dir, ids):
                _record(expt_dir, 'all', '', '', _multisession_idx(expt))
            continue
        for animal in get_subdirs(expt_dir):
            animal_dir = os.path.join(save_dir, lab, expt, animal)
            if animal[:5] == 'multi':
                # multisession across animals
                if contains_session(animal_dir, ids):
                    _record(animal_dir, expt, 'all', '', _multisession_idx(animal))
                continue
            for session in get_subdirs(animal_dir):
                session_dir = os.path.join(save_dir, lab, expt, animal, session)
                if session[:5] == 'multi':
                    # multisession across sessions of a single animal
                    if contains_session(session_dir, ids):
                        _record(session_dir, expt, animal, 'all', _multisession_idx(session))
                else:
                    tmp_ids = {'lab': lab, 'expt': expt, 'animal': animal, 'session': session}
                    if tmp_ids == ids:
                        _record(session_dir, expt, animal, session, None)

    return session_dirs, session_ids
def experiment_exists(hparams, which_version=False):
    """Search testtube versions to find if experiment with the same hyperparameters has been fit.

    A version only counts as a match if all model-defining hyperparameters (see
    :func:`get_model_params`) are equal to those stored in its :obj:`meta_tags.pkl` file and
    its training has been completed.

    Note that 'session_dir' and 'expt_dir' are added to :obj:`hparams` in place if they are not
    already present.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify a test tube experiment (model + training
        parameters)
    which_version : :obj:`bool`, optional
        :obj:`True` to return version number

    Returns
    -------
    variable
        - :obj:`bool` if :obj:`which_version=False`
        - :obj:`tuple` (:obj:`bool`, :obj:`int`) if :obj:`which_version=True`; the version is
          :obj:`None` if no match is found

    """
    # fill out path info if not present
    if 'expt_dir' not in hparams:
        if 'session_dir' not in hparams:
            hparams['session_dir'], _ = get_session_dir(
                hparams, session_source=hparams.get('all_source', 'save'))
        hparams['expt_dir'] = get_expt_dir(hparams)

    try:
        tt_versions = get_subdirs(hparams['expt_dir'])
    except (StopIteration, NotADirectoryError):
        # no versions yet (experiment directory is empty or has not been created)
        if which_version:
            return False, None
        else:
            return False

    # get model-specific params
    hparams_less = get_model_params(hparams)

    matched_version = None
    for version in tt_versions:
        # load hparams
        version_file = os.path.join(hparams['expt_dir'], version, 'meta_tags.pkl')
        try:
            # NOTE(review): pickle can execute arbitrary code on load; only point `expt_dir`
            # at directories whose contents are trusted
            with open(version_file, 'rb') as f:
                hparams_ = pickle.load(f)
        except IOError:
            # not a (finished) test tube version directory
            continue
        # a key missing from the stored hparams (e.g. one added after that model was fit)
        # means the version cannot be confirmed as identical, so treat it as a mismatch
        is_match = all(
            key in hparams_ and hparams_[key] == val for key, val in hparams_less.items())
        # found match - did it finish training?
        if is_match and hparams_.get('training_completed', False):
            matched_version = version
            break

    found_match = matched_version is not None
    if not which_version:
        return found_match
    if found_match:
        return found_match, int(matched_version.split('_')[-1])
    return found_match, None
def get_model_params(hparams):
    """Returns dict containing all params considered essential for defining a model in that class.

    Parameters
    ----------
    hparams : :obj:`dict`
        all relevant hparams for the given model class will be pulled from this dict

    Returns
    -------
    :obj:`dict`
        hparams dict

    Raises
    ------
    NotImplementedError
        if the model class is 'bayesian-decoding' or is not recognized

    """
    model_class = hparams['model_class']
    hparams_less = {}

    def _keep(*keys):
        """Copy the given (required) keys from `hparams` into `hparams_less`, in order."""
        for key in keys:
            hparams_less[key] = hparams[key]

    def _keep_transitions():
        """Copy arhmm dynamics params; kappa only matters for sticky transitions."""
        _keep('n_arhmm_lags', 'noise_type', 'transitions')
        if hparams['transitions'] == 'sticky':
            _keep('kappa')

    ae_classes = (
        'ae', 'vae', 'beta-tcvae', 'cond-vae', 'cond-ae', 'cond-ae-msp', 'ps-vae', 'msps-vae')
    decoder_classes = (
        'neural-ae', 'neural-ae-me', 'ae-neural', 'neural-arhmm', 'arhmm-neural',
        'neural-labels', 'labels-neural')

    # start with general params
    _keep(
        'rng_seed_data', 'trial_splits', 'train_frac', 'rng_seed_model', 'model_class',
        'model_type')

    # model-specific params
    if model_class in ae_classes:
        _keep('n_ae_latents', 'fit_sess_io_layers', 'learning_rate', 'l2_reg')
        if model_class in ('cond-ae', 'cond-vae'):
            hparams_less['conditional_encoder'] = hparams.get('conditional_encoder', False)
        if model_class == 'cond-ae-msp':
            _keep('msp.alpha')
        if model_class in ('vae', 'cond-vae'):
            _keep('vae.beta')
        if model_class == 'beta-tcvae':
            _keep('beta_tcvae.beta')
        if model_class in ('ps-vae', 'msps-vae'):
            _keep('ps_vae.alpha', 'ps_vae.beta')
        if model_class == 'msps-vae':
            _keep('ps_vae.delta', 'n_background', 'n_sessions_per_batch')
    elif model_class in ('arhmm', 'hmm'):
        _keep_transitions()
        _keep(
            'ae_experiment_name', 'ae_version', 'ae_model_class', 'ae_model_type',
            'n_ae_latents')
    elif model_class in ('arhmm-labels', 'hmm-labels'):
        _keep_transitions()
    elif model_class in ('neural-ae', 'neural-ae-me', 'ae-neural'):
        _keep(
            'ae_experiment_name', 'ae_version', 'ae_model_class', 'ae_model_type',
            'n_ae_latents')
    elif model_class in ('neural-labels', 'labels-neural'):
        pass
    elif model_class in ('neural-arhmm', 'arhmm-neural'):
        _keep('arhmm_experiment_name', 'arhmm_version', 'n_arhmm_states')
        _keep_transitions()
        _keep('ae_model_class', 'ae_model_type', 'n_ae_latents')
    elif model_class == 'bayesian-decoding':
        raise NotImplementedError
    elif model_class == 'labels-images':
        _keep('fit_sess_io_layers', 'learning_rate', 'l2_reg')
    else:
        raise NotImplementedError('"%s" is not a valid model class' % model_class)

    # decoder arch params
    if model_class in decoder_classes:
        # 'model_type' is already part of the general params; re-copying it is a no-op
        _keep('learning_rate', 'n_lags', 'l2_reg', 'model_type', 'n_hid_layers')
        if hparams['n_hid_layers'] != 0:
            _keep('n_hid_units')
        _keep('activation', 'subsample_method')
        if hparams_less['subsample_method'] != 'none':
            _keep('subsample_idxs_name', 'subsample_idxs_group_0', 'subsample_idxs_group_1')

    return hparams_less
def export_hparams(hparams, exp):
    """Export hyperparameter dictionary.

    The dict is exported once as a csv file (for easy human reading) and again as a pickled dict
    (for easy python loading/parsing).

    Parameters
    ----------
    hparams : :obj:`dict`
        hyperparameter dict to export; must contain the key 'expt_dir'
    exp : :obj:`test_tube.Experiment` object
        defines where parameters are saved

    """
    # save out as pickle in the directory of the current experiment version
    version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % exp.version)
    with open(os.path.join(version_dir, 'meta_tags.pkl'), 'wb') as meta_handle:
        pickle.dump(hparams, meta_handle)

    # save out as csv (handled by the experiment object)
    exp.tag(hparams)
    exp.save()
def get_lab_example(hparams, lab, expt):
    """Helper function to load data-specific hyperparameters and update hparams.

    These values are loaded from the json file defined by :obj:`lab` and :obj:`expt` in the
    :obj:`.behavenet` user directory. See
    https://behavenet.readthedocs.io/en/latest/source/installation.html#adding-a-new-dataset
    for more information.

    Parameters
    ----------
    hparams : :obj:`dict`
        hyperparameter dict to update (modified in place)
    lab : :obj:`str`
        lab id
    expt : :obj:`str`
        expt id

    """
    import json
    from behavenet import get_params_dir

    params_dir = get_params_dir()
    json_name = str('%s_%s_params.json' % (lab, expt))
    with open(os.path.join(params_dir, json_name), 'r') as json_handle:
        hparams.update(json.load(json_handle))
def get_region_dir(hparams):
    """Return brain region string that combines region name and inclusion info.

    If not subsampling regions, will return :obj:`'all'`

    If using neural activity from *only* specified region, will return e.g. :obj:`'mctx-single'`

    If using neural activity from all *but* specified region (leave-one-out), will return e.g.
    :obj:`'mctx-loo'`

    Parameters
    ----------
    hparams : :obj:`dict`
        must contain the key 'subsample_method', else function assumes no subsampling; the key
        'subsample_idxs_name' is required for the 'single' and 'loo' methods

    Returns
    -------
    :obj:`str`
        region directory name

    Raises
    ------
    ValueError
        if 'subsample_method' is not one of 'none', 'single' or 'loo'

    """
    method = hparams.get('subsample_method', 'none')
    if method == 'none':
        return 'all'
    if method == 'single':
        template = '%s-single'
    elif method == 'loo':
        template = '%s-loo'
    else:
        raise ValueError('"%s" is an invalid sampling type' % method)
    return str(template % hparams['subsample_idxs_name'])
def create_tt_experiment(hparams):
    """Create test-tube experiment for logging training and storing models.

    The keys 'session_dir' and 'expt_dir' (and 'version', if a new experiment is created) are
    written into :obj:`hparams` in place, and both directories are created if necessary.

    Parameters
    ----------
    hparams : :obj:`dict`
        dictionary of hyperparameters defining experiment that will be saved as a csv file

    Returns
    -------
    :obj:`tuple`
        - if experiment defined by hparams already exists, returns :obj:`(None, None, None)`
        - if experiment does not exist, returns :obj:`(hparams, sess_ids, exp)`

    """
    from test_tube import Experiment

    # get session_dir
    session_source = hparams.get('all_source', 'save')
    session_dir, sess_ids = get_session_dir(hparams, session_source=session_source)
    hparams['session_dir'] = session_dir
    if not os.path.isdir(session_dir):
        # brand new session directory: also record which sessions it is comprised of
        os.makedirs(session_dir)
        export_session_info_to_csv(session_dir, sess_ids)

    expt_dir = get_expt_dir(hparams)
    hparams['expt_dir'] = expt_dir
    if not os.path.isdir(expt_dir):
        os.makedirs(expt_dir)

    # check to see if experiment already exists
    if experiment_exists(hparams):
        return None, None, None

    # test tube appends the experiment name to its save_dir, hence the parent of expt_dir
    exp = Experiment(
        name=hparams['experiment_name'],
        debug=False,
        save_dir=os.path.dirname(expt_dir))
    exp.save()
    hparams['version'] = exp.version

    return hparams, sess_ids, exp
def get_best_model_version(expt_dir, measure='val_loss', best_def='min', n_best=1):
    """Get best model version from a test tube experiment.

    Only versions that contain a :obj:`meta_tags.pkl` file marking training as completed are
    considered.

    Parameters
    ----------
    expt_dir : :obj:`str`
        test tube experiment directory containing version_%i subdirectories
    measure : :obj:`str`, optional
        heading in csv file that is used to determine which model is best
    best_def : :obj:`str`, optional
        how :obj:`measure` should be parsed; 'min' | 'max'
    n_best : :obj:`int`, optional
        top `n_best` models are returned; if :obj:`n_best > 1`, versions tied with the last
        selected one are returned as well

    Returns
    -------
    :obj:`list`
        list of best models (version numbers as :obj:`int`), with best first

    Raises
    ------
    ValueError
        if :obj:`best_def` is invalid, or if no version has completed training

    """
    import pandas as pd

    if best_def not in ('min', 'max'):
        raise ValueError('"%s" is an invalid best_def; must be "min" or "max"' % best_def)
    minimize = best_def == 'min'

    # gather all versions
    versions = get_subdirs(expt_dir)

    # load csv files with model metrics (saved out from test tube)
    metrics = []
    for i, version in enumerate(versions):
        # make sure training has been completed
        meta_file = os.path.join(expt_dir, version, 'meta_tags.pkl')
        if not os.path.exists(meta_file):
            continue
        # NOTE(review): pickle can execute arbitrary code on load; only use trusted directories
        with open(meta_file, 'rb') as f:
            meta_tags = pickle.load(f)
        if not meta_tags.get('training_completed', False):
            continue
        # read metrics csv file
        metric = pd.read_csv(os.path.join(expt_dir, version, 'metrics.csv'))
        # get best value of the measure reached by this version
        val_loss = metric[measure].min() if minimize else metric[measure].max()
        metrics.append(pd.DataFrame({'loss': val_loss, 'version': version}, index=[i]))

    if len(metrics) == 0:
        raise ValueError('no versions with completed training found in %s' % expt_dir)

    # put everything in pandas dataframe
    metrics_df = pd.concat(metrics, sort=False)
    losses = metrics_df['loss']

    # get version(s) with best loss
    if n_best == 1:
        best_idx = losses.idxmin() if minimize else losses.idxmax()
        best_versions = [metrics_df['version'][best_idx]]
    else:
        # keep='all' also returns versions that are tied with the last selected one
        if minimize:
            ranked = losses.nsmallest(n_best, keep='all')
        else:
            ranked = losses.nlargest(n_best, keep='all')
        best_versions = np.asarray(metrics_df['version'][ranked.index])
        if best_versions.shape[0] > n_best:
            print('More versions than specified due to same validation loss')
        elif best_versions.shape[0] < n_best:
            print('Fewer versions than specified; only %i versions completed training'
                  % best_versions.shape[0])

    # convert string to integer
    return [int(version.split('_')[-1]) for version in best_versions]
def get_best_model_and_data(hparams, Model=None, load_data=True, version='best', data_kwargs=None):
    """Load the best model (and data) defined by hparams out of all available test-tube versions.

    The architecture is read from the `meta_tags.pkl` file in the selected version directory,
    and the weights are read from `best_val_model.pt` (or `best_val_model.ckpt`, see below).
    The returned model has been moved to the requested device and put in eval mode.

    Note that `hparams` is modified in place: its 'session_dir' entry is overwritten.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify both a model and the associated data
    Model : :obj:`behavenet.models` object, optional
        model type; if :obj:`NoneType` the class is selected using `hparams['model_class']`
    load_data : :obj:`bool`, optional
        if `False` then data generator is not returned
    version : :obj:`str` or :obj:`int` or :obj:`NoneType`, optional
        - 'best': use the first version returned by :func:`get_best_model_version`
        - :obj:`NoneType`: use the version that :func:`experiment_exists` matches to `hparams`
        - string starting with 'v': used as-is as the version directory name
        - anything else: inserted into the directory name 'version_{}'
    data_kwargs : :obj:`dict`, optional
        additional kwargs for data generator

    Returns
    -------
    :obj:`tuple`
        - model (:obj:`behavenet.models` object)
        - data generator (:obj:`ConcatSessionsGenerator` object or :obj:`NoneType`)

    Raises
    ------
    NotImplementedError
        if `Model` is :obj:`NoneType` and `hparams['model_class']` is 'arhmm' or is not one of
        the recognized model classes

    """
    import torch
    from behavenet.data.data_generator import ConcatSessionsGenerator
    from behavenet.data.utils import get_data_generator_inputs

    # maps `hparams['model_class']` to the name of the class exported by `behavenet.models`
    model_class_names = {
        'ae': 'AE',
        'vae': 'VAE',
        'cond-ae': 'ConditionalAE',
        'cond-vae': 'ConditionalVAE',
        'cond-ae-msp': 'AEMSP',
        'beta-tcvae': 'BetaTCVAE',
        'ps-vae': 'PSVAE',
        'msps-vae': 'MSPSVAE',
        'labels-images': 'ConvDecoder',
        'neural-ae': 'Decoder',
        'neural-ae-me': 'Decoder',
        'neural-arhmm': 'Decoder',
        'neural-labels': 'Decoder',
        'ae-neural': 'Decoder',
        'arhmm-neural': 'Decoder',
        'labels-neural': 'Decoder',
    }

    # get session_dir
    hparams['session_dir'], sess_ids = get_session_dir(
        hparams, session_source=hparams.get('all_source', 'save'))
    expt_dir = get_expt_dir(hparams)

    # resolve the name of the version directory
    if version == 'best':
        best_version = 'version_{}'.format(get_best_model_version(expt_dir)[0])
    elif version is None:
        # try to match hparams
        _, version_hp = experiment_exists(hparams, which_version=True)
        best_version = 'version_{}'.format(version_hp)
    elif isinstance(version, str) and version.startswith('v'):
        # assume we got a string of the form 'version_{%i}'
        best_version = version
    else:
        best_version = 'version_{}'.format(version)

    version_dir = os.path.join(expt_dir, best_version)
    arch_file = os.path.join(version_dir, 'meta_tags.pkl')
    model_file = os.path.join(version_dir, 'best_val_model.pt')
    # fall back to the '.ckpt' file name when neither the '.pt' file nor a '.pt.meta' file exists
    if not os.path.exists(model_file) and not os.path.exists(model_file + '.meta'):
        model_file = os.path.join(version_dir, 'best_val_model.ckpt')

    print('Loading model defined in %s' % arch_file)
    # NOTE: unpickling executes arbitrary code; only load model directories you trust
    with open(arch_file, 'rb') as f:
        hparams_new = pickle.load(f)

    # update paths if performing analysis on a different machine
    hparams_new['data_dir'] = hparams['data_dir']
    hparams_new['session_dir'] = hparams['session_dir']
    hparams_new['expt_dir'] = expt_dir
    hparams_new['use_output_mask'] = hparams.get('use_output_mask', False)
    hparams_new['use_label_mask'] = hparams.get('use_label_mask', False)
    hparams_new['device'] = hparams.get('device', 'cpu')

    # build data generator inputs; this is done even when `load_data` is False because the
    # call also returns the (possibly updated) hparams used to build the model
    hparams_new, signals, transforms, paths = get_data_generator_inputs(hparams_new, sess_ids)
    if load_data:
        # sometimes we want a single data_generator for multiple models
        if data_kwargs is None:
            data_kwargs = {}
        data_generator = ConcatSessionsGenerator(
            hparams_new['data_dir'], sess_ids,
            signals_list=signals, transforms_list=transforms, paths_list=paths,
            device=hparams_new['device'], as_numpy=hparams_new['as_numpy'],
            batch_load=hparams_new['batch_load'], rng_seed=hparams_new['rng_seed_data'],
            train_frac=hparams_new['train_frac'], **data_kwargs)
    else:
        data_generator = None

    # build model
    if Model is None:
        model_class = hparams['model_class']
        if model_class == 'arhmm':
            raise NotImplementedError('Cannot use get_best_model_and_data() for ssm models')
        if model_class not in model_class_names:
            raise NotImplementedError(
                '"%s" is not a valid model class for get_best_model_and_data()' % model_class)
        # import lazily so that behavenet.models is only needed when no Model is supplied
        from behavenet import models as behavenet_models
        Model = getattr(behavenet_models, model_class_names[model_class])

    model = Model(hparams_new)
    model.version = int(best_version.split('_')[1])
    # load weights onto cpu storage first; the model is moved to the target device afterwards
    model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))
    model.to(hparams_new['device'])
    model.eval()

    return model, data_generator
def _clean_tt_dir(hparams):
    """Delete all (unnecessary) subdirectories in the model directory (created by test-tube).

    The model directory is `hparams['expt_dir']/version_{hparams['version']}`. Every first-level
    subdirectory is removed recursively; files directly inside the model directory are untouched.
    If the model directory has no subdirectories this function does nothing.

    Parameters
    ----------
    hparams : :obj:`dict`
        must contain the keys 'expt_dir' (path) and 'version' (number formatted with `%i`)

    Raises
    ------
    NotADirectoryError
        raised by :func:`get_subdirs` if the model directory does not exist

    """
    import shutil

    version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % hparams['version'])
    try:
        subdirs = get_subdirs(version_dir)
    except StopIteration:
        # get_subdirs signals "no subdirectories" with StopIteration; nothing to clean up
        return
    for subdir in subdirs:
        shutil.rmtree(os.path.join(version_dir, subdir))
def _print_hparams(hparams):
    """Pretty print hparams to console.

    For each of the 'data', 'compute', 'training' and 'model' config files, the file at
    `hparams['{name}_config']` is parsed with `commentjson` to find out which keys belong to
    that config; the printed values are taken from `hparams`, not from the file.

    Parameters
    ----------
    hparams : :obj:`dict`
        must contain the four '{name}_config' paths as well as every key listed in those files

    """
    import commentjson

    config_files = ['data', 'compute', 'training', 'model']
    for config_file in config_files:
        print('\n%s CONFIG:' % config_file.upper())
        # use a context manager so the config file handle is closed right after parsing
        with open(hparams['%s_config' % config_file], 'r') as f:
            config_json = commentjson.load(f)
        for key in config_json:
            print(' {}: {}'.format(key, hparams[key]))
    print('')