Source code for behavenet.plotting

"""Utility functions shared across multiple plotting modules."""

from matplotlib.animation import FFMpegWriter
import numpy as np
import os
import pickle
import pandas as pd

from behavenet import make_dir_if_not_exists
from behavenet.fitting.utils import experiment_exists
from behavenet.fitting.utils import get_expt_dir
from behavenet.fitting.utils import get_session_dir
from behavenet.fitting.utils import get_best_model_version
from behavenet.fitting.utils import get_lab_example
from behavenet.fitting.utils import read_session_info_from_csv

# explicit public API so that sphinx-autoapidoc ignores the names imported above
__all__ = ['concat', 'get_crop', 'load_latents', 'load_metrics_csv_as_df', 'save_movie']

# TODO: use load_metrics_csv_as_df in ae example notebook


def concat(ims, axis=1):
    """Join the two channels of a multi-view frame, either stacked or side by side.

    Parameters
    ----------
    ims : :obj:`np.ndarray`
        shape (2, y_pix, x_pix); only the first two entries along the leading axis are used
    axis : :obj:`int`
        direction of the join; 0 stacks the channels along y, 1 places them along x

    Returns
    -------
    :obj:`np.ndarray`
        shape (2 * y_pix, x_pix) when :obj:`axis=0`, or (y_pix, 2 * x_pix) when :obj:`axis=1`

    """
    channels = tuple(ims[idx, :, :] for idx in (0, 1))
    return np.concatenate(channels, axis=axis)
def get_crop(im, y_0, y_ext, x_0, x_ext):
    """Get crop of image, filling in out-of-bounds regions with zeros.

    The crop window is [y_0 - y_ext, y_0 + y_ext) along y and [x_0 - x_ext, x_0 + x_ext)
    along x. Any part of the window lying outside the image, on any side, is filled with
    zeros. In-bounds pixels keep their position relative to the requested window.

    Parameters
    ----------
    im : :obj:`np.ndarray`
        input image, shape (y_pix, x_pix) or (y_pix, x_pix, ...)
    y_0 : :obj:`int`
        y-pixel center value
    y_ext : :obj:`int`
        y-pixel extent; crop in y-direction will be [y_0 - y_ext, y_0 + y_ext)
    x_0 : :obj:`int`
        x-pixel center value
    x_ext : :obj:`int`
        x-pixel extent; crop in x-direction will be [x_0 - x_ext, x_0 + x_ext)

    Returns
    -------
    :obj:`np.ndarray`
        cropped image of shape (2 * y_ext, 2 * x_ext, ...) and dtype float64

    """
    im = np.asarray(im)
    y_min, y_max = y_0 - y_ext, y_0 + y_ext
    x_min, x_max = x_0 - x_ext, x_0 + x_ext

    # zero-filled output; any trailing (e.g. channel) dims of `im` are carried through
    im_tmp = np.zeros((y_max - y_min, x_max - x_min) + im.shape[2:])

    # clip the requested window to the image bounds: slicing directly with a negative
    # start would wrap around to the opposite edge instead of padding with zeros
    src_y0, src_y1 = max(y_min, 0), min(y_max, im.shape[0])
    src_x0, src_x1 = max(x_min, 0), min(x_max, im.shape[1])
    if src_y1 > src_y0 and src_x1 > src_x0:
        # place the in-bounds pixels at their offset within the requested window
        im_tmp[src_y0 - y_min:src_y1 - y_min, src_x0 - x_min:src_x1 - x_min] = \
            im[src_y0:src_y1, src_x0:src_x1]
    return im_tmp
def load_latents(hparams, version, dtype='val'):
    """Load all latents of one data split as a single array.

    Parameters
    ----------
    hparams : :obj:`dict`
        needs to contain enough information to specify both a model and the associated data;
        the keys read here are 'lab', 'expt', 'animal', 'session' and 'expt_dir'
    version : :obj:`int`
        version from test tube experiment defined in :obj:`hparams`
    dtype : :obj:`str`
        'train' | 'val' | 'test'

    Returns
    -------
    :obj:`np.ndarray`
        latents of every trial in the requested split, concatenated along the first axis;
        shape (time, n_latents)

    Raises
    ------
    FileNotFoundError
        if the latents pickle file does not exist

    """
    sess_id = '%s_%s_%s_%s_latents.pkl' % (
        hparams['lab'], hparams['expt'], hparams['animal'], hparams['session'])
    filename = os.path.join(hparams['expt_dir'], 'version_%i' % version, sess_id)
    if not os.path.exists(filename):
        raise FileNotFoundError('latents located at %s do not exist' % filename)
    # NOTE: unpickling can execute arbitrary code; only load latents files you trust
    with open(filename, 'rb') as f:
        latent_dict = pickle.load(f)
    print('loaded latents from %s' % filename)

    # collect latents of every trial in the requested split, in the stored trial order
    latents = [latent_dict['latents'][trial] for trial in latent_dict['trials'][dtype]]
    return np.concatenate(latents)
