behavenet.fitting

Model fitting documentation.

behavenet.fitting.eval Module

Utility functions for evaluating model fits.

Functions

export_latents(data_generator, model[, filename])

Export predicted latents using an already initialized data_generator and model.

export_predictions(data_generator, model[, …])

Export decoder predictions using an already initialized data_generator and model.

export_states(hparams, data_generator, model)

Export predicted states using an already initialized data_generator and model.

export_train_plots(hparams, dtype[, …])

Export plot with MSE/LL as a function of training epochs.

get_reconstruction(model, inputs[, dataset, …])

Reconstruct an image from either image or latent inputs.

get_test_metric(hparams, model_version[, …])

Calculate a single R2 value across all test batches for a decoder.

behavenet.fitting.losses Module

Custom losses for PyTorch models.

Functions

mse(y_pred, y_true[, masks])

Compute mean squared error (MSE) loss with masks.
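The masked-MSE idea can be sketched as follows. This is an illustrative NumPy version, not the library's implementation (which operates on PyTorch tensors); the averaging convention over unmasked entries is an assumption:

```python
import numpy as np

def masked_mse(y_pred, y_true, masks=None):
    """Mean squared error; entries where mask == 0 are ignored."""
    se = (np.asarray(y_pred) - np.asarray(y_true)) ** 2
    if masks is None:
        return se.mean()
    masks = np.asarray(masks, dtype=float)
    # average only over the unmasked entries
    return (se * masks).sum() / masks.sum()
```

Masks are useful when batches are zero-padded to a common length and the padded entries should not contribute to the loss.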

gaussian_ll(y_pred, y_mean[, masks, std])

Compute multivariate Gaussian log-likelihood with a fixed diagonal noise covariance matrix.
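With a fixed diagonal covariance std**2 * I, the log-likelihood has a simple closed form. A minimal NumPy sketch (the argument roles and batch reduction here are assumptions; the library version works on PyTorch tensors and also accepts masks):

```python
import math
import numpy as np

def gaussian_ll(y_pred, y_mean, std=1.0):
    """Log-likelihood of each row of y_pred under N(y_mean, std**2 * I),
    averaged over the batch (rows)."""
    y_pred, y_mean = np.asarray(y_pred), np.asarray(y_mean)
    n = y_pred.shape[1]  # dimensionality of each sample
    sq = ((y_pred - y_mean) ** 2).sum(axis=1)
    return np.mean(-0.5 * (n * math.log(2 * math.pi)
                           + 2 * n * math.log(std)
                           + sq / std ** 2))
```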

gaussian_ll_to_mse(ll, n_dims[, …])

Convert a Gaussian log-likelihood term to MSE by removing constants and swapping variances.

kl_div_to_std_normal(mu, logvar)

Compute element-wise KL(q(z) || N(0, 1)) where q(z) is a normal parameterized by mu, logvar.
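For a normal q(z) parameterized by mu and logvar, this KL has the standard closed form -0.5 * (1 + logvar - mu**2 - exp(logvar)) per element. A single-element sketch for reference (the library computes this element-wise over tensors):

```python
import math

def kl_to_std_normal(mu, logvar):
    """Closed-form KL( N(mu, exp(logvar)) || N(0, 1) ) for one element."""
    return -0.5 * (1.0 + logvar - mu ** 2 - math.exp(logvar))
```

The divergence is zero exactly when mu = 0 and logvar = 0 (unit variance), and grows as the posterior moves away from the standard normal prior.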

index_code_mi(z, mu, logvar)

Estimate index code mutual information in a batch.

total_correlation(z, mu, logvar)

Estimate total correlation in a batch.

dimension_wise_kl_to_std_normal(z, mu, logvar)

Estimate dimensionwise KL divergence to standard normal in a batch.

decomposed_kl(z, mu, logvar)

Decompose KL term in VAE loss.

subspace_overlap(A, B[, C])

Compute inner product between subspaces defined by matrices A and B.
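The core quantity can be sketched as the sum of squared inner products between the rows of A and B, which is zero exactly when the two row spaces are orthogonal. This is an illustrative version only; the library's exact normalization and its handling of the optional third matrix C are not shown here:

```python
import numpy as np

def subspace_overlap(A, B):
    """Sum of squared inner products between rows of A and rows of B."""
    M = np.asarray(A) @ np.asarray(B).T  # cross inner-product matrix
    return float((M ** 2).sum())
```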

triplet_loss(triplet_loss_obj, z, datasets)

Compute triplet loss to learn separated embedding space.

behavenet.fitting.training Module

Functions and classes for fitting PyTorch models with stochastic gradient descent.

Functions

fit(hparams, model, data_generator, exp[, …])

Fit PyTorch models with stochastic gradient descent and early stopping.

Classes

Logger([n_datasets])

Base method for logging loss metrics.

EarlyStopping([patience, min_epochs, delta])

Stop training when a monitored quantity has stopped improving.
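A minimal sketch of the patience-based early-stopping pattern the class name suggests. The method name `step` and the exact reset logic are assumptions, not the library's API:

```python
class EarlyStopping:
    """Stop when the monitored loss has not improved by at least `delta`
    for `patience` consecutive checks, after `min_epochs` epochs."""

    def __init__(self, patience=5, min_epochs=1, delta=0.0):
        self.patience = patience
        self.min_epochs = min_epochs
        self.delta = delta
        self.best = float('inf')
        self.wait = 0
        self.should_stop = False

    def step(self, epoch, loss):
        if loss < self.best - self.delta:
            self.best = loss  # improvement: reset the counter
            self.wait = 0
        else:
            self.wait += 1
        if epoch >= self.min_epochs and self.wait >= self.patience:
            self.should_stop = True
        return self.should_stop
```

`min_epochs` prevents stopping during the noisy early phase of training, while `delta` ignores improvements too small to matter.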

behavenet.fitting.utils Module

Utility functions for managing model paths and the hparams dict.

Functions

get_subdirs(path)

Get all first-level subdirectories in a given path (no recursion).
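The behavior can be sketched with the standard library; the sorted-names return convention here is an assumption (the library version may differ, e.g. in how it handles an empty result):

```python
import os

def get_subdirs(path):
    """Return names of the first-level subdirectories of `path` (no recursion)."""
    return sorted(e.name for e in os.scandir(path) if e.is_dir())
```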

get_session_dir(hparams[, session_source])

Get session-level directory for saving model outputs from hparams dict.

get_expt_dir(hparams[, model_class, …])

Get output directories associated with a particular model class/type/test-tube experiment name.

read_session_info_from_csv(session_file)

Read csv file that contains lab/expt/animal/session info.

export_session_info_to_csv(session_dir, ids_list)

Export list of sessions to csv file.

contains_session(session_dir, session_id)

Determine if session defined by session_id dict is in the multi-session session_dir.

find_session_dirs(hparams)

Find all session dirs (single- and multi-session) that contain the session in hparams.

experiment_exists(hparams[, which_version])

Search test-tube versions to determine whether an experiment with the same hyperparameters has already been fit.

get_model_params(hparams)

Return dict containing all params considered essential for defining a model of the given class.

export_hparams(hparams, exp)

Export hyperparameter dictionary.

get_lab_example(hparams, lab, expt)

Helper function to load data-specific hyperparameters and update hparams.

get_region_dir(hparams)

Return brain region string that combines region name and inclusion info.

create_tt_experiment(hparams)

Create test-tube experiment for logging training and storing models.

get_best_model_version(expt_dir[, measure, …])

Get best model version from a test-tube experiment.

get_best_model_and_data(hparams[, Model, …])

Load the best model (and data) defined by hparams out of all available test-tube versions.