plot_latent_traversals

behavenet.plotting.cond_ae_utils.plot_latent_traversals(lab, expt, animal, session, model_class, alpha, beta, n_ae_latents, rng_seed_model, experiment_name, n_labels, label_idxs, hparams=None, label_min_p=5, label_max_p=95, channel=0, n_frames_zs=4, n_frames_zu=4, trial=None, trial_idx=1, batch_idx=1, crop_type=None, crop_kwargs=None, sess_idx=0, sess_ids=None, save_file=None, format='pdf', **kwargs)

Plot video frames representing the traversal of individual dimensions of the latent space.
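As a rough illustration of what traversing one latent dimension involves, the sketch below holds a base latent vector fixed and sweeps a single coordinate across evenly spaced values between two percentiles of that coordinate's training values, mirroring the roles of label_min_p, label_max_p and n_frames_zs / n_frames_zu. It is a hypothetical sketch, not BehaveNet's implementation: the helper names percentile and traverse_dimension are illustrative, even spacing between the bounds is an assumption, and the step that decodes each latent vector into a video frame with the trained model is omitted.

```python
def percentile(values, p):
    """Linear-interpolation percentile (same convention as numpy's default)."""
    if not values:
        raise ValueError("values must be non-empty")
    if not 0 <= p <= 100:
        raise ValueError("p must lie in [0, 100]")
    xs = sorted(values)
    pos = (len(xs) - 1) * p / 100.0
    lo = int(pos)                    # floor, since pos >= 0
    hi = min(lo + 1, len(xs) - 1)    # clamp at the last element
    frac = pos - lo
    return xs[lo] + (xs[hi] - xs[lo]) * frac


def traverse_dimension(z_base, dim, train_values, n_frames, min_p=5, max_p=95):
    """Return n_frames copies of z_base whose coordinate `dim` is swept
    linearly from the min_p to the max_p percentile of train_values.

    Hypothetical helper: even spacing between the bounds is an assumption.
    """
    if n_frames < 2:
        raise ValueError("n_frames must be at least 2")
    v_min = percentile(train_values, min_p)
    v_max = percentile(train_values, max_p)
    step = (v_max - v_min) / (n_frames - 1)
    frames = []
    for i in range(n_frames):
        z = list(z_base)             # copy so the base latent is untouched
        z[dim] = v_min + i * step
        frames.append(z)
    return frames
```

For example, traverse_dimension([0.0, 1.0, 2.0], 1, list(range(101)), 4) sweeps the second coordinate through 5.0, 35.0, 65.0 and 95.0 while leaving the other coordinates at their base values.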

Parameters:

lab (str) – lab id

expt (str) – expt id

animal (str) – animal id

session (str) – session id

model_class (str) – model class in which to perform the traversal; currently supported models are ‘ae’ | ‘vae’ | ‘cond-ae’ | ‘cond-vae’ | ‘beta-tcvae’ | ‘cond-ae-msp’ | ‘ps-vae’. Note that models with conditional encoders are not currently supported.

alpha (float) – ps-vae alpha value

beta (float) – ps-vae beta value

n_ae_latents (int) – dimensionality of unsupervised latents

rng_seed_model (int) – model seed

experiment_name (str) – test-tube experiment name

n_labels (int) – dimensionality of supervised latent space (ignored when using fully unsupervised models)

label_idxs (array-like) – set of label indices (dimensions) to traverse individually

hparams (str, optional) – if not None, these hparams are used instead of the required args

label_min_p (float, optional) – lower percentile of training data used to compute range of traversal

label_max_p (float, optional) – upper percentile of training data used to compute range of traversal

channel (int, optional) – image channel to plot

n_frames_zs (int, optional) – number of frames (points) to display for traversal through supervised dimensions

n_frames_zu (int, optional) – number of frames (points) to display for traversal through unsupervised dimensions

trial (int, optional) – trial index into all possible trials (train, val, test); one of trial or trial_idx must be specified, and trial takes precedence over trial_idx

trial_idx (int, optional) – trial index of base frame used for interpolation

batch_idx (int, optional) – batch index of base frame used for interpolation

crop_type (str, optional) – cropping method used on interpolated frames: ‘fixed’ | None

crop_kwargs (dict, optional) – if crop_type is not None, provides information about the crop. For the ‘fixed’ type the keys are ‘y_0’, ‘x_0’, ‘y_ext’ and ‘x_ext’; the window spans (y_0 - y_ext, y_0 + y_ext) in the vertical direction and (x_0 - x_ext, x_0 + x_ext) in the horizontal direction.

sess_idx (int, optional) – session index into data generator

sess_ids (list, optional) – list of session dicts, each with keys ‘lab’, ‘expt’, ‘animal’, ‘session’; used for loading labels and labels_sc

save_file (str, optional) – absolute path of save file; does not need a file extension

format (str, optional) – format of saved image: ‘pdf’ | ‘png’ | ‘jpeg’ | …

kwargs – keyword arguments whose names are keys of hparams, used for example to set train_frac, rng_seed_model, etc.
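The ‘fixed’ crop window described for crop_kwargs can be sketched as follows. This is a hypothetical helper (fixed_crop is not part of BehaveNet) operating on a frame stored as a list of rows; treating the upper bounds y_0 + y_ext and x_0 + x_ext as exclusive, in the style of Python slicing, is an assumption, as is raising an error rather than clamping when the window leaves the frame.

```python
def fixed_crop(frame, y_0, x_0, y_ext, x_ext):
    """Crop a 2D frame (list of rows) to the window
    (y_0 - y_ext, y_0 + y_ext) x (x_0 - x_ext, x_0 + x_ext).

    Assumption: upper bounds are exclusive (Python slice convention), so the
    result has 2 * y_ext rows and 2 * x_ext columns.
    """
    n_rows, n_cols = len(frame), len(frame[0])
    y_lo, y_hi = y_0 - y_ext, y_0 + y_ext
    x_lo, x_hi = x_0 - x_ext, x_0 + x_ext
    # Negative slice starts would silently wrap around, so reject them.
    if y_lo < 0 or x_lo < 0 or y_hi > n_rows or x_hi > n_cols:
        raise ValueError("crop window falls outside the frame")
    return [row[x_lo:x_hi] for row in frame[y_lo:y_hi]]


# A 6 x 6 toy frame whose pixel value encodes (row, column) as 10 * r + c.
frame = [[10 * r + c for c in range(6)] for r in range(6)]
crop_kwargs = {"y_0": 3, "x_0": 2, "y_ext": 1, "x_ext": 2}
cropped = fixed_crop(frame, **crop_kwargs)
```

With these crop_kwargs, cropped holds rows 2 to 3 and columns 0 to 3 of the toy frame, i.e. [[20, 21, 22, 23], [30, 31, 32, 33]].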