"""Base models/modules in PyTorch."""
import math
from torch import nn, save, Tensor
# Declare the public API explicitly; per the original note, this is so that
# sphinx-autoapidoc ignores the imports above.
__all__ = ['BaseModule', 'BaseModel', 'DiagLinear', 'CustomDataParallel']
class BaseModule(nn.Module):
    """Template for PyTorch modules.

    In this base class :meth:`__str__`, :meth:`build_model` and
    :meth:`forward` all raise ``NotImplementedError``.
    :meth:`freeze` and :meth:`unfreeze` toggle ``requires_grad`` on every
    parameter returned by ``self.parameters()``.
    """

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but not used here; only nn.Module is
        # initialized.
        super().__init__()

    def __str__(self):
        """Pretty print module architecture."""
        raise NotImplementedError

    def build_model(self):
        """Build model from hparams."""
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """Push data through module."""
        raise NotImplementedError

    def freeze(self):
        """Prevent updates to module parameters.

        Sets ``requires_grad = False`` on every parameter of this module.
        """
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Force updates to module parameters.

        Sets ``requires_grad = True`` on every parameter of this module.
        """
        for param in self.parameters():
            param.requires_grad = True
class BaseModel(nn.Module):
    """Template for PyTorch models.

    In this base class :meth:`__str__`, :meth:`build_model`,
    :meth:`forward` and :meth:`loss` all raise ``NotImplementedError``.
    :meth:`save` and :meth:`get_parameters` are concrete helpers.
    """

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but not used here; only nn.Module is
        # initialized.
        super().__init__()

    def __str__(self):
        """Pretty print model architecture."""
        raise NotImplementedError

    def build_model(self):
        """Build model from hparams."""
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """Push data through model."""
        raise NotImplementedError

    def loss(self, *args, **kwargs):
        """Compute loss."""
        raise NotImplementedError

    def save(self, filepath):
        """Save model parameters.

        Writes ``self.state_dict()`` to ``filepath`` with ``torch.save``.

        Parameters
        ----------
        filepath
            Destination handed unchanged to ``torch.save``.

        """
        # Inside a method body `save` resolves to the module-level
        # `torch.save` import, not to this method.
        save(self.state_dict(), filepath)

    def get_parameters(self):
        """Get all model parameters that have gradient updates turned on.

        Returns
        -------
        filter
            Lazy, single-pass iterator over the parameters of this model
            whose ``requires_grad`` flag is set.

        """
        return filter(lambda p: p.requires_grad, self.parameters())
class DiagLinear(nn.Module):
    """Applies a diagonal linear transformation to the incoming data: :math:`y = xD^T + b`

    The diagonal of :math:`D` is stored as the 1-D parameter ``weight`` of
    length ``features``, so the transformation is computed as an elementwise
    product of the input with ``weight`` followed by the optional addition of
    ``bias``.

    Parameters
    ----------
    features
        Length of the ``weight`` (and ``bias``) vectors.
    bias
        If True (default) a learnable ``bias`` vector is created; otherwise
        ``bias`` is registered as ``None``.

    """

    __constants__ = ['features']
    features: int
    weight: Tensor

    def __init__(self, features, bias=True):
        super().__init__()
        self.features = features
        # Tensor(features) allocates uninitialized storage; the values are
        # filled in by reset_parameters() below.
        self.weight = nn.Parameter(Tensor(features))
        if bias:
            self.bias = nn.Parameter(Tensor(features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw ``weight`` (and ``bias``, if present) from U(-k, k), k = 1/sqrt(features)."""
        bound = 1 / math.sqrt(self.features)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Return ``input * weight`` (elementwise), plus ``bias`` when it is not None."""
        output = input.mul(self.weight)
        if self.bias is None:
            return output
        return output + self.bias

    def extra_repr(self):
        """Summarize the layer configuration for ``repr``/``print``."""
        return 'features={}, bias={}'.format(self.features, self.bias is not None)
class CustomDataParallel(nn.DataParallel):
    """Wrapper class for multi-gpu training.

    Attribute lookups that fail on the wrapper itself are forwarded to the
    wrapped ``self.module``.

    from https://github.com/pytorch/tutorials/issues/836
    """

    def __getattr__(self, name):
        """Look up ``name`` on the wrapper first, then on the wrapped module.

        Raises
        ------
        AttributeError
            If neither the wrapper nor the wrapped module has ``name``, or
            if ``module`` itself has not been set on this instance.

        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If `module` itself is missing (e.g. an instance created without
            # running __init__), falling through to `self.module` would
            # re-enter this method forever; re-raise instead.
            if name == 'module':
                raise
            return getattr(self.module, name)