BaseModel

class behavenet.models.base.BaseModel(*args, **kwargs)

Bases: Module

Template for PyTorch models.
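To illustrate how a template like this is typically filled in, here is a minimal sketch. It subclasses torch.nn.Module directly rather than BaseModel so that it runs without behavenet installed. The class name TinyRegressor, the hparams keys n_input and n_output, and the choice of mean squared error are all hypothetical, chosen only for illustration. They are not taken from the behavenet source.

```python
import torch
from torch import nn


class TinyRegressor(nn.Module):
    """Hypothetical model that mirrors the template's method layout."""

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams  # assumed: a plain dict of hyperparameters
        self.build_model()

    def build_model(self):
        # Build layers from hparams (key names are assumptions).
        self.linear = nn.Linear(self.hparams["n_input"], self.hparams["n_output"])

    def forward(self, x):
        # Push a batch of inputs through the model.
        return self.linear(x)

    def loss(self, x, y):
        # Compute a scalar loss; MSE is an arbitrary choice for this sketch.
        return nn.functional.mse_loss(self.forward(x), y)


model = TinyRegressor({"n_input": 4, "n_output": 2})
x = torch.randn(8, 4)
y = torch.randn(8, 2)
loss_value = model.loss(x, y)
```

A real subclass would inherit from BaseModel and override the same methods. The point here is only the division of labor between build_model, forward, and loss.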

Methods Summary

build_model()

Build model from hparams.

forward(*args, **kwargs)

Push data through model.

get_parameters()

Get all model parameters that have gradient updates turned on.

loss(*args, **kwargs)

Compute loss.

save(filepath)

Save model parameters.

Methods Documentation

build_model()

Build model from hparams.

forward(*args, **kwargs)

Push data through model.

get_parameters()

Get all model parameters that have gradient updates turned on.
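As a sketch of what "gradient updates turned on" usually means in PyTorch, the snippet below keeps only the parameters whose requires_grad flag is True and hands them to an optimizer. The helper trainable_parameters is a hypothetical stand-in written for this example. It is not the behavenet implementation, which may differ (for instance, in its return type).

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 1))

# Freeze the first layer so its weights receive no gradient updates.
for p in model[0].parameters():
    p.requires_grad = False


def trainable_parameters(module):
    """Hypothetical helper: parameters that still require gradients."""
    return [p for p in module.parameters() if p.requires_grad]


params = trainable_parameters(model)
optimizer = torch.optim.SGD(params, lr=0.1)

# Only the second layer remains trainable: 3 weights + 1 bias.
n_trainable = sum(p.numel() for p in params)
```

Passing only the filtered parameters to the optimizer avoids allocating optimizer state for frozen layers.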

loss(*args, **kwargs)

Compute loss.

save(filepath)

Save model parameters.
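A common way to save model parameters in PyTorch is to write the module's state_dict to disk with torch.save and restore it with load_state_dict. The sketch below shows that round trip under the assumption that save(filepath) behaves similarly, which this page does not confirm. The file name model.pt is arbitrary.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmpdir:
    filepath = os.path.join(tmpdir, "model.pt")  # arbitrary file name

    # Save only the parameters (state_dict), not the whole module object.
    torch.save(model.state_dict(), filepath)

    # Restore into a freshly constructed model of the same architecture.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(filepath))
```

Saving the state_dict rather than the pickled module keeps the file independent of the class definition's import path, at the cost of having to rebuild the architecture before loading.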