import os
import copy
import pickle
import numpy as np
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm import tqdm
from behavenet import get_user_dir
from behavenet import make_dir_if_not_exists
from behavenet.data.utils import build_data_generator
from behavenet.data.utils import load_labels_like_latents
from behavenet.fitting.eval import get_reconstruction
from behavenet.fitting.utils import experiment_exists
from behavenet.fitting.utils import get_best_model_and_data
from behavenet.fitting.utils import get_expt_dir
from behavenet.fitting.utils import get_lab_example
from behavenet.fitting.utils import get_session_dir
from behavenet.plotting import concat
from behavenet.plotting import get_crop
from behavenet.plotting import load_latents
from behavenet.plotting import load_metrics_csv_as_df
from behavenet.plotting import save_movie
# to ignore imports for sphix-autoapidoc
__all__ = [
'get_input_range', 'compute_range', 'get_labels_2d_for_trial', 'get_model_input',
'interpolate_2d', 'interpolate_1d', 'interpolate_point_path', 'plot_2d_frame_array',
'plot_1d_frame_array', 'make_interpolated', 'make_interpolated_multipanel', 'fit_classifier',
'plot_psvae_training_curves', 'plot_hyperparameter_search_results',
'plot_label_reconstructions', 'plot_latent_traversals', 'make_latent_traversal_movie',
'plot_mspsvae_training_curves', 'plot_mspsvae_hyperparameter_search_results',
'make_session_swap_movie']
# ----------------------------------------
# low-level util functions
# ----------------------------------------
[docs]def compute_range(values_list, min_p=5, max_p=95):
"""Compute min and max of a list of numbers using percentiles.
Parameters
----------
values_list : :obj:`list`
list of np.ndarrays; min/max calculated over axis 0 once all lists are vertically stacked
min_p : :obj:`int`
defines lower end of range; percentile in [0, 100]
max_p : :obj:`int`
defines upper end of range; percentile in [0, 100]
Returns
-------
:obj:`dict`
lower ['min'] and upper ['max'] range of input
"""
if np.any([len(arr) == 0 for arr in values_list]):
values_ = []
for arr in values_list:
if len(arr) != 0:
values_.append(arr)
values = np.vstack(values_)
else:
values = np.vstack(values_list)
ranges = {
'min': np.nanpercentile(values, min_p, axis=0),
'max': np.nanpercentile(values, max_p, axis=0),
'med': np.nanpercentile(values, 50, axis=0)}
return ranges
[docs]def get_labels_2d_for_trial(
hparams, sess_ids, trial=None, trial_idx=None, sess_idx=0, dtype='test', data_gen=None):
"""Return scaled labels (in pixel space) for a given trial.
Parameters
----------
hparams : :obj:`dict`
needs to contain enough information to build a data generator
sess_ids : :obj:`list` of :obj:`dict`
each entry is a session dict with keys 'lab', 'expt', 'animal', 'session'
trial : :obj:`int`, optional
trial index into all possible trials (train, val, test); one of `trial` or `trial_idx`
must be specified; `trial` takes precedence over `trial_idx`
trial_idx : :obj:`int`, optional
trial index into trial type defined by `dtype`; one of `trial` or `trial_idx` must be
specified; `trial` takes precedence over `trial_idx`
sess_idx : :obj:`int`, optional
session index into data generator
dtype : :obj:`str`, optional
data type that is indexed by `trial_idx`; 'train' | 'val' | 'test'
data_gen : :obj:`ConcatSessionGenerator` object, optional
for generating labels
Returns
-------
:obj:`tuple`
- labels_2d_pt (:obj:`torch.Tensor`) of shape (batch, n_labels, y_pix, x_pix)
- labels_2d_np (:obj:`np.ndarray`) of shape (batch, n_labels, y_pix, x_pix)
"""
if (trial_idx is not None) and (trial is not None):
raise ValueError('only one of "trial" or "trial_idx" can be specified')
if data_gen is None:
hparams_new = copy.deepcopy(hparams)
hparams_new['conditional_encoder'] = True # ensure scaled labels are returned
hparams_new['device'] = 'cpu'
hparams_new['as_numpy'] = False
hparams_new['batch_load'] = True
data_gen = build_data_generator(hparams_new, sess_ids, export_csv=False)
# get trial
if trial is None:
trial = data_gen.datasets[sess_idx].batch_idxs[dtype][trial_idx]
batch = data_gen.datasets[sess_idx][trial]
labels_2d_pt = batch['labels_sc']
labels_2d_np = labels_2d_pt.cpu().detach().numpy()
return labels_2d_pt, labels_2d_np
[docs]def interpolate_2d(
interp_type, model, ims_0, latents_0, labels_0, labels_sc_0, mins, maxes, input_idxs,
n_frames, crop_type=None, mins_sc=None, maxes_sc=None, crop_kwargs=None,
marker_idxs=None, ch=0):
"""Return reconstructed images created by interpolating through latent/label space.
Parameters
----------
interp_type : :obj:`str`
'latents' | 'labels'
model : :obj:`behavenet.models` object
autoencoder model
ims_0 : :obj:`torch.Tensor`
base images for interpolating labels, of shape (1, n_channels, y_pix, x_pix)
latents_0 : :obj:`np.ndarray`
base latents of shape (1, n_latents); only two of these dimensions will be changed if
`interp_type='latents'`
labels_0 : :obj:`np.ndarray`
base labels of shape (1, n_labels)
labels_sc_0 : :obj:`np.ndarray`
base scaled labels in pixel space of shape (1, n_labels, y_pix, x_pix)
mins : :obj:`array-like`
minimum values of labels/latents, one for each dim
maxes : :obj:`list`
maximum values of labels/latents, one for each dim
input_idxs : :obj:`list`
indices of labels/latents that will be interpolated; for labels, must be y first, then x
for proper marker recording
n_frames : :obj:`int`
number of interpolation points between mins and maxes (inclusive)
crop_type : :obj:`str` or :obj:`NoneType`, optional
currently only implements 'fixed'; if not None, cropped images are returned, and returned
labels are also cropped so that they can be plotted on top of the cropped images; if None,
returned cropped images are empty and labels are relative to original image size
mins_sc : :obj:`list`, optional
min values of scaled labels that correspond to min values of labels when using conditional
encoders
maxes_sc : :obj:`list`, optional
max values of scaled labels that correspond to max values of labels when using conditional
encoders
crop_kwargs : :obj:`dict`, optional
define center and extent of crop if `crop_type='fixed'`; keys are 'x_0', 'x_ext', 'y_0',
'y_ext'
marker_idxs : :obj:`list`, optional
indices of `labels_sc_0` that will be interpolated; note that this is analogous but
different from `input_idxs`, since the 2d tensor `labels_sc_0` has half as many label
dimensions as `latents_0` and `labels_0`
ch : :obj:`int`, optional
specify which channel of input images to return (can only be a single value)
Returns
-------
:obj:`tuple`
- ims_list (:obj:`list` of :obj:`list` of :obj:`np.ndarray`) interpolated images
- labels_list (:obj:`list` of :obj:`list` of :obj:`np.ndarray`) interpolated labels
- ims_crop_list (:obj:`list` of :obj:`list` of :obj:`np.ndarray`) interpolated , cropped
images
"""
if interp_type == 'labels':
from behavenet.data.transforms import MakeOneHot2D
_, _, y_pix, x_pix = ims_0.shape
one_hot_2d = MakeOneHot2D(y_pix, x_pix)
# compute grid for relevant inputs
n_interp_dims = len(input_idxs)
assert n_interp_dims == 2
# compute ranges for relevant inputs
inputs = []
inputs_sc = []
for d in input_idxs:
inputs.append(np.linspace(mins[d], maxes[d], n_frames))
if mins_sc is not None and maxes_sc is not None:
inputs_sc.append(np.linspace(mins_sc[d], maxes_sc[d], n_frames))
else:
if interp_type == 'labels':
raise NotImplementedError
ims_list = []
ims_crop_list = []
labels_list = []
# latent_vals = []
for i0 in range(n_frames):
ims_tmp = []
ims_crop_tmp = []
labels_tmp = []
# latents_tmp = []
for i1 in range(n_frames):
if interp_type == 'latents':
# get (new) latents
latents = np.copy(latents_0)
latents[0, input_idxs[0]] = inputs[0][i0]
latents[0, input_idxs[1]] = inputs[1][i1]
# get scaled labels (for markers)
labels_sc = _get_updated_scaled_labels(labels_sc_0)
if model.hparams['model_class'] == 'cond-ae-msp':
# get reconstruction
im_tmp = get_reconstruction(
model,
torch.from_numpy(latents).float(),
apply_inverse_transform=True)
else:
# get labels
if model.hparams['model_class'] == 'ae' \
or model.hparams['model_class'] == 'vae' \
or model.hparams['model_class'] == 'beta-tcvae' \
or model.hparams['model_class'] == 'ps-vae' \
or model.hparams['model_class'] == 'msps-vae':
labels = None
elif model.hparams['model_class'] == 'cond-ae' \
or model.hparams['model_class'] == 'cond-vae':
labels = torch.from_numpy(labels_0).float()
else:
raise NotImplementedError
# get reconstruction
im_tmp = get_reconstruction(
model,
torch.from_numpy(latents).float(),
labels=labels)
elif interp_type == 'labels':
# get (new) scaled labels
labels_sc = _get_updated_scaled_labels(
labels_sc_0, input_idxs, [inputs_sc[0][i0], inputs_sc[1][i1]])
if len(labels_sc_0.shape) == 4:
# 2d scaled labels
labels_2d = torch.from_numpy(one_hot_2d(labels_sc)).float()
else:
# 1d scaled labels
labels_2d = None
if model.hparams['model_class'] == 'cond-ae-msp' \
or model.hparams['model_class'] == 'ps-vae' \
or model.hparams['model_class'] == 'msps-vae':
# change latents that correspond to desired labels
latents = np.copy(latents_0)
latents[0, input_idxs[0]] = inputs[0][i0]
latents[0, input_idxs[1]] = inputs[1][i1]
# get reconstruction
im_tmp = get_reconstruction(model, latents, apply_inverse_transform=True)
else:
# get (new) labels
labels = np.copy(labels_0)
labels[0, input_idxs[0]] = inputs[0][i0]
labels[0, input_idxs[1]] = inputs[1][i1]
# get reconstruction
im_tmp = get_reconstruction(
model,
ims_0,
labels=torch.from_numpy(labels).float(),
labels_2d=labels_2d)
else:
raise NotImplementedError
ims_tmp.append(np.copy(im_tmp[0, ch]))
if crop_type:
x_min_tmp = crop_kwargs['x_0'] - crop_kwargs['x_ext']
y_min_tmp = crop_kwargs['y_0'] - crop_kwargs['y_ext']
else:
x_min_tmp = 0
y_min_tmp = 0
if interp_type == 'labels':
labels_tmp.append([
np.copy(labels_sc[0, input_idxs[0]]) - y_min_tmp,
np.copy(labels_sc[0, input_idxs[1]]) - x_min_tmp])
elif interp_type == 'latents' and labels_sc_0 is not None:
labels_tmp.append([
np.copy(labels_sc[0, marker_idxs[0]]) - y_min_tmp,
np.copy(labels_sc[0, marker_idxs[1]]) - x_min_tmp])
else:
labels_tmp.append([np.nan, np.nan])
if crop_type:
ims_crop_tmp.append(get_crop(
im_tmp[0, 0], crop_kwargs['y_0'], crop_kwargs['y_ext'], crop_kwargs['x_0'],
crop_kwargs['x_ext']))
else:
ims_crop_tmp.append([])
ims_list.append(ims_tmp)
ims_crop_list.append(ims_crop_tmp)
labels_list.append(labels_tmp)
return ims_list, labels_list, ims_crop_list
[docs]def interpolate_1d(
interp_type, model, ims_0, latents_0, labels_0, labels_sc_0, mins, maxes, input_idxs,
n_frames, crop_type=None, mins_sc=None, maxes_sc=None, crop_kwargs=None,
marker_idxs=None, ch=0):
"""Return reconstructed images created by interpolating through latent/label space.
Parameters
----------
interp_type : :obj:`str`
'latents' | 'labels'
model : :obj:`behavenet.models` object
autoencoder model
ims_0 : :obj:`torch.Tensor`
base images for interpolating labels, of shape (1, n_channels, y_pix, x_pix)
latents_0 : :obj:`np.ndarray`
base latents of shape (1, n_latents); only two of these dimensions will be changed if
`interp_type='latents'`
labels_0 : :obj:`np.ndarray`
base labels of shape (1, n_labels)
labels_sc_0 : :obj:`np.ndarray`
base scaled labels in pixel space of shape (1, n_labels, y_pix, x_pix)
mins : :obj:`array-like`
minimum values of all labels/latents
maxes : :obj:`array-like`
maximum values of all labels/latents
input_idxs : :obj:`array-like`
indices of labels/latents that will be interpolated
n_frames : :obj:`int`
number of interpolation points between mins and maxes (inclusive)
crop_type : :obj:`str` or :obj:`NoneType`, optional
currently only implements 'fixed'; if not None, cropped images are returned, and returned
labels are also cropped so that they can be plotted on top of the cropped images; if None,
returned cropped images are empty and labels are relative to original image size
mins_sc : :obj:`list`, optional
min values of scaled labels that correspond to min values of labels when using conditional
encoders
maxes_sc : :obj:`list`, optional
max values of scaled labels that correspond to max values of labels when using conditional
encoders
crop_kwargs : :obj:`dict`, optional
define center and extent of crop if `crop_type='fixed'`; keys are 'x_0', 'x_ext', 'y_0',
'y_ext'
marker_idxs : :obj:`list`, optional
indices of `labels_sc_0` that will be interpolated; note that this is analogous but
different from `input_idxs`, since the 2d tensor `labels_sc_0` has half as many label
dimensions as `latents_0` and `labels_0`
ch : :obj:`int`, optional
specify which channel of input images to return (can only be a single value)
Returns
-------
:obj:`tuple`
- ims_list (:obj:`list` of :obj:`list` of :obj:`np.ndarray`) interpolated images
- labels_list (:obj:`list` of :obj:`list` of :obj:`np.ndarray`) interpolated labels
- ims_crop_list (:obj:`list` of :obj:`list` of :obj:`np.ndarray`) interpolated , cropped
images
"""
if interp_type == 'labels':
from behavenet.data.transforms import MakeOneHot2D
_, _, y_pix, x_pix = ims_0.shape
one_hot_2d = MakeOneHot2D(y_pix, x_pix)
n_interp_dims = len(input_idxs)
# compute ranges for relevant inputs
inputs = []
inputs_sc = []
for d in input_idxs:
inputs.append(np.linspace(mins[d], maxes[d], n_frames))
if mins_sc is not None and maxes_sc is not None:
inputs_sc.append(np.linspace(mins_sc[d], maxes_sc[d], n_frames))
else:
if interp_type == 'labels':
raise NotImplementedError
ims_list = []
ims_crop_list = []
labels_list = []
# latent_vals = []
for i0 in range(n_interp_dims):
ims_tmp = []
ims_crop_tmp = []
labels_tmp = []
for i1 in range(n_frames):
if interp_type == 'latents':
# get (new) latents
latents = np.copy(latents_0)
latents[0, input_idxs[i0]] = inputs[i0][i1]
# get scaled labels (for markers)
labels_sc = _get_updated_scaled_labels(labels_sc_0)
if model.hparams['model_class'] == 'cond-ae-msp':
# get reconstruction
im_tmp = get_reconstruction(
model,
torch.from_numpy(latents).float(),
apply_inverse_transform=True)
else:
# get labels
if model.hparams['model_class'] == 'ae' \
or model.hparams['model_class'] == 'vae' \
or model.hparams['model_class'] == 'beta-tcvae' \
or model.hparams['model_class'] == 'ps-vae' \
or model.hparams['model_class'] == 'msps-vae':
labels = None
elif model.hparams['model_class'] == 'cond-ae' \
or model.hparams['model_class'] == 'cond-vae':
labels = torch.from_numpy(labels_0).float()
else:
raise NotImplementedError
# get reconstruction
im_tmp = get_reconstruction(
model,
torch.from_numpy(latents).float(),
labels=labels)
elif interp_type == 'labels':
# get (new) scaled labels
labels_sc = _get_updated_scaled_labels(
labels_sc_0, input_idxs[i0], inputs_sc[i0][i1])
if len(labels_sc_0.shape) == 4:
# 2d scaled labels
labels_2d = torch.from_numpy(one_hot_2d(labels_sc)).float()
else:
# 1d scaled labels
labels_2d = None
if model.hparams['model_class'] == 'cond-ae-msp' \
or model.hparams['model_class'] == 'ps-vae' \
or model.hparams['model_class'] == 'msps-vae':
# change latents that correspond to desired labels
latents = np.copy(latents_0)
latents[0, input_idxs[i0]] = inputs[i0][i1]
# get reconstruction
im_tmp = get_reconstruction(model, latents, apply_inverse_transform=True)
else:
# get (new) labels
labels = np.copy(labels_0)
labels[0, input_idxs[i0]] = inputs[i0][i1]
# get reconstruction
im_tmp = get_reconstruction(
model,
ims_0,
labels=torch.from_numpy(labels).float(),
labels_2d=labels_2d)
else:
raise NotImplementedError
ims_tmp.append(np.copy(im_tmp[0, ch]))
if crop_type:
x_min_tmp = crop_kwargs['x_0'] - crop_kwargs['x_ext']
y_min_tmp = crop_kwargs['y_0'] - crop_kwargs['y_ext']
else:
x_min_tmp = 0
y_min_tmp = 0
if interp_type == 'labels':
labels_tmp.append([
np.copy(labels_sc[0, input_idxs[0]]) - y_min_tmp,
np.copy(labels_sc[0, input_idxs[1]]) - x_min_tmp])
elif interp_type == 'latents' and labels_sc_0 is not None:
labels_tmp.append([
np.copy(labels_sc[0, marker_idxs[0]]) - y_min_tmp,
np.copy(labels_sc[0, marker_idxs[1]]) - x_min_tmp])
else:
labels_tmp.append([np.nan, np.nan])
if crop_type:
ims_crop_tmp.append(get_crop(
im_tmp[0, 0], crop_kwargs['y_0'], crop_kwargs['y_ext'], crop_kwargs['x_0'],
crop_kwargs['x_ext']))
else:
ims_crop_tmp.append([])
ims_list.append(ims_tmp)
ims_crop_list.append(ims_crop_tmp)
labels_list.append(labels_tmp)
return ims_list, labels_list, ims_crop_list
[docs]def interpolate_point_path(
interp_type, model, ims_0, labels_0, points, n_frames=10, ch=0, crop_kwargs=None,
apply_inverse_transform=True):
"""Return reconstructed images created by interpolating through multiple points.
This function is a simplified version of :func:`interpolate_1d()`; this function computes a
traversal for a single dimension instead of all dimensions; also, this function does not
support conditional encoders, nor does it attempt to compute the interpolated, scaled values
of the labels as :func:`interpolate_1d()` does. This function should supercede
:func:`interpolate_1d()` in a future refactor. Also note that this function is utilized by
the code to make traversal movies, whereas :func:`interpolate_1d()` is utilized by the code to
make traversal plots.
Parameters
----------
interp_type : :obj:`str`
'latents' | 'labels'
model : :obj:`behavenet.models` object
autoencoder model
ims_0 : :obj:`np.ndarray`
base images for interpolating labels, of shape (1, n_channels, y_pix, x_pix)
labels_0 : :obj:`np.ndarray`
base labels of shape (1, n_labels); these values will be used if
`interp_type='latents'`, and they will be ignored if `inter_type='labels'`
(since `points` will be used)
points : :obj:`list`
one entry for each point in path; each entry is an np.ndarray of shape (n_latents,)
n_frames : :obj:`int` or :obj:`array-like`
number of interpolation points between each point; can be an integer that is used
for all paths, or an array/list of length one less than number of points
ch : :obj:`int`, optional
specify which channel of input images to return; if not an int, all channels are
concatenated in the horizontal dimension
crop_kwargs : :obj:`dict`, optional
if crop_type is not None, provides information about the crop (for a fixed crop window)
keys : 'y_0', 'x_0', 'y_ext', 'x_ext'; window is
(y_0 - y_ext, y_0 + y_ext) in vertical direction and
(x_0 - x_ext, x_0 + x_ext) in horizontal direction
apply_inverse_transform : :obj:`bool`
if inputs are latents (and model class is 'cond-ae-msp' or 'ps-vae'), apply inverse
transform to put in original latent space
Returns
-------
:obj:`tuple`
- ims_list (:obj:`list` of :obj:`np.ndarray`) interpolated images
- inputs_list (:obj:`list` of :obj:`np.ndarray`) interpolated values
"""
if model.hparams.get('conditional_encoder', False):
raise NotImplementedError
n_points = len(points)
if isinstance(n_frames, int):
n_frames = [n_frames] * (n_points - 1)
assert len(n_frames) == (n_points - 1)
ims_list = []
inputs_list = []
for p in range(n_points - 1):
p0 = points[None, p]
p1 = points[None, p + 1]
p_vec = (p1 - p0) / n_frames[p]
for pn in range(n_frames[p] + 1):
vec = p0 + pn * p_vec
if interp_type == 'latents':
if model.hparams['model_class'] == 'cond-ae' \
or model.hparams['model_class'] == 'cond-vae':
im_tmp = get_reconstruction(
model, vec, apply_inverse_transform=apply_inverse_transform,
labels=torch.from_numpy(labels_0).float().to(model.hparams['device']))
else:
im_tmp = get_reconstruction(
model, vec, apply_inverse_transform=apply_inverse_transform)
elif interp_type == 'labels':
if model.hparams['model_class'] == 'cond-ae-msp' \
or model.hparams['model_class'] == 'ps-vae' \
or model.hparams['model_class'] == 'msps-vae':
im_tmp = get_reconstruction(
model, vec, apply_inverse_transform=True)
else: # cond-ae
im_tmp = get_reconstruction(
model, ims_0,
labels=torch.from_numpy(vec).float().to(model.hparams['device']))
else:
raise NotImplementedError
if crop_kwargs is not None:
if not isinstance(ch, int):
raise ValueError('"ch" must be an integer to use crop_kwargs')
ims_list.append(get_crop(
im_tmp[0, ch],
crop_kwargs['y_0'], crop_kwargs['y_ext'],
crop_kwargs['x_0'], crop_kwargs['x_ext']))
else:
if isinstance(ch, int):
ims_list.append(np.copy(im_tmp[0, ch]))
else:
ims_list.append(np.copy(concat(im_tmp[0])))
inputs_list.append(vec)
return ims_list, inputs_list
def _get_updated_scaled_labels(labels_og, idxs=None, vals=None):
"""Helper function for interpolate_xd functions."""
if labels_og is not None:
if len(labels_og.shape) == 4:
# 2d scaled labels
tmp = np.copy(labels_og)
t, y, x = np.where(tmp[0] == 1)
labels_sc = np.hstack([x, y])[None, :]
else:
# 1d scaled labels
labels_sc = np.copy(labels_og)
if idxs is not None:
if isinstance(idxs, int):
assert isinstance(vals, float)
idxs = [idxs]
vals = [vals]
else:
assert len(idxs) == len(vals)
for idx, val in zip(idxs, vals):
labels_sc[0, idx] = val
else:
labels_sc = None
return labels_sc
# ----------------------------------------
# mid-level plotting functions
# ----------------------------------------
[docs]def plot_2d_frame_array(
ims_list, markers=None, im_kwargs=None, marker_kwargs=None, figsize=None, save_file=None,
format='pdf'):
"""Plot list of list of interpolated images output by :func:`interpolate_2d()` in a 2d grid.
Parameters
----------
ims_list : :obj:`list` of :obj:`list`
each inner list element holds an np.ndarray of shape (y_pix, x_pix)
markers : :obj:`list` of :obj:`list` or NoneType, optional
each inner list element holds an array-like object with values (y_pix, x_pix); if None,
markers are not plotted on top of frames
im_kwargs : :obj:`dict` or NoneType, optional
kwargs for `matplotlib.pyplot.imshow()` function (vmin, vmax, cmap, etc)
marker_kwargs : :obj:`dict` or NoneType, optional
kwargs for `matplotlib.pyplot.plot()` function (markersize, markeredgewidth, etc)
figsize : :obj:`tuple`, optional
(width, height) in inches
save_file : :obj:`str` or NoneType, optional
figure saved if not None
format : :obj:`str`, optional
format of saved image; 'pdf' | 'png' | 'jpeg' | ...
"""
n_y = len(ims_list)
n_x = len(ims_list[0])
if figsize is None:
y_pix, x_pix = ims_list[0][0].shape
# how many inches per pixel?
in_per_pix = 15 / (x_pix * n_x)
figsize = (15, in_per_pix * y_pix * n_y)
fig, axes = plt.subplots(n_y, n_x, figsize=figsize)
if im_kwargs is None:
im_kwargs = {'vmin': 0, 'vmax': 1, 'cmap': 'gray'}
if marker_kwargs is None:
marker_kwargs = {'markersize': 20, 'markeredgewidth': 3}
for r, ims_list_y in enumerate(ims_list):
for c, im in enumerate(ims_list_y):
axes[r, c].imshow(im, **im_kwargs)
axes[r, c].set_xticks([])
axes[r, c].set_yticks([])
if markers is not None:
axes[r, c].plot(
markers[r][c][1], markers[r][c][0], 'o', **marker_kwargs)
plt.subplots_adjust(wspace=0, hspace=0, bottom=0, left=0, top=1, right=1)
if save_file is not None:
make_dir_if_not_exists(save_file)
plt.savefig(save_file + '.' + format, dpi=300, bbox_inches='tight')
plt.show()
[docs]def plot_1d_frame_array(
ims_list, markers=None, im_kwargs=None, marker_kwargs=None, plot_ims=True, plot_diffs=True,
figsize=None, save_file=None, format='pdf'):
"""Plot list of list of interpolated images output by :func:`interpolate_1d()` in a 2d grid.
Parameters
----------
ims_list : :obj:`list` of :obj:`list`
each inner list element holds an np.ndarray of shape (y_pix, x_pix)
markers : :obj:`list` of :obj:`list` or NoneType, optional
each inner list element holds an array-like object with values (y_pix, x_pix); if None,
markers are not plotted on top of frames
im_kwargs : :obj:`dict` or NoneType, optional
kwargs for `matplotlib.pyplot.imshow()` function (vmin, vmax, cmap, etc)
marker_kwargs : :obj:`dict` or NoneType, optional
kwargs for `matplotlib.pyplot.plot()` function (markersize, markeredgewidth, etc)
plot_ims : :obj:`bool`, optional
plot images
plot_diffs : :obj:`bool`, optional
plot differences
figsize : :obj:`tuple`, optional
(width, height) in inches
save_file : :obj:`str` or NoneType, optional
figure saved if not None
format : :obj:`str`, optional
format of saved image; 'pdf' | 'png' | 'jpeg' | ...
"""
if not (plot_ims or plot_diffs):
raise ValueError('Must plot at least one of ims or diffs')
if plot_ims and plot_diffs:
n_y = len(ims_list) * 2
offset = 2
else:
n_y = len(ims_list)
offset = 1
n_x = len(ims_list[0])
if figsize is None:
y_pix, x_pix = ims_list[0][0].shape
# how many inches per pixel?
in_per_pix = 15 / (x_pix * n_x)
figsize = (15, in_per_pix * y_pix * n_y)
fig, axes = plt.subplots(n_y, n_x, figsize=figsize)
if im_kwargs is None:
im_kwargs = {'vmin': 0, 'vmax': 1, 'cmap': 'gray'}
if marker_kwargs is None:
marker_kwargs = {'markersize': 20, 'markeredgewidth': 3}
for r, ims_list_y in enumerate(ims_list):
base_im = ims_list_y[0]
for c, im in enumerate(ims_list_y):
# plot original images
if plot_ims:
axes[offset * r, c].imshow(im, **im_kwargs)
axes[offset * r, c].set_xticks([])
axes[offset * r, c].set_yticks([])
if markers is not None:
axes[offset * r, c].plot(
markers[r][c][1], markers[r][c][0], 'o', **marker_kwargs)
# plot differences
if plot_diffs and plot_ims:
axes[offset * r + 1, c].imshow(0.5 + (im - base_im), **im_kwargs)
axes[offset * r + 1, c].set_xticks([])
axes[offset * r + 1, c].set_yticks([])
elif plot_diffs:
axes[offset * r, c].imshow(0.5 + (im - base_im), **im_kwargs)
axes[offset * r, c].set_xticks([])
axes[offset * r, c].set_yticks([])
plt.subplots_adjust(wspace=0, hspace=0, bottom=0, left=0, top=1, right=1)
if save_file is not None:
make_dir_if_not_exists(save_file)
plt.savefig(save_file + '.' + format, dpi=300, bbox_inches='tight')
plt.show()
[docs]def make_interpolated(
ims, save_file, markers=None, text=None, text_title=None, text_color=[1, 1, 1],
fontsize=4, frame_rate=20, scale=3, markersize=10, markeredgecolor='w', markeredgewidth=1,
ax=None):
"""Make a latent space interpolation movie.
Parameters
----------
ims : :obj:`list` of :obj:`np.ndarray`
each list element is an array of shape (y_pix, x_pix)
save_file : :obj:`str`
absolute path of save file; does not need file extension, will automatically be saved as
mp4. To save as a gif, include the '.gif' file extension in `save_file`. The movie will
only be saved if `ax` is `NoneType`; else the list of animated frames is returned
markers : :obj:`array-like`, optional
array of size (n_frames, 2) which specifies the (x, y) coordinates of a marker on each
frame
text : :obj:`array-like`, optional
array of size (n_frames) which specifies text printed in the lower left corner of each
frame
text_title : :obj:`array-like`, optional
array of size (n_frames) which specifies text printed in the upper left corner of each
frame
text_color : :obj:`array-like`, optional
rgb array specifying color of `text` and `text_title`, if applicable
frame_rate : :obj:`float`, optional
frame rate of saved movie
scale : :obj:`float`, optional
width of panel is (scale / 2) inches
markersize : :obj:`float`, optional
size of marker if `markers` is not `NoneType`
markeredgecolor : :obj:`float`, optional
color of marker edge if `markers` is not `NoneType`
markeredgewidth : :obj:`float`, optional
width of marker edge if `markers` is not `NoneType`
ax : :obj:`matplotlib.axes.Axes` object
optional axis in which to plot the frames; if this argument is not `NoneType` the list of
animated frames is returned and the movie is not saved
Returns
-------
:obj:`list`
list of list of animated frames if `ax` is True; else save movie
"""
y_pix, x_pix = ims[0].shape
if ax is None:
fig_width = scale / 2
fig_height = y_pix / x_pix * scale / 2
fig = plt.figure(figsize=(fig_width, fig_height), dpi=300)
ax = plt.gca()
return_ims = False
else:
return_ims = True
ax.set_xticks([])
ax.set_yticks([])
default_kwargs = {'animated': True, 'cmap': 'gray', 'vmin': 0, 'vmax': 1}
txt_kwargs = {
'fontsize': fontsize, 'color': text_color, 'fontname': 'monospace',
'horizontalalignment': 'left', 'verticalalignment': 'center',
'transform': ax.transAxes}
# ims is a list of lists, each row is a list of artists to draw in the current frame; here we
# are just animating one artist, the image, in each frame
ims_ani = []
for i, im in enumerate(ims):
im_tmp = []
im_tmp.append(ax.imshow(im, **default_kwargs))
# [s.set_visible(False) for s in ax.spines.values()]
if markers is not None:
im_tmp.append(ax.plot(
markers[i, 0], markers[i, 1], '.r', markersize=markersize,
markeredgecolor=markeredgecolor, markeredgewidth=markeredgewidth)[0])
if text is not None:
im_tmp.append(ax.text(0.02, 0.06, text[i], **txt_kwargs))
if text_title is not None:
im_tmp.append(ax.text(0.02, 0.92, text_title[i], **txt_kwargs))
ims_ani.append(im_tmp)
if return_ims:
return ims_ani
else:
plt.tight_layout(pad=0)
ani = animation.ArtistAnimation(fig, ims_ani, blit=True, repeat_delay=1000)
save_movie(save_file, ani, frame_rate=frame_rate)
[docs]def make_interpolated_multipanel(
ims, save_file, markers=None, text=None, text_title=None, frame_rate=20, n_cols=3, scale=1,
**kwargs):
"""Make a multi-panel latent space interpolation movie.
Parameters
----------
ims : :obj:`list` of :obj:`list` of :obj:`np.ndarray`
each list element is used to for a single panel, and is another list that contains arrays
of shape (y_pix, x_pix)
save_file : :obj:`str`
absolute path of save file; does not need file extension, will automatically be saved as
mp4. To save as a gif, include the '.gif' file extension in `save_file`.
markers : :obj:`list` of :obj:`array-like`, optional
each list element is used for a single panel, and is an array of size (n_frames, 2)
which specifies the (x, y) coordinates of a marker on each frame for that panel
text : :obj:`list` of :obj:`array-like`, optional
each list element is used for a single panel, and is an array of size (n_frames) which
specifies text printed in the lower left corner of each frame for that panel
text_title : :obj:`list` of :obj:`array-like`, optional
each list element is used for a single panel, and is an array of size (n_frames) which
specifies text printed in the upper left corner of each frame for that panel
frame_rate : :obj:`float`, optional
frame rate of saved movie
n_cols : :obj:`int`, optional
movie is `n_cols` panels wide
scale : :obj:`float`, optional
width of panel is (scale / 2) inches
kwargs
arguments are additional arguments to :func:`make_interpolated`, like 'markersize',
'markeredgewidth', 'markeredgecolor', etc.
"""
n_panels = len(ims)
markers = [None] * n_panels if markers is None else markers
text = [None] * n_panels if text is None else text
y_pix, x_pix = ims[0][0].shape
n_rows = int(np.ceil(n_panels / n_cols))
fig_width = scale / 2 * n_cols
fig_height = y_pix / x_pix * scale / 2 * n_rows
fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), dpi=300)
plt.subplots_adjust(wspace=0, hspace=0, left=0, bottom=0, right=1, top=1)
# fill out empty panels with black frames
while len(ims) < n_rows * n_cols:
ims.append(np.zeros(ims[0].shape))
markers.append(None)
text.append(None)
# ims is a list of lists, each row is a list of artists to draw in the current frame; here we
# are just animating one artist, the image, in each frame
ims_ani = []
for i, (ims_curr, markers_curr, text_curr) in enumerate(zip(ims, markers, text)):
col = i % n_cols
row = int(np.floor(i / n_cols))
if i == 0:
text_title_str = text_title
else:
text_title_str = None
if n_rows == 1:
ax = axes[col]
elif n_cols == 1:
ax = axes[row]
else:
ax = axes[row, col]
ims_ani_curr = make_interpolated(
ims=ims_curr, markers=markers_curr, text=text_curr, text_title=text_title_str, ax=ax,
save_file=None, **kwargs)
ims_ani.append(ims_ani_curr)
# turn off other axes
i += 1
while i < n_rows * n_cols:
col = i % n_cols
row = int(np.floor(i / n_cols))
axes[row, col].set_axis_off()
i += 1
# rearrange ims:
# currently a list of length n_panels, each element of which is a list of length n_t
# we need a list of length n_t, each element of which is a list of length n_panels
n_frames = len(ims_ani[0])
ims_final = [[] for _ in range(n_frames)]
for i in range(n_frames):
for j in range(n_panels):
ims_final[i] += ims_ani[j][i]
ani = animation.ArtistAnimation(fig, ims_final, blit=True, repeat_delay=1000)
save_movie(save_file, ani, frame_rate=frame_rate)
# ----------------------------------------
# helper functions for high-level plotting
# ----------------------------------------
def _get_psvae_hparams(**kwargs):
hparams = {
'data_dir': get_user_dir('data'),
'save_dir': get_user_dir('save'),
'model_class': 'ps-vae',
'model_type': 'conv',
'rng_seed_data': 0,
'trial_splits': '8;1;1;0',
'train_frac': 1.0,
'rng_seed_model': 0,
'device': 'cpu',
'as_numpy': False,
'batch_load': True,
'fit_sess_io_layers': False,
'learning_rate': 1e-4,
'l2_reg': 0,
'conditional_encoder': False,
'vae.beta': 1}
# update hparams
for key, val in kwargs.items():
if key == 'alpha' or key == 'beta' or key == 'delta':
hparams['ps_vae.%s' % key] = val
else:
hparams[key] = val
return hparams
def apply_masks(data, masks):
return data[masks == 1]
def get_label_r2(
hparams, model, data_generator, version, label_names, dtype='val', overwrite=False):
from sklearn.metrics import r2_score
n_labels = len(label_names)
save_file = os.path.join(
hparams['expt_dir'], 'version_%i' % version, 'r2_supervised.csv')
if not os.path.exists(save_file) or overwrite:
if not os.path.exists(save_file):
print('R^2 metrics do not exist; computing from scratch')
else:
print('overwriting metrics at %s' % save_file)
metrics_df = []
data_generator.reset_iterators(dtype)
for i_test in tqdm(range(data_generator.n_tot_batches[dtype])):
# get next minibatch and put it on the device
data, sess = data_generator.next_batch(dtype)
x = data['images'][0]
y = data['labels'][0].cpu().detach().numpy()
if 'labels_masks' in data:
n = data['labels_masks'][0].cpu().detach().numpy()
else:
n = np.ones_like(y)
z = model.get_transformed_latents(x, dataset=sess)
for i in range(n_labels):
y_true = apply_masks(y[:, i], n[:, i])
y_pred = apply_masks(z[:, i], n[:, i])
if len(y_true) > 10:
r2 = r2_score(y_true, y_pred, multioutput='variance_weighted')
mse = np.mean(np.square(y_true - y_pred))
else:
r2 = np.nan
mse = np.nan
metrics_df.append(pd.DataFrame({
'Trial': data['batch_idx'].item(),
'Label': label_names[i],
'R2': r2,
'MSE': mse,
'Model': 'MSPS-VAE'}, index=[0]))
metrics_df = pd.concat(metrics_df)
print('saving results to %s' % save_file)
metrics_df.to_csv(save_file, index=False, header=True)
else:
print('loading results from %s' % save_file)
metrics_df = pd.read_csv(save_file)
return metrics_df
def collect_data(data_generator, model, dtype, fit_full=False):
ys = []
zs = []
masks = []
trials = []
sessions = []
data_generator.reset_iterators(dtype)
for _ in tqdm(range(data_generator.n_tot_batches[dtype])):
data, sess = data_generator.next_batch(dtype)
x = data['images'][0]
y = data['labels'][0] if 'labels' in data else None
n = data['labels_masks'][0] if 'labels_masks' in data else None
if model.hparams['model_class'] == 'ae':
z, _, _ = model.encoding(x, dataset=sess)
elif model.hparams['model_class'] == 'vae' or model.hparams['model_class'] == 'cond-vae':
z, _, _, _ = model.encoding(x, dataset=sess)
elif model.hparams['model_class'] == 'sss-vae' or model.hparams['model_class'] == 'ps-vae':
yhat, w, _, _, _ = model.encoding(x, dataset=sess)
if fit_full:
z = torch.cat([yhat, w], axis=1)
else:
z = w
elif model.hparams['model_class'] == 'msps-vae':
z_s, z_b, z, _, _, _ = model.encoding(x, dataset=sess)
else:
raise NotImplementedError
if y is not None:
ys.append(y.cpu().detach().numpy())
zs.append(z.cpu().detach().numpy())
if n is None:
if len(ys) > 0:
masks.append(np.ones_like(ys[-1]))
else:
masks.append(None)
else:
masks.append(n.cpu().detach().numpy())
trials.append(data['batch_idx'].item())
sessions.append(sess * np.ones(zs[-1].shape[0]))
return ys, zs, masks, trials, sessions
[docs]def fit_classifier(model, data_generator, dtype='val', fit_full=False, overwrite=False):
"""Fit classifier model from latent space to session id."""
from sklearn.linear_model import LogisticRegressionCV as LR
from sklearn.metrics import accuracy_score
save_file = os.path.join(
model.hparams['expt_dir'], 'version_%i' % model.version, 'fc_session_classification.npy')
if not os.path.exists(save_file) or overwrite:
if not os.path.exists(save_file):
print('FC metrics do not exist; computing from scratch')
else:
print('overwriting metrics at %s' % save_file)
print('collecting training labels and latents')
_, zs_tr, _, _, sessions_tr = collect_data(
data_generator, model, dtype='train', fit_full=fit_full)
print('done')
print('collecting %s labels and latents' % dtype)
_, zs, _, _, sessions = collect_data(data_generator, model, dtype=dtype, fit_full=fit_full)
print('done')
print('fitting linear classifier model with training data')
ys_mat = np.concatenate(sessions_tr)
zs_mat = np.concatenate(zs_tr, axis=0)
lr = LR(
Cs=[0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], cv=5, penalty='l2',
multi_class='multinomial').fit(zs_mat, ys_mat)
print('done')
print('computing fraction correct on %s data' % dtype)
y_pred = lr.predict(np.concatenate(zs))
y_true = np.concatenate(sessions)
fc = accuracy_score(y_true, y_pred)
print('done')
print('saving results to %s' % save_file)
np.save(save_file, np.array([fc]))
else:
print('loading results from %s' % save_file)
fc = np.load(save_file)[0]
lr = None
return fc, lr
# ----------------------------------------
# high-level plotting functions
# ----------------------------------------
[docs]def plot_psvae_training_curves(
lab, expt, animal, session, alphas, betas, n_ae_latents, rng_seeds_model,
experiment_name, n_labels, dtype='val', save_file=None, format='pdf', **kwargs):
"""Create training plots for each term in the ps-vae objective function.
The `dtype` argument controls which type of trials are plotted ('train' or 'val').
Additionally, multiple models can be plotted simultaneously by varying one (and only one) of
the following parameters:
- alpha
- beta
- number of unsupervised latents
- random seed used to initialize model weights
Each of these entries must be an array of length 1 except for one option, which can be an array
of arbitrary length (corresponding to already trained models). This function generates a single
plot with panels for each of the following terms:
- total loss
- pixel mse
- label R^2 (note the objective function contains the label MSE, but R^2 is easier to parse)
- KL divergence of supervised latents
- index-code mutual information of unsupervised latents
- total correlation of unsupervised latents
- dimension-wise KL of unsupervised latents
- subspace overlap
Parameters
----------
lab : :obj:`str`
lab id
expt : :obj:`str`
expt id
animal : :obj:`str`
animal id
session : :obj:`str`
session id
alphas : :obj:`array-like`
alpha values to plot
betas : :obj:`array-like`
beta values to plot
n_ae_latents : :obj:`array-like`
unsupervised dimensionalities to plot
rng_seeds_model : :obj:`array-like`
model seeds to plot
experiment_name : :obj:`str`
test-tube experiment name
n_labels : :obj:`int`
dimensionality of supervised latent space
dtype : :obj:`str`
'train' | 'val'
save_file : :obj:`str`, optional
absolute path of save file; does not need file extension
format : :obj:`str`, optional
format of saved image; 'pdf' | 'png' | 'jpeg' | ...
kwargs
arguments are keys of `hparams`, for example to set `train_frac`, `rng_seed_model`, etc.
"""
# check for arrays, turn ints into lists
n_arrays = 0
hue = None
if len(alphas) > 1:
n_arrays += 1
hue = 'alpha'
if len(betas) > 1:
n_arrays += 1
hue = 'beta'
if len(n_ae_latents) > 1:
n_arrays += 1
hue = 'n latents'
if len(rng_seeds_model) > 1:
n_arrays += 1
hue = 'rng seed'
if n_arrays > 1:
raise ValueError(
'Can only set one of "alphas", "betas", "n_ae_latents", or "rng_seeds_model"' +
'as an array')
# set model info
hparams = _get_psvae_hparams(experiment_name=experiment_name, **kwargs)
metrics_list = [
'loss', 'loss_data_mse', 'label_r2', 'loss_zs_kl', 'loss_zu_mi', 'loss_zu_tc',
'loss_zu_dwkl']
metrics_dfs = []
i = 0
for alpha in alphas:
for beta in betas:
for n_latents in n_ae_latents:
for rng in rng_seeds_model:
# update hparams
hparams['ps_vae.alpha'] = alpha
hparams['ps_vae.beta'] = beta
hparams['n_ae_latents'] = n_latents + n_labels
hparams['rng_seed_model'] = rng
try:
get_lab_example(hparams, lab, expt)
hparams['animal'] = animal
hparams['session'] = session
hparams['session_dir'], sess_ids = get_session_dir(hparams)
hparams['expt_dir'] = get_expt_dir(hparams)
_, version = experiment_exists(hparams, which_version=True)
print(
'loading results with alpha=%i, beta=%i (version %i)' %
(alpha, beta, version))
metrics_dfs.append(load_metrics_csv_as_df(
hparams, lab, expt, metrics_list, version=None))
metrics_dfs[i]['alpha'] = alpha
metrics_dfs[i]['beta'] = beta
metrics_dfs[i]['n latents'] = hparams['n_ae_latents']
metrics_dfs[i]['rng seed'] = rng
i += 1
except TypeError:
print('could not find model for alpha=%i, beta=%i' % (alpha, beta))
continue
metrics_df = pd.concat(metrics_dfs, sort=False)
sns.set_style('white')
sns.set_context('talk')
data_queried = metrics_df[
(metrics_df.epoch > 10) & ~pd.isna(metrics_df.val) & (metrics_df.dtype == dtype)]
g = sns.FacetGrid(
data_queried, col='loss', col_wrap=3, hue=hue, sharey=False, height=4)
g = g.map(plt.plot, 'epoch', 'val').add_legend() # , color=".3", fit_reg=False, x_jitter=.1);
if save_file is not None:
make_dir_if_not_exists(save_file)
g.savefig(save_file + '.' + format, dpi=300, format=format)
[docs]def plot_hyperparameter_search_results(
lab, expt, animal, session, n_labels, label_names, alpha_weights, alpha_n_ae_latents,
alpha_expt_name, beta_weights, beta_n_ae_latents, beta_expt_name, alpha, beta, save_file,
batch_size=None, format='pdf', **kwargs):
"""Create a variety of diagnostic plots to assess the ps-vae hyperparameters.
These diagnostic plots are based on the recommended way to perform a hyperparameter search in
the ps-vae models; first, fix beta=1, and do a sweep over alpha values and number
of latents (for example alpha=[50, 100, 500, 1000] and n_ae_latents=[2, 4, 8, 16]). The best
alpha value is subjective because it involves a tradeoff between pixel mse and label mse. After
choosing a suitable value, fix alpha and the number of latents and vary beta. This function
will then plot the following panels:
- pixel mse as a function of alpha/num latents (for fixed beta)
- label mse as a function of alpha/num_latents (for fixed beta)
- pixel mse as a function of beta (for fixed alpha/n_ae_latents)
- label mse as a function of beta (for fixed alpha/n_ae_latents)
- index-code mutual information (part of the KL decomposition) as a function of beta (for
fixed alpha/n_ae_latents)
- total correlation(part of the KL decomposition) as a function of beta (for fixed
alpha/n_ae_latents)
- dimension-wise KL (part of the KL decomposition) as a function of beta (for fixed
alpha/n_ae_latents)
- average correlation coefficient across all pairs of unsupervised latent dims as a function of
beta (for fixed alpha/n_ae_latents)
Parameters
----------
lab : :obj:`str`
lab id
expt : :obj:`str`
expt id
animal : :obj:`str`
animal id
session : :obj:`str`
session id
n_labels : :obj:`str`
number of label dims
label_names : :obj:`array-like`
names of label dims
alpha_weights : :obj:`array-like`
array of alpha weights for fixed values of beta
alpha_n_ae_latents : :obj:`array-like`
array of latent dimensionalities for fixed values of beta using alpha_weights
alpha_expt_name : :obj:`str`
test-tube experiment name of alpha-based hyperparam search
beta_weights : :obj:`array-like`
array of beta weights for a fixed value of alpha
beta_n_ae_latents : :obj:`int`
latent dimensionality used for beta hyperparam search
beta_expt_name : :obj:`str`
test-tube experiment name of beta hyperparam search
alpha : :obj:`float`
fixed value of alpha for beta search
beta : :obj:`float`
fixed value of beta for alpha search
save_file : :obj:`str`
absolute path of save file; does not need file extension
batch_size : :obj:`int`, optional
size of batches, used to compute correlation coefficient per batch; if NoneType, the
correlation coefficient is computed across all time points
format : :obj:`str`, optional
format of saved image; 'pdf' | 'png' | 'jpeg' | ...
kwargs
arguments are keys of `hparams`, preceded by either `alpha_` or `beta_`. For example,
to set the train frac of the alpha models, use `alpha_train_frac`; to set the rng_data_seed
of the beta models, use `beta_rng_data_seed`.
"""
# -----------------------------------------------------
# load pixel/label MSE as a function of n_latents/alpha
# -----------------------------------------------------
# set model info
hparams = _get_psvae_hparams(experiment_name=alpha_expt_name)
# update hparams
for key, val in kwargs.items():
# hparam vals should be named 'alpha_[property]', for example 'alpha_train_frac'
if key.split('_')[0] == 'alpha':
prop = key[6:]
hparams[prop] = val
else:
hparams[key] = val
metrics_list = ['loss_data_mse']
metrics_dfs_frame = []
metrics_dfs_marker = []
for n_latent in alpha_n_ae_latents:
hparams['n_ae_latents'] = n_latent + n_labels
for alpha_ in alpha_weights:
hparams['ps_vae.alpha'] = alpha_
hparams['ps_vae.beta'] = beta
try:
get_lab_example(hparams, lab, expt)
hparams['animal'] = animal
hparams['session'] = session
hparams['session_dir'], sess_ids = get_session_dir(hparams)
hparams['expt_dir'] = get_expt_dir(hparams)
_, version = experiment_exists(hparams, which_version=True)
print('loading results with alpha=%i, beta=%i (version %i)' % (
hparams['ps_vae.alpha'], hparams['ps_vae.beta'], version))
# get frame mse
metrics_dfs_frame.append(load_metrics_csv_as_df(
hparams, lab, expt, metrics_list, version=None, test=True))
metrics_dfs_frame[-1]['alpha'] = alpha_
metrics_dfs_frame[-1]['n_latents'] = hparams['n_ae_latents']
# get marker mse
model, data_gen = get_best_model_and_data(
hparams, Model=None, load_data=True, version=version)
metrics_df_ = get_label_r2(
hparams, model, data_gen, version, label_names, dtype='val')
metrics_df_['alpha'] = alpha_
metrics_df_['n_latents'] = hparams['n_ae_latents']
metrics_dfs_marker.append(metrics_df_[metrics_df_.Model == 'PS-VAE'])
except TypeError:
print('could not find model for alpha=%i, beta=%i' % (
hparams['ps_vae.alpha'], hparams['ps_vae.beta']))
continue
metrics_df_frame = pd.concat(metrics_dfs_frame, sort=False)
metrics_df_marker = pd.concat(metrics_dfs_marker, sort=False)
print('done')
# -----------------------------------------------------
# load pixel/label MSE as a function of beta
# -----------------------------------------------------
# update hparams
hparams['experiment_name'] = beta_expt_name
for key, val in kwargs.items():
# hparam vals should be named 'beta_[property]', for example 'beta_train_frac'
if key.split('_')[0] == 'beta':
prop = key[5:]
hparams[prop] = val
metrics_list = ['loss_data_mse', 'loss_zu_mi', 'loss_zu_tc', 'loss_zu_dwkl']
metrics_dfs_frame_bg = []
metrics_dfs_marker_bg = []
metrics_dfs_corr_bg = []
for beta in beta_weights:
hparams['n_ae_latents'] = beta_n_ae_latents + n_labels
hparams['ps_vae.alpha'] = alpha
hparams['ps_vae.beta'] = beta
try:
get_lab_example(hparams, lab, expt)
hparams['animal'] = animal
hparams['session'] = session
hparams['session_dir'], sess_ids = get_session_dir(hparams)
hparams['expt_dir'] = get_expt_dir(hparams)
_, version = experiment_exists(hparams, which_version=True)
print('loading results with alpha=%i, beta=%i, (version %i)' % (
hparams['ps_vae.alpha'], hparams['ps_vae.beta'], version))
# get frame mse
metrics_dfs_frame_bg.append(load_metrics_csv_as_df(
hparams, lab, expt, metrics_list, version=None, test=True))
metrics_dfs_frame_bg[-1]['beta'] = beta
# get marker mse
model, data_gen = get_best_model_and_data(
hparams, Model=None, load_data=True, version=version)
metrics_df_ = get_label_r2(hparams, model, data_gen, version, label_names, dtype='val')
metrics_df_['beta'] = beta
metrics_dfs_marker_bg.append(metrics_df_[metrics_df_.Model == 'PS-VAE'])
# get corr
latents = load_latents(hparams, version, dtype='test')
if batch_size is None:
corr = np.corrcoef(latents[:, n_labels + np.array([0, 1])].T)
metrics_dfs_corr_bg.append(pd.DataFrame({
'loss': 'corr',
'dtype': 'test',
'val': np.abs(corr[0, 1]),
'beta': beta}, index=[0]))
else:
n_batches = int(np.ceil(latents.shape[0] / batch_size))
for i in range(n_batches):
corr = np.corrcoef(
latents[i * batch_size:(i + 1) * batch_size,
n_labels + np.array([0, 1])].T)
metrics_dfs_corr_bg.append(pd.DataFrame({
'loss': 'corr',
'dtype': 'test',
'val': np.abs(corr[0, 1]),
'beta': beta}, index=[0]))
except TypeError:
print('could not find model for alpha=%i, beta=%i' % (
hparams['ps_vae.alpha'], hparams['ps_vae.beta']))
continue
print()
metrics_df_frame_bg = pd.concat(metrics_dfs_frame_bg, sort=False)
metrics_df_marker_bg = pd.concat(metrics_dfs_marker_bg, sort=False)
metrics_df_corr_bg = pd.concat(metrics_dfs_corr_bg, sort=False)
print('done')
# -----------------------------------------------------
# ----------------- PLOT DATA -------------------------
# -----------------------------------------------------
sns.set_style('white')
sns.set_context('paper', font_scale=1.2)
alpha_palette = sns.color_palette('Greens')
beta_palette = sns.color_palette('Reds', len(metrics_df_corr_bg.beta.unique()))
from matplotlib.gridspec import GridSpec
fig = plt.figure(figsize=(12, 7), dpi=300)
n_rows = 2
n_cols = 12
gs = GridSpec(n_rows, n_cols, figure=fig)
def despine(ax):
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
sns.set_palette(alpha_palette)
# --------------------------------------------------
# MSE per pixel
# --------------------------------------------------
ax_pixel_mse_alpha = fig.add_subplot(gs[0, 0:3])
data_queried = metrics_df_frame[(metrics_df_frame.dtype == 'test')]
sns.barplot(x='n_latents', y='val', hue='alpha', data=data_queried, ax=ax_pixel_mse_alpha)
ax_pixel_mse_alpha.legend().set_visible(False)
ax_pixel_mse_alpha.set_xlabel('Latent dimension')
ax_pixel_mse_alpha.set_ylabel('MSE per pixel')
ax_pixel_mse_alpha.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
ax_pixel_mse_alpha.set_title('Beta=1, Gamma=0')
despine(ax_pixel_mse_alpha)
# --------------------------------------------------
# MSE per marker
# --------------------------------------------------
ax_marker_mse_alpha = fig.add_subplot(gs[0, 3:6])
data_queried = metrics_df_marker
sns.barplot(x='n_latents', y='MSE', hue='alpha', data=data_queried, ax=ax_marker_mse_alpha)
ax_marker_mse_alpha.set_xlabel('Latent dimension')
ax_marker_mse_alpha.set_ylabel('MSE per marker')
ax_marker_mse_alpha.set_title('Beta=1, Gamma=0')
ax_marker_mse_alpha.legend(frameon=True, title='Alpha')
despine(ax_marker_mse_alpha)
# --------------------------------------------------
# MSE per pixel (beta)
# --------------------------------------------------
ax_pixel_mse_bg = fig.add_subplot(gs[0, 6:9])
data_queried = metrics_df_frame_bg[
(metrics_df_frame_bg.dtype == 'test') &
(metrics_df_frame_bg.loss == 'loss_data_mse') &
(metrics_df_frame_bg.epoch == 200)]
sns.barplot(x='beta', y='val', data=data_queried, ax=ax_pixel_mse_bg)
ax_pixel_mse_bg.legend().set_visible(False)
ax_pixel_mse_bg.set_xlabel('Beta')
ax_pixel_mse_bg.set_ylabel('MSE per pixel')
ax_pixel_mse_bg.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
ax_pixel_mse_bg.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_pixel_mse_bg)
# --------------------------------------------------
# MSE per marker (beta)
# --------------------------------------------------
ax_marker_mse_bg = fig.add_subplot(gs[0, 9:12])
data_queried = metrics_df_marker_bg
sns.barplot(x='beta', y='MSE', data=data_queried, ax=ax_marker_mse_bg)
ax_marker_mse_bg.set_xlabel('Beta')
ax_marker_mse_bg.set_ylabel('MSE per marker')
ax_marker_mse_bg.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_marker_mse_bg)
# --------------------------------------------------
# ICMI
# --------------------------------------------------
ax_icmi = fig.add_subplot(gs[1, 0:3])
data_queried = metrics_df_frame_bg[
(metrics_df_frame_bg.dtype == 'test') &
(metrics_df_frame_bg.loss == 'loss_zu_mi') &
(metrics_df_frame_bg.epoch == 200)]
sns.lineplot(x='beta', y='val', data=data_queried, ax=ax_icmi, ci=None)
ax_icmi.legend().set_visible(False)
ax_icmi.set_xlabel('Beta')
ax_icmi.set_ylabel('Index-code Mutual Information')
ax_icmi.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_icmi)
# --------------------------------------------------
# TC
# --------------------------------------------------
ax_tc = fig.add_subplot(gs[1, 3:6])
data_queried = metrics_df_frame_bg[
(metrics_df_frame_bg.dtype == 'test') &
(metrics_df_frame_bg.loss == 'loss_zu_tc') &
(metrics_df_frame_bg.epoch == 200)]
sns.lineplot(x='beta', y='val', data=data_queried, ax=ax_tc, ci=None)
ax_tc.legend().set_visible(False)
ax_tc.set_xlabel('Beta')
ax_tc.set_ylabel('Total Correlation')
ax_tc.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_tc)
# --------------------------------------------------
# DWKL
# --------------------------------------------------
ax_dwkl = fig.add_subplot(gs[1, 6:9])
data_queried = metrics_df_frame_bg[
(metrics_df_frame_bg.dtype == 'test') &
(metrics_df_frame_bg.loss == 'loss_zu_dwkl') &
(metrics_df_frame_bg.epoch == 200)]
sns.lineplot(x='beta', y='val', data=data_queried, ax=ax_dwkl, ci=None)
ax_dwkl.legend().set_visible(False)
ax_dwkl.set_xlabel('Beta')
ax_dwkl.set_ylabel('Dimension-wise KL')
ax_dwkl.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_dwkl)
# --------------------------------------------------
# CC
# --------------------------------------------------
ax_cc = fig.add_subplot(gs[1, 9:12])
data_queried = metrics_df_corr_bg
sns.lineplot(x='beta', y='val', data=data_queried, ax=ax_cc, ci=None)
ax_cc.legend().set_visible(False)
ax_cc.set_xlabel('Beta')
ax_cc.set_ylabel('Correlation Coefficient')
ax_cc.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_cc)
plt.tight_layout(h_pad=3) # h_pad is fraction of font size
# reset to default color palette
# sns.set_palette(sns.color_palette(None, 10))
sns.reset_orig()
if save_file is not None:
make_dir_if_not_exists(save_file)
plt.savefig(save_file + '.' + format, dpi=300, format=format)
[docs]def plot_label_reconstructions(
lab, expt, animal, session, n_ae_latents, experiment_name, n_labels, trials, version=None,
plot_scale=0.5, sess_idx=0, save_file=None, format='pdf', xtick_locs=None, frame_rate=None,
max_traces=8, add_r2=True, add_legend=True, colored_predictions=True, concat_trials=False,
hparams=None, **kwargs):
"""Plot labels and their reconstructions from an ps-vae.
Parameters
----------
lab : :obj:`str`
lab id
expt : :obj:`str`
expt id
animal : :obj:`str`
animal id
session : :obj:`str`
session id
n_ae_latents : :obj:`str`
dimensionality of unsupervised latent space; n_labels will be added to this
experiment_name : :obj:`str`
test-tube experiment name
n_labels : :obj:`str`
dimensionality of supervised latent space
trials : :obj:`array-like`
array of trials to reconstruct
version : :obj:`str` or :obj:`int`, optional
can be 'best' to load best model, and integer to load a specific model, or NoneType to use
the values in hparams to load a specific model
plot_scale : :obj:`float`
scale the magnitude of reconstructions
sess_idx : :obj:`int`, optional
session index into data generator
save_file : :obj:`str`, optional
absolute path of save file; does not need file extension
format : :obj:`str`, optional
format of saved image; 'pdf' | 'png' | 'jpeg' | ...
xtick_locs : :obj:`array-like`, optional
tick locations in units of bins
frame_rate : :obj:`float`, optional
frame rate of behavorial video; to properly relabel xticks
max_traces : :obj:`int`, optional
maximum number of traces to plot, for easier visualization
add_r2 : :obj:`bool`, optional
print R2 value on plot
add_legend : :obj:`bool`, optional
print legend on plot
colored_predictions : :obj:`bool`, optional
color predictions using default seaborn colormap; else predictions are black
concat_trials : :obj:`bool`, optional
True to plot all trials together, separated by a small gap
hparams : :obj:`dict`, optional
If not NoneType, uses these hparams instead of required args
kwargs
arguments are keys of `hparams`, for example to set `train_frac`, `rng_seed_model`, etc.
"""
from behavenet.plotting.decoder_utils import plot_neural_reconstruction_traces
if len(trials) == 1:
concat_trials = False
# set model info
if hparams is None:
hparams = _get_psvae_hparams(
experiment_name=experiment_name, n_ae_latents=n_ae_latents + n_labels, **kwargs)
get_lab_example(hparams, lab, expt)
hparams['animal'] = animal
hparams['session'] = session
model, data_generator = get_best_model_and_data(
hparams, Model=None, load_data=True, version=version, data_kwargs=None)
print(data_generator)
print('alpha: %i' % model.hparams['ps_vae.alpha'])
print('beta: %i' % model.hparams['ps_vae.beta'])
print('model seed: %i' % model.hparams['rng_seed_model'])
n_blank = 5 # buffer time points between trials if concatenating
labels_og_all = []
labels_pred_all = []
for trial in trials:
# collect data
batch = data_generator.datasets[sess_idx][trial]
labels_og = batch['labels'].detach().cpu().numpy()
labels_pred = model.get_predicted_labels(batch['images']).detach().cpu().numpy()
if 'labels_masks' in batch:
labels_masks = batch['labels_masks'].detach().cpu().numpy()
labels_og[labels_masks == 0] = np.nan
# store data
labels_og_all.append(labels_og)
labels_pred_all.append(labels_pred)
if trial != trials[-1]:
labels_og_all.append(np.nan * np.zeros((n_blank, labels_og.shape[1])))
labels_pred_all.append(np.nan * np.zeros((n_blank, labels_pred.shape[1])))
# plot data from single trial
if not concat_trials:
if save_file is not None:
save_file_trial = save_file + '_trial-%i' % trial
else:
save_file_trial = None
plot_neural_reconstruction_traces(
labels_og, labels_pred, scale=plot_scale, save_file=save_file_trial, format=format,
xtick_locs=xtick_locs, frame_rate=frame_rate, max_traces=max_traces, add_r2=add_r2,
add_legend=add_legend, colored_predictions=colored_predictions)
# plot data from all trials
if concat_trials:
if save_file is not None:
save_file_trial = save_file + '_trial-{}'.format(trials)
else:
save_file_trial = None
plot_neural_reconstruction_traces(
np.vstack(labels_og_all), np.vstack(labels_pred_all), scale=plot_scale,
save_file=save_file_trial, format=format,
xtick_locs=xtick_locs, frame_rate=frame_rate, max_traces=max_traces, add_r2=add_r2,
add_legend=add_legend, colored_predictions=colored_predictions)
[docs]def plot_latent_traversals(
lab, expt, animal, session, model_class, alpha, beta, n_ae_latents, rng_seed_model,
experiment_name, n_labels, label_idxs, hparams=None, label_min_p=5, label_max_p=95,
channel=0, n_frames_zs=4, n_frames_zu=4, trial=None, trial_idx=1, batch_idx=1,
crop_type=None, crop_kwargs=None, sess_idx=0, sess_ids=None, save_file=None, format='pdf',
**kwargs):
"""Plot video frames representing the traversal of individual dimensions of the latent space.
Parameters
----------
lab : :obj:`str`
lab id
expt : :obj:`str`
expt id
animal : :obj:`str`
animal id
session : :obj:`str`
session id
model_class : :obj:`str`
model class in which to perform traversal; currently supported models are:
'ae' | 'vae' | 'cond-ae' | 'cond-vae' | 'beta-tcvae' | 'cond-ae-msp' | 'ps-vae'
note that models with conditional encoders are not currently supported
alpha : :obj:`float`
ps-vae alpha value
beta : :obj:`float`
ps-vae beta value
n_ae_latents : :obj:`int`
dimensionality of unsupervised latents
rng_seed_model : :obj:`int`
model seed
experiment_name : :obj:`str`
test-tube experiment name
n_labels : :obj:`str`
dimensionality of supervised latent space (ignored when using fully unsupervised models)
label_idxs : :obj:`array-like`, optional
set of label indices (dimensions) to individually traverse
hparams : :obj:`str`, optional
If not NoneType, uses these hparams instead of required args
label_min_p : :obj:`float`, optional
lower percentile of training data used to compute range of traversal
label_max_p : :obj:`float`, optional
upper percentile of training data used to compute range of traversal
channel : :obj:`int`, optional
image channel to plot
n_frames_zs : :obj:`int`, optional
number of frames (points) to display for traversal through supervised dimensions
n_frames_zu : :obj:`int`, optional
number of frames (points) to display for traversal through unsupervised dimensions
trial : :obj:`int`, optional
trial index into all possible trials (train, val, test); one of `trial` or `trial_idx`
must be specified; `trial` takes precedence over `trial_idx`
trial_idx : :obj:`int`, optional
trial index of base frame used for interpolation
batch_idx : :obj:`int`, optional
batch index of base frame used for interpolation
crop_type : :obj:`str`, optional
cropping method used on interpolated frames
'fixed' | None
crop_kwargs : :obj:`dict`, optional
if crop_type is not None, provides information about the crop
keys for 'fixed' type: 'y_0', 'x_0', 'y_ext', 'x_ext'; window is
(y_0 - y_ext, y_0 + y_ext) in vertical direction and
(x_0 - x_ext, x_0 + x_ext) in horizontal direction
sess_idx : :obj:`int`, optional
session index into data generator
sess_ids : :obj:`list`, optional
each entry is a session dict with keys 'lab', 'expt', 'animal', 'session'; for loading
labels and labels_sc
save_file : :obj:`str`, optional
absolute path of save file; does not need file extension
format : :obj:`str`, optional
format of saved image; 'pdf' | 'png' | 'jpeg' | ...
kwargs
arguments are keys of `hparams`, for example to set `train_frac`, `rng_seed_model`, etc.
"""
if hparams is None:
hparams = _get_psvae_hparams(
model_class=model_class, alpha=alpha, beta=beta, n_ae_latents=n_ae_latents,
experiment_name=experiment_name, rng_seed_model=rng_seed_model, **kwargs)
if model_class == 'cond-ae-msp' or model_class == 'ps-vae':
hparams['n_ae_latents'] += n_labels
if model_class == 'msps-vae':
hparams['n_ae_latents'] += hparams.get('n_background', 0)
# programmatically fill out other hparams options
get_lab_example(hparams, lab, expt)
hparams['animal'] = animal
hparams['session'] = session
hparams['session_dir'], sess_ids = get_session_dir(hparams)
hparams['expt_dir'] = get_expt_dir(hparams)
_, version = experiment_exists(hparams, which_version=True)
model_ae, data_generator = get_best_model_and_data(hparams, Model=None, version=version)
# temporarily set n_sessions_per_batch to 1 for msps; reset at end of function
n_sessions_per_batch = hparams.get('n_sessions_per_batch', 1)
hparams['n_sessions_per_batch'] = 1
n_background = hparams.get('n_background', 0)
model_class = hparams['model_class']
# get latent/label info
latent_range = get_input_range(
'latents', hparams, sess_ids=sess_ids, sess_idx=sess_idx, model=model_ae,
data_gen=data_generator, min_p=15, max_p=85, version=version, apply_label_masks=True)
label_range = get_input_range(
'labels', hparams, sess_ids=sess_ids, sess_idx=sess_idx, min_p=label_min_p,
max_p=label_max_p, apply_label_masks=True)
try:
label_sc_range = get_input_range(
'labels_sc', hparams, sess_ids=sess_ids, sess_idx=sess_idx, min_p=label_min_p,
max_p=label_max_p, apply_label_masks=True)
except KeyError:
import copy
label_sc_range = copy.deepcopy(label_range)
# ----------------------------------------
# label traversals
# ----------------------------------------
interp_func_label = interpolate_1d
plot_func_label = plot_1d_frame_array
tmp = trial_idx if trial_idx is not None else trial
save_file_new = save_file + '_label-traversals_%i-%i' % (tmp, batch_idx)
if model_class == 'cond-ae' or model_class == 'cond-ae-msp' or model_class == 'ps-vae' or \
model_class == 'cond-vae' or model_class == 'msps-vae':
# get model input for this trial
ims_pt, ims_np, latents_np, labels_pt, labels_np, labels_2d_pt, labels_2d_np = \
get_model_input(
data_generator, hparams, model_ae, trial_idx=trial_idx, trial=trial,
compute_latents=True, compute_scaled_labels=False, compute_2d_labels=False,
sess_idx=sess_idx)
if labels_2d_np is None:
labels_2d_np = np.copy(labels_np)
if crop_type == 'fixed':
crop_kwargs_ = crop_kwargs
else:
crop_kwargs_ = None
# perform interpolation
ims_label, markers_loc_label, ims_crop_label = interp_func_label(
'labels', model_ae, ims_pt[None, batch_idx, :], latents_np[None, batch_idx, :],
labels_np[None, batch_idx, :], labels_2d_np[None, batch_idx, :],
mins=label_range['min'], maxes=label_range['max'],
n_frames=n_frames_zs, input_idxs=label_idxs, crop_type=crop_type,
mins_sc=label_sc_range['min'], maxes_sc=label_sc_range['max'],
crop_kwargs=crop_kwargs_, ch=channel)
# plot interpolation
if crop_type:
marker_kwargs = {
'markersize': 30, 'markeredgewidth': 8, 'markeredgecolor': [1, 1, 0],
'fillstyle': 'none'}
plot_func_label(
ims_crop_label, markers=None, marker_kwargs=marker_kwargs, save_file=save_file_new,
format=format)
else:
marker_kwargs = {
'markersize': 20, 'markeredgewidth': 5, 'markeredgecolor': [1, 1, 0],
'fillstyle': 'none'}
plot_func_label(
ims_label, markers=None, marker_kwargs=marker_kwargs, save_file=save_file_new,
format=format)
# ----------------------------------------
# latent traversals
# ----------------------------------------
interp_func_latent = interpolate_1d
plot_func_latent = plot_1d_frame_array
save_file_new = save_file + '_latent-traversals_%i-%i' % (tmp, batch_idx)
if hparams['model_class'] == 'cond-ae-msp' or hparams['model_class'] == 'ps-vae':
if n_ae_latents is None:
n_ae_latents = hparams['n_ae_latents'] - n_labels
latent_idxs = n_labels + np.arange(n_ae_latents)
elif hparams['model_class'] == 'msps-vae':
if n_ae_latents is None:
n_ae_latents = hparams['n_ae_latents'] - n_labels - n_background
latent_idxs = n_labels + n_background + np.arange(n_ae_latents)
elif hparams['model_class'] == 'ae' \
or hparams['model_class'] == 'vae' \
or hparams['model_class'] == 'cond-vae' \
or hparams['model_class'] == 'beta-tcvae':
if n_ae_latents is None:
n_ae_latents = hparams['n_ae_latents']
latent_idxs = np.arange(n_ae_latents)
else:
raise NotImplementedError
# simplify options here
scaled_labels = False
twod_labels = False
crop_type = None
crop_kwargs = None
labels_2d_np_sel = None
# get model input for this trial
ims_pt, ims_np, latents_np, labels_pt, labels_np, labels_2d_pt, labels_2d_np = \
get_model_input(
data_generator, hparams, model_ae, trial=trial, trial_idx=trial_idx,
compute_latents=True, compute_scaled_labels=scaled_labels,
compute_2d_labels=twod_labels, sess_idx=sess_idx)
# latents_np[:, n_labels:] = 0
if hparams['model_class'] == 'ae' or hparams['model_class'] == 'beta-tcvae':
labels_np_sel = labels_np
else:
labels_np_sel = labels_np[None, batch_idx, :]
# perform interpolation
ims_latent, markers_loc_latent_, ims_crop_latent = interp_func_latent(
'latents', model_ae, ims_pt[None, batch_idx, :], latents_np[None, batch_idx, :],
labels_np_sel, labels_2d_np_sel,
mins=latent_range['min'], maxes=latent_range['max'],
n_frames=n_frames_zu, input_idxs=latent_idxs, crop_type=crop_type,
mins_sc=None, maxes_sc=None, crop_kwargs=crop_kwargs, ch=channel)
# plot interpolation
marker_kwargs = {
'markersize': 20, 'markeredgewidth': 5, 'markeredgecolor': [1, 1, 0],
'fillstyle': 'none'}
plot_func_latent(
ims_latent, markers=None, marker_kwargs=marker_kwargs, save_file=save_file_new,
format=format)
hparams['n_sessions_per_batch'] = n_sessions_per_batch
[docs]def make_latent_traversal_movie(
lab, expt, animal, session, model_class, alpha, beta, n_ae_latents,
rng_seed_model, experiment_name, n_labels, trial_idxs, batch_idxs, trials, hparams=None,
label_min_p=5, label_max_p=95, channel=0, sess_idx=0, sess_ids=None, force_sess_ids=False,
n_frames=10, n_buffer_frames=5, crop_kwargs=None, n_cols=3, movie_kwargs={},
panel_titles=None, order_idxs=None, split_movies=False, save_file=None, **kwargs):
"""Create a multi-panel movie with each panel showing traversals of an individual latent dim.
The traversals will start at a lower bound, increase to an upper bound, then return to a lower
bound; the traversal of each dimension occurs simultaneously. It is also possible to specify
multiple base frames for the traversals; the traversal of each base frame is separated by
several blank frames. Note that support for plotting markers on top of the corresponding
supervised dimensions is not supported by this function.
Parameters
----------
lab : :obj:`str`
lab id
expt : :obj:`str`
expt id
animal : :obj:`str`
animal id
session : :obj:`str`
session id
model_class : :obj:`str`
model class in which to perform traversal; currently supported models are:
'ae' | 'vae' | 'cond-ae' | 'cond-vae' | 'ps-vae'
note that models with conditional encoders are not currently supported
alpha : :obj:`float`
ps-vae alpha value
beta : :obj:`float`
ps-vae beta value
n_ae_latents : :obj:`int`
dimensionality of unsupervised latents
rng_seed_model : :obj:`int`
model seed
experiment_name : :obj:`str`
test-tube experiment name
n_labels : :obj:`str`
dimensionality of supervised latent space (ignored when using fully unsupervised models)
trial_idxs : :obj:`array-like` of :obj:`int`
trial indices of base frames used for interpolation; if an entry is an integer, the
corresponding entry in `trials` must be `None`. This value is a trial index into all
*test* trials, and is not affected by how the test trials are shuffled. The `trials`
argument (see below) takes precedence over `trial_idxs`.
batch_idxs : :obj:`array-like` of :obj:`int`
batch indices of base frames used for interpolation; correspond to entries in `trial_idxs`
and `trials`
trials : :obj:`array-like` of :obj:`int`
trials of base frame used for interpolation; if an entry is an integer, the
corresponding entry in `trial_idxs` must be `None`. This value is a trial index into all
possible trials (train, val, test), whereas `trial_idxs` is an index only into test trials
hparams : :obj:`str`, optional
If not NoneType, uses these hparams instead of required args
label_min_p : :obj:`float`, optional
lower percentile of training data used to compute range of traversal
label_max_p : :obj:`float`, optional
upper percentile of training data used to compute range of traversal
channel : :obj:`int`, optional
image channel to plot
sess_idx : :obj:`int`, optional
session index into data generator
sess_ids : :obj:`list`, optional
each entry is a session dict with keys 'lab', 'expt', 'animal', 'session'; for loading
labels and labels_sc
force_sess_ids : :obj:`bool`, optional
True to force the creation of a new data generator based on the provided sess_ids, rather
than the default associated with the model; necessary for performing latent traversals on
sessions that were not used for training
n_frames : :obj:`int`, optional
number of frames (points) to display for traversal across latent dimensions; the movie
will display a traversal of `n_frames` across each dim, then another traversal of
`n_frames` in the opposite direction
n_buffer_frames : :obj:`int`, optional
number of blank frames to insert between base frames
crop_kwargs : :obj:`dict`, optional
if crop_type is not None, provides information about the crop (for a fixed crop window)
keys : 'y_0', 'x_0', 'y_ext', 'x_ext'; window is
(y_0 - y_ext, y_0 + y_ext) in vertical direction and
(x_0 - x_ext, x_0 + x_ext) in horizontal direction
n_cols : :obj:`int`, optional
movie is `n_cols` panels wide
movie_kwargs : :obj:`dict`, optional
additional kwargs for individual panels; possible keys are 'markersize', 'markeredgecolor',
'markeredgewidth', and 'text_color'
panel_titles : :obj:`list` of :obj:`str`, optional
optional titles for each panel
order_idxs : :obj:`array-like`, optional
used to reorder panels (which are plotted in row-major order) if desired; can also be used
to choose a subset of latent dimensions to include
split_movies : :obj:`bool`, optional
True to save a separate latent traversal movie for each latent dimension
save_file : :obj:`str`, optional
absolute path of save file; does not need file extension, will automatically be saved as
mp4. To save as a gif, include the '.gif' file extension in `save_file`
hparams : :obj:`dict`, optional
kwargs
arguments are keys of `hparams`, for example to set `train_frac`, `rng_seed_model`, etc.
"""
if hparams is None:
hparams = _get_psvae_hparams(
model_class=model_class, alpha=alpha, beta=beta, n_ae_latents=n_ae_latents,
experiment_name=experiment_name, rng_seed_model=rng_seed_model, **kwargs)
if model_class == 'cond-ae-msp' or model_class == 'ps-vae' or model_class == 'beta-tcvae':
hparams['n_ae_latents'] += n_labels
elif model_class == 'msps-vae':
hparams['n_ae_latents'] += n_labels + hparams['n_background']
get_lab_example(hparams, lab, expt)
hparams['animal'] = animal
hparams['session'] = session
hparams['session_dir'], sess_ids = get_session_dir(hparams)
hparams['expt_dir'] = get_expt_dir(hparams)
_, version = experiment_exists(hparams, which_version=True)
# load model and data
if force_sess_ids:
model_ae, _ = get_best_model_and_data(hparams, version=version, load_data=False)
data_generator = build_data_generator(hparams, sess_ids, export_csv=False)
else:
model_ae, data_generator = get_best_model_and_data(hparams, version=version)
# temporarily set n_sessions_per_batch to 1 for msps; reset at end of function
n_sessions_per_batch = hparams.get('n_sessions_per_batch', 1)
hparams['n_sessions_per_batch'] = 1
n_background = hparams.get('n_background', 0)
panel_titles = [''] * (n_labels + n_background + n_ae_latents) if panel_titles is None \
else panel_titles
# get latent/label info
# latent_range = get_input_range(
# 'latents', hparams, sess_ids=sess_ids, sess_idx=sess_idx, model=model_ae,
# data_gen=data_generator, min_p=15, max_p=85, version=version, apply_label_masks=True)
latent_range = get_input_range(
'latents', hparams, sess_ids=sess_ids, sess_idx=np.arange(len(sess_ids)), model=model_ae,
data_gen=data_generator, min_p=15, max_p=85, version=version, apply_label_masks=True)
label_range = get_input_range(
'labels', hparams, sess_ids=sess_ids, sess_idx=sess_idx, min_p=label_min_p,
max_p=label_max_p, apply_label_masks=True)
# ----------------------------------------
# collect frames/latents/labels
# ----------------------------------------
if hparams['model_class'] == 'vae':
csl = False
c2dl = False
else:
csl = False
c2dl = False
ims_pt = []
ims_np = []
latents_np = []
labels_pt = []
labels_np = []
# labels_2d_pt = []
# labels_2d_np = []
for trial, trial_idx in zip(trials, trial_idxs):
ims_pt_, ims_np_, latents_np_, labels_pt_, labels_np_, labels_2d_pt_, labels_2d_np_ = \
get_model_input(
data_generator, hparams, model_ae, trial_idx=trial_idx, trial=trial,
sess_idx=sess_idx, compute_latents=True, compute_scaled_labels=csl,
compute_2d_labels=c2dl, max_frames=200)
ims_pt.append(ims_pt_)
ims_np.append(ims_np_)
latents_np.append(latents_np_)
labels_pt.append(labels_pt_)
labels_np.append(labels_np_)
# labels_2d_pt.append(labels_2d_pt_)
# labels_2d_np.append(labels_2d_np_)
if hparams['model_class'] == 'ps-vae':
label_idxs = np.arange(n_labels)
latent_idxs = n_labels + np.arange(n_ae_latents)
elif hparams['model_class'] == 'vae' or hparams['model_class'] == 'beta-tcvae':
label_idxs = []
latent_idxs = np.arange(hparams['n_ae_latents'])
elif hparams['model_class'] == 'cond-vae':
label_idxs = np.arange(n_labels)
latent_idxs = np.arange(hparams['n_ae_latents'])
elif hparams['model_class'] == 'msps-vae':
label_idxs = np.arange(n_labels)
latent_idxs = n_labels + np.arange(n_ae_latents + n_background)
else:
raise Exception
# ----------------------------------------
# label traversals
# ----------------------------------------
ims_all = []
txt_strs_all = []
txt_strs_titles = []
for label_idx in label_idxs:
ims = []
txt_strs = []
for b, batch_idx in enumerate(batch_idxs):
if hparams['model_class'] == 'ps-vae' or hparams['model_class'] == 'msps-vae':
points = np.array([latents_np[b][batch_idx, :]] * 3)
elif hparams['model_class'] == 'cond-vae':
points = np.array([labels_np[b][batch_idx, :]] * 3)
else:
raise Exception
points[0, label_idx] = label_range['min'][label_idx]
points[1, label_idx] = label_range['max'][label_idx]
points[2, label_idx] = label_range['min'][label_idx]
ims_curr, inputs = interpolate_point_path(
'labels', model_ae, ims_pt[b][None, batch_idx, :],
labels_np[b][None, batch_idx, :], points=points, n_frames=n_frames, ch=channel,
crop_kwargs=crop_kwargs)
ims.append(ims_curr)
txt_strs += [panel_titles[label_idx] for _ in range(len(ims_curr))]
if label_idx == 0:
tmp = trial_idxs[b] if trial_idxs[b] is not None else trials[b]
txt_strs_titles += [
'base frame %02i-%02i' % (tmp, batch_idx) for _ in range(len(ims_curr))]
# add blank frames
if len(batch_idxs) > 1:
y_pix, x_pix = ims_curr[0].shape
ims.append([np.zeros((y_pix, x_pix)) for _ in range(n_buffer_frames)])
txt_strs += ['' for _ in range(n_buffer_frames)]
if label_idx == 0:
txt_strs_titles += ['' for _ in range(n_buffer_frames)]
ims_all.append(np.vstack(ims))
txt_strs_all.append(txt_strs)
# ----------------------------------------
# latent traversals
# ----------------------------------------
crop_kwargs_ = None
for latent_idx in latent_idxs:
ims = []
txt_strs = []
for b, batch_idx in enumerate(batch_idxs):
points = np.array([latents_np[b][batch_idx, :]] * 3)
# points[:, latent_idxs] = 0
points[0, latent_idx] = latent_range['min'][latent_idx]
points[1, latent_idx] = latent_range['max'][latent_idx]
points[2, latent_idx] = latent_range['min'][latent_idx]
if hparams['model_class'] == 'vae' or hparams['model_class'] == 'beta-tcvae':
labels_curr = None
else:
labels_curr = labels_np[b][None, batch_idx, :]
ims_curr, inputs = interpolate_point_path(
'latents', model_ae, ims_pt[b][None, batch_idx, :],
labels_curr, points=points, n_frames=n_frames, ch=channel,
crop_kwargs=crop_kwargs_)
ims.append(ims_curr)
if hparams['model_class'] == 'cond-vae':
txt_strs += [panel_titles[latent_idx + n_labels] for _ in range(len(ims_curr))]
else:
txt_strs += [panel_titles[latent_idx] for _ in range(len(ims_curr))]
if latent_idx == 0 and len(label_idxs) == 0:
# add frame ids here if skipping labels
tmp = trial_idxs[b] if trial_idxs[b] is not None else trials[b]
txt_strs_titles += [
'base frame %02i-%02i' % (tmp, batch_idx) for _ in range(len(ims_curr))]
# add blank frames
if len(batch_idxs) > 1:
y_pix, x_pix = ims_curr[0].shape
ims.append([np.zeros((y_pix, x_pix)) for _ in range(n_buffer_frames)])
txt_strs += ['' for _ in range(n_buffer_frames)]
if latent_idx == 0 and len(label_idxs) == 0:
txt_strs_titles += ['' for _ in range(n_buffer_frames)]
ims_all.append(np.vstack(ims))
txt_strs_all.append(txt_strs)
# ----------------------------------------
# make video
# ----------------------------------------
if order_idxs is None:
# don't change order of latents
order_idxs = np.arange(len(ims_all))
if split_movies:
for idx in order_idxs:
if save_file.split('.')[-1] == 'gif':
save_file_new = save_file[:-4] + '_latent-%i.gif' % idx
elif save_file.split('.')[-1] == 'mp4':
save_file_new = save_file[:-4] + '_latent-%i.mp4' % idx
else:
save_file_new = save_file + '_latent-%i' % 0
make_interpolated(
ims=ims_all[idx],
text=txt_strs_all[idx],
text_title=txt_strs_titles,
save_file=save_file_new, scale=3, **movie_kwargs)
else:
make_interpolated_multipanel(
ims=[ims_all[i] for i in order_idxs],
text=[txt_strs_all[i] for i in order_idxs],
text_title=txt_strs_titles,
save_file=save_file, scale=2, n_cols=n_cols, **movie_kwargs)
hparams['n_sessions_per_batch'] = n_sessions_per_batch
[docs]def plot_mspsvae_training_curves(
hparams, alpha, beta, delta, rng_seed_model, n_latents, n_background, n_labels, lab=None,
expt=None, dtype='val', version_dir=None, save_file=None, format='pdf', **kwargs):
"""Create training plots for each term in the ps-vae objective function.
The `dtype` argument controls which type of trials are plotted ('train' or 'val').
Additionally, multiple models can be plotted simultaneously by varying one (and only one) of
the following parameters:
- alpha
- beta
- gamma
- number of unsupervised latents
- random seed used to initialize model weights
Each of these entries must be an array of length 1 except for one option, which can be an array
of arbitrary length (corresponding to already trained models). This function generates a single
plot with panels for each of the following terms:
- total loss
- pixel mse
- label R^2 (note the objective function contains the label MSE, but R^2 is easier to parse)
- KL divergence of supervised latents
- index-code mutual information of unsupervised latents
- total correlation of unsupervised latents
- dimension-wise KL of unsupervised latents
- subspace overlap
Parameters
----------
hparams
alpha : :obj:`array-like`
alpha values to plot
beta : :obj:`array-like`
beta values to plot
delta : :obj:`array-like`
delta values to plot
n_ae_latents : :obj:`array-like`
unsupervised dimensionalities to plot
rng_seeds_model : :obj:`array-like`
model seeds to plot
n_labels : :obj:`int`
dimensionality of supervised latent space
dtype : :obj:`str`
'train' | 'val'
save_file : :obj:`str`, optional
absolute path of save file; does not need file extension
format : :obj:`str`, optional
format of saved image; 'pdf' | 'png' | 'jpeg' | ...
kwargs
arguments are keys of `hparams`, for example to set `train_frac`, `rng_seed_model`, etc.
"""
if dtype == 'val':
hue = 'dataset'
else:
hue = None
metrics_list = [
'loss', 'loss_data_mse', 'label_r2',
'loss_zs_kl', 'loss_zu_mi', 'loss_zu_tc', 'loss_zu_dwkl', 'loss_triplet']
# update hparams
hparams['ps_vae.alpha'] = alpha
hparams['ps_vae.beta'] = beta
hparams['ps_vae.delta'] = delta
hparams['n_ae_latents'] = n_latents + n_background + n_labels
hparams['rng_seed_model'] = rng_seed_model
if version_dir is None:
try:
_, version = experiment_exists(hparams, which_version=True)
print(
'loading results with alpha=%i, beta=%i, delta=%i (version %i)' %
(alpha, beta, delta, version))
metrics_df = load_metrics_csv_as_df(hparams, lab, expt, metrics_list, version=None)
except TypeError:
print('could not find model for alpha=%i, beta=%i, delta=%i' % (alpha, beta, delta))
return None
else:
metrics_df = load_metrics_csv_as_df(
hparams, lab=None, expt=None, metrics_list=metrics_list, version_dir=version_dir)
sns.set_style('white')
sns.set_context('talk')
data_queried = metrics_df[
(metrics_df.epoch > 10) & ~pd.isna(metrics_df.val) & (metrics_df.dtype == dtype)]
g = sns.FacetGrid(
data_queried, col='loss', col_wrap=3, hue=hue, sharey=False, height=4)
g = g.map(plt.plot, 'epoch', 'val').add_legend() # , color=".3", fit_reg=False, x_jitter=.1);
g.fig.subplots_adjust(top=0.9)
g.fig.suptitle('alpha=%i, beta=%i, delta=%i, rng=%i' % (alpha, beta, delta, rng_seed_model))
if save_file is not None:
make_dir_if_not_exists(save_file)
g.savefig(save_file + '.' + format, dpi=300, format=format)
[docs]def plot_mspsvae_hyperparameter_search_results(
hparams, sess_ids, label_names, n_background, alpha_weights, alpha_n_ae_latents,
alpha_expt_name, beta_weights, delta_weights, beta_delta_n_ae_latents,
beta_delta_expt_name, alpha, beta, delta, save_file, batch_size=None, format='pdf',
**kwargs):
"""Create a variety of diagnostic plots to assess the msps-vae hyperparameters.
These diagnostic plots are based on the recommended way to perform a hyperparameter search in
the ps-vae models; first, fix beta=1 and gamma=0, and do a sweep over alpha values and number
of latents (for example alpha=[50, 100, 500, 1000] and n_ae_latents=[2, 4, 8, 16]). The best
alpha value is subjective because it involves a tradeoff between pixel mse and label mse. After
choosing a suitable value, fix alpha and the number of latents and vary beta and gamma. This
function will then plot the following panels:
- pixel mse as a function of alpha/num latents (for fixed beta/gamma)
- label mse as a function of alpha/num_latents (for fixed beta/gamma)
- pixel mse as a function of beta/gamma (for fixed alpha/n_ae_latents)
- label mse as a function of beta/gamma (for fixed alpha/n_ae_latents)
- index-code mutual information (part of the KL decomposition) as a function of beta/gamma (for
fixed alpha/n_ae_latents)
- total correlation(part of the KL decomposition) as a function of beta/gamma (for fixed
alpha/n_ae_latents)
- dimension-wise KL (part of the KL decomposition) as a function of beta/gamma (for fixed
alpha/n_ae_latents)
- average correlation coefficient across all pairs of unsupervised latent dims as a function of
beta/gamma (for fixed alpha/n_ae_latents)
- subspace overlap computed as ||[A; B] - I||_2^2 for A, B the projections to the supervised
and unsupervised subspaces, respectively, and I the identity - as a function of beta/gamma
(for fixed alpha/n_ae_latents)
- example subspace overlap matrix for gamma=0 and beta=1, with fixed alpha/n_ae_latents
- example subspace overlap matrix for gamma=1000 and beta=1, with fixed alpha/n_ae_latents
Parameters
----------
hparams : :obj:`dict`
sess_ids : :obj:`list`
label_names : :obj:`array-like`
names of label dims
n_background : :obj:`int`
dimensionality of background latents
alpha_weights : :obj:`array-like`
array of alpha weights for fixed values of beta, delta
alpha_n_ae_latents : :obj:`array-like`
array of latent dimensionalities for fixed values of beta, delta using alpha_weights
alpha_expt_name : :obj:`str`
test-tube experiment name of alpha-based hyperparam search
beta_weights : :obj:`array-like`
array of beta weights for a fixed value of alpha
delta_weights : :obj:`array-like`
array of beta weights for a fixed value of alpha
beta_delta_n_ae_latents : :obj:`int`
latent dimensionality used for beta-delta hyperparam search
beta_delta_expt_name : :obj:`str`
test-tube experiment name of beta-delta hyperparam search
alpha : :obj:`float`
fixed value of alpha for beta-delta search
beta : :obj:`float`
fixed value of beta for alpha search
delta : :obj:`float`
fixed value of gamma for alpha search
save_file : :obj:`str`
absolute path of save file; does not need file extension
batch_size : :obj:`int`, optional
size of batches, used to compute correlation coefficient per batch; if NoneType, the
correlation coefficient is computed across all time points
format : :obj:`str`, optional
format of saved image; 'pdf' | 'png' | 'jpeg' | ...
kwargs
arguments are keys of `hparams`, preceded by either `alpha_` or `beta_delta_`. For example,
to set the train frac of the alpha models, use `alpha_train_frac`; to set the rng_data_seed
of the beta-delta models, use `beta_delta_rng_data_seed`.
"""
# -----------------------------------------------------
# load pixel/label MSE as a function of n_latents/alpha
# -----------------------------------------------------
n_labels = len(label_names)
# set model info
# hparams = _get_psvae_hparams(experiment_name=alpha_expt_name)
hparams['experiment_name'] = alpha_expt_name
# update hparams
for key, val in kwargs.items():
# hparam vals should be named 'alpha_[property]', for example 'alpha_train_frac'
if key.split('_')[0] == 'alpha':
prop = key[6:]
hparams[prop] = val
metrics_list = ['loss_data_mse']
metrics_dfs_frame = []
metrics_dfs_marker = []
for n_latent in alpha_n_ae_latents:
hparams['n_ae_latents'] = n_latent + n_labels + n_background
hparams['expt_dir'] = get_expt_dir(hparams)
for alpha_ in alpha_weights:
hparams['ps_vae.alpha'] = alpha_
hparams['ps_vae.beta'] = beta
hparams['ps_vae.delta'] = delta
try:
_, version = experiment_exists(hparams, which_version=True)
version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % version)
print('loading results with alpha=%i, beta=%i, delta=%i (version %i)' % (
hparams['ps_vae.alpha'], hparams['ps_vae.beta'], hparams['ps_vae.delta'],
version))
# get frame mse
metrics_dfs_frame.append(load_metrics_csv_as_df(
hparams, None, None, metrics_list, version=None, test=False,
version_dir=version_dir))
metrics_dfs_frame[-1]['alpha'] = alpha_
metrics_dfs_frame[-1]['n_latents'] = hparams['n_ae_latents']
# get marker mse
hparams_new = copy.deepcopy(hparams)
hparams_new['n_sessions_per_bach'] = 1
model, data_gen = get_best_model_and_data(
hparams_new, Model=None, load_data=True, version=version)
metrics_df_ = get_label_r2(
hparams, model, data_gen, version, label_names, dtype='val')
metrics_df_['alpha'] = alpha_
metrics_df_['n_latents'] = hparams['n_ae_latents']
metrics_dfs_marker.append(metrics_df_[metrics_df_.Model == 'MSPS-VAE'])
except TypeError:
print('could not find model for alpha=%i, beta=%i, delta=%i' % (
hparams['ps_vae.alpha'], hparams['ps_vae.beta'], hparams['ps_vae.delta']))
continue
metrics_df_frame = pd.concat(metrics_dfs_frame, sort=False)
metrics_df_marker = pd.concat(metrics_dfs_marker, sort=False)
print('done')
# -----------------------------------------------------
# load pixel/label MSE as a function of beta/delta
# -----------------------------------------------------
# update hparams
hparams['experiment_name'] = beta_delta_expt_name
for key, val in kwargs.items():
# hparam vals should be named 'beta_delta_[property]', for example 'alpha_train_frac'
if key.split('_')[0] == 'beta' and key.split('_')[1] == 'delta':
prop = key[11:]
hparams[prop] = val
metrics_list = ['loss_data_mse', 'loss_zu_mi', 'loss_zu_tc', 'loss_zu_dwkl', 'loss_triplet']
metrics_dfs_frame_bg = []
metrics_dfs_marker_bg = []
metrics_dfs_corr_bg = []
for beta in beta_weights:
for delta in delta_weights:
if delta < 50:
continue
hparams['n_ae_latents'] = beta_delta_n_ae_latents + n_labels + n_background
hparams['ps_vae.alpha'] = alpha
hparams['ps_vae.beta'] = beta
hparams['ps_vae.delta'] = delta
hparams['rng_seed_model'] = 3 if (beta == 10 and delta == 50) else 0
try:
hparams['expt_dir'] = get_expt_dir(hparams)
_, version = experiment_exists(hparams, which_version=True)
version_dir = os.path.join(hparams['expt_dir'], 'version_%i' % version)
print('loading results with alpha=%i, beta=%i, delta=%i (version %i)' % (
hparams['ps_vae.alpha'], hparams['ps_vae.beta'], hparams['ps_vae.delta'],
version))
# get frame mse -------------------------------------------------------------------
metrics_dfs_frame_bg.append(load_metrics_csv_as_df(
hparams, None, None, metrics_list, version=None, test=False,
version_dir=version_dir))
metrics_dfs_frame_bg[-1]['beta'] = beta
metrics_dfs_frame_bg[-1]['delta'] = delta
# get marker mse ------------------------------------------------------------------
hparams_new = copy.deepcopy(hparams)
hparams_new['n_sessions_per_bach'] = 1
model, data_gen = get_best_model_and_data(
hparams_new, Model=None, load_data=True, version=version)
metrics_df_ = get_label_r2(
hparams, model, data_gen, version, label_names, dtype='val')
metrics_df_['beta'] = beta
metrics_df_['delta'] = delta
metrics_dfs_marker_bg.append(metrics_df_[metrics_df_.Model == 'MSPS-VAE'])
# get classification accuracy -----------------------------------------------------
fc, _ = fit_classifier(model, data_gen, dtype='val')
metrics_dfs_corr_bg.append(pd.DataFrame({
'loss': 'fc',
'dtype': 'val',
'val': fc,
'beta': beta,
'delta': delta}, index=[0]))
# get corr ------------------------------------------------------------------------
latents = []
for sess_id in sess_ids:
hparams['lab'] = sess_id['lab']
hparams['expt'] = sess_id['expt']
hparams['animal'] = sess_id['animal']
hparams['session'] = sess_id['session']
latents.append(load_latents(hparams, version, dtype='train'))
latents = np.vstack(latents)
if batch_size is None:
corr = np.corrcoef(latents[:, n_labels + n_background + np.array([0, 1])].T)
metrics_dfs_corr_bg.append(pd.DataFrame({
'loss': 'corr',
'dtype': 'val',
'val': np.abs(corr[0, 1]),
'beta': beta,
'delta': delta}, index=[0]))
else:
n_batches = int(np.ceil(latents.shape[0] / batch_size))
for i in range(n_batches):
corr = np.corrcoef(
latents[i * batch_size:(i + 1) * batch_size,
n_labels + n_background + np.array([0, 1])].T)
metrics_dfs_corr_bg.append(pd.DataFrame({
'loss': 'corr',
'dtype': 'val',
'val': np.abs(corr[0, 1]),
'beta': beta,
'delta': delta}, index=[0]))
except TypeError:
print('could not find model for alpha=%i, beta=%i, delta=%i' % (
hparams['ps_vae.alpha'], hparams['ps_vae.beta'], hparams['ps_vae.delta']))
continue
print()
metrics_df_frame_bg = pd.concat(metrics_dfs_frame_bg, sort=False)
metrics_df_marker_bg = pd.concat(metrics_dfs_marker_bg, sort=False)
metrics_df_corr_bg = pd.concat(metrics_dfs_corr_bg, sort=False)
print('done')
# -----------------------------------------------------
# ----------------- PLOT DATA -------------------------
# -----------------------------------------------------
sns.set_style('white')
sns.set_context('paper', font_scale=1.2)
alpha_palette = sns.color_palette('Greens')
beta_palette = sns.color_palette('Reds', len(metrics_df_corr_bg.beta.unique()))
delta_palette = sns.color_palette('Blues', len(metrics_df_corr_bg.delta.unique()))
from matplotlib.gridspec import GridSpec
fig = plt.figure(figsize=(12, 10), dpi=300)
n_rows = 3
n_cols = 12
gs = GridSpec(n_rows, n_cols, figure=fig)
def despine(ax):
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
sns.set_palette(alpha_palette)
# --------------------------------------------------
# MSE per pixel
# --------------------------------------------------
ax_pixel_mse_alpha = fig.add_subplot(gs[0, 0:3])
data_queried = metrics_df_frame[
(metrics_df_frame.dtype == 'val') &
(metrics_df_frame.epoch == metrics_df_frame.epoch.max())]
sns.barplot(x='n_latents', y='val', hue='alpha', data=data_queried, ax=ax_pixel_mse_alpha)
ax_pixel_mse_alpha.legend().set_visible(False)
ax_pixel_mse_alpha.set_xlabel('Latent dimension')
ax_pixel_mse_alpha.set_ylabel('MSE per pixel')
ax_pixel_mse_alpha.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
ax_pixel_mse_alpha.set_title('Beta=%i, Delta=%i' % (beta, delta))
despine(ax_pixel_mse_alpha)
# --------------------------------------------------
# MSE per marker
# --------------------------------------------------
ax_marker_mse_alpha = fig.add_subplot(gs[0, 3:6])
data_queried = metrics_df_marker
sns.barplot(x='n_latents', y='MSE', hue='alpha', data=data_queried, ax=ax_marker_mse_alpha)
ax_marker_mse_alpha.set_xlabel('Latent dimension')
ax_marker_mse_alpha.set_ylabel('MSE per marker')
ax_marker_mse_alpha.set_title('Beta=%i, Delta=%i' % (beta, delta))
ax_marker_mse_alpha.legend(frameon=True, title='Alpha')
despine(ax_marker_mse_alpha)
sns.set_palette(delta_palette)
# --------------------------------------------------
# MSE per pixel (beta/delta)
# --------------------------------------------------
ax_pixel_mse_bg = fig.add_subplot(gs[0, 6:9])
data_queried = metrics_df_frame_bg[
(metrics_df_frame_bg.dtype == 'val') &
(metrics_df_frame_bg.loss == 'loss_data_mse') &
(metrics_df_frame_bg.epoch == 200)]
sns.barplot(x='beta', y='val', hue='delta', data=data_queried, ax=ax_pixel_mse_bg)
ax_pixel_mse_bg.legend().set_visible(False)
ax_pixel_mse_bg.set_xlabel('Beta')
ax_pixel_mse_bg.set_ylabel('MSE per pixel')
ax_pixel_mse_bg.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
ax_pixel_mse_bg.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_pixel_mse_bg)
# --------------------------------------------------
# MSE per marker (beta/delta)
# --------------------------------------------------
ax_marker_mse_bg = fig.add_subplot(gs[0, 9:12])
data_queried = metrics_df_marker_bg
sns.barplot(x='beta', y='MSE', hue='delta', data=data_queried, ax=ax_marker_mse_bg)
ax_marker_mse_bg.set_xlabel('Beta')
ax_marker_mse_bg.set_ylabel('MSE per marker')
ax_marker_mse_bg.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
ax_marker_mse_bg.legend(frameon=True, title='Delta', loc='lower left')
despine(ax_marker_mse_bg)
# --------------------------------------------------
# ICMI
# --------------------------------------------------
ax_icmi = fig.add_subplot(gs[1, 0:4])
data_queried = metrics_df_frame_bg[
(metrics_df_frame_bg.dtype == 'val') &
(metrics_df_frame_bg.loss == 'loss_zu_mi') &
(metrics_df_frame_bg.epoch == 200)]
sns.lineplot(
x='beta', y='val', hue='delta', data=data_queried, ax=ax_icmi, ci=None,
palette=delta_palette)
ax_icmi.legend().set_visible(False)
ax_icmi.set_xlabel('Beta')
ax_icmi.set_ylabel('Index-code Mutual Information')
ax_icmi.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_icmi)
# --------------------------------------------------
# TC
# --------------------------------------------------
ax_tc = fig.add_subplot(gs[1, 4:8])
data_queried = metrics_df_frame_bg[
(metrics_df_frame_bg.dtype == 'val') &
(metrics_df_frame_bg.loss == 'loss_zu_tc') &
(metrics_df_frame_bg.epoch == 200)]
sns.lineplot(
x='beta', y='val', hue='delta', data=data_queried, ax=ax_tc, ci=None,
palette=delta_palette)
ax_tc.legend().set_visible(False)
ax_tc.set_xlabel('Beta')
ax_tc.set_ylabel('Total Correlation')
ax_tc.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_tc)
# --------------------------------------------------
# DWKL
# --------------------------------------------------
ax_dwkl = fig.add_subplot(gs[1, 8:12])
data_queried = metrics_df_frame_bg[
(metrics_df_frame_bg.dtype == 'val') &
(metrics_df_frame_bg.loss == 'loss_zu_dwkl') &
(metrics_df_frame_bg.epoch == 200)]
sns.lineplot(
x='beta', y='val', hue='delta', data=data_queried, ax=ax_dwkl, ci=None,
palette=delta_palette)
ax_dwkl.legend().set_visible(False)
ax_dwkl.set_xlabel('Beta')
ax_dwkl.set_ylabel('Dimension-wise KL')
ax_dwkl.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_dwkl)
# --------------------------------------------------
# CC
# --------------------------------------------------
ax_cc = fig.add_subplot(gs[2, 0:4])
data_queried = metrics_df_corr_bg[metrics_df_corr_bg.loss == 'corr']
sns.lineplot(
x='beta', y='val', hue='delta', data=data_queried, ax=ax_cc, ci=None,
palette=delta_palette)
ax_cc.legend().set_visible(False)
ax_cc.set_xlabel('Beta')
ax_cc.set_ylabel('Correlation Coefficient')
ax_cc.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_cc)
# --------------------------------------------------
# session classification
# --------------------------------------------------
ax_fc = fig.add_subplot(gs[2, 4:8])
data_queried = metrics_df_corr_bg[metrics_df_corr_bg.loss == 'fc']
sns.lineplot(
x='beta', y='val', hue='delta', data=data_queried, ax=ax_fc, ci=None,
palette=delta_palette)
ax_fc.legend().set_visible(False)
ax_fc.set_xlabel('Beta')
ax_fc.set_ylabel('Session Classification')
ax_fc.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_fc)
# --------------------------------------------------
# triplet loss
# --------------------------------------------------
ax_orth = fig.add_subplot(gs[2, 8:12])
data_queried = metrics_df_frame_bg[
(metrics_df_frame_bg.dtype == 'train') &
(metrics_df_frame_bg.loss == 'loss_triplet') &
(metrics_df_frame_bg.epoch == 200) &
~metrics_df_frame_bg.val.isna()]
sns.lineplot(
x='delta', y='val', hue='beta', data=data_queried, ax=ax_orth, ci=None,
palette=beta_palette)
ax_orth.legend(frameon=False, title='Beta')
ax_orth.set_xlabel('Delta')
ax_orth.set_ylabel('Triplet loss')
ax_orth.set_title('Latents=%i, Alpha=%i' % (hparams['n_ae_latents'], alpha))
despine(ax_orth)
plt.tight_layout(h_pad=3) # h_pad is fraction of font size
# reset to default color palette
# sns.set_palette(sns.color_palette(None, 10))
sns.reset_orig()
if save_file is not None:
make_dir_if_not_exists(save_file)
plt.savefig(save_file + '.' + format, dpi=300, format=format)
[docs]def make_session_swap_movie(
sess_ids, hparams, version, n_labels, n_background, sess_idx, trials, trial_idxs=None,
n_buffer_frames=5, frame_rate=15, layout_pattern=None, save_file=None, **kwargs):
"""Create a multipanel movie, each panel showing reconstruction with different session context.
TODO
Parameters
----------
sess_ids : :obj:`list` of `dicts`
hparams
version
n_labels : :obj:`int`
dimensionality of supervised latent space (ignored when using fully unsupervised models)
n_background : :obj:`int`
sess_idx : :obj:`int`
session index into data generator
trials : :obj:`array-like` of :obj:`int`
trials of base frame used for interpolation; if an entry is an integer, the
corresponding entry in `trial_idxs` must be `None`. This value is a trial index into all
possible trials (train, val, test), whereas `trial_idxs` is an index only into test trials
trial_idxs : :obj:`list`, optional
list of test trials to construct videos from; if :obj:`NoneType`, use first test
trial
n_buffer_frames : :obj:`int`, optional
number of blank frames to insert between base frames
frame_rate
layout_pattern : :obj:`np.ndarray`
boolean array that determines where reconstructed frames are placed in a grid
save_file : :obj:`str`, optional
absolute path of save file; does not need file extension, will automatically be saved as
mp4. To save as a gif, include the '.gif' file extension in `save_file`
kwargs
arguments are keys of `hparams`, for example to set `train_frac`, `rng_seed_model`, etc.
"""
from behavenet.plotting.ae_utils import make_reconstruction_movie
panel_titles = ['Original'] + ['Transfer %i' % i for i in range(len(sess_ids) - 1)]
# load standard data generator
hp = copy.deepcopy(hparams)
hp['n_sessions_per_batch'] = 1
model_ae, data_generator = get_best_model_and_data(hp, Model=None, version=version)
# get latent/label info
background_idxs = np.arange(n_labels, n_labels + n_background)
background_medians = []
for s in range(len(sess_ids)):
latent_range = get_input_range(
'latents', hp, sess_ids=sess_ids, sess_idx=s, model=model_ae,
data_gen=data_generator, min_p=15, max_p=85, version=version,
apply_label_masks=True)
background_medians.append(latent_range['med'][background_idxs])
if trial_idxs is None:
trial_idxs = [None] * len(trials)
if trials is None:
trials = [None] * len(trial_idxs)
ims_recon = [[] for _ in panel_titles]
for trial_idx, trial in zip(trial_idxs, trials):
# get model inputs
ims_orig_pt, ims_orig_np, latents_np, labels_pt, _, labels_2d_pt, _ = get_model_input(
data_generator, hp, model_ae, trial_idx=trial_idx, trial=trial,
sess_idx=sess_idx, max_frames=400, compute_latents=True,
compute_2d_labels=False)
for s in range(len(sess_ids)):
# get model outputs
# if s == sess_idx:
# # get normal reconstruction
# ims_recon_tmp = get_reconstruction(
# model_ae, ims_orig_pt, labels=labels_pt, labels_2d=labels_2d_pt,
# dataset=sess_idx)
# else:
# swap out context latents
latents_np[:, background_idxs] = background_medians[s]
ims_recon_tmp = get_reconstruction(
model_ae, latents_np, apply_inverse_transform=False)
ims_recon[s].append(ims_recon_tmp)
# add a couple black frames to separate trials
final_trial = True
if (trial_idx is not None and (trial_idx != trial_idxs[-1])) or \
(trial is not None and (trial != trials[-1])):
final_trial = False
if not final_trial:
_, n, y_p, x_p = ims_recon[s][-1].shape
ims_recon[s].append(np.zeros((n_buffer_frames, n, y_p, x_p)))
# concatenate everything along time dimension
for i, ims in enumerate(ims_recon):
ims_recon[i] = np.concatenate(ims, axis=0)
# put original session in first position
if sess_idx != 0:
ims_recon[0], ims_recon[sess_idx] = ims_recon[sess_idx], ims_recon[0]
if layout_pattern is None:
if len(panel_titles) < 4:
n_rows, n_cols = 1, len(panel_titles)
elif len(panel_titles) == 4:
n_rows, n_cols = 2, 2
elif len(panel_titles) > 4:
n_rows, n_cols = 2, 3
else:
raise ValueError('too many sessions')
else:
assert np.sum(layout_pattern) == len(ims_recon)
n_rows, n_cols = layout_pattern.shape
count = 0
for pos_r in layout_pattern:
for pos_c in pos_r:
if not pos_c:
ims_recon.insert(count, [])
panel_titles.insert(count, [])
count += 1
make_reconstruction_movie(
ims=ims_recon, titles=panel_titles, n_rows=n_rows, n_cols=n_cols,
save_file=save_file, frame_rate=frame_rate)