"""Utility functions for evaluating model fits."""
import numpy as np
def _encode_latents(model, model_class, images, labels_2d, sess):
    """Encode a contiguous run of frames and return the latents as a numpy array.

    Parameters
    ----------
    model : :obj:`AE` object
        pytorch model exposing `encoding` (and `U` for 'cond-ae-msp' models)
    model_class : :obj:`str`
        value of `model.hparams['model_class']`
    images : :obj:`torch.Tensor` object
        frames to encode
    labels_2d : :obj:`torch.Tensor` object or :obj:`NoneType`
        if not :obj:`NoneType`, concatenated to `images` along dim 1 before encoding
    sess : :obj:`int`
        session index, passed to the encoder as `dataset`

    Returns
    -------
    :obj:`np.ndarray`
        latents for the given frames

    """
    import torch

    if labels_2d is None:
        inputs = images
    else:
        inputs = torch.cat((images, labels_2d), dim=1)
    output = model.encoding(inputs, dataset=sess)
    if model_class == 'ps-vae':
        # the first two encoder outputs are concatenated to form the exported latents
        curr_latents = torch.cat([output[0], output[1]], dim=1)
    else:
        curr_latents = output[0]
    if model_class == 'cond-ae-msp':
        # push latents through linear transformation
        curr_latents = model.U(curr_latents)
    return curr_latents.cpu().detach().numpy()


def export_latents(data_generator, model, filename=None):
    """Export predicted latents using an already initialized data_generator and model.

    Latents are saved based on the model's hparams dict unless another file is provided. The
    default filename is `[lab_id]_[expt_id]_[animal_id]_[session_id]_latents.pkl`.

    For 'msps-vae' models the work is delegated to `model.export_latents`.

    Note that when `filename` is given it is used for every session, so with more than one
    dataset each write replaces the previous one.

    Parameters
    ----------
    data_generator : :obj:`ConcatSessionGenerator` object
        data generator to use for latent creation
    model : :obj:`AE` object
        pytorch model
    filename : :obj:`str` or :obj:`NoneType`, optional
        absolute path to save latents; if :obj:`NoneType`, latents are stored in model directory

    Returns
    -------
    :obj:`list`
        list of latent filenames

    """
    import os
    import pickle
    import torch

    model_class = model.hparams['model_class']
    if model_class == 'msps-vae':
        return model.export_latents(data_generator, filename=filename)

    model.eval()

    # only conditional autoencoders with a conditional encoder consume the 2d labels
    use_labels_2d = \
        model_class == 'cond-ae' and model.hparams.get('conditional_encoder', False)

    # batches longer than this are encoded piecewise so they fit on the gpu
    chunk_size = 200

    # initialize container for latents: one entry per trial; trials that are never visited
    # below keep their empty placeholder array
    latents = [
        [np.array([]) for _ in range(dataset.n_trials)]
        for dataset in data_generator.datasets]

    # inference only: no need to build autograd graphs for the forward passes
    with torch.no_grad():
        for dtype in ('train', 'val', 'test'):
            data_generator.reset_iterators(dtype)
            for _ in range(data_generator.n_tot_batches[dtype]):
                data, sess = data_generator.next_batch(dtype)
                trial = data['batch_idx'].item()

                y = data['images'][0]
                labels_2d = data['labels_sc'][0] if use_labels_2d else None

                batch_size = y.shape[0]
                if batch_size > chunk_size:
                    # preallocate with nans, then fill one non-overlapping chunk at a time
                    trial_latents = np.full(
                        shape=(data['images'].shape[1], model.hparams['n_ae_latents']),
                        fill_value=np.nan)
                    for idx_beg in range(0, batch_size, chunk_size):
                        idx_end = min(idx_beg + chunk_size, batch_size)
                        if labels_2d is None:
                            labels_chunk = None
                        else:
                            labels_chunk = labels_2d[idx_beg:idx_end]
                        trial_latents[idx_beg:idx_end, :] = _encode_latents(
                            model, model_class, y[idx_beg:idx_end], labels_chunk, sess)
                    latents[sess][trial] = trial_latents
                else:
                    latents[sess][trial] = _encode_latents(
                        model, model_class, y, labels_2d, sess)

    # save latents separately for each dataset
    filenames = []
    for sess, dataset in enumerate(data_generator.datasets):
        if filename is None:
            # get save name which includes lab/expt/animal/session
            sess_id = '%s_%s_%s_%s_latents.pkl' % (
                dataset.lab, dataset.expt, dataset.animal, dataset.session)
            filename_save = os.path.join(
                model.hparams['expt_dir'], 'version_%i' % model.version, sess_id)
        else:
            filename_save = filename
        # save out array in pickle file
        print(
            'saving latents %i of %i:\n%s' % (sess + 1, data_generator.n_datasets, filename_save))
        latents_dict = {'latents': latents[sess], 'trials': dataset.batch_idxs}
        with open(filename_save, 'wb') as f:
            pickle.dump(latents_dict, f)
        filenames.append(filename_save)

    return filenames
