Source code for behavenet.fitting.training

"""Functions and classes for fitting PyTorch models with stochastic gradient descent."""

import copy
import os
import numpy as np
from tqdm import tqdm
import torch

# TODO: make it easy to finish training if unexpectedly stopped
# TODO: save models at prespecified intervals (check ae recon as a func of epoch w/o retraining)

# to ignore imports for sphinx-autoapidoc
__all__ = ['Logger', 'EarlyStopping', 'fit']


class Logger:
    """Accumulate loss metrics over batches and export them as flat rows.

    Two sets of running sums are kept: one aggregated over every dataset
    (session) and, when more than one dataset is present, one per dataset so
    that session-specific curves can be plotted downstream.

    """

    def __init__(self, n_datasets=1):
        """

        Parameters
        ----------
        n_datasets : :obj:`int`
            total number of datasets (sessions) served by data generator

        """
        dtype_names = ['train', 'val', 'test', 'curr']

        self.n_datasets = n_datasets

        # running sums aggregated over all datasets, keyed by data type
        self.metrics = {name: {} for name in dtype_names}

        # running sums per dataset; stays empty when only one dataset exists
        if self.n_datasets > 1:
            self.metrics_by_dataset = [
                {name: {} for name in dtype_names} for _ in range(self.n_datasets)]
        else:
            self.metrics_by_dataset = []

    def reset_metrics(self, dtype):
        """Zero out every accumulated metric for one data type.

        Parameters
        ----------
        dtype : :obj:`str`
            datatype to reset metrics for (e.g. 'train', 'val', 'test')

        """
        # the aggregate store first, followed by each per-dataset store
        for store in [self.metrics, *self.metrics_by_dataset]:
            sums = store[dtype]
            for name in sums:
                sums[name] = 0

    def update_metrics(self, dtype, loss_dict, dataset=None):
        """Add the values from one batch to the running sums.

        Parameters
        ----------
        dtype : :obj:`str`
            dataset type to update metrics for (e.g. 'train', 'val', 'test')
        loss_dict : :obj:`dict`
            key-value pairs correspond to all quantities that should be logged throughout
            training; dictionary returned by `loss` attribute of BehaveNet models
        dataset : :obj:`int` or :obj:`NoneType`, optional
            if :obj:`NoneType`, updates the aggregated metrics; if :obj:`int`, updates the
            associated dataset/session

        """
        # per-dataset sums are only tracked for integer ids with multiple datasets
        track_session = isinstance(dataset, int) and self.n_datasets > 1

        # every batch also contributes 1 to the `batches` counter used for averaging
        for name, value in dict(loss_dict, batches=1).items():

            aggregate = self.metrics[dtype]
            aggregate.setdefault(name, 0)  # first time this metric is seen
            aggregate[name] += value

            if track_session:
                session = self.metrics_by_dataset[dataset][dtype]
                session.setdefault(name, 0)
                session[name] += value

    def create_metric_row(
            self, dtype, epoch, batch, dataset, trial, best_epoch=None, by_dataset=False):
        """Build a flat dict of batch-averaged metrics plus bookkeeping fields.

        Parameters
        ----------
        dtype : :obj:`str`
            'train' | 'val' | 'test'
        epoch : :obj:`int`
            current training epoch
        batch : :obj:`int`
            current training batch
        dataset : :obj:`int`
            dataset id for current batch
        trial : :obj:`int` or :obj:`NoneType`
            trial id within the current dataset
        best_epoch : :obj:`int`, optional
            best current training epoch
        by_dataset : :obj:`bool`, optional
            :obj:`True` to return metrics for a specific dataset, :obj:`False` to return metrics
            aggregated over multiple datasets

        Returns
        -------
        :obj:`dict`
            aggregated metrics for current epoch/batch

        Raises
        ------
        ValueError
            if `dtype` is not one of 'train', 'val', 'test'

        """
        prefixes = {'train': 'tr', 'val': 'val', 'test': 'test'}
        if dtype not in prefixes:
            raise ValueError("%s is an invalid data type" % dtype)
        prefix = prefixes[dtype]

        row = {'epoch': epoch, 'batch': batch, 'trial': trial}
        if dtype == 'val':
            row['best_val_epoch'] = best_epoch

        if by_dataset and self.n_datasets > 1:
            sums = self.metrics_by_dataset[dataset][dtype]
        else:
            # aggregated rows are flagged with a dataset id of -1
            dataset = -1
            sums = self.metrics[dtype]

        # running sums are divided by the number of batches they cover
        n_batches = sums['batches']
        row.update({
            '%s_%s' % (prefix, name): total / n_batches
            for name, total in sums.items() if name != 'batches'})

        row['dataset'] = dataset
        return row

    def get_loss(self, dtype):
        """Return the batch-averaged loss aggregated over all datasets.

        Parameters
        ----------
        dtype : :obj:`str`
            datatype to calculate loss for (e.g. 'train', 'val', 'test')

        """
        sums = self.metrics[dtype]
        return sums['loss'] / sums['batches']
