plot_label_reconstructions

behavenet.plotting.cond_ae_utils.plot_label_reconstructions(lab, expt, animal, session, n_ae_latents, experiment_name, n_labels, trials, version=None, plot_scale=0.5, sess_idx=0, save_file=None, format='pdf', xtick_locs=None, frame_rate=None, max_traces=8, add_r2=True, add_legend=True, colored_predictions=True, concat_trials=False, hparams=None, **kwargs)

Plot labels and their reconstructions from a PS-VAE.

Parameters:
  • lab (str) – lab id

  • expt (str) – expt id

  • animal (str) – animal id

  • session (str) – session id

  • n_ae_latents (int) – dimensionality of the unsupervised latent space; n_labels will be added to this

  • experiment_name (str) – test-tube experiment name

  • n_labels (int) – dimensionality of the supervised latent space

  • trials (array-like) – array of trials to reconstruct

  • version (str or int, optional) – ‘best’ to load the best model, an integer to load a specific model version, or None to use the values in hparams to load a specific model

  • plot_scale (float, optional) – scale the magnitude of reconstructions

  • sess_idx (int, optional) – session index into data generator

  • save_file (str, optional) – absolute path of save file; does not need file extension

  • format (str, optional) – format of saved image; ‘pdf’ | ‘png’ | ‘jpeg’ | …

  • xtick_locs (array-like, optional) – tick locations in units of bins

  • frame_rate (float, optional) – frame rate of the behavioral video, used to properly relabel xticks

  • max_traces (int, optional) – maximum number of traces to plot, for easier visualization

  • add_r2 (bool, optional) – print R2 value on plot

  • add_legend (bool, optional) – print legend on plot

  • colored_predictions (bool, optional) – if True, color predictions using the default seaborn colormap; otherwise predictions are plotted in black

  • concat_trials (bool, optional) – True to plot all trials together, separated by a small gap

  • hparams (dict, optional) – if not None, use these hparams instead of the required args

  • kwargs – additional keyword arguments are treated as keys of hparams, for example to set train_frac, rng_seed_model, etc.
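The concat_trials, xtick_locs and frame_rate options control how trials are laid out along the time axis. The sketch below illustrates the idea with NumPy only; it is not behavenet's implementation. The helper names concat_with_gaps and xtick_labels_in_seconds, the use of NaN rows as the "small gap" between trials, and the default gap of 5 bins are all assumptions made for illustration.

```python
import numpy as np


def concat_with_gaps(trial_arrays, gap=5):
    """Hypothetical helper: concatenate per-trial (T_i, n_labels) arrays along time.

    Assumption: trials are separated by `gap` rows of NaN, which makes
    matplotlib break the plotted line at each trial boundary.
    Expects at least one trial.
    """
    n_cols = trial_arrays[0].shape[1]
    spacer = np.full((gap, n_cols), np.nan)
    pieces = []
    for i, arr in enumerate(trial_arrays):
        if i > 0:
            pieces.append(spacer)  # insert the gap before every trial except the first
        pieces.append(arr)
    return np.concatenate(pieces, axis=0)


def xtick_labels_in_seconds(xtick_locs, frame_rate):
    """Hypothetical helper: convert tick locations given in bins (frames) to seconds."""
    return np.asarray(xtick_locs, dtype=float) / frame_rate


if __name__ == "__main__":
    trials = [np.zeros((10, 2)), np.ones((8, 2))]
    combined = concat_with_gaps(trials, gap=5)
    print(combined.shape)  # (23, 2)
    print(xtick_labels_in_seconds([0, 30, 60], frame_rate=30.0))  # [0. 1. 2.]
```

With two trials of 10 and 8 bins and a 5-bin gap, the combined trace spans 23 bins; at 30 frames per second, ticks placed at bins 0, 30 and 60 would be labeled 0, 1 and 2 seconds.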
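For add_r2 and max_traces, here is a minimal sketch of how a single R² value might be computed across all plotted labels and how the number of traces might be capped. The variance-weighted R² used here (equivalent to scikit-learn's r2_score with multioutput='variance_weighted') is one common choice and is not necessarily the metric behavenet reports; r2_variance_weighted and select_traces are hypothetical helpers, and keeping the first max_traces columns is an assumption.

```python
import numpy as np


def r2_variance_weighted(y_true, y_pred):
    """Hypothetical helper: R^2 pooled over all label dimensions.

    Computes 1 - sum(SS_res) / sum(SS_tot), where SS_tot is measured
    around each column's mean; this equals the variance-weighted average
    of the per-column R^2 values.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot


def select_traces(labels, max_traces=8):
    """Hypothetical helper: keep at most `max_traces` label columns.

    Assumption: the first `max_traces` columns are kept.
    """
    return np.asarray(labels)[:, :max_traces]


if __name__ == "__main__":
    labels = np.array([[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]])
    recon = labels + 0.1  # constant offset as a stand-in for a reconstruction
    print(round(r2_variance_weighted(labels, recon), 3))  # 0.994
    print(select_traces(np.zeros((5, 10)), max_traces=8).shape)  # (5, 8)
```

Pooling the sums across columns means label dimensions with larger variance contribute more to the single reported value, which is one reason a per-column R² can look quite different for individual traces.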