def load_metrics_csv_as_df(
        hparams, lab, expt, metrics_list, test=False, version='best', version_dir=None):
    """Load metrics csv file and return as a pandas dataframe for easy plotting.

    The returned dataframe holds one row per (csv row, split, metric). This long format
    is easy to use with seaborn's FacetGrid object for multi-axis plotting.

    Parameters
    ----------
    hparams : :obj:`dict`
        requires `sessions_csv`, `multisession`, `lab`, `expt`, `animal` and `session`; when
        :obj:`version_dir` is None this dict is updated in place (by `get_lab_example` and
        with the 'session_dir' and 'expt_dir' keys)
    lab : :obj:`str`
        for `get_lab_example`
    expt : :obj:`str`
        for `get_lab_example`
    metrics_list : :obj:`list`
        names of metrics to pull from csv; do not prepend with 'tr', 'val', or 'test'
    test : :obj:`bool`
        True to only return test values (computed once at end of training)
    version : :obj:`str` or :obj:`int` or NoneType
        'best' to find best model in tt expt, None to find model with hyperparams defined in
        `hparams`, int to load specific model; ignored when :obj:`version_dir` is provided
    version_dir : :obj:`str` or NoneType, optional
        directory containing 'metrics.csv' and 'session_info.csv'; if None it is located
        from :obj:`hparams`, :obj:`lab`, :obj:`expt` and :obj:`version`

    Returns
    -------
    :obj:`pandas.DataFrame` object
        columns 'dataset', 'dtype', 'epoch', 'loss', 'val'

    """
    if version_dir is None:
        # programmatically fill out other hparams options
        get_lab_example(hparams, lab, expt)
        hparams['session_dir'], _ = get_session_dir(hparams)
        hparams['expt_dir'] = get_expt_dir(hparams)

        # find the model version whose metrics should be loaded; compare with `==`,
        # since `is` on a string literal tests identity, not equality
        if isinstance(version, str) and version == 'best':
            version = get_best_model_version(hparams['expt_dir'])[0]
        elif not isinstance(version, (int, np.integer)):
            _, version = experiment_exists(hparams, which_version=True)
        version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % version)

    metrics = pd.read_csv(os.path.join(version_dir, 'metrics.csv'))

    # map dataset indices stored in the csv to 'animal/session' strings
    sess_ids = read_session_info_from_csv(os.path.join(version_dir, 'session_info.csv'))
    sess_ids_strs = ['%s/%s' % (sess_id['animal'], sess_id['session']) for sess_id in sess_ids]

    # (dtype label written to the output, column prefix used in the csv)
    if test:
        splits = [('test', 'test')]
    else:
        splits = [('val', 'val'), ('train', 'tr')]

    records = []
    for _, row in metrics.iterrows():
        # iterrows upcasts each row to a common dtype, so 'dataset' can arrive as a float
        # when the csv also holds float losses; cast before using it as a list index
        dataset_idx = int(row['dataset'])
        dataset = 'all' if dataset_idx == -1 else sess_ids_strs[dataset_idx]
        for dtype, prefix in splits:
            for metric in metrics_list:
                records.append({
                    'dataset': dataset,
                    'epoch': row['epoch'],
                    'dtype': dtype,
                    'loss': metric,
                    'val': row['%s_%s' % (prefix, metric)]})

    # an all-zero index and alphabetically sorted columns are kept for backward
    # compatibility with callers of this function
    return pd.DataFrame(
        records,
        index=[0] * len(records),
        columns=['dataset', 'dtype', 'epoch', 'loss', 'val'])
def save_movie(save_file, ani, frame_rate=15):
    """Write a matplotlib ArtistAnimation to disk.

    Filenames ending in 'gif' are written with the imagemagick writer. Any other filename is
    written with ffmpeg as an mp4, and '.mp4' is appended when the name does not already end
    in 'mp4'. Nothing happens when :obj:`save_file` is None.

    Parameters
    ----------
    save_file : :obj:`str` or NoneType
        full save file (path and filename)
    ani : :obj:`matplotlib.animation.ArtistAnimation` object
        animation to save
    frame_rate : :obj:`int`, optional
        frame rate of saved movie

    """
    if save_file is None:
        return

    make_dir_if_not_exists(save_file)

    if save_file.endswith('gif'):
        save_kwargs = {'writer': 'imagemagick', 'fps': frame_rate}
    else:
        if not save_file.endswith('mp4'):
            save_file += '.mp4'
        save_kwargs = {'writer': FFMpegWriter(fps=frame_rate, bitrate=-1)}

    print('saving video to %s...' % save_file, end='')
    ani.save(save_file, **save_kwargs)
    print('done')