class EarlyStopping:
    """Signal that training should stop once a monitored loss stops improving.

    Adapted from:
    https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py

    """

    def __init__(self, patience=10, min_epochs=10, delta=0):
        """

        Parameters
        ----------
        patience : :obj:`int`, optional
            number of checks without improvement (since the last improvement) that triggers
            stopping
        min_epochs : :obj:`int`, optional
            stopping is only triggered on epochs strictly greater than this value
        delta : :obj:`float`, optional
            minimum change in monitored quantity to qualify as an improvement

        """
        # user-supplied settings
        self.patience = patience
        self.min_epochs = min_epochs
        self.delta = delta

        # running state
        self.counter = 0  # checks since the last improvement
        self.best_epoch = 0
        self.best_loss = np.inf
        self.stopped_epoch = 0
        self.should_stop = False

    def on_val_check(self, epoch, curr_loss):
        """Record the latest loss and flag whether training should stop.

        Nothing is returned; the outcome is stored in the attribute :obj:`should_stop`, which
        is checked externally by the fitting function.

        Parameters
        ----------
        epoch : :obj:`int`
            current epoch
        curr_loss : :obj:`float`
            current loss

        """
        # an improvement must beat the best loss so far by more than `delta`
        improved = curr_loss < self.best_loss - self.delta
        if improved:
            self.best_loss, self.best_epoch = curr_loss, epoch
            self.counter = 0
        else:
            self.counter += 1

        # keep training until past `min_epochs` AND patience has run out
        if epoch <= self.min_epochs or self.counter < self.patience:
            return

        report = [
            '\n== early stopping criteria met; exiting train loop ==',
            'training epochs: %d' % epoch,
            'end cost: %04f' % curr_loss,
            'best epoch: %i' % self.best_epoch,
            'best cost: %04f\n' % self.best_loss,
        ]
        for line in report:
            print(line)

        self.stopped_epoch = epoch
        self.should_stop = True
def _snapshot_model(model, hparams):
    """Return a deep copy of `model` with `hparams` attached to both original and copy."""
    # hparams is detached while copying and re-attached afterwards
    # (presumably so that it is not deep-copied along with the model)
    model.hparams = None
    snapshot = copy.deepcopy(model)
    model.hparams = hparams
    snapshot.hparams = hparams
    return snapshot


