behavenet.plotting

Plotting and video documentation.

behavenet.plotting Package

Utility functions shared across multiple plotting modules.

Functions

concat(ims[, axis])

Concatenate two channels along the x or y direction (useful for data with multiple views).
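As an illustration of the idea only (this is not behavenet's code), the sketch below tiles two camera views into a single frame with NumPy. It assumes the input has shape (n_frames, 2, y_pix, x_pix) and that axis=0 stacks the views along y while axis=1 places them side by side along x; both the function name concat_channels_sketch and this axis convention are hypothetical.

```python
import numpy as np


def concat_channels_sketch(ims: np.ndarray, axis: int = 1) -> np.ndarray:
    """Hypothetical sketch: tile two views of each frame into one image.

    Assumes `ims` has shape (n_frames, 2, y_pix, x_pix). axis=0 stacks the
    two views vertically (along y); axis=1 puts them side by side (along x).
    """
    if ims.ndim != 4 or ims.shape[1] != 2:
        raise ValueError("expected an array of shape (n_frames, 2, y_pix, x_pix)")
    # ims[:, k] has shape (n_frames, y_pix, x_pix), so axis + 1 maps
    # 0 -> the y dimension and 1 -> the x dimension of each frame
    out = np.concatenate([ims[:, 0], ims[:, 1]], axis=axis + 1)
    # restore a singleton channel dimension: (n_frames, 1, y, x)
    return out[:, None]


frames = np.zeros((3, 2, 4, 5))
side_by_side = concat_channels_sketch(frames, axis=1)  # shape (3, 1, 4, 10)
stacked = concat_channels_sketch(frames, axis=0)       # shape (3, 1, 8, 5)
```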

get_crop(im, y_0, y_ext, x_0, x_ext)

Get crop of image, filling in borders with zeros.
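A minimal sketch of a zero-filled crop, assuming (not confirmed by this page) that y_0/x_0 are the crop center and y_ext/x_ext are half-widths, so the output is (2 * y_ext, 2 * x_ext). Any part of the window that falls outside the image stays zero. The name get_crop_sketch is hypothetical and this is not behavenet's implementation.

```python
import numpy as np


def get_crop_sketch(im, y_0, y_ext, x_0, x_ext):
    """Hypothetical sketch: crop a 2D image around (y_0, x_0), zero-padding
    any part of the window that lies outside the image.

    Assumes y_0/x_0 are the crop center and y_ext/x_ext are half-widths.
    """
    y_pix, x_pix = im.shape
    crop = np.zeros((2 * y_ext, 2 * x_ext), dtype=im.dtype)
    # window bounds in image coordinates (may extend past the borders)
    y_min, y_max = y_0 - y_ext, y_0 + y_ext
    x_min, x_max = x_0 - x_ext, x_0 + x_ext
    # intersection of the window with the image
    iy0, iy1 = max(y_min, 0), min(y_max, y_pix)
    ix0, ix1 = max(x_min, 0), min(x_max, x_pix)
    if iy0 < iy1 and ix0 < ix1:
        # copy the overlapping region, offset into crop coordinates
        crop[iy0 - y_min:iy1 - y_min, ix0 - x_min:ix1 - x_min] = im[iy0:iy1, ix0:ix1]
    return crop


image = np.arange(16).reshape(4, 4) + 1
corner = get_crop_sketch(image, 0, 1, 0, 1)  # mostly zeros: window hangs off the corner
center = get_crop_sketch(image, 2, 1, 2, 1)  # fully inside the image
```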

load_latents(hparams, version[, dtype])

Load all latents as a single array.

load_metrics_csv_as_df(hparams, lab, expt, …)

Load metrics csv file and return as a pandas dataframe for easy plotting.

save_movie(save_file, ani[, frame_rate])

Save out a matplotlib ArtistAnimation.

behavenet.plotting.ae_utils Module

Plotting and video making functions for autoencoders.

Functions

make_ae_reconstruction_movie_wrapper(…[, …])

Produce movie with original video, reconstructed video, and residual.

make_reconstruction_movie(ims[, titles, …])

Produce movie with original video and reconstructed videos.

behavenet.plotting.cond_ae_utils Module

Functions

get_input_range(input_type, hparams[, …])

Helper function to compute input range for a variety of data types.

compute_range(values_list[, min_p, max_p])

Compute min and max of a list of numbers using percentiles.
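Using percentiles instead of the raw min and max keeps a few outliers from stretching a plot range. The sketch below pools every array in the list and takes two percentiles. The defaults of 5 and 95 and the tuple return value are assumptions for illustration; behavenet's actual defaults and return type are not stated here.

```python
import numpy as np


def compute_range_sketch(values_list, min_p=5, max_p=95):
    """Hypothetical sketch: robust (min, max) of a list of arrays.

    The default percentiles are illustrative assumptions.
    """
    # flatten every array in the list into one 1D vector
    values = np.concatenate([np.ravel(v) for v in values_list])
    return np.percentile(values, min_p), np.percentile(values, max_p)


lo, hi = compute_range_sketch([np.arange(0, 50), np.arange(50, 101)])
```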

get_labels_2d_for_trial(hparams, sess_ids[, …])

Return scaled labels (in pixel space) for a given trial.

get_model_input(data_generator, hparams, model)

Return images, latents, and labels for a given trial.

interpolate_2d(interp_type, model, ims_0, …)

Return reconstructed images created by interpolating through latent/label space.

interpolate_1d(interp_type, model, ims_0, …)

Return reconstructed images created by interpolating through latent/label space.
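The core of a one-dimensional traversal can be shown without a trained model: sweep one latent (or label) dimension linearly between two values while holding the others fixed, then pass the resulting codes to a decoder. The sketch below builds only the codes. Decoding is omitted, and the name traverse_1d_sketch and its arguments are hypothetical rather than the interpolate_1d signature.

```python
import numpy as np


def traverse_1d_sketch(z_base, dim, z_min, z_max, n_frames):
    """Hypothetical sketch: sweep latent dimension `dim` from z_min to z_max
    while all other dimensions stay at z_base. Returns (n_frames, n_latents)
    codes; feeding them through a decoder is left out.
    """
    z_base = np.asarray(z_base, dtype=float)
    codes = np.tile(z_base, (n_frames, 1))  # every row starts as z_base
    codes[:, dim] = np.linspace(z_min, z_max, n_frames)
    return codes


codes = traverse_1d_sketch([1.0, 2.0, 3.0], dim=1, z_min=-1.0, z_max=1.0, n_frames=3)
```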

interpolate_point_path(interp_type, model, …)

Return reconstructed images created by interpolating through multiple points.

plot_2d_frame_array(ims_list[, markers, …])

Plot a list of lists of interpolated images output by interpolate_2d() in a 2D grid.

plot_1d_frame_array(ims_list[, markers, …])

Plot a list of lists of interpolated images output by interpolate_1d() in a 2D grid.

make_interpolated(ims, save_file[, markers, …])

Make a latent space interpolation movie.

make_interpolated_multipanel(ims, save_file)

Make a multi-panel latent space interpolation movie.

fit_classifier(model, data_generator[, …])

Fit a classifier model from latent space to session id.

plot_psvae_training_curves(lab, expt, …[, …])

Create training plots for each term in the ps-vae objective function.

plot_hyperparameter_search_results(lab, …)

Create a variety of diagnostic plots to assess the ps-vae hyperparameters.

plot_label_reconstructions(lab, expt, …[, …])

Plot labels and their reconstructions from a ps-vae.

plot_latent_traversals(lab, expt, animal, …)

Plot video frames representing the traversal of individual dimensions of the latent space.

make_latent_traversal_movie(lab, expt, …)

Create a multi-panel movie with each panel showing traversals of an individual latent dim.

plot_mspsvae_training_curves(hparams, alpha, …)

Create training plots for each term in the ps-vae objective function.

plot_mspsvae_hyperparameter_search_results(…)

Create a variety of diagnostic plots to assess the msps-vae hyperparameters.

make_session_swap_movie(sess_ids, hparams, …)

Create a multi-panel movie, with each panel showing the reconstruction under a different session context.

behavenet.plotting.arhmm_utils Module

Plotting and video making functions for ARHMMs.

Functions

get_discrete_chunks(states[, include_edges])

Find occurrences of each discrete state.
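Finding occurrences of a discrete state amounts to locating contiguous runs in each trial's state sequence. The sketch below assumes `states` is a list of 1D integer arrays (one per trial) and returns, for each state, rows of [trial, start, end) frame indices. It reads include_edges as "keep runs that touch the first or last frame of a trial". That output layout and that reading of include_edges are assumptions, and the function name is hypothetical.

```python
import numpy as np


def get_discrete_chunks_sketch(states, include_edges=True):
    """Hypothetical sketch: contiguous runs of each discrete state.

    `states` is a list of 1D int arrays, one per trial. Entry k of the result
    is an (n_runs, 3) array of [trial, start, end) rows for state k. With
    include_edges=False, runs touching a trial boundary are dropped.
    """
    n_states = max(int(s.max()) for s in states) + 1
    chunks = [[] for _ in range(n_states)]
    for trial, seq in enumerate(states):
        # frame indices where the state changes value
        change = np.flatnonzero(np.diff(seq)) + 1
        starts = np.concatenate(([0], change))
        ends = np.concatenate((change, [len(seq)]))
        for start, end in zip(starts, ends):
            if not include_edges and (start == 0 or end == len(seq)):
                continue
            chunks[seq[start]].append([trial, start, end])
    return [np.array(c, dtype=int).reshape(-1, 3) for c in chunks]


runs = get_discrete_chunks_sketch([np.array([0, 0, 1, 1, 1, 0])])
inner_runs = get_discrete_chunks_sketch([np.array([0, 0, 1, 1, 1, 0])], include_edges=False)
```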

get_state_durations(latents, hmm[, …])

Calculate frame count for each state.
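Once states have been inferred, the per-state frame count is a histogram over all frames. The listed function takes latents and an hmm, so it presumably infers states first. The sketch below skips that step and starts from already-inferred state sequences (one 1D integer array per trial). The name and arguments are hypothetical.

```python
import numpy as np


def state_frame_counts_sketch(states, n_states):
    """Hypothetical sketch: total number of frames assigned to each state.

    Operates on already-inferred state sequences (a list of 1D int arrays);
    the state-inference step from latents is omitted.
    """
    all_states = np.concatenate(states)
    # minlength ensures unused states still get a (zero) count
    return np.bincount(all_states, minlength=n_states)


counts = state_frame_counts_sketch([np.array([0, 0, 1]), np.array([2, 1, 1])], n_states=4)
```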

get_latent_arrays_by_dtype(data_generator[, …])

Collect data from data generator and put into dictionary with dtypes for keys.

get_model_latents_states(hparams, version[, …])

Return arhmm defined in hparams with associated latents and states.

make_syllable_movies_wrapper(hparams, save_file)

Present video clips of each individual syllable in separate panels.

make_syllable_movies(ims_orig, state_list, …)

Present video clips of each individual syllable in separate panels.

real_vs_sampled_wrapper(output_type, …[, …])

Produce movie with (AE) reconstructed video and sampled video.

make_real_vs_sampled_movies(ims_recon, …)

Produce movie with (AE) reconstructed video and sampled video.

plot_real_vs_sampled(latents, latents_samp, …)

Plot real and sampled latents overlaying real and (potentially sampled) states.

plot_states_overlaid_with_latents(latents, …)

Plot states for a single trial overlaid with latents.

plot_state_transition_matrix(model[, deridge])

Plot Markov transition matrix for arhmm.
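ARHMM transition matrices are often dominated by large self-transition probabilities on the diagonal, which compresses the color scale for everything else. The sketch below shows one way to prepare such a matrix for display. It assumes the deridge option masks the diagonal, which this page does not confirm. Masked entries become NaN, and matplotlib's imshow draws NaN in the colormap's "bad" color (transparent by default). The helper name is hypothetical, and the plotting call itself is left out.

```python
import numpy as np


def transition_matrix_for_plot_sketch(trans, deridge=False):
    """Hypothetical sketch: prepare a transition matrix for imshow.

    Assumes `deridge` means masking the diagonal so off-diagonal
    probabilities use the full color range.
    """
    trans = np.array(trans, dtype=float)  # copy so the model's matrix is untouched
    if deridge:
        np.fill_diagonal(trans, np.nan)
    # upper color limit from the remaining finite entries
    vmax = np.nanmax(trans)
    return trans, vmax


full, vmax_full = transition_matrix_for_plot_sketch([[0.9, 0.1], [0.2, 0.8]])
masked, vmax_masked = transition_matrix_for_plot_sketch([[0.9, 0.1], [0.2, 0.8]], deridge=True)
```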

plot_dynamics_matrices(model[, deridge])

Plot autoregressive dynamics matrices for arhmm.

plot_obs_biases(model)

Plot observation bias vectors for arhmm.

plot_obs_covariance_matrices(model)

Plot observation covariance matrices for arhmm.

behavenet.plotting.decoder_utils Module

Plotting functions for decoders.

Functions

get_r2s_by_trial(hparams, model_types)

For a given session, load R^2 metrics from all decoders defined by hparams.

get_best_models(metrics_df)

Find best decoder over l2 regularization and learning rate.
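Choosing the best decoder over a hyperparameter grid reduces to keeping, for each model type, the configuration with the best validation score. The sketch uses plain Python instead of pandas to stay self-contained. It assumes lower validation loss is better, and the keys 'model_type', 'l2_reg', 'learning_rate' and 'val_loss' are illustrative, not the column names in behavenet's metrics files.

```python
def get_best_models_sketch(metrics):
    """Hypothetical sketch: lowest-val_loss configuration per model type.

    `metrics` is a list of dicts with illustrative keys 'model_type',
    'l2_reg', 'learning_rate' and 'val_loss'.
    """
    best = {}
    for row in metrics:
        key = row["model_type"]
        # keep this row if it is the first for its model type or beats the current best
        if key not in best or row["val_loss"] < best[key]["val_loss"]:
            best[key] = row
    return best


rows = [
    {"model_type": "mlp", "l2_reg": 1e-4, "learning_rate": 1e-3, "val_loss": 0.52},
    {"model_type": "mlp", "l2_reg": 1e-3, "learning_rate": 1e-3, "val_loss": 0.47},
    {"model_type": "linear", "l2_reg": 1e-4, "learning_rate": 1e-3, "val_loss": 0.61},
]
best = get_best_models_sketch(rows)
```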

get_r2s_across_trials(hparams, best_models_df)

Calculate R^2 across all test trials (rather than on a trial-by-trial basis).
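The difference between per-trial and across-trial R^2 lies in the mean used in the denominator. Pooling frames from every test trial compares the residuals against the global mean rather than each trial's own mean. The sketch below pools first and then computes a variance-weighted R^2 over all output dimensions. Whether behavenet weights dimensions this way is an assumption, and the function name is hypothetical.

```python
import numpy as np


def r2_across_trials_sketch(true_list, pred_list):
    """Hypothetical sketch: variance-weighted R^2 over pooled test trials.

    Each list holds one (n_frames, n_dims) array per trial. Frames are
    concatenated before computing R^2, so the denominator uses the global
    mean of each dimension across all trials.
    """
    y_true = np.concatenate(true_list, axis=0)
    y_pred = np.concatenate(pred_list, axis=0)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot


trials = [np.array([[0.0], [1.0]]), np.array([[2.0], [3.0]])]
perfect = r2_across_trials_sketch(trials, trials)
mean_only = r2_across_trials_sketch(trials, [np.full((2, 1), 1.5), np.full((2, 1), 1.5)])
```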

make_neural_reconstruction_movie_wrapper(…)

Produce movie with original video, ae reconstructed video, and neural reconstructed video.

make_neural_reconstruction_movie(ims_orig, …)

Produce movie with original video, ae reconstructed video, and neural reconstructed video.

plot_neural_reconstruction_traces_wrapper(hparams)

Plot ae latents and their neural reconstructions.

plot_neural_reconstruction_traces(traces_ae, …)

Plot ae latents and their neural reconstructions.