Source code for behavenet.plotting.arhmm_utils

"""Plotting and video making functions for ARHMMs."""

import pickle
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.animation as animation
from behavenet import make_dir_if_not_exists
from behavenet.models import AE as AE
from behavenet.plotting import save_movie

# to ignore imports for sphix-autoapidoc
__all__ = [
    'get_discrete_chunks', 'get_state_durations', 'get_latent_arrays_by_dtype',
    'get_model_latents_states',
    'make_syllable_movies_wrapper', 'make_syllable_movies',
    'real_vs_sampled_wrapper', 'make_real_vs_sampled_movies', 'plot_real_vs_sampled',
    'plot_states_overlaid_with_latents', 'plot_state_transition_matrix', 'plot_dynamics_matrices',
    'plot_obs_biases', 'plot_obs_covariance_matrices']


[docs]def get_discrete_chunks(states, include_edges=True): """Find occurences of each discrete state. Parameters ---------- states : :obj:`list` list of trials; each trial is numpy array containing discrete state for each frame include_edges : :obj:`bool` include states at start and end of chunk Returns ------- :obj:`list` list of length discrete states; each list contains all occurences of that discrete state by :obj:`[chunk number, starting index, ending index]` """ max_state = max([max(x) for x in states]) indexing_list = [[] for _ in range(max_state + 1)] for i_chunk, chunk in enumerate(states): # pad either side so we get start and end chunks chunk = np.pad(chunk, (1, 1), mode='constant', constant_values=-1) # don't add 1 because of start padding, this is now index in original unpadded data split_indices = np.where(np.ediff1d(chunk) != 0)[0] # last index will be 1 higher that it should be due to padding # split_indices[-1] -= 1 for i in range(len(split_indices)-1): # get which state this chunk was (+1 because data is still padded) which_state = chunk[split_indices[i]+1] if not include_edges: if split_indices[i] != 0 and split_indices[i+1] != (len(chunk)-2): indexing_list[which_state].append( [i_chunk, split_indices[i], split_indices[i+1]]) else: indexing_list[which_state].append( [i_chunk, split_indices[i], split_indices[i+1]]) # convert lists to numpy arrays indexing_list = [np.asarray(indexing_list[i_state]) for i_state in range(max_state + 1)] return indexing_list
[docs]def get_state_durations(latents, hmm, include_edges=True): """Calculate frame count for each state. Parameters ---------- latents : :obj:`list` of :obj:`np.ndarray` latent states hmm : :obj:`ssm.HMM` arhmm objecct include_edges : :obj:`bool` include states at start and end of chunk Returns ------- :obj:`list` number of frames for each state run; list is empty if single-state model """ if hmm.K == 1: return [] states = [hmm.most_likely_states(x) for x in latents if len(x) > 0] state_indices = get_discrete_chunks(states, include_edges=include_edges) durations = [] for i_state in range(0, len(state_indices)): if len(state_indices[i_state]) > 0: durations.append(np.concatenate(np.diff(state_indices[i_state][:, 1:3], 1))) else: durations.append(np.array([])) return durations
[docs]def get_latent_arrays_by_dtype(data_generator, sess_idxs=0, data_key='ae_latents'): """Collect data from data generator and put into dictionary with dtypes for keys. Parameters ---------- data_generator : :obj:`ConcatSessionsGenerator` sess_idxs : :obj:`int` or :obj:`list` concatenate train/test/val data across one or more sessions data_key : :obj:`str` key into data generator object; 'ae_latents' | 'labels' Returns ------- :obj:`tuple` - latents (:obj:`dict`): with keys 'train', 'val', 'test' - trial indices (:obj:`dict`): with keys 'train', 'val', 'test' """ if isinstance(sess_idxs, int): sess_idxs = [sess_idxs] dtypes = ['train', 'val', 'test'] latents = {key: [] for key in dtypes} trial_idxs = {key: [] for key in dtypes} for sess_idx in sess_idxs: dataset = data_generator.datasets[sess_idx] for data_type in dtypes: curr_idxs = dataset.batch_idxs[data_type] trial_idxs[data_type] += list(curr_idxs) latents[data_type] += [dataset[i_trial][data_key][0][:] for i_trial in curr_idxs] return latents, trial_idxs
[docs]def get_model_latents_states( hparams, version, sess_idx=0, return_samples=0, cond_sampling=False, dtype='test', dtypes=['train', 'val', 'test'], rng_seed=0): """Return arhmm defined in :obj:`hparams` with associated latents and states. Can also return sampled latents and states. Parameters ---------- hparams : :obj:`dict` needs to contain enough information to specify an arhmm version : :obj:`str` or :obj:`int` test tube model version (can be 'best') sess_idx : :obj:`int`, optional session index into data generator return_samples : :obj:`int`, optional number of trials to sample from model cond_sampling : :obj:`bool`, optional if :obj:`True` return samples conditioned on most likely state sequence; else return unconditioned samples dtype : :obj:`str`, optional trial type to use for conditonal sampling; 'train' | 'val' | 'test' dtypes : :obj:`array-like`, optional trial types for which to collect latents and states rng_seed : :obj:`int`, optional random number generator seed to control sampling Returns ------- :obj:`dict` - 'model' (:obj:`ssm.HMM` object) - 'latents' (:obj:`dict`): latents from train, val and test trials - 'states' (:obj:`dict`): states from train, val and test trials - 'trial_idxs' (:obj:`dict`): trial indices from train, val and test trials - 'latents_gen' (:obj:`list`) - 'states_gen' (:obj:`list`) """ from behavenet.data.utils import get_transforms_paths from behavenet.fitting.utils import experiment_exists from behavenet.fitting.utils import get_best_model_version from behavenet.fitting.utils import get_expt_dir from behavenet.fitting.utils import get_session_dir hparams['session_dir'], sess_ids = get_session_dir( hparams, session_source=hparams.get('all_source', 'save')) hparams['expt_dir'] = get_expt_dir(hparams) # get version/model if version == 'best': version = get_best_model_version( hparams['expt_dir'], measure='val_loss', best_def='min')[0] else: _, version = experiment_exists(hparams, which_version=True) if version is None: raise FileNotFoundError( 'Could not find the specified model version in %s' % hparams['expt_dir']) # load model model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version, 'best_val_model.pt') with open(model_file, 'rb') as f: hmm = pickle.load(f) # load latents/labels if hparams['model_class'].find('labels') > -1: from behavenet.data.utils import load_labels_like_latents all_latents = load_labels_like_latents(hparams, sess_ids, sess_idx) else: _, latents_file = get_transforms_paths('ae_latents', hparams, sess_ids[sess_idx]) with open(latents_file, 'rb') as f: all_latents = pickle.load(f) # collect inferred latents/states trial_idxs = {} latents = {} states = {} for data_type in dtypes: trial_idxs[data_type] = all_latents['trials'][data_type] latents[data_type] = [all_latents['latents'][i_trial] for i_trial in trial_idxs[data_type]] states[data_type] = [hmm.most_likely_states(x) for x in latents[data_type]] # collect sampled latents/states if return_samples > 0: states_gen = [] np.random.seed(rng_seed) if cond_sampling: n_latents = latents[dtype][0].shape[1] latents_gen = [np.zeros((len(state_seg), n_latents)) for state_seg in states[dtype]] for i_seg, state_seg in enumerate(states[dtype]): for i_t in range(len(state_seg)): if i_t >= 1: latents_gen[i_seg][i_t] = hmm.observations.sample_x( states[dtype][i_seg][i_t], latents_gen[i_seg][:i_t], input=np.zeros(0)) else: latents_gen[i_seg][i_t] = hmm.observations.sample_x( states[dtype][i_seg][i_t], latents[dtype][i_seg][0].reshape((1, n_latents)), input=np.zeros(0)) else: latents_gen = [] offset = 200 for i in range(return_samples): these_states_gen, these_latents_gen = hmm.sample( latents[dtype][0].shape[0] + offset) latents_gen.append(these_latents_gen[offset:]) states_gen.append(these_states_gen[offset:]) else: latents_gen = [] states_gen = [] return_dict = { 'model': hmm, 'latents': latents, 'states': states, 'trial_idxs': trial_idxs, 'latents_gen': latents_gen, 'states_gen': states_gen, } return return_dict
[docs]def make_syllable_movies_wrapper( hparams, save_file, sess_idx=0, dtype='test', max_frames=400, frame_rate=10, min_threshold=0, n_buffer=5, n_pre_frames=3, n_rows=None, single_syllable=None): """Present video clips of each individual syllable in separate panels. This is a high-level function that loads the arhmm model described in the hparams dictionary and produces the necessary states/video frames. Parameters ---------- hparams : :obj:`dict` needs to contain enough information to specify an arhmm save_file : :obj:`str` full save file (path and filename) sess_idx : :obj:`int`, optional session index into data generator dtype : :obj:`str`, optional types of trials to make video with; 'train' | 'val' | 'test' max_frames : :obj:`int`, optional maximum number of frames to animate frame_rate : :obj:`float`, optional frame rate of saved movie min_threshold : :obj:`int`, optional minimum number of frames in a syllable run to be considered for movie n_buffer : :obj:`int` number of blank frames between syllable instances n_pre_frames : :obj:`int` number of behavioral frames to precede each syllable instance n_rows : :obj:`int` or :obj:`NoneType` number of rows in output movie single_syllable : :obj:`int` or :obj:`NoneType` choose only a single state for movie """ from behavenet.data.data_generator import ConcatSessionsGenerator from behavenet.data.utils import get_data_generator_inputs from behavenet.data.utils import get_transforms_paths from behavenet.fitting.utils import experiment_exists from behavenet.fitting.utils import get_expt_dir from behavenet.fitting.utils import get_session_dir # load images, latents, and states hparams['session_dir'], sess_ids = get_session_dir( hparams, session_source=hparams.get('all_source', 'save')) hparams['expt_dir'] = get_expt_dir(hparams) hparams['load_videos'] = True hparams, signals, transforms, paths = get_data_generator_inputs(hparams, sess_ids) data_generator = ConcatSessionsGenerator( hparams['data_dir'], sess_ids, signals_list=[signals[sess_idx]], transforms_list=[transforms[sess_idx]], paths_list=[paths[sess_idx]], device='cpu', as_numpy=True, batch_load=False, rng_seed=hparams['rng_seed_data']) ims_orig = data_generator.datasets[sess_idx].data['images'] del data_generator # free up memory # get tt version number _, version = experiment_exists(hparams, which_version=True) print('producing syllable videos for arhmm %s' % version) # load latents/labels if hparams['model_class'].find('labels') > -1: from behavenet.data.utils import load_labels_like_latents latents = load_labels_like_latents(hparams, sess_ids, sess_idx) else: _, latents_file = get_transforms_paths('ae_latents', hparams, sess_ids[sess_idx]) with open(latents_file, 'rb') as f: latents = pickle.load(f) trial_idxs = latents['trials'][dtype] # load model model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version, 'best_val_model.pt') with open(model_file, 'rb') as f: hmm = pickle.load(f) # infer discrete states states = [hmm.most_likely_states(latents['latents'][s]) for s in latents['trials'][dtype]] if len(states) == 0: raise ValueError('No latents for dtype=%s' % dtype) # find runs of discrete states; state indices is a list, each entry of which is a np array with # shape (n_state_instances, 3), where the 3 values are: # chunk_idx, chunk_start_idx, chunk_end_idx # chunk_idx is in [0, n_chunks], and indexes trial_idxs state_indices = get_discrete_chunks(states, include_edges=True) K = len(state_indices) # get all example over minimum state length threshold over_threshold_instances = [[] for _ in range(K)] for i_state in range(K): if state_indices[i_state].shape[0] > 0: state_lens = np.diff(state_indices[i_state][:, 1:3], axis=1) over_idxs = state_lens > min_threshold over_threshold_instances[i_state] = state_indices[i_state][over_idxs[:, 0]] np.random.shuffle(over_threshold_instances[i_state]) # shuffle instances make_syllable_movies( ims_orig=ims_orig, state_list=over_threshold_instances, trial_idxs=trial_idxs, save_file=save_file, max_frames=max_frames, frame_rate=frame_rate, n_buffer=n_buffer, n_pre_frames=n_pre_frames, n_rows=n_rows, single_syllable=single_syllable)
[docs]def make_syllable_movies( ims_orig, state_list, trial_idxs, save_file=None, max_frames=400, frame_rate=10, n_buffer=5, n_pre_frames=3, n_rows=None, single_syllable=None): """Present video clips of each individual syllable in separate panels Parameters ---------- ims_orig : :obj:`np.ndarray` shape (n_frames, n_channels, y_pix, x_pix) state_list : :obj:`list` each entry (one per state) contains all occurences of that discrete state by :obj:`[chunk number, starting index, ending index]` trial_idxs : :obj:`array-like` indices into :obj:`states` for which trials should be plotted save_file : :obj:`str` full save file (path and filename) max_frames : :obj:`int`, optional maximum number of frames to animate frame_rate : :obj:`float`, optional frame rate of saved movie n_buffer : :obj:`int`, optional number of blank frames between syllable instances n_pre_frames : :obj:`int`, optional number of behavioral frames to precede each syllable instance n_rows : :obj:`int` or :obj:`NoneType`, optional number of rows in output movie single_syllable : :obj:`int` or :obj:`NoneType`, optional choose only a single state for movie """ K = len(state_list) # Initialize syllable movie frames plt.clf() if single_syllable is not None: K = 1 fig_width = 5 n_rows = 1 else: fig_width = 10 # aiming for dim 1 being 10 # get video dims bs, n_channels, y_dim, x_dim = ims_orig[0].shape movie_dim1 = n_channels * y_dim movie_dim2 = x_dim if n_rows is None: n_rows = int(np.floor(np.sqrt(K))) n_cols = int(np.ceil(K / n_rows)) fig_dim_div = movie_dim2 * n_cols / fig_width fig_width = (movie_dim2 * n_cols) / fig_dim_div fig_height = (movie_dim1 * n_rows) / fig_dim_div fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height)) for i, ax in enumerate(fig.axes): ax.set_yticks([]) ax.set_xticks([]) if i >= K: ax.set_axis_off() elif single_syllable is not None: ax.set_title('Syllable %i' % single_syllable, fontsize=16) else: ax.set_title('Syllable %i' % i, fontsize=16) fig.tight_layout(pad=0, h_pad=1.005) imshow_kwargs = {'animated': True, 'cmap': 'gray', 'vmin': 0, 'vmax': 1} ims = [[] for _ in range(max_frames + bs + 200)] # Loop through syllables for i_k, ax in enumerate(fig.axes): # skip if no syllable in this axis if i_k >= K: continue print('processing syllable %i/%i' % (i_k + 1, K)) # skip if no syllables are longer than threshold if len(state_list[i_k]) == 0: continue if single_syllable is not None: i_k = single_syllable i_chunk = 0 i_frame = 0 while i_frame < max_frames: if i_chunk >= len(state_list[i_k]): # show blank if out of syllable examples im = ax.imshow(np.zeros((movie_dim1, movie_dim2)), **imshow_kwargs) ims[i_frame].append(im) i_frame += 1 else: # Get movies/latents chunk_idx = state_list[i_k][i_chunk, 0] which_trial = trial_idxs[chunk_idx] tr_beg = state_list[i_k][i_chunk, 1] tr_end = state_list[i_k][i_chunk, 2] batch = ims_orig[which_trial] movie_chunk = batch[max(tr_beg - n_pre_frames, 0):tr_end] movie_chunk = np.concatenate( [movie_chunk[:, j] for j in range(movie_chunk.shape[1])], axis=1) # if np.sum(states[chunk_idx][tr_beg:tr_end-1] != i_k) > 0: # raise ValueError('Misaligned states for syllable segmentation') # Loop over this chunk for i in range(movie_chunk.shape[0]): im = ax.imshow(movie_chunk[i], **imshow_kwargs) ims[i_frame].append(im) # Add red box if start of syllable syllable_start = n_pre_frames if tr_beg >= n_pre_frames else tr_beg if syllable_start <= i < (syllable_start + 2): rect = matplotlib.patches.Rectangle( (5, 5), 10, 10, linewidth=1, edgecolor='r', facecolor='r') im = ax.add_patch(rect) ims[i_frame].append(im) i_frame += 1 # Add buffer black frames for j in range(n_buffer): im = ax.imshow(np.zeros((movie_dim1, movie_dim2)), **imshow_kwargs) ims[i_frame].append(im) i_frame += 1 i_chunk += 1 print('creating animation...', end='') ani = animation.ArtistAnimation( fig, [ims[i] for i in range(len(ims)) if ims[i] != []], interval=20, blit=True, repeat=False) print('done') if save_file is not None: # put together file name if save_file[-3:] == 'mp4': save_file = save_file[:-3] if single_syllable is not None: state_str = str('_syllable-%02i' % single_syllable) else: state_str = '' save_file += state_str save_file += '.mp4' save_movie(save_file, ani, frame_rate=frame_rate)
[docs]def real_vs_sampled_wrapper( output_type, hparams, save_file, sess_idx, dtype='test', conditional=True, max_frames=400, frame_rate=20, n_buffer=5, xtick_locs=None, frame_rate_beh=None, format='png'): """Produce movie with (AE) reconstructed video and sampled video. This is a high-level function that loads the model described in the hparams dictionary and produces the necessary state sequences/samples. The sampled video can be completely unconditional (states and latents are sampled) or conditioned on the most likely state sequence. Parameters ---------- output_type : :obj:`str` 'plot' | 'movie' | 'both' hparams : :obj:`dict` needs to contain enough information to specify an autoencoder save_file : :obj:`str` full save file (path and filename) sess_idx : :obj:`int`, optional session index into data generator dtype : :obj:`str`, optional types of trials to make plot/video with; 'train' | 'val' | 'test' conditional : :obj:`bool` conditional vs unconditional samples; for creating reconstruction title max_frames : :obj:`int`, optional maximum number of frames to animate frame_rate : :obj:`float`, optional frame rate of saved movie n_buffer : :obj:`int` number of blank frames between animated trials if more one are needed to reach :obj:`max_frames` xtick_locs : :obj:`array-like`, optional tick locations in bin values for plot frame_rate_beh : :obj:`float`, optional behavioral video framerate; to properly relabel xticks format : :obj:`str`, optional any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg' Returns ------- :obj:`matplotlib.figure.Figure` matplotlib figure handle if :obj:`output_type='plot'` or :obj:`output_type='both'`, else nothing returned (movie is saved) """ from behavenet.data.utils import get_transforms_paths from behavenet.fitting.utils import get_expt_dir from behavenet.fitting.utils import get_session_dir # check input - cannot create sampled movies for arhmm-labels models (no mapping from labels to # frames) if hparams['model_class'].find('labels') > -1: if output_type == 'both' or output_type == 'movie': print('warning: cannot create video with "arhmm-labels" model; producing plots') output_type = 'plot' # load latents and states (observed and sampled) model_output = get_model_latents_states( hparams, '', sess_idx=sess_idx, return_samples=50, cond_sampling=conditional) if output_type == 'both' or output_type == 'movie': # load in AE decoder if hparams.get('ae_model_path', None) is not None: ae_model_file = os.path.join(hparams['ae_model_path'], 'best_val_model.pt') ae_arch = pickle.load( open(os.path.join(hparams['ae_model_path'], 'meta_tags.pkl'), 'rb')) else: hparams['session_dir'], sess_ids = get_session_dir( hparams, session_source=hparams.get('all_source', 'save')) hparams['expt_dir'] = get_expt_dir(hparams) _, latents_file = get_transforms_paths('ae_latents', hparams, sess_ids[sess_idx]) ae_model_file = os.path.join(os.path.dirname(latents_file), 'best_val_model.pt') ae_arch = pickle.load( open(os.path.join(os.path.dirname(latents_file), 'meta_tags.pkl'), 'rb')) print('loading model from %s' % ae_model_file) ae_model = AE(ae_arch) ae_model.load_state_dict( torch.load(ae_model_file, map_location=lambda storage, loc: storage)) ae_model.eval() n_channels = ae_model.hparams['n_input_channels'] y_pix = ae_model.hparams['y_pixels'] x_pix = ae_model.hparams['x_pixels'] # push observed latents through ae decoder ims_recon = np.zeros((0, n_channels * y_pix, x_pix)) i_trial = 0 while ims_recon.shape[0] < max_frames: recon = ae_model.decoding( torch.tensor(model_output['latents'][dtype][i_trial]).float(), None, None). \ cpu().detach().numpy() recon = np.concatenate([recon[:, i] for i in range(recon.shape[1])], axis=1) zero_frames = np.zeros((n_buffer, n_channels * y_pix, x_pix)) # add a few black frames ims_recon = np.concatenate((ims_recon, recon, zero_frames), axis=0) i_trial += 1 # push sampled latents through ae decoder ims_recon_samp = np.zeros((0, n_channels * y_pix, x_pix)) i_trial = 0 while ims_recon_samp.shape[0] < max_frames: recon = ae_model.decoding(torch.tensor( model_output['latents_gen'][i_trial]).float(), None, None).cpu().detach().numpy() recon = np.concatenate([recon[:, i] for i in range(recon.shape[1])], axis=1) zero_frames = np.zeros((n_buffer, n_channels * y_pix, x_pix)) # add a few black frames ims_recon_samp = np.concatenate((ims_recon_samp, recon, zero_frames), axis=0) i_trial += 1 make_real_vs_sampled_movies( ims_recon, ims_recon_samp, conditional=conditional, save_file=save_file, frame_rate=frame_rate) if output_type == 'both' or output_type == 'plot': i_trial = 0 latents = model_output['latents'][dtype][i_trial][:max_frames] states = model_output['states'][dtype][i_trial][:max_frames] latents_samp = model_output['latents_gen'][i_trial][:max_frames] if not conditional: states_samp = model_output['states_gen'][i_trial][:max_frames] else: states_samp = [] fig = plot_real_vs_sampled( latents, latents_samp, states, states_samp, save_file=save_file, xtick_locs=xtick_locs, frame_rate=hparams['frame_rate'] if frame_rate_beh is None else frame_rate_beh, format=format) if output_type == 'movie': return None elif output_type == 'both' or output_type == 'plot': return fig else: raise ValueError('"%s" is an invalid output_type' % output_type)
[docs]def make_real_vs_sampled_movies( ims_recon, ims_recon_samp, conditional, save_file=None, frame_rate=15): """Produce movie with (AE) reconstructed video and sampled video. Parameters ---------- ims_recon : :obj:`np.ndarray` shape (n_frames, y_pix, x_pix) ims_recon_samp : :obj:`np.ndarray` shape (n_frames, y_pix, x_pix) conditional : :obj:`bool` conditional vs unconditional samples; for creating reconstruction title save_file : :obj:`str`, optional full save file (path and filename) frame_rate : :obj:`float`, optional frame rate of saved movie """ n_frames = ims_recon.shape[0] n_plots = 2 [y_pix, x_pix] = ims_recon[0].shape fig_dim_div = x_pix * n_plots / 10 # aiming for dim 1 being 10 x_dim = x_pix * n_plots / fig_dim_div y_dim = y_pix / fig_dim_div fig, axes = plt.subplots(1, n_plots, figsize=(x_dim, y_dim)) for j in range(2): axes[j].set_xticks([]) axes[j].set_yticks([]) axes[0].set_title('Real Reconstructions\n', fontsize=16) if conditional: title_str = 'Generative Reconstructions\n(Conditional)' else: title_str = 'Generative Reconstructions\n(Unconditional)' axes[1].set_title(title_str, fontsize=16) fig.tight_layout(pad=0) im_kwargs = {'cmap': 'gray', 'vmin': 0, 'vmax': 1, 'animated': True} ims = [] for i in range(n_frames): ims_curr = [] im = axes[0].imshow(ims_recon[i], **im_kwargs) ims_curr.append(im) im = axes[1].imshow(ims_recon_samp[i], **im_kwargs) ims_curr.append(im) ims.append(ims_curr) ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000) save_movie(save_file, ani, frame_rate=frame_rate)
[docs]def plot_real_vs_sampled( latents, latents_samp, states, states_samp, save_file=None, xtick_locs=None, frame_rate=None, format='png'): """Plot real and sampled latents overlaying real and (potentially sampled) states. Parameters ---------- latents : :obj:`np.ndarray` shape (n_frames, n_latents) latents_samp : :obj:`np.ndarray` shape (n_frames, n_latents) states : :obj:`np.ndarray` shape (n_frames,) states_samp : :obj:`np.ndarray` shape (n_frames,) if :obj:`latents_samp` are not conditioned on :obj:`states`, otherwise shape (0,) save_file : :obj:`str` full save file (path and filename) xtick_locs : :obj:`array-like`, optional tick locations in bin values for plot frame_rate : :obj:`float`, optional behavioral video framerate; to properly relabel xticks format : :obj:`str`, optional any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg' Returns ------- :obj:`matplotlib.figure.Figure` matplotlib figure handle """ fig, axes = plt.subplots(2, 1, figsize=(10, 8)) # plot observations axes[0] = plot_states_overlaid_with_latents( latents, states, ax=axes[0], xtick_locs=xtick_locs, frame_rate=frame_rate) axes[0].set_xticks([]) axes[0].set_xlabel('') axes[0].set_title('Inferred latents') # plot samples if len(states_samp) == 0: plot_states = states title_str = 'Sampled latents' else: plot_states = states_samp title_str = 'Sampled states and latents' axes[1] = plot_states_overlaid_with_latents( latents_samp, plot_states, ax=axes[1], xtick_locs=xtick_locs, frame_rate=frame_rate) axes[1].set_title(title_str) if save_file is not None: make_dir_if_not_exists(save_file) plt.savefig(save_file + '.' + format, dpi=300, format=format) return fig
[docs]def plot_states_overlaid_with_latents( latents, states, save_file=None, ax=None, xtick_locs=None, frame_rate=None, cmap='tab20b', format='png'): """Plot states for a single trial overlaid with latents. Parameters ---------- latents : :obj:`np.ndarray` shape (n_frames, n_latents) states : :obj:`np.ndarray` shape (n_frames,) save_file : :obj:`str`, optional full save file (path and filename) ax : :obj:`matplotlib.Axes` object or :obj:`NoneType`, optional axes to plot into; if :obj:`NoneType`, a new figure is created xtick_locs : :obj:`array-like`, optional tick locations in bin values for plot frame_rate : :obj:`float`, optional behavioral video framerate; to properly relabel xticks cmap : :obj:`str`, optional matplotlib colormap format : :obj:`str`, optional any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg' Returns ------- :obj:`matplotlib.figure.Figure` matplotlib figure handle if :obj:`ax=None`, otherwise updated axis """ if ax is None: fig = plt.figure(figsize=(8, 4)) ax = fig.gca() else: fig = None spc = 1.1 * abs(latents.max()) n_latents = latents.shape[1] plotting_latents = latents + spc * np.arange(n_latents) ymin = min(-spc, np.min(plotting_latents)) ymax = max(spc * n_latents, np.max(plotting_latents)) ax.imshow( states[None, :], aspect='auto', extent=(0, len(latents), ymin, ymax), cmap=cmap, alpha=1.0) ax.plot(plotting_latents, '-k', lw=3) ax.set_ylim([ymin, ymax]) # yticks = spc * np.arange(n_latents) # ax.set_yticks(yticks[::2]) # ax.set_yticklabels(np.arange(n_latents)[::2]) ax.set_yticks([]) # ax.set_ylabel('Latent') ax.set_xlabel('Time (bins)') if xtick_locs is not None: ax.set_xticks(xtick_locs) if frame_rate is not None: ax.set_xticklabels((np.asarray(xtick_locs) / frame_rate).astype('int')) ax.set_xlabel('Time (sec)') if save_file is not None: make_dir_if_not_exists(save_file) plt.savefig(save_file + '.' + format, dpi=300, format=format) if fig is None: return ax else: return fig
[docs]def plot_state_transition_matrix(model, deridge=False): """Plot Markov transition matrix for arhmm. Parameters ---------- model : :obj:`ssm.HMM` object deridge : :obj:`bool`, optional remove diagonal to more clearly see dynamic range of off-diagonal entries Returns ------- :obj:`matplotlib.figure.Figure` matplotlib figure handle """ trans = np.copy(model.transitions.transition_matrix) if deridge: n_states = trans.shape[0] for i in range(n_states): trans[i, i] = np.nan clim = np.nanmax(np.abs(trans)) else: clim = 1 fig = plt.figure() plt.imshow(trans, clim=[-clim, clim], cmap='RdBu_r') plt.colorbar() plt.ylabel('State (t)') plt.xlabel('State (t+1)') plt.title('State transition matrix') plt.show() return fig
[docs]def plot_dynamics_matrices(model, deridge=False): """Plot autoregressive dynamics matrices for arhmm. Parameters ---------- model : :obj:`ssm.HMM` object deridge : :obj:`bool`, optional remove diagonal to more clearly see dynamic range of off-diagonal entries Returns ------- :obj:`matplotlib.figure.Figure` matplotlib figure handle """ K = model.K D = model.D n_lags = model.observations.lags if n_lags == 1: n_cols = 3 fac = 1 elif n_lags == 2: n_cols = 3 fac = 1 / n_lags elif n_lags == 3: n_cols = 3 fac = 1.25 / n_lags elif n_lags == 4: n_cols = 3 fac = 1.50 / n_lags elif n_lags == 5: n_cols = 2 fac = 1.75 / n_lags else: n_cols = 1 fac = 1 n_rows = int(np.ceil(K / n_cols)) fig = plt.figure(figsize=(4 * n_cols, 4 * n_rows * fac)) mats = np.copy(model.observations.As) if deridge: for k in range(K): for d in range(model.D): mats[k, d, d] = np.nan clim = np.nanmax(np.abs(mats)) else: clim = np.max(np.abs(mats)) for k in range(K): plt.subplot(n_rows, n_cols, k + 1) im = plt.imshow(mats[k], cmap='RdBu_r', clim=[-clim, clim]) for lag in range(n_lags - 1): plt.axvline((lag + 1) * D - 0.5, ymin=0, ymax=K, color=[0, 0, 0]) plt.xticks([]) plt.yticks([]) plt.title('State %i' % k) plt.tight_layout() fig.subplots_adjust(right=0.8) cbar_ax = fig.add_axes([0.85, 0.4, 0.03, 0.2]) fig.colorbar(im, cax=cbar_ax) # plt.suptitle('Dynamics matrices') return fig
[docs]def plot_obs_biases(model): """Plot observation bias vectors for arhmm. Parameters ---------- model : :obj:`ssm.HMM` object Returns ------- :obj:`matplotlib.figure.Figure` matplotlib figure handle """ fig = plt.figure(figsize=(6, 4)) mats = np.copy(model.observations.bs.T) clim = np.max(np.abs(mats)) plt.imshow(mats, cmap='RdBu_r', clim=[-clim, clim], aspect='auto') plt.xlabel('State') plt.yticks([]) plt.ylabel('Observation dimension') plt.tight_layout() plt.colorbar() plt.title('State biases') plt.show() return fig
[docs]def plot_obs_covariance_matrices(model): """Plot observation covariance matrices for arhmm. Parameters ---------- model : :obj:`ssm.HMM` object Returns ------- :obj:`matplotlib.figure.Figure` matplotlib figure handle """ K = model.K n_cols = int(np.sqrt(K)) n_rows = int(np.ceil(K / n_cols)) fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows)) mats = np.copy(model.observations.Sigmas) clim = np.quantile(np.abs(mats), 0.95) for k in range(K): plt.subplot(n_rows, n_cols, k + 1) im = plt.imshow(mats[k], cmap='RdBu_r', clim=[-clim, clim]) plt.xticks([]) plt.yticks([]) plt.title('State %i' % k) plt.tight_layout() fig.subplots_adjust(right=0.8) cbar_ax = fig.add_axes([0.85, 0.4, 0.03, 0.2]) fig.colorbar(im, cax=cbar_ax) return fig