plot_psvae_training_curves

behavenet.plotting.cond_ae_utils.plot_psvae_training_curves(lab, expt, animal, session, alphas, betas, n_ae_latents, rng_seeds_model, experiment_name, n_labels, dtype='val', save_file=None, format='pdf', **kwargs)

Create training plots for each term in the PS-VAE objective function.

The dtype argument controls which type of trials is plotted (‘train’ or ‘val’). Additionally, multiple models can be plotted simultaneously by varying one (and only one) of the following parameters:

  • alpha

  • beta

  • number of unsupervised latents

  • random seed used to initialize model weights

Each of these arguments must be an array of length 1, except for the one being varied, which can be an array of arbitrary length; every value must correspond to an already trained model. This function generates a single plot with panels for each of the following terms:

  • total loss

  • pixel MSE

  • label R^2 (note the objective function contains the label MSE, but R^2 is easier to parse)

  • KL divergence of supervised latents

  • index-code mutual information of unsupervised latents

  • total correlation of unsupervised latents

  • dimension-wise KL of unsupervised latents

  • subspace overlap
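The relationship between the label MSE in the objective and the plotted label R^2 can be sketched as follows. This is a minimal illustration using the standard definition R^2 = 1 - MSE / Var(y) for a single label; the helper name r2_from_mse is hypothetical, and the exact computation inside behavenet (for example, how multiple labels are aggregated) may differ.

```python
import numpy as np


def r2_from_mse(y_true, y_pred):
    """Return R^2 for one label, computed from the mean squared error.

    Hypothetical helper for illustration; not part of behavenet.
    Assumes y_true is not constant (its variance must be nonzero).
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    mse = np.mean((y_true - y_pred) ** 2)
    var = np.var(y_true)  # population variance of the true labels
    return 1.0 - mse / var
```

Unlike the raw MSE, this value does not depend on the scale of the label: 1 indicates perfect reconstruction and 0 indicates performance no better than predicting the label mean.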

Parameters:
  • lab (str) – lab id

  • expt (str) – expt id

  • animal (str) – animal id

  • session (str) – session id

  • alphas (array-like) – alpha values to plot

  • betas (array-like) – beta values to plot

  • n_ae_latents (array-like) – unsupervised dimensionalities to plot

  • rng_seeds_model (array-like) – model seeds to plot

  • experiment_name (str) – test-tube experiment name

  • n_labels (int) – dimensionality of supervised latent space

  • dtype (str) – type of trials to plot; ‘train’ | ‘val’

  • save_file (str, optional) – absolute path of save file; does not need file extension

  • format (str, optional) – format of saved image; ‘pdf’ | ‘png’ | ‘jpeg’ | …

  • kwargs – keyword arguments whose names are keys of hparams, for example to set train_frac, rng_seed_model, etc.
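
A usage sketch, assuming a set of PS-VAE models has already been trained over several alpha values. The helper find_varying_param is hypothetical (it is not part of behavenet) and only mirrors the "one and only one varying argument" rule described above; all ids, paths, and hyperparameter values are placeholders.

```python
def find_varying_param(alphas, betas, n_ae_latents, rng_seeds_model):
    """Return the name of the single argument with more than one value.

    Hypothetical pre-flight check, not part of behavenet. Returns None
    if every argument has exactly one value.
    """
    params = {
        'alphas': alphas,
        'betas': betas,
        'n_ae_latents': n_ae_latents,
        'rng_seeds_model': rng_seeds_model,
    }
    empty = [name for name, vals in params.items() if len(vals) == 0]
    if empty:
        raise ValueError('empty arguments: %s' % ', '.join(empty))
    varying = [name for name, vals in params.items() if len(vals) > 1]
    if len(varying) > 1:
        raise ValueError(
            'only one argument may vary, got: %s' % ', '.join(varying))
    return varying[0] if varying else None


def plot_alpha_sweep():
    """Plot validation curves for models that differ only in alpha."""
    # Imported here so the helper above can be used without behavenet.
    from behavenet.plotting.cond_ae_utils import plot_psvae_training_curves

    # Placeholder values; replace with those of your trained models.
    alphas = [50, 100, 500]   # the one argument that varies
    betas = [1]
    n_ae_latents = [2]
    rng_seeds_model = [0]
    find_varying_param(alphas, betas, n_ae_latents, rng_seeds_model)

    plot_psvae_training_curves(
        'my_lab', 'my_expt', 'my_animal', 'my_session',
        alphas, betas, n_ae_latents, rng_seeds_model,
        experiment_name='my_experiment',
        n_labels=4,
        dtype='val',
        save_file='/path/to/figures/psvae_training_curves',  # no extension
        format='pdf',
    )
```

Running the check before the call reports a clear error when two arguments vary at once, before any model files are looked up.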