Logger

class behavenet.fitting.training.Logger(n_datasets=1)

Bases: object

Base class for logging loss metrics.

Loss metrics are tracked both for the aggregate dataset (potentially spanning multiple sessions) and for each individual session, which simplifies downstream plotting.
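One way to picture these two levels of bookkeeping is a pair of nested dictionaries: one keyed by dtype for the aggregate metrics, and a list holding one such dictionary per dataset/session. The sketch below is illustrative only; the helper name `make_metric_store` and this exact layout are assumptions, not the internal attributes of BehaveNet's Logger.

```python
def make_metric_store(n_datasets=1, dtypes=("train", "val", "test")):
    """Build empty containers for aggregate and per-dataset metrics.

    Hypothetical helper: it illustrates one possible layout only and is
    not BehaveNet's internal implementation.
    """
    # aggregate metrics: {dtype: {metric_name: running_value}}
    aggregate = {dtype: {} for dtype in dtypes}
    # per-dataset metrics: one {dtype: {...}} dict per dataset/session
    by_dataset = [{dtype: {} for dtype in dtypes} for _ in range(n_datasets)]
    return aggregate, by_dataset


aggregate, by_dataset = make_metric_store(n_datasets=3)
print(sorted(aggregate))  # ['test', 'train', 'val']
print(len(by_dataset))    # 3
```

Building each per-dataset dictionary inside the comprehension keeps the sessions independent, so updating one session never touches another.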

Methods Summary

create_metric_row(dtype, epoch, batch, ...)

Export metrics and other data (e.g. epoch) for logging training progress.

get_loss(dtype)

Return loss aggregated over all datasets.

reset_metrics(dtype)

Reset all metrics.

update_metrics(dtype, loss_dict[, dataset])

Update metrics for a specific dtype/dataset.

Methods Documentation

create_metric_row(dtype, epoch, batch, dataset, trial, best_epoch=None, by_dataset=False)

Export metrics and other data (e.g. epoch) for logging training progress.

Parameters:
  • dtype (str) – dataset type; one of ‘train’, ‘val’, or ‘test’

  • epoch (int) – current training epoch

  • batch (int) – current training batch

  • dataset (int) – dataset id for current batch

  • trial (int or NoneType) – trial id within the current dataset

  • best_epoch (int, optional) – best training epoch so far

  • by_dataset (bool, optional) – True to return metrics for a specific dataset, False to return metrics aggregated over multiple datasets

Returns:

aggregated metrics for current epoch/batch

Return type:

dict
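A metric row of this kind might be assembled by averaging running sums over the number of batches seen and attaching the bookkeeping fields. The sketch below assumes the logger stores summed values plus a `batches` counter and that output keys are prefixed with the dtype; the helper name `create_metric_row_sketch` and the key naming are hypothetical, not BehaveNet's actual row format. The `by_dataset` switch is omitted here; it would simply select a per-dataset store instead of the aggregate one.

```python
def create_metric_row_sketch(sums, dtype, epoch, batch, dataset, trial, best_epoch=None):
    """Turn running sums into a flat, loggable row (hypothetical helper).

    Assumes ``sums[dtype]`` holds summed metric values plus a 'batches' count.
    """
    metrics = sums[dtype]
    n_batches = max(metrics.get("batches", 0), 1)  # avoid dividing by zero
    row = {"epoch": epoch, "batch": batch, "dataset": dataset, "trial": trial}
    if best_epoch is not None:
        row["best_epoch"] = best_epoch
    for key, total in metrics.items():
        if key != "batches":
            # average each summed metric over the batches seen so far
            row[f"{dtype}_{key}"] = total / n_batches
    return row


sums = {"train": {"loss": 6.0, "batches": 3}}
row = create_metric_row_sketch(sums, "train", epoch=2, batch=10, dataset=0, trial=None, best_epoch=1)
print(row)
# {'epoch': 2, 'batch': 10, 'dataset': 0, 'trial': None, 'best_epoch': 1, 'train_loss': 2.0}
```

A flat dictionary like this maps directly onto one row of a CSV or DataFrame, which is what makes the per-epoch and per-session plotting straightforward.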

get_loss(dtype)

Return loss aggregated over all datasets.

Parameters:

dtype (str) – dataset type to calculate loss for (e.g. ‘train’, ‘val’, ‘test’)
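If the logger keeps a running sum of each batch's `loss` together with a count of batches, the aggregated loss is simply their ratio. This is an assumption about the storage, and `get_loss_sketch` is a hypothetical stand-in rather than the library method:

```python
def get_loss_sketch(sums, dtype):
    """Mean loss over all batches logged for ``dtype`` (hypothetical helper).

    Assumes ``sums[dtype]`` holds a summed 'loss' and a 'batches' count.
    """
    metrics = sums[dtype]
    if metrics.get("batches", 0) == 0:
        raise ValueError(f"no batches logged for dtype {dtype!r}")
    return metrics["loss"] / metrics["batches"]


sums = {"val": {"loss": 4.5, "batches": 3}}
print(get_loss_sketch(sums, "val"))  # 1.5
```

Note that averaging per-batch losses this way weights every batch equally, regardless of how many trials or time points each batch contains.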

reset_metrics(dtype)

Reset all metrics.

Parameters:

dtype (str) – dataset type to reset metrics for (e.g. ‘train’, ‘val’, ‘test’)
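Resetting typically happens at the start of each epoch so that running sums do not carry over. A minimal sketch, assuming nested aggregate and per-dataset dictionaries; `reset_metrics_sketch` is hypothetical. It zeroes values for a single dtype while keeping the keys, so the same quantities continue to be tracked:

```python
def reset_metrics_sketch(aggregate, by_dataset, dtype):
    """Zero every tracked metric for ``dtype`` (hypothetical helper).

    Keys are kept so the same metrics continue to be tracked next epoch;
    other dtypes are left untouched.
    """
    for key in aggregate[dtype]:
        aggregate[dtype][key] = 0
    for dataset_metrics in by_dataset:
        for key in dataset_metrics[dtype]:
            dataset_metrics[dtype][key] = 0


aggregate = {"train": {"loss": 3.2, "batches": 4}, "val": {"loss": 1.0, "batches": 2}}
by_dataset = [{"train": {"loss": 3.2, "batches": 4}, "val": {}}]
reset_metrics_sketch(aggregate, by_dataset, "train")
print(aggregate)  # {'train': {'loss': 0, 'batches': 0}, 'val': {'loss': 1.0, 'batches': 2}}
```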

update_metrics(dtype, loss_dict, dataset=None)

Update metrics for a specific dtype/dataset.

Parameters:
  • dtype (str) – dataset type to update metrics for (e.g. ‘train’, ‘val’, ‘test’)

  • loss_dict (dict) – key-value pairs corresponding to all quantities that should be logged throughout training; this is the dictionary returned by the loss attribute of BehaveNet models

  • dataset (int or NoneType, optional) – if None, updates the aggregated metrics; if int, updates the metrics for the associated dataset/session
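The accumulation step can be sketched as adding each value in `loss_dict` to a running sum and counting batches, so averages can be computed later. This sketch assumes scalar values in `loss_dict` and assumes the aggregate is updated in both cases, with the per-dataset store updated additionally when `dataset` is an int; `update_metrics_sketch` and the `batches` counter are hypothetical, not BehaveNet's API.

```python
def update_metrics_sketch(aggregate, by_dataset, dtype, loss_dict, dataset=None):
    """Accumulate one batch of losses (hypothetical helper).

    Always updates the aggregate store; also updates ``by_dataset[dataset]``
    when a dataset index is given. A 'batches' counter is added so that
    averages can be computed later.
    """
    entry = {**loss_dict, "batches": 1}
    targets = [aggregate[dtype]]
    if dataset is not None:
        targets.append(by_dataset[dataset][dtype])
    for store in targets:
        for key, val in entry.items():
            # create the metric on first use, then add to its running sum
            store[key] = store.get(key, 0) + val


aggregate = {"train": {}}
by_dataset = [{"train": {}}, {"train": {}}]
update_metrics_sketch(aggregate, by_dataset, "train", {"loss": 0.5}, dataset=1)
update_metrics_sketch(aggregate, by_dataset, "train", {"loss": 0.25})
print(aggregate["train"])      # {'loss': 0.75, 'batches': 2}
print(by_dataset[1]["train"])  # {'loss': 0.5, 'batches': 1}
```

Creating metrics lazily on first use means the logger does not need to know in advance which quantities a given model's loss dictionary will report.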