Source code for behavenet.models.base

"""Base models/modules in PyTorch."""

import math
from torch import nn, save, Tensor

# to ignore imports for sphinx-autoapidoc
__all__ = ['BaseModule', 'BaseModel', 'DiagLinear', 'CustomDataParallel']


class BaseModule(nn.Module):
    """Abstract skeleton for PyTorch modules.

    Subclasses are expected to override :meth:`__str__`, :meth:`build_model`
    and :meth:`forward`; the versions defined here only raise
    ``NotImplementedError``. :meth:`freeze` and :meth:`unfreeze` are usable
    as-is.
    """

    def __init__(self, *args, **kwargs):
        # Extra arguments are accepted for subclass convenience but are not
        # forwarded to ``nn.Module``.
        super().__init__()

    def __str__(self):
        """Return a pretty-printed description of the module architecture."""
        raise NotImplementedError

    def build_model(self):
        """Construct the module from hparams."""
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """Run data through the module."""
        raise NotImplementedError

    def freeze(self):
        """Turn off gradient tracking for every parameter of this module."""
        for tensor in self.parameters():
            tensor.requires_grad = False

    def unfreeze(self):
        """Turn on gradient tracking for every parameter of this module."""
        for tensor in self.parameters():
            tensor.requires_grad = True
class BaseModel(nn.Module):
    """Abstract skeleton for PyTorch models.

    Subclasses are expected to override :meth:`__str__`, :meth:`build_model`,
    :meth:`forward` and :meth:`loss`; the versions defined here only raise
    ``NotImplementedError``. :meth:`save` and :meth:`get_parameters` are
    usable as-is.
    """

    def __init__(self, *args, **kwargs):
        # Extra arguments are accepted for subclass convenience but are not
        # forwarded to ``nn.Module``.
        super().__init__()

    def __str__(self):
        """Return a pretty-printed description of the model architecture."""
        raise NotImplementedError

    def build_model(self):
        """Construct the model from hparams."""
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """Run data through the model."""
        raise NotImplementedError

    def loss(self, *args, **kwargs):
        """Compute the loss."""
        raise NotImplementedError

    def save(self, filepath):
        """Write the model's ``state_dict`` to ``filepath`` via ``torch.save``."""
        state = self.state_dict()
        save(state, filepath)

    def get_parameters(self):
        """Return an iterator over the parameters with ``requires_grad`` set."""

        def _is_trainable(param):
            return param.requires_grad

        return filter(_is_trainable, self.parameters())
class DiagLinear(nn.Module):
    """Applies a diagonal linear transformation to the incoming data: :math:`y = xD^T + b`

    The diagonal of :math:`D` is stored as a 1-D ``weight`` parameter of
    length ``features``, so the transformation is an element-wise scaling
    followed by an optional element-wise shift.

    Parameters
    ----------
    features : int
        Length of the ``weight`` (and ``bias``) vectors.
    bias : bool, optional
        If ``True`` (default) a learnable ``bias`` vector is added; otherwise
        ``bias`` is registered as ``None``.
    """

    __constants__ = ['features']
    features: int
    weight: Tensor

    def __init__(self, features, bias=True):
        super(DiagLinear, self).__init__()
        self.features = features
        # Allocated uninitialized; values are set by reset_parameters() below.
        self.weight = nn.Parameter(Tensor(features))
        if bias:
            self.bias = nn.Parameter(Tensor(features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw ``weight`` (and ``bias``) from U(-1/sqrt(features), 1/sqrt(features))."""
        # Guard the zero-feature case, which would otherwise divide by zero;
        # the parameters are empty then, so the bound is irrelevant.
        bound = 1 / math.sqrt(self.features) if self.features > 0 else 0
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Scale ``input`` element-wise by ``weight`` and add ``bias`` if present.

        ``input`` must be broadcastable against a vector of length
        ``features``. The input tensor itself is not modified.
        """
        # ``mul`` allocates a fresh tensor, so the in-place add below never
        # touches the caller's data.
        output = input.mul(self.weight)
        if self.bias is not None:
            output += self.bias
        return output

    def extra_repr(self):
        """Return the summary string shown inside the module's ``repr``."""
        return 'features={}, bias={}'.format(self.features, self.bias is not None)
class CustomDataParallel(nn.DataParallel):
    """Wrapper class for multi-gpu training.

    Attribute lookups that fail on the ``DataParallel`` wrapper are forwarded
    to the wrapped ``module``, so custom attributes and methods of the
    underlying model stay reachable through the wrapper.

    from https://github.com/pytorch/tutorials/issues/836
    """

    def __getattr__(self, name):
        """Resolve ``name`` on the wrapper first, then on the wrapped module.

        Raises
        ------
        AttributeError
            If neither the wrapper nor the wrapped module has ``name``, or if
            the wrapped module itself is not available yet.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If ``module`` itself is missing (e.g. a partially constructed
            # instance), falling back to ``self.module`` would re-enter this
            # method and recurse without end.
            if name == 'module':
                raise
            return getattr(self.module, name)