def export_states(hparams, data_generator, model, filename=None):
    """Export most likely discrete states using an initialized data_generator and model.

    States are saved based on the hparams dict unless another file is provided. The default
    filename is `[lab_id]_[expt_id]_[animal_id]_[session_id]_states.pkl`.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain 'model_class', and 'expt_dir' and 'version' when no filename is given
    data_generator : :obj:`ConcatSessionGenerator` object
        data generator to use for state creation
    model : :obj:`HMM` object
        ssm model
    filename : :obj:`str` or :obj:`NoneType`, optional
        absolute path to save states; if :obj:`NoneType`, states are stored in model directory

    Returns
    -------
    :obj:`list`
        list of state filenames

    """
    import os
    import pickle

    # one placeholder per trial in every session; trials that are never visited below keep
    # their empty array
    states = [
        [np.array([]) for _ in range(session_data.n_trials)]
        for session_data in data_generator.datasets]

    for split in ('train', 'val', 'test'):
        data_generator.reset_iterators(split)
        n_batches = data_generator.n_tot_batches[split]
        for _ in range(n_batches):
            batch, sess_idx = data_generator.next_batch(split)
            # models fit on labels read the 'labels' signal, all others read AE latents
            signal = 'labels' if 'label' in hparams['model_class'] else 'ae_latents'
            observations = batch[signal][0][0]
            trial_idx = batch['batch_idx'].item()
            states[sess_idx][trial_idx] = model.most_likely_states(observations)

    # write one pickle file per session
    saved_files = []
    for sess_idx, session_data in enumerate(data_generator.datasets):
        if filename is not None:
            save_path = filename
        else:
            # default name encodes lab/expt/animal/session
            base_name = '%s_%s_%s_%s_states.pkl' % (
                session_data.lab, session_data.expt, session_data.animal,
                session_data.session)
            version_dir = 'version_%i' % hparams['version']
            save_path = os.path.join(hparams['expt_dir'], version_dir, base_name)
        print(
            'saving states %i of %i:\n%s' % (sess_idx + 1, data_generator.n_datasets, save_path))
        payload = {'states': states[sess_idx], 'trials': session_data.batch_idxs}
        with open(save_path, 'wb') as f:
            pickle.dump(payload, f)
        saved_files.append(save_path)

    return saved_files
