plot_mspsvae_hyperparameter_search_results
- behavenet.plotting.cond_ae_utils.plot_mspsvae_hyperparameter_search_results(hparams, sess_ids, label_names, n_background, alpha_weights, alpha_n_ae_latents, alpha_expt_name, beta_weights, delta_weights, beta_delta_n_ae_latents, beta_delta_expt_name, alpha, beta, delta, save_file, batch_size=None, format='pdf', **kwargs)
Create a variety of diagnostic plots to assess the msps-vae hyperparameters.
These diagnostic plots follow the recommended procedure for a hyperparameter search in the ps-vae models. First, fix beta=1 and delta=0, and sweep over alpha values and the number of latents (for example alpha=[50, 100, 500, 1000] and n_ae_latents=[2, 4, 8, 16]). The best alpha value is subjective because it involves a tradeoff between pixel mse and label mse. After choosing a suitable value, fix alpha and the number of latents and vary beta and delta. This function then plots the following panels:
pixel mse as a function of alpha/n_ae_latents (for fixed beta/delta)
label mse as a function of alpha/n_ae_latents (for fixed beta/delta)
pixel mse as a function of beta/delta (for fixed alpha/n_ae_latents)
label mse as a function of beta/delta (for fixed alpha/n_ae_latents)
index-code mutual information (part of the KL decomposition) as a function of beta/delta (for fixed alpha/n_ae_latents)
total correlation (part of the KL decomposition) as a function of beta/delta (for fixed alpha/n_ae_latents)
dimension-wise KL (part of the KL decomposition) as a function of beta/delta (for fixed alpha/n_ae_latents)
average correlation coefficient across all pairs of unsupervised latent dims as a function of beta/delta (for fixed alpha/n_ae_latents)
subspace overlap computed as ||[A; B] - I||_2^2 for A, B the projections to the supervised and unsupervised subspaces, respectively, and I the identity, as a function of beta/delta (for fixed alpha/n_ae_latents)
example subspace overlap matrix for delta=0 and beta=1, with fixed alpha/n_ae_latents
example subspace overlap matrix for delta=1000 and beta=1, with fixed alpha/n_ae_latents
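The subspace overlap panels can be illustrated with a small NumPy sketch. It is one plausible reading of the formula above, not BehaveNet's implementation: because the stacked matrix [A; B] is generally not square, the sketch assumes the comparison with the identity is made on the Gram matrix U U^T, where U stacks the rows of A and B, and it uses the squared Frobenius norm. The function name `subspace_overlap` and the toy matrices are hypothetical; the library's exact normalization may differ.

```python
import numpy as np


def subspace_overlap(A, B):
    """Return the overlap matrix and a scalar overlap score.

    A, B: 2-D arrays whose rows span the supervised and unsupervised
    subspaces (assumed layout: n_latents x n_features).
    """
    U = np.concatenate([A, B], axis=0)       # stack the two projections
    gram = U @ U.T                           # pairwise inner products of rows
    residual = gram - np.eye(U.shape[0])     # deviation from orthonormality
    return gram, float(np.sum(residual ** 2))


# Toy data: four orthonormal rows in a 6-dimensional space.
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(6, 4)))  # Q has orthonormal columns
rows = Q.T

# Case 1: A and B use disjoint orthonormal rows, so the score is ~0.
A_orth, B_orth = rows[:2], rows[2:]
_, score_orth = subspace_overlap(A_orth, B_orth)

# Case 2: B reuses the first row of A, so the subspaces overlap.
B_shared = np.stack([rows[0], rows[3]])
overlap_matrix, score_shared = subspace_overlap(A_orth, B_shared)

print(score_orth, score_shared)
```

In the second case the only nonzero entries of the residual are the two symmetric off-diagonal ones linking the shared row, which is the kind of structure the example overlap-matrix panels are meant to make visible.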
- Parameters:
hparams (dict) –
sess_ids (list) –
label_names (array-like) – names of label dims
n_background (int) – dimensionality of background latents
alpha_weights (array-like) – array of alpha weights for fixed values of beta, delta
alpha_n_ae_latents (array-like) – array of latent dimensionalities for fixed values of beta, delta using alpha_weights
alpha_expt_name (str) – test-tube experiment name of alpha-based hyperparam search
beta_weights (array-like) – array of beta weights for a fixed value of alpha
delta_weights (array-like) – array of delta weights for a fixed value of alpha
beta_delta_n_ae_latents (int) – latent dimensionality used for beta-delta hyperparam search
beta_delta_expt_name (str) – test-tube experiment name of beta-delta hyperparam search
alpha (float) – fixed value of alpha for beta-delta search
beta (float) – fixed value of beta for alpha search
delta (float) – fixed value of delta for alpha search
save_file (str) – absolute path of save file; does not need file extension
batch_size (int, optional) – size of batches, used to compute correlation coefficient per batch; if None, the correlation coefficient is computed across all time points
format (str, optional) – format of saved image; ‘pdf’ | ‘png’ | ‘jpeg’ | …
kwargs – arguments are keys of hparams, preceded by either alpha_ or beta_delta_. For example, to set the train frac of the alpha models, use alpha_train_frac; to set the rng_data_seed of the beta-delta models, use beta_delta_rng_data_seed.
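A call might be assembled along the following lines. This is a sketch under stated assumptions rather than a tested recipe: the argument names come from the signature above, and the alpha and latent grids are the example values from the description, but the session dictionary layout, label names, experiment names, the beta and delta grids, and the helper names `build_plot_args` and `make_plots` are all hypothetical. The contents of `hparams` are not documented here, so the caller is expected to supply it.

```python
def build_plot_args(save_file):
    """Assemble keyword arguments matching the documented signature.

    All concrete values below are placeholders for illustration.
    """
    return {
        # assumed layout of a session identifier; adjust to your data
        "sess_ids": [
            {"lab": "my_lab", "expt": "my_expt",
             "animal": "animal-01", "session": "session-01"},
        ],
        "label_names": ["paw_x", "paw_y"],        # hypothetical labels
        "n_background": 3,                         # hypothetical value
        # stage 1: sweep alpha and latent count with beta, delta fixed
        "alpha_weights": [50, 100, 500, 1000],
        "alpha_n_ae_latents": [2, 4, 8, 16],
        "alpha_expt_name": "alpha_search",         # hypothetical name
        "beta": 1,
        "delta": 0,
        # stage 2: sweep beta and delta with alpha, latent count fixed
        "beta_weights": [1, 5, 10, 20],            # hypothetical grid
        "delta_weights": [10, 100, 1000],          # hypothetical grid
        "beta_delta_n_ae_latents": 8,
        "beta_delta_expt_name": "beta_delta_search",  # hypothetical name
        "alpha": 100,
        "save_file": save_file,                    # no file extension needed
        "batch_size": None,                        # correlate over all time points
        "format": "pdf",
        # extra kwargs: hparams keys prefixed with alpha_ or beta_delta_
        "alpha_train_frac": 0.5,
        "beta_delta_rng_data_seed": 1,
    }


def make_plots(hparams, save_file):
    """Call the plotting function; requires BehaveNet and fitted models."""
    from behavenet.plotting.cond_ae_utils import (
        plot_mspsvae_hyperparameter_search_results,
    )
    plot_mspsvae_hyperparameter_search_results(
        hparams, **build_plot_args(save_file))
```

Keeping the argument construction separate from the call makes it easy to check that the fixed values used in one stage (`alpha`, `beta_delta_n_ae_latents`) actually appear in the grids swept in the other stage before any models are loaded.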