fit

behavenet.fitting.training.fit(hparams, model, data_generator, exp, method='ae')

Fit PyTorch models with stochastic gradient descent and early stopping.

Training parameters such as the minimum and maximum number of epochs and the early stopping hyperparameters are specified in hparams.

For more information on how early stopping is implemented, see the class EarlyStopping.
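As a rough illustration of patience-based early stopping, here is a minimal sketch. It is not behavenet's EarlyStopping class: the class name `SimpleEarlyStopping` and its `patience`, `delta`, and `min_epochs` arguments are hypothetical, and the sketch assumes training should stop once the validation loss has failed to improve for `patience` consecutive checks, but never before `min_epochs`.

```python
class SimpleEarlyStopping:
    """Patience-based early stopping on the validation loss (hypothetical sketch)."""

    def __init__(self, patience=10, delta=0.0, min_epochs=0):
        self.patience = patience      # validation checks allowed without improvement
        self.delta = delta            # decrease required to count as an improvement
        self.min_epochs = min_epochs  # never request a stop before this (0-indexed) epoch
        self.best_loss = float("inf")
        self.counter = 0
        self.should_stop = False

    def on_val_check(self, epoch, val_loss):
        """Record a validation loss; return True once training should stop."""
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss  # new best: reset the patience counter
            self.counter = 0
        else:
            self.counter += 1
        if self.counter >= self.patience and epoch >= self.min_epochs:
            self.should_stop = True
        return self.should_stop


# Usage: the loss stops improving after epoch 1, so with patience=2
# and min_epochs=3 the stop is requested at epoch 3.
stopper = SimpleEarlyStopping(patience=2, min_epochs=3)
stopped_epoch = None
for epoch, val_loss in enumerate([1.0, 0.8, 0.81, 0.82, 0.83]):
    if stopper.on_val_check(epoch, val_loss):
        stopped_epoch = epoch
        break
```

Tying `min_epochs` to the stop decision, rather than to the counter, means the best loss is still tracked during the early epochs even though no stop can be requested yet.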

Training progress is monitored by calculating the model loss on both training data and validation data. The training loss is calculated every epoch, and the validation loss is calculated according to the hparams key 'val_check_interval'. For example, if val_check_interval=5 then the validation loss is calculated every 5 epochs. If val_check_interval=0.5 then the validation loss is calculated twice per epoch: once after the first half of the batches have been processed, and again after all batches have been processed.
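The two regimes of 'val_check_interval' can be made concrete with a small helper that lists, for a given epoch, the batch counts after which validation would run. This is a sketch under assumptions, not behavenet's code: the function name `val_check_points` is hypothetical, it assumes whole-number intervals count epochs from 1 (so an interval of 5 validates at the end of epochs 5, 10, ...), and it assumes fractional intervals split each epoch into `round(1 / interval)` roughly equal chunks.

```python
def val_check_points(val_check_interval, n_batches, epoch):
    """Return the 1-indexed batch counts after which validation runs in `epoch`.

    Hypothetical sketch; `epoch` is 0-indexed.
      * interval >= 1: validate after the final batch of every interval-th epoch.
      * 0 < interval < 1: validate after each of round(1 / interval) chunks.
    """
    if val_check_interval <= 0:
        raise ValueError("val_check_interval must be positive")
    if val_check_interval >= 1:
        interval = int(val_check_interval)
        return [n_batches] if (epoch + 1) % interval == 0 else []
    n_checks = int(round(1 / val_check_interval))
    # Python's round() rounds halves to even, so 9 batches split in two gives [4, 9].
    return sorted({round(n_batches * k / n_checks) for k in range(1, n_checks + 1)})


print(val_check_points(5, 100, 4))    # fifth epoch: validate after the last batch
print(val_check_points(0.5, 100, 0))  # twice per epoch: after batches 50 and 100
```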

Monitored metrics are saved in a CSV file in the model directory. This logging is handled by the test_tube package and the class Logger.
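To show what per-check metrics in a CSV file can look like, the sketch below uses only Python's standard `csv` module; it is a stand-in for, not a reproduction of, the test_tube/Logger machinery. The function name `save_metrics_csv`, the default file name `metrics.csv`, and the metric names `tr_loss` and `val_loss` are all hypothetical. Rows that lack a metric (for example, epochs without a validation check) get an empty cell.

```python
import csv
import os
import tempfile


def save_metrics_csv(metrics, model_dir, filename="metrics.csv"):
    """Write a list of metric dicts to <model_dir>/<filename> (hypothetical sketch).

    Columns are the union of all keys, in order of first appearance;
    metrics missing from a row are written as empty cells.
    """
    fieldnames = []
    for row in metrics:
        for key in row:
            if key not in fieldnames:
                fieldnames.append(key)
    path = os.path.join(model_dir, filename)
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, restval="")
        writer.writeheader()
        writer.writerows(metrics)
    return path


# Usage: the second row includes a validation loss, the first does not.
metrics = [
    {"epoch": 0, "tr_loss": 1.2},
    {"epoch": 1, "tr_loss": 1.1, "val_loss": 1.3},
]
with tempfile.TemporaryDirectory() as model_dir:
    path = save_metrics_csv(metrics, model_dir)
    with open(path, newline="") as f:
        rows = list(csv.DictReader(f))
```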

At the end of training, model outputs (such as latents for autoencoder models, or predictions for decoder models) can optionally be computed and saved using the hparams keys 'export_latents' or 'export_predictions', respectively.
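The optional post-training exports amount to checking boolean flags in hparams. A minimal sketch of that pattern follows; only the keys 'export_latents' and 'export_predictions' come from the text above, while the function `run_requested_exports`, the `exporters` mapping, and the file names in the usage example are hypothetical placeholders for whatever actually computes and saves the outputs.

```python
from typing import Callable, Dict, List


def run_requested_exports(hparams: Dict, exporters: Dict[str, Callable[[], None]]) -> List[str]:
    """Call each exporter whose hparams flag 'export_<name>' is truthy (sketch).

    Returns the names of the exports that ran, in the order of `exporters`.
    """
    ran = []
    for name, export_fn in exporters.items():
        if hparams.get(f"export_{name}", False):
            export_fn()  # e.g. compute latents or predictions and save them to disk
            ran.append(name)
    return ran


# Usage with stand-in exporters that only record what they would have saved.
saved = []
exporters = {
    "latents": lambda: saved.append("latents.pkl"),
    "predictions": lambda: saved.append("predictions.pkl"),
}
ran = run_requested_exports({"export_latents": True, "export_predictions": False}, exporters)
```

Treating a missing key as False keeps exports opt-in, which matches the text's description of them as optional.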

Parameters:
  • hparams (dict) – model/training specification

  • model (PyTorch model) – model to fit

  • data_generator (ConcatSessionsGenerator object) – data generator to serve data batches

  • exp (test_tube.Experiment object) – for logging training progress

  • method (str) – specifies the type of loss - 'ae' | 'ae-msp' | 'nll' | 'conv-decoder'
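Before a long training run it can help to reject bad arguments early. The sketch below checks `method` against the four documented values and applies simple range checks to hparams. The function name `check_fit_args` is hypothetical, and so are the key names 'min_n_epochs' and 'max_n_epochs', which stand in for whatever keys hold the epoch limits. Only 'val_check_interval' and the `method` values come from this page.

```python
VALID_METHODS = ("ae", "ae-msp", "nll", "conv-decoder")


def check_fit_args(hparams, method="ae"):
    """Raise ValueError for obviously invalid fit() arguments (hypothetical sketch)."""
    if method not in VALID_METHODS:
        raise ValueError(f"method must be one of {VALID_METHODS}, got {method!r}")
    interval = hparams.get("val_check_interval", 1)
    if interval <= 0:
        raise ValueError("val_check_interval must be positive")
    # 'min_n_epochs' / 'max_n_epochs' are assumed key names for the epoch limits.
    min_epochs = hparams.get("min_n_epochs", 0)
    max_epochs = hparams.get("max_n_epochs", min_epochs)
    if min_epochs > max_epochs:
        raise ValueError("min_n_epochs must not exceed max_n_epochs")


# Usage: a valid configuration passes silently; an unknown method raises.
check_fit_args({"val_check_interval": 0.5, "min_n_epochs": 1, "max_n_epochs": 10}, method="nll")
```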