def export_predictions(data_generator, model, filename=None):
    """Export decoder predictions using an already initialized data_generator and model.

    Predictions are saved based on the model's hparams dict unless another file is provided. The
    default filename is `[lab_id]_[expt_id]_[animal_id]_[session_id]_predictions.pkl`.

    This function only supports pytorch decoding models - not autoencoders. To get AE
    reconstructions see the `get_reconstruction` function in this module.

    The first and last `n_max_lags` time points of each trial are left as nans. Trials that
    are never visited by the data generator are stored as empty arrays.

    Note that when `filename` is given it is used for every session, so with more than one
    dataset each write replaces the previous one.

    Parameters
    ----------
    data_generator : :obj:`ConcatSessionGenerator` object
        data generator to use for prediction creation
    model : :obj:`NN` object
        pytorch model
    filename : :obj:`str` or :obj:`NoneType`, optional
        absolute path to save predictions; if :obj:`NoneType`, predictions are stored in model
        directory

    Returns
    -------
    :obj:`list`
        list of prediction filenames

    """
    import os
    import pickle
    import torch

    model.eval()

    # initialize container for predictions: one placeholder per trial in every session
    predictions = [
        [np.array([]) for _ in range(dataset.n_trials)]
        for dataset in data_generator.datasets]

    max_lags = model.hparams['n_max_lags']

    # batches longer than this are processed piecewise so they fit on the gpu
    chunk_size = 200

    # inference only: no need to build autograd graphs for the forward passes
    with torch.no_grad():
        for dtype in ('train', 'val', 'test'):
            data_generator.reset_iterators(dtype)
            for _ in range(data_generator.n_tot_batches[dtype]):
                data, sess = data_generator.next_batch(dtype)

                predictors = data[model.hparams['input_signal']][0]
                targets = data[model.hparams['output_signal']][0]
                batch_size = targets.shape[0]

                trial_preds = np.full(
                    shape=(batch_size, model.hparams['output_size']), fill_value=np.nan)

                # a batch that fits in a single chunk is simply the one-chunk case
                for chunk_beg in range(0, batch_size, chunk_size):
                    # take chunks of size chunk_size, plus overlap due to max_lags
                    idx_beg = max(chunk_beg - max_lags, 0)
                    idx_end = min(chunk_beg + chunk_size + max_lags, batch_size)
                    # rows of the trial that this chunk is responsible for
                    dst_beg = idx_beg + max_lags
                    dst_end = idx_end - max_lags
                    if dst_end <= dst_beg:
                        # chunk is too short to hold any lag-free time points
                        continue
                    outputs, _ = model(predictors[idx_beg:idx_end])
                    # use an explicit end index: `-max_lags` would select nothing when
                    # max_lags == 0
                    n_out = outputs.shape[0]
                    trial_preds[dst_beg:dst_end, :] = \
                        outputs[max_lags:n_out - max_lags].cpu().detach().numpy()

                predictions[sess][data['batch_idx'].item()] = trial_preds

    # save predictions separately for each dataset
    filenames = []
    for sess, dataset in enumerate(data_generator.datasets):
        if filename is None:
            # get save name which includes lab/expt/animal/session
            sess_id = '%s_%s_%s_%s_predictions.pkl' % (
                dataset.lab, dataset.expt, dataset.animal, dataset.session)
            filename_save = os.path.join(
                model.hparams['expt_dir'], 'version_%i' % model.version, sess_id)
        else:
            filename_save = filename
        # save out array in pickle file
        print(
            'saving predictions %i of %i to %s' %
            (sess + 1, data_generator.n_datasets, filename_save))
        predictions_dict = {'predictions': predictions[sess], 'trials': dataset.batch_idxs}
        with open(filename_save, 'wb') as f:
            pickle.dump(predictions_dict, f)
        filenames.append(filename_save)

    return filenames
def get_reconstruction(
        model, inputs, dataset=None, return_latents=False, labels=None, labels_2d=None,
        apply_inverse_transform=True, use_mean=True):
    """Reconstruct an image from either image or latent inputs.

    Two-dimensional inputs are treated as latents and are only pushed through the decoder;
    inputs of any other rank are treated as images and pushed through the full model.

    Parameters
    ----------
    model : :obj:`AE` object
        pytorch model
    inputs : :obj:`torch.Tensor` object
        - image tensor of shape (batch, channels, y_pix, x_pix)
        - latents tensor of shape (batch, n_ae_latents)
        non-tensor inputs are converted to a tensor on `model.hparams['device']`
    dataset : :obj:`int` or :obj:`NoneType`, optional
        for use with session-specific io layers; forwarded to the model for both image and
        latent inputs
    return_latents : :obj:`bool`, optional
        if :obj:`True` return tuple of (recon, latents)
    labels : :obj:`torch.Tensor` object or :obj:`NoneType`, optional
        label tensor of shape (batch, n_labels); required when decoding latents with a
        'cond-ae' or 'cond-vae' model, where it is concatenated to the latents
    labels_2d : :obj:`torch.Tensor` object or :obj:`NoneType`, optional
        label tensor of shape (batch, n_labels, y_pix, x_pix)
    apply_inverse_transform : :obj:`bool`
        if inputs are latents (and model class is 'cond-ae-msp', 'ps-vae' or 'msps-vae'), apply
        inverse transform to put in original latent space
    use_mean : :obj:`bool`
        if inputs are images (and model class is 'vae', 'beta-tcvae', 'ps-vae' or 'msps-vae'),
        use mean of approximate posterior without sampling

    Returns
    -------
    :obj:`np.ndarray` or :obj:`tuple`
        reconstructed images of shape (batch, channels, y_pix, x_pix); if `return_latents` is
        :obj:`True`, a tuple of (reconstructed images, latents) instead

    Raises
    ------
    ValueError
        if inputs are images and the model class is not recognized

    """
    import torch

    model.eval()

    if not isinstance(inputs, torch.Tensor):
        inputs = torch.Tensor(inputs).to(model.hparams['device'])

    model_class = model.hparams['model_class']

    # check to see if inputs are images or latents
    inputs_are_latents = len(inputs.shape) == 2

    # inference only: no need to build autograd graphs for the forward passes
    with torch.no_grad():
        if not inputs_are_latents:
            if model_class == 'ae':
                ims_recon, latents = model(inputs, dataset=dataset)
            elif model_class == 'cond-ae-msp':
                ims_recon, latents, _ = model(inputs, dataset=dataset)
            elif model_class in ('vae', 'beta-tcvae'):
                ims_recon, latents, _, _ = model(inputs, dataset=dataset, use_mean=use_mean)
            elif model_class in ('ps-vae', 'msps-vae'):
                ims_recon, _, latents, _, _ = model(inputs, dataset=dataset, use_mean=use_mean)
            elif model_class == 'cond-ae':
                ims_recon, latents = model(
                    inputs, dataset=dataset, labels=labels, labels_2d=labels_2d)
            elif model_class == 'cond-vae':
                ims_recon, latents, _, _ = model(
                    inputs, dataset=dataset, labels=labels, labels_2d=labels_2d)
            else:
                raise ValueError('Invalid model class %s' % model_class)
        else:
            # TODO: how to incorporate maxpool layers for decoding only?
            if model_class in ('cond-ae', 'cond-vae'):
                inputs = torch.cat((inputs, labels), dim=1)
            elif model_class in ('cond-ae-msp', 'ps-vae', 'msps-vae') \
                    and apply_inverse_transform:
                # inputs are expected in the transformed space, e.g. [labels, unsupervised
                # latents] for 'ps-vae' or [labels, background latents, unsupervised latents]
                # for 'msps-vae'; map them back to the space the decoder operates on
                inputs = model.get_inverse_transformed_latents(inputs, as_numpy=False)
            # forward the caller's `dataset` so it is honored when decoding from latents too
            ims_recon = model.decoding(inputs, None, None, dataset=dataset)
            latents = inputs

    ims_recon = ims_recon.cpu().detach().numpy()
    latents = latents.cpu().detach().numpy()

    if return_latents:
        return ims_recon, latents
    return ims_recon
