plot_latent_traversals
- behavenet.plotting.cond_ae_utils.plot_latent_traversals(lab, expt, animal, session, model_class, alpha, beta, n_ae_latents, rng_seed_model, experiment_name, n_labels, label_idxs, hparams=None, label_min_p=5, label_max_p=95, channel=0, n_frames_zs=4, n_frames_zu=4, trial=None, trial_idx=1, batch_idx=1, crop_type=None, crop_kwargs=None, sess_idx=0, sess_ids=None, save_file=None, format='pdf', **kwargs)
Plot video frames representing the traversal of individual dimensions of the latent space.
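As an illustration of what traversing a single latent dimension means, here is a minimal NumPy sketch. It is not BehaveNet's implementation: the helper `traversal_inputs` and the random `latents_train` array are hypothetical, and the sketch only assumes that the traversal range is taken between two percentiles of the training latents (as `label_min_p` and `label_max_p` suggest), with all other dimensions held at the base frame's values.

```python
import numpy as np


def traversal_inputs(latents_train, base_latent, dim, n_frames=4, min_p=5, max_p=95):
    """Build an (n_frames, n_latents) array that varies one latent dimension.

    Hypothetical helper for illustration only; not part of BehaveNet.
    """
    # Range of the traversal: percentiles of the training data along `dim`.
    lo, hi = np.percentile(latents_train[:, dim], [min_p, max_p])
    # Evenly spaced points between the lower and upper percentile.
    values = np.linspace(lo, hi, n_frames)
    # Copy the base latent once per frame, then overwrite the traversed dimension.
    inputs = np.tile(base_latent, (n_frames, 1))
    inputs[:, dim] = values
    return inputs


# Stand-in for latents computed on training data (assumed shape: n_samples x n_latents).
rng = np.random.default_rng(0)
latents_train = rng.normal(size=(1000, 3))
base_latent = latents_train[0]

inputs = traversal_inputs(latents_train, base_latent, dim=1)
```

Each row of `inputs` would then be passed through the model's decoder to produce one frame of the traversal.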
- Parameters:
- lab
str lab id
- expt
str expt id
- animal
str animal id
- session
str session id
- model_class
str model class in which to perform traversal; currently supported models are: ‘ae’ | ‘vae’ | ‘cond-ae’ | ‘cond-vae’ | ‘beta-tcvae’ | ‘cond-ae-msp’ | ‘ps-vae’; note that models with conditional encoders are not currently supported
- alpha
float ps-vae alpha value
- beta
float ps-vae beta value
- n_ae_latents
int dimensionality of unsupervised latents
- rng_seed_model
int model seed
- experiment_name
str test-tube experiment name
- n_labels
int dimensionality of supervised latent space (ignored when using fully unsupervised models)
- label_idxs
array-like, optional set of label indices (dimensions) to individually traverse
- hparams
dict, optional if not None, uses these hparams instead of the required args
- label_min_p
float, optional lower percentile of training data used to compute range of traversal
- label_max_p
float, optional upper percentile of training data used to compute range of traversal
- channel
int, optional image channel to plot
- n_frames_zs
int, optional number of frames (points) to display for traversal through supervised dimensions
- n_frames_zu
int, optional number of frames (points) to display for traversal through unsupervised dimensions
- trial
int, optional trial index into all possible trials (train, val, test); one of trial or trial_idx must be specified; trial takes precedence over trial_idx
- trial_idx
int, optional trial index of base frame used for interpolation
- batch_idx
int, optional batch index of base frame used for interpolation
- crop_type
str, optional cropping method used on interpolated frames: ‘fixed’ | None
- crop_kwargs
dict, optional if crop_type is not None, provides information about the crop; keys for ‘fixed’ type: ‘y_0’, ‘x_0’, ‘y_ext’, ‘x_ext’; window is (y_0 - y_ext, y_0 + y_ext) in vertical direction and (x_0 - x_ext, x_0 + x_ext) in horizontal direction
- sess_idx
int, optional session index into data generator
- sess_ids
list, optional each entry is a session dict with keys ‘lab’, ‘expt’, ‘animal’, ‘session’; for loading labels and labels_sc
- save_file
str, optional absolute path of save file; does not need file extension
- format
str, optional format of saved image; ‘pdf’ | ‘png’ | ‘jpeg’ | …
- kwargs
arguments are keys of hparams, for example to set train_frac, rng_seed_model, etc.
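A hedged usage sketch follows. Every value in it (lab, expt, animal and session ids, experiment name, hyperparameter values, crop coordinates, save path) is a placeholder chosen for illustration, and the helper `crop_window` is hypothetical; it only restates the ‘fixed’ crop window described under crop_kwargs. The call itself is wrapped in a function so that the block can be read and run without a trained model; calling `run_example()` assumes BehaveNet is installed and that a fitted model matching these arguments exists.

```python
# Placeholder arguments; replace with the ids and hyperparameters of a fitted model.
traversal_args = dict(
    lab='mylab',
    expt='myexpt',
    animal='animal01',
    session='session01',
    model_class='ps-vae',
    alpha=1000,            # ps-vae alpha value (placeholder)
    beta=5,                # ps-vae beta value (placeholder)
    n_ae_latents=6,        # dimensionality of unsupervised latents
    rng_seed_model=0,
    experiment_name='my_experiment',
    n_labels=4,            # dimensionality of supervised latent space
    label_idxs=[0, 1],     # supervised dimensions to traverse individually
    n_frames_zs=4,
    n_frames_zu=4,
    trial_idx=1,
    batch_idx=1,
    crop_type='fixed',
    crop_kwargs={'y_0': 64, 'x_0': 64, 'y_ext': 32, 'x_ext': 32},
    save_file='/tmp/latent_traversals',  # no file extension needed
    format='png',
)


def crop_window(y_0, x_0, y_ext, x_ext):
    """Hypothetical helper: the 'fixed' crop window as (vertical, horizontal) ranges."""
    return (y_0 - y_ext, y_0 + y_ext), (x_0 - x_ext, x_0 + x_ext)


def run_example():
    """Call the plotting function; requires BehaveNet and a fitted model."""
    from behavenet.plotting.cond_ae_utils import plot_latent_traversals
    plot_latent_traversals(**traversal_args)


vertical, horizontal = crop_window(**traversal_args['crop_kwargs'])
```

With these placeholder crop values, the window spans rows 32 to 96 and columns 32 to 96 of each interpolated frame.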