plot_mspsvae_hyperparameter_search_results

behavenet.plotting.cond_ae_utils.plot_mspsvae_hyperparameter_search_results(hparams, sess_ids, label_names, n_background, alpha_weights, alpha_n_ae_latents, alpha_expt_name, beta_weights, delta_weights, beta_delta_n_ae_latents, beta_delta_expt_name, alpha, beta, delta, save_file, batch_size=None, format='pdf', **kwargs)

Create a variety of diagnostic plots to assess the msps-vae hyperparameters.

These diagnostic plots follow the recommended procedure for a hyperparameter search in the ps-vae models. First, fix beta=1 and delta=0, and sweep over alpha values and the number of latents (for example, alpha=[50, 100, 500, 1000] and n_ae_latents=[2, 4, 8, 16]). The best alpha value is subjective because it involves a tradeoff between pixel mse and label mse. After choosing a suitable value, fix alpha and the number of latents and vary beta and delta. This function then plots the following panels:

  • pixel mse as a function of alpha/n_ae_latents (for fixed beta/delta)

  • label mse as a function of alpha/n_ae_latents (for fixed beta/delta)

  • pixel mse as a function of beta/delta (for fixed alpha/n_ae_latents)

  • label mse as a function of beta/delta (for fixed alpha/n_ae_latents)

  • index-code mutual information (part of the KL decomposition) as a function of beta/delta (for fixed alpha/n_ae_latents)

  • total correlation (part of the KL decomposition) as a function of beta/delta (for fixed alpha/n_ae_latents)

  • dimension-wise KL (part of the KL decomposition) as a function of beta/delta (for fixed alpha/n_ae_latents)

  • average correlation coefficient across all pairs of unsupervised latent dims as a function of beta/delta (for fixed alpha/n_ae_latents)

  • subspace overlap, computed as ||[A; B] - I||_2^2, where A and B are the projections onto the supervised and unsupervised subspaces, respectively, and I is the identity, as a function of beta/delta (for fixed alpha/n_ae_latents)

  • example subspace overlap matrix for delta=0 and beta=1, with fixed alpha/n_ae_latents

  • example subspace overlap matrix for delta=1000 and beta=1, with fixed alpha/n_ae_latents
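
To make the last few panels more concrete, here is a minimal NumPy sketch of two of the quantities listed above: the average pairwise correlation between latent dims and the subspace overlap. It is an illustration under stated assumptions, not behavenet's implementation. The function names are hypothetical. The correlation is averaged in absolute value so that positive and negative correlations do not cancel. The overlap is measured on the Gram matrix U U^T of the stacked projections U = [A; B] (since U is generally not square), using a squared Frobenius norm.

```python
import numpy as np


def mean_abs_pairwise_corr(latents, batch_size=None):
    """Average |correlation| over all distinct pairs of latent dims.

    latents: array of shape (n_timepoints, n_dims), with n_dims >= 2.
    batch_size: if None, use all time points at once; otherwise compute the
        value per batch and average over batches (assumed interpretation).
    """
    latents = np.asarray(latents, dtype=float)
    n_time, n_dims = latents.shape
    if batch_size is None:
        batch_size = n_time
    pair_idx = np.triu_indices(n_dims, k=1)  # upper triangle: distinct pairs
    per_batch = []
    for start in range(0, n_time, batch_size):
        batch = latents[start:start + batch_size]
        if batch.shape[0] < 2:
            continue  # correlation is undefined for a single sample
        corr = np.corrcoef(batch, rowvar=False)  # (n_dims, n_dims)
        per_batch.append(np.mean(np.abs(corr[pair_idx])))
    return float(np.mean(per_batch))


def subspace_overlap(A, B):
    """Squared Frobenius distance between U U^T and I, with U = [A; B].

    A: (n_supervised, n_features) projection onto the supervised subspace.
    B: (n_unsupervised, n_features) projection onto the unsupervised subspace.
    Returns 0 when all rows of A and B are orthonormal (assumed definition).
    """
    U = np.concatenate([A, B], axis=0)
    gram = U @ U.T
    return float(np.sum((gram - np.eye(U.shape[0])) ** 2))


# Toy data: three perfectly (anti-)correlated latent dims.
x = np.arange(10.0)
toy_latents = np.stack([x, 2 * x, -x], axis=1)
corr_all = mean_abs_pairwise_corr(toy_latents)
corr_batched = mean_abs_pairwise_corr(toy_latents, batch_size=5)

# Orthogonal subspaces give zero overlap; identical ones do not.
overlap_orthogonal = subspace_overlap(np.array([[1.0, 0.0, 0.0]]),
                                      np.array([[0.0, 1.0, 0.0]]))
overlap_identical = subspace_overlap(np.array([[1.0, 0.0, 0.0]]),
                                     np.array([[1.0, 0.0, 0.0]]))
```

Under these assumptions, lower values of both quantities indicate better-separated latents: uncorrelated unsupervised dims and non-overlapping supervised/unsupervised subspaces.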

Parameters:
  • hparams (dict) –

  • sess_ids (list) –

  • label_names (array-like) – names of label dims

  • n_background (int) – dimensionality of background latents

  • alpha_weights (array-like) – array of alpha weights for fixed values of beta, delta

  • alpha_n_ae_latents (array-like) – array of latent dimensionalities for fixed values of beta, delta using alpha_weights

  • alpha_expt_name (str) – test-tube experiment name of alpha-based hyperparam search

  • beta_weights (array-like) – array of beta weights for a fixed value of alpha

  • delta_weights (array-like) – array of delta weights for a fixed value of alpha

  • beta_delta_n_ae_latents (int) – latent dimensionality used for beta-delta hyperparam search

  • beta_delta_expt_name (str) – test-tube experiment name of beta-delta hyperparam search

  • alpha (float) – fixed value of alpha for beta-delta search

  • beta (float) – fixed value of beta for alpha search

  • delta (float) – fixed value of delta for alpha search

  • save_file (str) – absolute path of save file; does not need file extension

  • batch_size (int, optional) – size of batches, used to compute the correlation coefficient per batch; if None, the correlation coefficient is computed across all time points

  • format (str, optional) – format of saved image; ‘pdf’ | ‘png’ | ‘jpeg’ | …

  • kwargs – arguments are keys of hparams, preceded by either alpha_ or beta_delta_. For example, to set the train frac of the alpha models, use alpha_train_frac; to set the rng_data_seed of the beta-delta models, use beta_delta_rng_data_seed.
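
The following sketch shows how the two search stages map onto the arguments above, and how prefixed keyword arguments could be separated into per-stage overrides. It is a hedged illustration only: the helper names (build_alpha_sweep, build_beta_delta_sweep, split_prefixed_kwargs) are hypothetical and not part of behavenet, and the beta/delta values and the chosen alpha=100, n_ae_latents=8 are placeholders rather than recommendations.

```python
from itertools import product


def build_alpha_sweep(alpha_weights, alpha_n_ae_latents, beta, delta):
    """Stage 1: vary alpha and the number of latents; beta and delta stay fixed."""
    return [
        {"alpha": a, "n_ae_latents": n, "beta": beta, "delta": delta}
        for a, n in product(alpha_weights, alpha_n_ae_latents)
    ]


def build_beta_delta_sweep(beta_weights, delta_weights, alpha, n_ae_latents):
    """Stage 2: vary beta and delta; alpha and the number of latents stay fixed."""
    return [
        {"alpha": alpha, "n_ae_latents": n_ae_latents, "beta": b, "delta": d}
        for b, d in product(beta_weights, delta_weights)
    ]


def split_prefixed_kwargs(**kwargs):
    """Split kwargs into (alpha overrides, beta-delta overrides) by prefix.

    Hypothetical helper: strips the 'alpha_' or 'beta_delta_' prefix so the
    remaining key matches a key of hparams.
    """
    alpha_overrides, beta_delta_overrides = {}, {}
    for key, value in kwargs.items():
        if key.startswith("beta_delta_"):
            beta_delta_overrides[key[len("beta_delta_"):]] = value
        elif key.startswith("alpha_"):
            alpha_overrides[key[len("alpha_"):]] = value
        else:
            raise ValueError(f"unexpected keyword argument: {key}")
    return alpha_overrides, beta_delta_overrides


# Stage 1 uses the example values from the description (beta=1, delta=0).
alpha_sweep = build_alpha_sweep([50, 100, 500, 1000], [2, 4, 8, 16], beta=1, delta=0)

# Stage 2 uses placeholder beta/delta values and a placeholder stage-1 choice.
beta_delta_sweep = build_beta_delta_sweep(
    [1, 5, 10], [0, 100, 1000], alpha=100, n_ae_latents=8)

alpha_overrides, beta_delta_overrides = split_prefixed_kwargs(
    alpha_train_frac=0.5, beta_delta_rng_data_seed=1)
```

Each dict in alpha_sweep or beta_delta_sweep corresponds to one trained model whose metrics would appear as one point in the panels; the override dicts mirror the alpha_train_frac and beta_delta_rng_data_seed examples given for kwargs.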