def get_test_metric(
        hparams, model_version, metric='r2', dtype='test', multioutput='variance_weighted',
        sess_idx=0):
    r"""Calculate a single metric (R\ :sup:`2` by default) across all batches for a decoder.

    The first and last `hparams['n_max_lags']` time points of every batch are excluded from
    the metric.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to load the decoder, as well as 'n_max_lags'
    model_version : :obj:`int` or :obj:`str`
        version from test tube experiment defined in :obj:`hparams` or the string 'best'
    metric : :obj:`str`, optional
        'r2' | 'fc' | 'mse'
    dtype : :obj:`str`
        type of trials to use for computing metric
        'train' | 'val' | 'test'
    multioutput : :obj:`str`
        defines how to aggregate multiple r2 scores; see r2_score documentation in sklearn
        'raw_values' | 'uniform_average' | 'variance_weighted'
    sess_idx : :obj:`int`, optional
        session index into data generator; used to determine how many batches are drawn

    Returns
    -------
    :obj:`tuple`
        - hparams (:obj:`dict`): hparams of model used to calculate metrics
        - metric: value of the requested metric
        - true (:obj:`list`): per-batch arrays of true values, lag-trimmed
        - pred (:obj:`list`): per-batch arrays of model outputs, lag-trimmed

    Raises
    ------
    ValueError
        if `metric` is not a valid metric type, or if a batch contains neither 'ae_latents'
        nor 'labels' when computing 'r2' or 'mse'

    """
    from sklearn.metrics import r2_score, accuracy_score
    from behavenet.fitting.utils import get_best_model_and_data
    from behavenet.models import Decoder

    # fail fast, before the model and data are loaded
    if metric not in ('r2', 'mse', 'fc'):
        raise ValueError('"%s" is an invalid metric type' % metric)

    model, data_generator = get_best_model_and_data(
        hparams, Decoder, load_data=True, version=model_version)

    n_test_batches = len(data_generator.datasets[sess_idx].batch_idxs[dtype])
    max_lags = hparams['n_max_lags']

    def _trim_lags(arr):
        # explicit end index: `arr[max_lags:-max_lags]` would be empty when max_lags == 0
        return arr[max_lags:max(arr.shape[0] - max_lags, max_lags)]

    true = []
    pred = []
    data_generator.reset_iterators(dtype)
    for _ in range(n_test_batches):
        # NOTE(review): the session index returned with the batch is ignored, so batches are
        # presumably expected to come from `sess_idx` only - confirm for multi-session data
        batch, _sess = data_generator.next_batch(dtype)

        # get true latents/states
        if metric == 'fc':
            curr_true = batch['arhmm_states'][0].cpu().detach().numpy()
        elif 'ae_latents' in batch:
            curr_true = batch['ae_latents'][0].cpu().detach().numpy()
        elif 'labels' in batch:
            curr_true = batch['labels'][0].cpu().detach().numpy()
        else:
            raise ValueError('no valid key in {}'.format(batch.keys()))

        # get predicted latents
        curr_pred = model(batch['neural'][0])[0].cpu().detach().numpy()

        true.append(_trim_lags(curr_true))
        pred.append(_trim_lags(curr_pred))

    all_true = np.concatenate(true, axis=0)
    all_pred = np.concatenate(pred, axis=0)
    if metric == 'r2':
        value = r2_score(all_true, all_pred, multioutput=multioutput)
    elif metric == 'mse':
        value = np.mean(np.square(all_true - all_pred))
    else:
        # 'fc': model outputs are scored per class, the most likely class is the prediction
        value = accuracy_score(all_true, np.argmax(all_pred, axis=1))

    return model.hparams, value, true, pred
