behavenet.fitting
Model fitting documentation.
behavenet.fitting.eval Module
Utility functions for evaluating model fits.
Functions
- Export predicted latents using an already initialized data_generator and model.
- Export decoder predictions using an already initialized data_generator and model.
- Export predicted latents using an already initialized data_generator and model.
- Export plot with MSE/LL as a function of training epochs.
- Reconstruct an image from either image or latent inputs.
- Calculate a single R2 value across all test batches for a decoder.
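As a rough illustration of what "a single R2 value across all test batches" can mean, the sketch below pools every batch before computing R2 instead of averaging per-batch scores. The function name `pooled_r2` and the pooling choice are assumptions for illustration, not behavenet's implementation; the real function may, for example, weight output dimensions differently.

```python
from typing import Sequence


def pooled_r2(true_batches: Sequence[Sequence[float]],
              pred_batches: Sequence[Sequence[float]]) -> float:
    """Return one R2 value computed over the concatenation of all test batches.

    Hypothetical helper: pooling all batches is an assumed convention.
    """
    # Flatten every batch into one list of targets and one list of predictions
    y_true = [y for batch in true_batches for y in batch]
    y_pred = [y for batch in pred_batches for y in batch]
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need the same, nonzero number of targets and predictions")
    mean = sum(y_true) / len(y_true)
    # Residual and total sums of squares over the pooled data
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R2 is undefined when all targets are equal")
    return 1.0 - ss_res / ss_tot


# Example: two batches of targets and predictions
true = [[1.0, 2.0], [3.0, 4.0]]
print(pooled_r2(true, true))  # 1.0
print(pooled_r2(true, [[2.5, 2.5], [2.5, 2.5]]))  # 0.0
```

Pooling gives the same answer however the test set is split into batches, whereas a mean of per-batch R2 values can be skewed by small batches with little variance.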
behavenet.fitting.losses Module
Custom losses for PyTorch models.
Functions
- Compute mean squared error (MSE) loss with masks.
- Compute multivariate Gaussian log-likelihood with a fixed diagonal noise covariance matrix.
- Convert a Gaussian log-likelihood term to MSE by removing constants and swapping variances.
- Compute element-wise KL(q(z) || N(0, 1)), where q(z) is a normal distribution parameterized by mu and logvar.
- Estimate index-code mutual information in a batch.
- Estimate total correlation in a batch.
- Estimate dimension-wise KL divergence to a standard normal in a batch.
- Decompose the KL term in the VAE loss.
- Compute the inner product between subspaces defined by matrices A and B.
- Compute triplet loss to learn a separated embedding space.
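Several of the losses listed above reduce to short closed-form expressions. The pure-Python sketch below implements four of them under stated assumptions: the masked MSE averages over unmasked entries only, the Gaussian log-likelihood uses one shared noise variance for every dimension, the KL term uses the closed form 0.5 * (exp(logvar) + mu^2 - 1 - logvar), and the triplet loss uses Euclidean distance with a hinge margin. All function names are hypothetical; behavenet's versions operate on PyTorch tensors and may normalize differently.

```python
import math
from typing import List, Sequence


def masked_mse(y_pred: Sequence[float], y_true: Sequence[float],
               mask: Sequence[int]) -> float:
    """Mean squared error averaged over entries where mask == 1 (assumed convention)."""
    sq = [m * (p - t) ** 2 for p, t, m in zip(y_pred, y_true, mask)]
    n = sum(mask)
    return sum(sq) / n if n > 0 else 0.0


def gaussian_ll(y_pred: Sequence[float], y_true: Sequence[float],
                noise_var: float) -> float:
    """Log-likelihood of y_true under N(y_pred, noise_var * I), a fixed diagonal covariance."""
    ll = 0.0
    for p, t in zip(y_pred, y_true):
        ll += -0.5 * (math.log(2 * math.pi * noise_var) + (t - p) ** 2 / noise_var)
    return ll


def kl_to_std_normal(mu: Sequence[float], logvar: Sequence[float]) -> List[float]:
    """Element-wise KL(N(mu, exp(logvar)) || N(0, 1))."""
    return [0.5 * (math.exp(lv) + m ** 2 - 1.0 - lv) for m, lv in zip(mu, logvar)]


def triplet_loss(anchor: Sequence[float], positive: Sequence[float],
                 negative: Sequence[float], margin: float = 1.0) -> float:
    """Hinge loss max(0, d(a, p) - d(a, n) + margin) with Euclidean distance d."""
    d_ap = math.dist(anchor, positive)
    d_an = math.dist(anchor, negative)
    return max(0.0, d_ap - d_an + margin)


print(masked_mse([1.0, 2.0, 3.0], [1.0, 0.0, 5.0], [1, 1, 0]))  # 2.0
print(kl_to_std_normal([0.0, 1.0], [0.0, 0.0]))  # [0.0, 0.5]
print(triplet_loss([0.0, 0.0], [0.0, 1.0], [0.0, 1.5]))  # 0.5
```

With noise_var = 1, the negative log-likelihood equals half the sum of squared residuals plus a constant, which is the sense in which a Gaussian log-likelihood and an MSE loss are interchangeable up to scaling.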
behavenet.fitting.training Module
Functions and classes for fitting PyTorch models with stochastic gradient descent.
Functions
- Fit PyTorch models with stochastic gradient descent and early stopping.
Classes
- Base class for logging loss metrics.
- Stop training when a monitored quantity has stopped improving.
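Early stopping of the kind described above can be sketched as a patience counter on a validation loss. The class and method names (`EarlyStoppingSketch`, `on_val_check`) and the `patience` and `min_delta` parameters are assumptions for illustration, not behavenet's API.

```python
class EarlyStoppingSketch:
    """Signal a stop once a monitored loss has not improved for `patience` checks.

    Hypothetical class: names and parameters are illustrative only.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience      # checks allowed without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best = float("inf")
        self.counter = 0
        self.should_stop = False

    def on_val_check(self, loss: float) -> bool:
        # An improvement must beat the best loss so far by more than min_delta
        if loss < self.best - self.min_delta:
            self.best = loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop


if __name__ == "__main__":
    stopper = EarlyStoppingSketch(patience=2)
    for epoch, val_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
        if stopper.on_val_check(val_loss):
            break
    print(epoch, stopper.best)  # 3 0.8
```

Once `should_stop` is set it stays set, so the training loop only needs to check the return value once per validation pass.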
behavenet.fitting.utils Module
Utility functions for managing model paths and the hparams dict.
Functions
- Get all first-level subdirectories in a given path (no recursion).
- Get the session-level directory for saving model outputs from the hparams dict.
- Get output directories associated with a particular model class/type/test-tube experiment name.
- Read a CSV file that contains lab/expt/animal/session info.
- Export a list of sessions to a CSV file.
- Determine whether the session defined by the session_id dict is in the multi-session session_dir.
- Find all session directories (single- and multi-session) that contain the session in hparams.
- Search test-tube versions to determine whether an experiment with the same hyperparameters has already been fit.
- Return a dict containing all hyperparameters considered essential for defining a model of a given class.
- Export the hyperparameter dictionary.
- Helper function to load data-specific hyperparameters and update hparams.
- Return a brain region string that combines the region name and inclusion info.
- Create a test-tube experiment for logging training and storing models.
- Get the best model version from a test-tube experiment.
- Load the best model (and data) defined by hparams out of all available test-tube versions.
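Several of the path and bookkeeping utilities above map onto standard-library calls. The sketch below lists first-level subdirectories with `os.scandir`, round-trips a sessions table through a CSV file with lab/expt/animal/session columns, and returns the first stored version whose hyperparameters match a target on a chosen set of keys. Every function name, the column order, and the dict-of-dicts representation of test-tube versions are assumptions for illustration, not behavenet's implementation.

```python
import csv
import os
import tempfile
from typing import Dict, List, Optional

# Assumed column order for the sessions CSV (hypothetical)
SESSION_KEYS = ["lab", "expt", "animal", "session"]


def list_subdirs(path: str) -> List[str]:
    """Return sorted names of first-level subdirectories of `path` (no recursion)."""
    with os.scandir(path) as entries:
        return sorted(e.name for e in entries if e.is_dir())


def write_sessions_csv(sessions: List[Dict[str, str]], filename: str) -> None:
    """Write one row per session with lab/expt/animal/session columns."""
    with open(filename, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=SESSION_KEYS)
        writer.writeheader()
        writer.writerows(sessions)


def read_sessions_csv(filename: str) -> List[Dict[str, str]]:
    """Read rows written by write_sessions_csv back into a list of dicts."""
    with open(filename, newline="") as f:
        return [dict(row) for row in csv.DictReader(f)]


def find_matching_version(candidates: Dict[str, Dict], target: Dict,
                          keys: List[str]) -> Optional[str]:
    """Return the first version whose hparams agree with `target` on every key in `keys`."""
    for version, hparams in candidates.items():
        if all(hparams.get(k) == target.get(k) for k in keys):
            return version
    return None


if __name__ == "__main__":
    # Usage demo in a throwaway directory
    with tempfile.TemporaryDirectory() as root:
        os.makedirs(os.path.join(root, "lab0", "expt0"))
        print(list_subdirs(root))  # ['lab0']
        sessions_file = os.path.join(root, "sessions.csv")
        write_sessions_csv(
            [{"lab": "lab0", "expt": "expt0", "animal": "m1", "session": "s1"}],
            sessions_file)
        print(read_sessions_csv(sessions_file))
```

Comparing only a chosen list of keys mirrors the idea of a set of hyperparameters "considered essential" for a model class: incidental settings such as logging options do not prevent a match.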