def fit(hparams, model, data_generator, exp, method='ae'):
    """Fit pytorch models with stochastic gradient descent and early stopping.

    Training parameters such as min epochs, max epochs, and early stopping hyperparameters are
    specified in :obj:`hparams`.

    For more information on how early stopping is implemented, see the class
    :class:`EarlyStopping`.

    Training progress is monitored by calculating the model loss on both training data and
    validation data. The training loss is calculated each epoch, and the validation loss is
    calculated according to the :obj:`hparams` key :obj:`'val_check_interval'`. For example, if
    :obj:`val_check_interval=5` then the validation loss is calculated every 5 epochs. If
    :obj:`val_check_interval=0.5` then the validation loss is calculated twice per epoch - after
    the first half of the batches have been processed, then again after all batches have been
    processed.

    Monitored metrics are saved in a csv file in the model directory. This logging is handled by
    the :obj:`testtube` package and the class :class:`Logger`.

    The model with the lowest validation loss is saved as `best_val_model.pt` in the experiment
    version directory; if no validation check ever improved on the initial value, the final model
    is saved under that name instead. Test metrics are computed per batch with this best
    validation model.

    At the end of training, model outputs (such as latents for autoencoder models, or predictions
    for decoder models) can optionally be computed and saved using the :obj:`hparams` keys
    :obj:`'export_latents'` or :obj:`'export_predictions'`, respectively.

    Parameters
    ----------
    hparams : :obj:`dict`
        model/training specification
    model : :obj:`PyTorch` model
        model to fit
    data_generator : :obj:`ConcatSessionsGenerator` object
        data generator to serve data batches
    exp : :obj:`test_tube.Experiment` object
        for logging training progress
    method : :obj:`str`
        specifies the type of loss - 'ae' | 'ae-msp' | 'nll' | 'conv-decoder'

    """

    # optimizer setup
    optimizer = torch.optim.Adam(
        model.get_parameters(), lr=hparams['learning_rate'],
        weight_decay=hparams.get('l2_reg', 0), amsgrad=True)

    # logging setup
    logger = Logger(n_datasets=data_generator.n_datasets)

    # early stopping setup
    if hparams['enable_early_stop']:
        early_stop = EarlyStopping(
            patience=hparams['early_stop_history'], min_epochs=hparams['min_n_epochs'])
    else:
        early_stop = None

    # enumerate batches on which validation metrics should be recorded; the two appended
    # entries are the cumulative batch counts at the end of the last two epochs
    best_val_loss = np.inf
    best_val_epoch = None
    best_val_model = None
    val_check_batch = np.append(
        hparams['val_check_interval'] * data_generator.n_tot_batches['train'] *
        np.arange(1, int((hparams['max_n_epochs'] + 1) / hparams['val_check_interval'])),
        [data_generator.n_tot_batches['train'] * hparams['max_n_epochs'],
         data_generator.n_tot_batches['train'] * (hparams['max_n_epochs'] + 1)]).astype('int')

    # set random seeds for training
    if hparams.get('rng_seed_train', None) is None:
        rng_train = np.random.randint(0, 10000)
    else:
        rng_train = int(hparams['rng_seed_train'])
    torch.manual_seed(rng_train)
    np.random.seed(rng_train)

    expt_dir = os.path.join(hparams['expt_dir'], 'version_%i' % exp.version)

    i_epoch = 0
    best_model_saved = False
    for i_epoch in range(hparams['max_n_epochs'] + 1):
        # Note: the 0th epoch has no training (randomly initialized model is evaluated) so we cycle
        # through `max_n_epochs` training epochs

        print_epoch(i_epoch, hparams['max_n_epochs'])

        # control how data is batched so that models can be restarted from a particular epoch
        torch.manual_seed(rng_train + i_epoch)  # order of trials within sessions
        np.random.seed(rng_train + i_epoch)  # order of sessions

        logger.reset_metrics('train')
        data_generator.reset_iterators('train')
        model.curr_epoch = i_epoch  # for updating annealed loss terms

        # tracks whether a validation pass ran during this epoch
        val_checked = False

        for i_train in tqdm(range(data_generator.n_tot_batches['train'])):

            model.train()

            # zero out gradients. Don't want gradients from previous iterations
            optimizer.zero_grad()

            # get next minibatch and put it on the device
            data, dataset = data_generator.next_batch('train')

            if data is not None:  # this happens when n_tot_batches is incorrectly calculated

                # call the appropriate loss function
                loss_dict = model.loss(data, dataset=dataset, accumulate_grad=True)
                logger.update_metrics('train', loss_dict, dataset=dataset)

                # step (evaluate untrained network on epoch 0)
                if i_epoch > 0:
                    optimizer.step()

            # export training metrics at end of epoch
            if (i_train + 1) % data_generator.n_tot_batches['train'] == 0:
                # export aggregated metrics on train data
                exp.log(logger.create_metric_row(
                    'train', i_epoch, i_train, -1, trial=-1,
                    by_dataset=False, best_epoch=best_val_epoch))
                # export individual session metrics on train/val data
                if data_generator.n_datasets > 1 and dataset is not None and \
                        (isinstance(dataset, int) or len(dataset) == 1):
                    for i_dataset in range(data_generator.n_datasets):
                        exp.log(logger.create_metric_row(
                            'train', i_epoch, i_train, i_dataset, trial=-1,
                            by_dataset=True, best_epoch=best_val_epoch))
                exp.save()

            # check validation according to schedule
            curr_batch = (i_train + 1) + i_epoch * data_generator.n_tot_batches['train']
            if np.any(curr_batch == val_check_batch):

                val_checked = True

                logger.reset_metrics('val')
                data_generator.reset_iterators('val')
                model.eval()

                for i_val in range(data_generator.n_tot_batches['val']):

                    # get next minibatch and put it on the device
                    data, dataset = data_generator.next_batch('val')

                    # call the appropriate loss function
                    loss_dict = model.loss(data, dataset=dataset, accumulate_grad=False)
                    logger.update_metrics('val', loss_dict, dataset=dataset)

                # save best val model
                if logger.get_loss('val') < best_val_loss:
                    best_val_loss = logger.get_loss('val')
                    model.save(os.path.join(expt_dir, 'best_val_model.pt'))
                    best_model_saved = True
                    best_val_model = _snapshot_model(model, hparams)
                    best_val_epoch = i_epoch

                # export aggregated metrics on val data
                exp.log(logger.create_metric_row(
                    'val', i_epoch, i_train, -1, trial=-1,
                    by_dataset=False, best_epoch=best_val_epoch))
                # export individual session metrics on val data
                if data_generator.n_datasets > 1 and \
                        (isinstance(dataset, int) or len(dataset) == 1):
                    for i_dataset in range(data_generator.n_datasets):
                        exp.log(logger.create_metric_row(
                            'val', i_epoch, i_train, i_dataset, trial=-1,
                            by_dataset=True, best_epoch=best_val_epoch))
                exp.save()

        # only consult early stopping on epochs where validation actually ran, so that the
        # monitored loss is never missing or left over from an earlier epoch
        if hparams['enable_early_stop'] and val_checked:
            early_stop.on_val_check(i_epoch, logger.get_loss('val'))
            if early_stop.should_stop:
                break

    # save out last model as best model if no best model saved
    if not best_model_saved:
        model.save(os.path.join(expt_dir, 'best_val_model.pt'))
        best_val_model = _snapshot_model(model, hparams)

    # save out last model
    if hparams.get('save_last_model', False):
        model.save(os.path.join(expt_dir, 'last_model.pt'))

    # compute test loss
    logger.reset_metrics('test')
    data_generator.reset_iterators('test')
    best_val_model.eval()

    for i_test in range(data_generator.n_tot_batches['test']):

        # get next minibatch and put it on the device
        data, dataset = data_generator.next_batch('test')

        # call the appropriate loss function; metrics are reset before every batch because
        # test metrics are reported per batch. The best validation model (the one that was put
        # in eval mode above, saved to disk and used for exports) is evaluated here, not the
        # model from the final training epoch.
        logger.reset_metrics('test')
        loss_dict = best_val_model.loss(data, dataset=dataset, accumulate_grad=False)
        logger.update_metrics('test', loss_dict, dataset=dataset)

        # calculate metrics for each *batch* (rather than whole dataset)
        exp.log(logger.create_metric_row(
            'test', i_epoch, i_test, dataset, trial=data['batch_idx'].item(), by_dataset=True))

    exp.save()

    # export latents
    if method == 'ae' and hparams['export_latents']:
        print('exporting latents')
        from behavenet.fitting.eval import export_latents
        export_latents(data_generator, best_val_model)
    elif method == 'nll' and hparams['export_predictions']:
        print('exporting predictions')
        from behavenet.fitting.eval import export_predictions
        export_predictions(data_generator, best_val_model)
    elif method == 'conv-decoder' and hparams.get('export_predictions', False):
        print('warning! exporting predictions not currently implemented for convolutional decoder')
def print_epoch(curr, total):
    """Print `epoch <curr>/<total>`, zero-padded to the digit count of `total`.

    Padding is applied for totals of two to five digits; smaller or larger
    totals are printed without padding.
    """
    # (exclusive upper bound on `total`, format used below that bound)
    layouts = (
        (10, 'epoch %i/%i'),
        (100, 'epoch %02i/%02i'),
        (1000, 'epoch %03i/%03i'),
        (10000, 'epoch %04i/%04i'),
        (100000, 'epoch %05i/%05i'),
    )
    # first matching bound wins; fall back to the unpadded form
    template = next((fmt for bound, fmt in layouts if total < bound), 'epoch %i/%i')
    print(template % (curr, total))