Logger
- class behavenet.fitting.training.Logger(n_datasets=1)
Bases: object
Base method for logging loss metrics.
Loss metrics are tracked for the aggregate dataset (potentially spanning multiple sessions) as well as session-specific metrics for easier downstream plotting.
Methods Summary
create_metric_row(dtype, epoch, batch, ...) – Export metrics and other data (e.g. epoch) for logging train progress.
get_loss(dtype) – Return loss aggregated over all datasets.
reset_metrics(dtype) – Reset all metrics.
update_metrics(dtype, loss_dict[, dataset]) – Update metrics for a specific dtype/dataset.
Methods Documentation
- create_metric_row(dtype, epoch, batch, dataset, trial, best_epoch=None, by_dataset=False)
Export metrics and other data (e.g. epoch) for logging train progress.
- Parameters:
dtype (str) – ‘train’ | ‘val’ | ‘test’
epoch (int) – current training epoch
batch (int) – current training batch
dataset (int) – dataset id for current batch
trial (int or NoneType) – trial id within the current dataset
best_epoch (int, optional) – best current training epoch
by_dataset (bool, optional) – True to return metrics for a specific dataset, False to return metrics aggregated over multiple datasets
- Returns:
aggregated metrics for current epoch/batch
- Return type:
dict
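As a rough illustration of what such a row might look like, the sketch below assembles a flat dict from running metric sums. It is not BehaveNet's implementation: the helper name `build_metric_row`, the `batches` counter, and the key names in the returned dict are all assumptions made for this example.

```python
def build_metric_row(sums, dtype, epoch, batch, dataset, trial, best_epoch=None):
    """Assemble one flat log row from running metric sums (hypothetical helper).

    `sums` maps metric names to values summed over batches and is expected to
    carry a 'batches' counter; both conventions are assumptions of this sketch.
    """
    n_batches = max(sums.get('batches', 0), 1)  # guard against division by zero
    row = {'epoch': epoch, 'batch': batch, 'dataset': dataset, 'trial': trial}
    if best_epoch is not None:
        row['best_epoch'] = best_epoch
    for name, total in sums.items():
        if name == 'batches':
            continue
        # prefix each metric with its data type, e.g. 'val_loss'
        row['%s_%s' % (dtype, name)] = total / n_batches
    return row


row = build_metric_row(
    {'loss': 6.0, 'batches': 3}, 'val', epoch=2, batch=10, dataset=0, trial=None)
print(row)
```

A flat dict like this is convenient to append to a CSV file or a list of rows, one per logging step.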
- get_loss(dtype)
Return loss aggregated over all datasets.
- Parameters:
dtype (str) – datatype to calculate loss for (e.g. ‘train’, ‘val’, ‘test’)
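One plausible reading of "aggregated over all datasets" is a batch-weighted mean, so that sessions contributing more batches carry proportionally more weight. The sketch below shows that calculation only; the function name `aggregate_loss` and the `(summed_loss, n_batches)` input format are hypothetical, and the library's actual aggregation may differ.

```python
def aggregate_loss(per_dataset):
    """Batch-weighted mean loss over datasets.

    per_dataset: list of (summed_loss, n_batches) tuples, one per dataset
    (a hypothetical input format chosen for this sketch).
    """
    total_loss = sum(loss for loss, _ in per_dataset)
    total_batches = sum(n for _, n in per_dataset)
    if total_batches == 0:
        raise ValueError('no batches have been logged')
    return total_loss / total_batches


# dataset 0: 4 batches with summed loss 4.0; dataset 1: 1 batch with loss 6.0
print(aggregate_loss([(4.0, 4), (6.0, 1)]))
```

Under this weighting the example evaluates to 10.0 / 5 = 2.0, whereas averaging the two per-dataset means (1.0 and 6.0) would give 3.5.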
- reset_metrics(dtype)
Reset all metrics.
- Parameters:
dtype (str) – datatype to reset metrics for (e.g. ‘train’, ‘val’, ‘test’)
- update_metrics(dtype, loss_dict, dataset=None)
Update metrics for a specific dtype/dataset.
- Parameters:
dtype (str) – dataset type to update metrics for (e.g. ‘train’, ‘val’, ‘test’)
loss_dict (dict) – key-value pairs for all quantities that should be logged throughout training; this is the dictionary returned by the loss attribute of BehaveNet models
dataset (int or NoneType, optional) – if NoneType, updates the aggregated metrics; if int, updates the associated dataset/session
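The reset/update/read cycle described above can be sketched as a running accumulator. The class below is a simplified stand-in written for illustration, not the library's code: the name `SketchLogger`, its `metrics` and `metrics_by_dataset` attributes, and the `batches` counter are assumptions. It follows the `dataset` parameter description literally, so an `int` updates only that dataset's sums; whether the real class also updates the aggregate in that case is not stated here.

```python
class SketchLogger:
    """Minimal stand-in that mirrors the documented method names."""

    def __init__(self, n_datasets=1):
        dtypes = ('train', 'val', 'test')
        # running sums aggregated over all datasets
        self.metrics = {d: {} for d in dtypes}
        # one set of running sums per dataset/session
        self.metrics_by_dataset = [
            {d: {} for d in dtypes} for _ in range(n_datasets)]

    def update_metrics(self, dtype, loss_dict, dataset=None):
        # None -> aggregate store; int -> that dataset's store
        if dataset is None:
            target = self.metrics
        else:
            target = self.metrics_by_dataset[dataset]
        # count one batch per call alongside the logged quantities
        for key, val in {**loss_dict, 'batches': 1}.items():
            target[dtype][key] = target[dtype].get(key, 0) + val

    def reset_metrics(self, dtype):
        # zero every known metric for this dtype in all stores
        for store in [self.metrics] + self.metrics_by_dataset:
            store[dtype] = {key: 0 for key in store[dtype]}

    def get_loss(self, dtype):
        m = self.metrics[dtype]
        return m['loss'] / m['batches']


logger = SketchLogger(n_datasets=2)
logger.reset_metrics('train')
for dataset, loss in [(0, 1.0), (1, 3.0)]:
    logger.update_metrics('train', {'loss': loss})                   # aggregate
    logger.update_metrics('train', {'loss': loss}, dataset=dataset)  # per session
print(logger.get_loss('train'))
```

Storing sums together with a batch count, rather than running means, keeps each update a simple addition and defers the division until a loss is actually requested.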