fit
- behavenet.fitting.training.fit(hparams, model, data_generator, exp, method='ae')
Fit PyTorch models with stochastic gradient descent and early stopping.
Training parameters such as the minimum and maximum number of epochs and the early stopping hyperparameters are specified in hparams. For more information on how early stopping is implemented, see the class EarlyStopping.
Training progress is monitored by calculating the model loss on both training data and validation data. The training loss is calculated each epoch, and the validation loss is calculated according to the hparams key 'val_check_interval'. For example, if val_check_interval=5, the validation loss is calculated every 5 epochs. If val_check_interval=0.5, the validation loss is calculated twice per epoch: once after the first half of the batches has been processed, and again after all batches have been processed.
Monitored metrics are saved in a csv file in the model directory. This logging is handled by the test_tube package and the class Logger.
At the end of training, model outputs (such as latents for autoencoder models, or predictions for decoder models) can optionally be computed and saved using the hparams keys 'export_latents' or 'export_predictions', respectively.
- Parameters:
  - hparams (dict) – model/training specification
  - model (PyTorch model) – model to fit
  - data_generator (ConcatSessionsGenerator object) – data generator to serve data batches
  - exp (test_tube.Experiment object) – for logging training progress
  - method (str) – specifies the type of loss: 'ae' | 'ae-msp' | 'nll' | 'conv-decoder'
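The following is a minimal, hypothetical sketch of patience-based early stopping. It is meant only to illustrate the general idea behind stopping training once the validation loss stops improving. It is not behavenet's EarlyStopping class. The name SimpleEarlyStopping, its parameters (patience, min_epochs, delta), and the method on_val_check are assumptions chosen for this example, so consult EarlyStopping itself for the actual behavior.

```python
class SimpleEarlyStopping:
    """Hypothetical patience-based early stopping (NOT behavenet's EarlyStopping)."""

    def __init__(self, patience=10, min_epochs=0, delta=0.0):
        self.patience = patience      # validation checks without improvement before stopping
        self.min_epochs = min_epochs  # never request a stop before this epoch
        self.delta = delta            # minimum decrease in loss that counts as improvement
        self.best_loss = float("inf")
        self.best_epoch = None
        self.counter = 0
        self.should_stop = False

    def on_val_check(self, epoch, val_loss):
        """Record one validation loss; return True once training should stop."""
        if val_loss < self.best_loss - self.delta:
            # New best model: remember it and reset the patience counter.
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
        if epoch >= self.min_epochs and self.counter >= self.patience:
            self.should_stop = True
        return self.should_stop


# Usage: made-up validation losses, one check per epoch.
stopper = SimpleEarlyStopping(patience=2, min_epochs=3)
val_losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]
for epoch, loss in enumerate(val_losses):
    if stopper.on_val_check(epoch, loss):
        break
print(epoch, stopper.best_epoch, stopper.best_loss)  # 4 2 0.7
```

Here training stops at epoch 4. This is the first epoch at or after min_epochs where two consecutive checks have failed to beat the best loss, which was reached at epoch 2. A real training loop would typically also restore or save the weights from best_epoch.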
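The 'val_check_interval' semantics described above can be expressed as a small schedule. This sketch uses a hypothetical helper, val_check_batches (not part of behavenet). For a given epoch, it returns the 0-indexed batches after which the validation loss would be computed. Values of 1 or more are read as a number of epochs, and values below 1 as a fraction of an epoch. It assumes that epochs are counted from 0, so an interval of 5 triggers checks after epochs 4, 9, and so on. That offset is an assumption, not something this docstring states.

```python
def val_check_batches(n_batches, val_check_interval, epoch):
    """Return the 0-indexed batches of `epoch` after which validation runs.

    Hypothetical helper illustrating 'val_check_interval'; not part of behavenet.
    """
    if n_batches < 1:
        raise ValueError("n_batches must be a positive integer")
    if val_check_interval <= 0:
        raise ValueError("val_check_interval must be positive")
    if val_check_interval >= 1:
        # Interval in epochs: one check, after the last batch, every k-th epoch.
        # Assumption: epochs start at 0, so k=5 checks after epochs 4, 9, ...
        k = int(val_check_interval)
        return [n_batches - 1] if (epoch + 1) % k == 0 else []
    # Fractional interval: about 1 / val_check_interval checks per epoch, spread
    # over the batches (integer ceiling division); the last is after the final batch.
    n_checks = max(1, round(1 / val_check_interval))
    points = {(n_batches * j + n_checks - 1) // n_checks - 1 for j in range(1, n_checks + 1)}
    return sorted(points)


print(val_check_batches(10, 0.5, epoch=0))  # [4, 9]
print(val_check_batches(10, 5, epoch=4))    # [9]
print(val_check_batches(10, 5, epoch=0))    # []
```

With 10 batches and val_check_interval=0.5, validation runs after batch index 4 (the first half) and after batch index 9 (all batches). This matches the two checks per epoch described in the text.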
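As a rough illustration of how monitored metrics end up in a csv file in the model directory, here is a stdlib-only stand-in. It does not reproduce test_tube or behavenet's Logger. The class name CsvMetricsLogger, the file name metrics.csv, and the column names (epoch, tr_loss, val_loss) are all hypothetical. The sketch shows one detail such loggers must handle: rows logged at training-only steps have no validation loss, so the columns are taken from the union of all logged keys and missing values are left blank.

```python
import csv
import os
import tempfile


class CsvMetricsLogger:
    """Hypothetical stand-in for metric logging: one csv row per logged step."""

    def __init__(self, model_dir, filename="metrics.csv"):
        self.path = os.path.join(model_dir, filename)
        self.rows = []

    def log(self, **metrics):
        # Buffer one row; different rows may carry different metric names.
        self.rows.append(metrics)

    def save(self):
        # Columns are the union of all metric names, in first-seen order;
        # DictWriter writes '' for any metric missing from a row.
        fieldnames = []
        for row in self.rows:
            for key in row:
                if key not in fieldnames:
                    fieldnames.append(key)
        with open(self.path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(self.rows)
        return self.path


# Usage: a training-only row followed by a row that also has a validation loss.
model_dir = tempfile.mkdtemp()
logger = CsvMetricsLogger(model_dir)
logger.log(epoch=0, tr_loss=1.25)
logger.log(epoch=0, tr_loss=1.10, val_loss=1.30)
path = logger.save()
```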