def export_train_plots(hparams, dtype, loss_type='mse', save_file=None, format='png'):
    """Export plot with MSE/LL as a function of training epochs.

    Losses are read from the `metrics.csv` file in the model's version directory, and session
    names from the `session_info.csv` file in the same directory.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify the desired model (autoencoder, arhmm,
        etc.); 'expt_dir' and 'version' are used to locate the csv files
    dtype : :obj:`str`
        type of trials to use for plotting: 'train' | 'val' (metrics are not computed for 'test'
        trials throughout training)
    loss_type : :obj:`str`, optional
        'mse' | 'll'
    save_file : :obj:`str` or :obj:`NoneType`, optional
        full filename (absolute path) for saving plot; if :obj:`NoneType`, plot is displayed
    format : :obj:`str`
        file format of plot, e.g. 'png' | 'pdf' | 'jpeg'; appended to `save_file` as extension

    Returns
    -------
    :obj:`seaborn.FacetGrid`
        grid object returned by :obj:`seaborn.relplot`

    Raises
    ------
    ValueError
        if `loss_type` is not 'mse' or 'll'

    """
    import os
    import pandas as pd
    import seaborn as sns
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from behavenet.fitting.utils import read_session_info_from_csv

    # fail fast, before any file is read or figure is created
    if loss_type not in ('mse', 'll'):
        raise ValueError('"%s" is an invalid loss type' % loss_type)

    mpl.use('Agg')  # deal with display-less machines
    sns.set_style('white')
    sns.set_context('talk')

    # find metrics csv file
    version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % hparams['version'])
    metric_file = os.path.join(version_dir, 'metrics.csv')
    metrics = pd.read_csv(metric_file)

    # collect data from csv file
    sess_ids = read_session_info_from_csv(os.path.join(version_dir, 'session_info.csv'))
    sess_ids_strs = ['%s/%s' % (sess_id['animal'], sess_id['session']) for sess_id in sess_ids]

    # reshape to long format: one record per (row, trial type), built in a single DataFrame
    records = []
    for _, row in metrics.iterrows():
        # iterrows does not preserve dtypes, so the index may arrive as a float; cast before
        # using it to index the list of session names
        dataset_idx = int(row['dataset'])
        dataset = 'all' if dataset_idx == -1 else sess_ids_strs[dataset_idx]
        for trial_type, loss_column in (('val', 'val_loss'), ('train', 'tr_loss')):
            records.append({
                'dataset': dataset,
                'epoch': row['epoch'],
                'loss': row[loss_column],
                'dtype': trial_type,
            })
    metrics_df = pd.DataFrame(records, columns=['dataset', 'epoch', 'loss', 'dtype'])

    # plot data; use item access for the 'dtype' column so it cannot be confused with a
    # DataFrame attribute
    data_queried = metrics_df[
        (metrics_df['dtype'] == dtype) &
        (metrics_df['epoch'] > 0) &
        ~pd.isna(metrics_df['loss'])]
    splt = sns.relplot(x='epoch', y='loss', hue='dataset', kind='line', data=data_queried)
    splt.ax.set_xlabel('Epoch')
    if loss_type == 'mse':
        splt.ax.set_yscale('log')
        splt.ax.set_ylabel('MSE per pixel')
    else:
        splt.ax.set_ylabel('Neg log prob per datapoint')
    title_str = 'Validation' if dtype == 'val' else 'Training'
    plt.title('%s loss' % title_str)

    if save_file is not None:
        plt.savefig('%s.%s' % (save_file, format), dpi=300, format=format)
        plt.close()
    else:
        plt.show()

    return splt