plot_psvae_training_curves
- behavenet.plotting.cond_ae_utils.plot_psvae_training_curves(lab, expt, animal, session, alphas, betas, n_ae_latents, rng_seeds_model, experiment_name, n_labels, dtype='val', save_file=None, format='pdf', **kwargs)
Create training plots for each term in the PS-VAE objective function.
The dtype argument controls which set of trials is plotted (‘train’ or ‘val’). Additionally, multiple models can be plotted simultaneously by varying one (and only one) of the following parameters:
- alpha (alphas)
- beta (betas)
- number of unsupervised latents (n_ae_latents)
- random seed used to initialize model weights (rng_seeds_model)
Each of these arguments must be an array of length 1, except for the one being varied, which can be an array of arbitrary length; every value must correspond to an already-trained model. This function generates a single plot with panels for each of the following terms:
- total loss
- pixel MSE
- label R^2 (note that the objective function contains the label MSE, but R^2 is easier to interpret)
- KL divergence of supervised latents
- index-code mutual information of unsupervised latents
- total correlation of unsupervised latents
- dimension-wise KL of unsupervised latents
- subspace overlap
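The three unsupervised-latent terms listed above match the standard decomposition of the averaged KL term used in β-TCVAE. As a reference sketch, assuming the unsupervised latents z have a factorized prior p(z) = ∏_j p(z_j) and writing q(z) = E_{p(x)}[q(z | x)] for the aggregate posterior, the decomposition is:

```latex
\mathbb{E}_{p(x)}\!\left[\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big)\right]
= \underbrace{I_q(x; z)}_{\text{index-code MI}}
+ \underbrace{\mathrm{KL}\Big(q(z)\,\Big\|\,\textstyle\prod_j q(z_j)\Big)}_{\text{total correlation}}
+ \underbrace{\textstyle\sum_j \mathrm{KL}\big(q(z_j)\,\|\,p(z_j)\big)}_{\text{dimension-wise KL}}
```

Plotting the three terms in separate panels makes it possible to see which part of the KL penalty changes as, for example, beta is varied.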
- Parameters:
  - lab (str) – lab id
  - expt (str) – expt id
  - animal (str) – animal id
  - session (str) – session id
  - alphas (array-like) – alpha values to plot
  - betas (array-like) – beta values to plot
  - n_ae_latents (array-like) – unsupervised dimensionalities to plot
  - rng_seeds_model (array-like) – model seeds to plot
  - experiment_name (str) – test-tube experiment name
  - n_labels (int) – dimensionality of supervised latent space
  - dtype (str) – ‘train’ | ‘val’
  - save_file (str, optional) – absolute path of save file; does not need file extension
  - format (str, optional) – format of saved image; ‘pdf’ | ‘png’ | ‘jpeg’ | …
  - kwargs – arguments are keys of hparams, for example to set train_frac, rng_seed_model, etc.
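Because only one of alphas, betas, n_ae_latents and rng_seeds_model may hold more than one value, it can help to check the sweep before calling the function. The sketch below uses a hypothetical helper, find_varying_param, which is not part of behavenet; the parameter values are illustrative only.

```python
from typing import Dict, Optional, Sequence


def find_varying_param(params: Dict[str, Sequence]) -> Optional[str]:
    """Return the name of the one argument that holds several values.

    Hypothetical helper (not part of behavenet): it checks the documented
    rule that at most one of the sweep arguments may have length > 1.
    Returns None when every argument holds exactly one value.
    """
    for name, values in params.items():
        if len(values) == 0:
            raise ValueError(f"{name} must contain at least one value")
    # collect every argument that lists more than one value
    varying = [name for name, values in params.items() if len(values) > 1]
    if len(varying) > 1:
        raise ValueError(
            f"only one of {', '.join(params)} may contain multiple values; "
            f"got multiple values for {', '.join(varying)}"
        )
    return varying[0] if varying else None


# Illustrative sweep over beta with all other arguments held fixed
sweep = {
    "alphas": [1000],
    "betas": [1, 5, 10, 20],
    "n_ae_latents": [2],
    "rng_seeds_model": [0],
}
print(find_varying_param(sweep))  # betas
```

If the check passes, the same lists can be passed as the alphas, betas, n_ae_latents and rng_seeds_model arguments, and the returned name indicates which parameter distinguishes the plotted models.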