"""Variational autoencoder models implemented in PyTorch."""
import numpy as np
from sklearn.metrics import r2_score
import torch
from torch import nn
import behavenet.fitting.losses as losses
from behavenet.models.aes import AE, ConvAEDecoder, ConvAEEncoder
# to ignore imports for sphinx-autoapidoc
__all__ = [
'reparameterize', 'VAE', 'ConditionalVAE', 'BetaTCVAE', 'PSVAE', 'MSPSVAE', 'ConvAEPSEncoder',
'ConvAEMSPSEncoder']
def reparameterize(mu, logvar):
    """Sample from N(mu, var) using the reparameterization trick.

    Parameters
    ----------
    mu : :obj:`torch.Tensor`
        vector of mean parameters
    logvar : :obj:`torch.Tensor`
        vector of log variances; only mean field approximation is currently implemented

    Returns
    -------
    :obj:`torch.Tensor`
        sampled vector of shape (n_frames, n_latents)

    """
    # logvar = log(sigma^2), so the standard deviation is exp(logvar / 2); using exp(logvar)
    # here would sample with the variance in place of the standard deviation
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    # `eps.mul(std)` allocates a new tensor, so the in-place add does not modify `mu`
    return eps.mul(std).add_(mu)
class VAE(AE):
    """Base variational autoencoder class.

    This class constructs convolutional variational autoencoders. The convolutional autoencoder
    architecture is defined by various keys in the dict that serves as the constructor input. See
    the :mod:`behavenet.fitting.ae_model_architecture_generator` module to see examples for how
    this is done.

    The VAE class can also be used to fit β-VAE models (see https://arxiv.org/pdf/1804.03599.pdf)
    by changing the value of the `vae.beta` parameter in the `ae_model.json` file; a value of 1
    corresponds to a standard VAE; a value >1 will upweight the KL divergence term which, in some
    cases, can lead to disentangling of the latent representation.

    """

    def __init__(self, hparams):
        """Build the model and precompute the per-epoch weight on the KL term.

        Note that `hparams` is modified in place: 'variational' is set to True.

        Parameters
        ----------
        hparams : :obj:`dict`
            - 'model_type' (:obj:`str`): 'conv'
            - 'model_class' (:obj:`str`): 'vae'
            - 'y_pixels' (:obj:`int`)
            - 'x_pixels' (:obj:`int`)
            - 'n_input_channels' (:obj:`int`)
            - 'n_ae_latents' (:obj:`int`)
            - 'fit_sess_io_layers' (:obj:`bool`): fit session-specific input/output layers
            - 'vae.beta' (:obj:`float`)
            - 'vae.beta_anneal_epochs' (:obj:`int`): optional, defaults to 0 (no annealing)
            - 'max_n_epochs' (:obj:`int`): sets the length of the precomputed beta schedule
            - 'ae_encoding_x_dim' (:obj:`list`)
            - 'ae_encoding_y_dim' (:obj:`list`)
            - 'ae_encoding_n_channels' (:obj:`list`)
            - 'ae_encoding_kernel_size' (:obj:`list`)
            - 'ae_encoding_stride_size' (:obj:`list`)
            - 'ae_encoding_x_padding' (:obj:`list`)
            - 'ae_encoding_y_padding' (:obj:`list`)
            - 'ae_encoding_layer_type' (:obj:`list`)
            - 'ae_decoding_x_dim' (:obj:`list`)
            - 'ae_decoding_y_dim' (:obj:`list`)
            - 'ae_decoding_n_channels' (:obj:`list`)
            - 'ae_decoding_kernel_size' (:obj:`list`)
            - 'ae_decoding_stride_size' (:obj:`list`)
            - 'ae_decoding_x_padding' (:obj:`list`)
            - 'ae_decoding_y_padding' (:obj:`list`)
            - 'ae_decoding_layer_type' (:obj:`list`)
            - 'ae_decoding_starting_dim' (:obj:`list`)
            - 'ae_decoding_last_FF_layer' (:obj:`bool`)

        Raises
        ------
        NotImplementedError
            if :obj:`hparams['model_type']` is 'linear'

        """
        if hparams['model_type'] == 'linear':
            raise NotImplementedError
        hparams['variational'] = True
        super().__init__(hparams)

        # set up kl annealing
        anneal_epochs = self.hparams.get('vae.beta_anneal_epochs', 0)
        self.curr_epoch = 0  # must be modified by training script
        beta = hparams['vae.beta']
        # constant schedule at the target beta; the extra entry keeps
        # `curr_epoch == max_n_epochs` a valid index
        beta_hold = beta * np.ones(hparams['max_n_epochs'] + 1)
        if anneal_epochs > 0:
            # ramp linearly from 0 to beta, then hold at beta (not at 1) for the remaining
            # epochs; appending the full-length hold is a sloppy way to fully cover the rest
            self.beta_vals = np.append(np.linspace(0, beta, anneal_epochs), beta_hold)
        else:
            self.beta_vals = beta_hold

    def forward(self, x, dataset=None, use_mean=False, **kwargs):
        """Process input data.

        Parameters
        ----------
        x : :obj:`torch.Tensor` object
            input data
        dataset : :obj:`int`
            used with session-specific io layers
        use_mean : :obj:`bool`
            True to skip sampling step

        Returns
        -------
        :obj:`tuple`
            - x_hat (:obj:`torch.Tensor`): output of shape (n_frames, n_channels, y_pix, x_pix)
            - z (:obj:`torch.Tensor`): sampled latent variable of shape (n_frames, n_latents)
            - mu (:obj:`torch.Tensor`): mean parameter of shape (n_frames, n_latents)
            - logvar (:obj:`torch.Tensor`): logvar parameter of shape (n_frames, n_latents)

        """
        mu, logvar, pool_idx, outsize = self.encoding(x, dataset=dataset)
        if use_mean:
            z = mu
        else:
            z = reparameterize(mu, logvar)
        x_hat = self.decoding(z, pool_idx, outsize, dataset=dataset)
        return x_hat, z, mu, logvar

    def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200):
        """Calculate ELBO loss for VAE.

        The batch is split into chunks if larger than `chunk_size` to keep memory requirements
        low; gradients are accumulated across all chunks before a gradient step is taken.

        Parameters
        ----------
        data : :obj:`dict`
            batch of data; keys should include 'images' and 'masks', if necessary
        dataset : :obj:`int`, optional
            used for session-specific io layers
        accumulate_grad : :obj:`bool`, optional
            accumulate gradient for training step
        chunk_size : :obj:`int`, optional
            batch is split into chunks of this size to keep memory requirements low

        Returns
        -------
        :obj:`dict`
            - 'loss' (:obj:`float`): full elbo
            - 'loss_ll' (:obj:`float`): log-likelihood portion of elbo
            - 'loss_kl' (:obj:`float`): kl portion of elbo
            - 'loss_mse' (:obj:`float`): mse (without gaussian constants)
            - 'beta' (:obj:`float`): weight in front of kl term

        """
        x = data['images'][0]
        m = data['masks'][0] if 'masks' in data else None
        beta = self.beta_vals[self.curr_epoch]

        batch_size = x.shape[0]
        n_chunks = int(np.ceil(batch_size / chunk_size))
        n_dims = np.prod(x.shape[1:])

        loss_val = 0
        loss_ll_val = 0
        loss_kl_val = 0
        loss_mse_val = 0
        for chunk in range(n_chunks):

            idx_beg = chunk * chunk_size
            idx_end = min((chunk + 1) * chunk_size, batch_size)
            bs = idx_end - idx_beg

            x_in = x[idx_beg:idx_end]
            m_in = m[idx_beg:idx_end] if m is not None else None
            x_hat, _, mu, logvar = self.forward(x_in, dataset=dataset, use_mean=False)

            # log-likelihood
            loss_ll = losses.gaussian_ll(x_in, x_hat, m_in)

            # kl
            loss_kl = losses.kl_div_to_std_normal(mu, logvar)

            # combine
            loss = -loss_ll + beta * loss_kl

            if accumulate_grad:
                loss.backward()

            # get loss value (weighted by chunk size so the final average is per-batch)
            loss_val += loss.item() * bs
            loss_ll_val += loss_ll.item() * bs
            loss_kl_val += loss_kl.item() * bs
            loss_mse_val += losses.gaussian_ll_to_mse(loss_ll.item(), n_dims) * bs

        loss_val /= batch_size
        loss_ll_val /= batch_size
        loss_kl_val /= batch_size
        loss_mse_val /= batch_size

        loss_dict = {
            'loss': loss_val, 'loss_ll': loss_ll_val, 'loss_kl': loss_kl_val,
            'loss_mse': loss_mse_val, 'beta': beta}

        return loss_dict
class ConditionalVAE(VAE):
    """Conditional variational autoencoder class.

    This class constructs conditional convolutional variational autoencoders. At the latent layer
    an additional set of variables, saved under the 'labels' key in the hdf5 data file, are
    concatenated with the latents before being reshaped into a 2D array for decoding.

    """

    def __init__(self, hparams):
        """See constructor documentation of VAE for hparams details.

        Parameters
        ----------
        hparams : :obj:`dict`
            in addition to the standard keys, must also contain :obj:`n_labels` and
            :obj:`conditional_encoder`

        """
        super().__init__(hparams)

    def build_model(self):
        """Construct the encoder and decoder using hparams.

        The hidden layer is sized to hold the latents plus the labels that are concatenated to
        them before decoding; only convolutional encoders/decoders are built here.

        """
        n_hidden = self.hparams['n_ae_latents'] + self.hparams['n_labels']
        self.hparams['hidden_layer_size'] = n_hidden
        self.encoding = ConvAEEncoder(self.hparams)
        self.decoding = ConvAEDecoder(self.hparams)

    def forward(self, x, dataset=None, labels=None, labels_2d=None, use_mean=False, **kwargs):
        """Process input data.

        Parameters
        ----------
        x : :obj:`torch.Tensor` object
            input data of shape (batch, n_channels, y_pix, x_pix)
        dataset : :obj:`int`
            used with session-specific io layers
        labels : :obj:`torch.Tensor` object
            continuous labels corresponding to input data, of shape (batch, n_labels)
        labels_2d: :obj:`torch.Tensor` object
            one-hot labels corresponding to input data, of shape (batch, n_labels, y_pix, x_pix);
            for a given frame, each channel corresponds to a label and is all zeros with a single
            value of one in the proper x/y position; only used when
            :obj:`hparams['conditional_encoder']` is set
        use_mean : :obj:`bool`
            True to skip sampling step

        Returns
        -------
        :obj:`tuple`
            - x_hat (:obj:`torch.Tensor`): output of shape (n_frames, n_channels, y_pix, x_pix)
            - z (:obj:`torch.Tensor`): latent variable (without labels appended)
            - mu (:obj:`torch.Tensor`): mean parameter returned by the encoder
            - logvar (:obj:`torch.Tensor`): logvar parameter returned by the encoder

        """
        encoder_input = x
        if self.hparams['conditional_encoder']:
            # append label information to input along the channel dimension
            encoder_input = torch.cat((x, labels_2d), dim=1)
        mu, logvar, pool_idx, outsize = self.encoding(encoder_input, dataset=dataset)

        z = mu if use_mean else reparameterize(mu, logvar)

        # the decoder sees the latents augmented with the continuous labels
        decoder_input = torch.cat((z, labels), dim=1)
        x_hat = self.decoding(decoder_input, pool_idx, outsize, dataset=dataset)
        return x_hat, z, mu, logvar

    def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200):
        """Calculate ELBO loss for ConditionalVAE.

        The batch is processed in chunks of at most `chunk_size` frames to keep memory
        requirements low; gradients are accumulated across all chunks before a gradient step is
        taken.

        Parameters
        ----------
        data : :obj:`dict`
            batch of data; keys should include 'images' and 'labels', plus 'masks' and
            'labels_sc' (2d one-hot labels, read when the encoder is conditional) if necessary
        dataset : :obj:`int`, optional
            used for session-specific io layers
        accumulate_grad : :obj:`bool`, optional
            accumulate gradient for training step
        chunk_size : :obj:`int`, optional
            batch is split into chunks of this size to keep memory requirements low

        Returns
        -------
        :obj:`dict`
            - 'loss' (:obj:`float`): full elbo
            - 'loss_ll' (:obj:`float`): log-likelihood portion of elbo
            - 'loss_kl' (:obj:`float`): kl portion of elbo
            - 'loss_mse' (:obj:`float`): mse (without gaussian constants)
            - 'beta' (:obj:`float`): weight in front of kl term

        """
        x = data['images'][0]
        y = data['labels'][0]
        m = data['masks'][0] if 'masks' in data else None
        # continuous labels transformed into 2d one-hot array as input to encoder
        y_2d = data['labels_sc'][0] if self.hparams['conditional_encoder'] else None

        beta = self.beta_vals[self.curr_epoch]

        batch_size = x.shape[0]
        n_chunks = int(np.ceil(batch_size / chunk_size))

        # running sums of each loss term, weighted by the number of frames in each chunk
        totals = {'loss': 0, 'loss_ll': 0, 'loss_kl': 0, 'loss_mse': 0}

        for chunk in range(n_chunks):

            start = chunk * chunk_size
            stop = np.min([(chunk + 1) * chunk_size, batch_size])
            n_frames = stop - start

            x_chunk = x[start:stop]
            y_chunk = y[start:stop]
            m_chunk = None if m is None else m[start:stop]
            y_2d_chunk = None if y_2d is None else y_2d[start:stop]

            x_hat, _, mu, logvar = self.forward(
                x_chunk, dataset=dataset, use_mean=False, labels=y_chunk, labels_2d=y_2d_chunk)

            # log-likelihood and kl terms of the elbo
            chunk_ll = losses.gaussian_ll(x_chunk, x_hat, m_chunk)
            chunk_kl = losses.kl_div_to_std_normal(mu, logvar)

            # negative elbo
            chunk_loss = -chunk_ll + beta * chunk_kl

            if accumulate_grad:
                chunk_loss.backward()

            totals['loss'] += chunk_loss.item() * n_frames
            totals['loss_ll'] += chunk_ll.item() * n_frames
            totals['loss_kl'] += chunk_kl.item() * n_frames
            totals['loss_mse'] += losses.gaussian_ll_to_mse(
                chunk_ll.item(), np.prod(x.shape[1:])) * n_frames

        # convert frame-weighted sums into per-batch averages
        loss_dict = {key: total / batch_size for key, total in totals.items()}
        loss_dict['beta'] = beta

        return loss_dict
class BetaTCVAE(VAE):
    """Beta Total Correlation VAE class.

    This class constructs convolutional variational autoencoders and decomposes the KL divergence
    term in the ELBO into three terms:

    1. index code mutual information
    2. total correlation
    3. dimension-wise KL

    The total correlation term is up-weighted to encourage "disentangled" latents; for more
    information, see https://arxiv.org/pdf/1802.04942.pdf.

    """

    def __init__(self, hparams):
        """Build the model and precompute the per-epoch loss weights.

        Parameters
        ----------
        hparams : :obj:`dict`
            in addition to the standard keys, must also contain :obj:`beta_tcvae.beta` and
            :obj:`max_n_epochs`; :obj:`beta_tcvae.beta_anneal_epochs` is optional (default 0)

        Raises
        ------
        NotImplementedError
            if :obj:`hparams['model_type']` is 'linear'

        """
        if hparams['model_type'] == 'linear':
            raise NotImplementedError
        # NOTE(review): this runs VAE.__init__, which reads hparams['vae.beta'] - confirm that
        # callers always provide that key for this model class
        super().__init__(hparams)

        # set up beta annealing
        anneal_epochs = self.hparams.get('beta_tcvae.beta_anneal_epochs', 0)
        self.curr_epoch = 0  # must be modified by training script
        beta = hparams['beta_tcvae.beta']
        # TODO: these values should not be precomputed
        if anneal_epochs > 0:
            # annealing for total correlation term
            self.beta_vals = np.append(
                np.linspace(0, beta, anneal_epochs),  # USED TO START AT 1!!
                beta * np.ones(hparams['max_n_epochs'] + 1))  # sloppy addition to cover rest
            # annealing for remaining kl terms - index code mutual info and dim-wise kl
            self.kl_anneal_vals = np.append(
                np.linspace(0, 1, anneal_epochs),
                np.ones(hparams['max_n_epochs'] + 1))  # sloppy addition to fully cover rest
        else:
            self.beta_vals = beta * np.ones(hparams['max_n_epochs'] + 1)
            self.kl_anneal_vals = np.ones(hparams['max_n_epochs'] + 1)

    def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200):
        """Calculate (decomposed) ELBO loss for VAE.

        The batch is split into chunks if larger than `chunk_size` to keep memory requirements
        low; gradients are accumulated across all chunks before a gradient step is taken.

        Parameters
        ----------
        data : :obj:`dict`
            batch of data; keys should include 'images' and 'masks', if necessary
        dataset : :obj:`int`, optional
            used for session-specific io layers
        accumulate_grad : :obj:`bool`, optional
            accumulate gradient for training step
        chunk_size : :obj:`int`, optional
            batch is split into chunks of this size to keep memory requirements low

        Returns
        -------
        :obj:`dict`
            - 'loss' (:obj:`float`): full elbo
            - 'loss_ll' (:obj:`float`): log-likelihood portion of elbo
            - 'loss_mi' (:obj:`float`): code index mutual info portion of kl of elbo
            - 'loss_tc' (:obj:`float`): total correlation portion of kl of elbo
            - 'loss_dwkl' (:obj:`float`): dim-wise kl portion of kl of elbo
            - 'loss_mse' (:obj:`float`): mse (without gaussian constants)
            - 'beta' (:obj:`float`): weight in front of total correlation term

        """
        x = data['images'][0]
        m = data['masks'][0] if 'masks' in data else None
        beta = self.beta_vals[self.curr_epoch]
        kl = self.kl_anneal_vals[self.curr_epoch]

        batch_size = x.shape[0]
        n_chunks = int(np.ceil(batch_size / chunk_size))
        n_dims = np.prod(x.shape[1:])

        loss_strs = ['loss', 'loss_ll', 'loss_mi', 'loss_tc', 'loss_dwkl']

        loss_dict_vals = {loss: 0 for loss in loss_strs}
        loss_dict_vals['loss_mse'] = 0

        for chunk in range(n_chunks):

            idx_beg = chunk * chunk_size
            idx_end = min((chunk + 1) * chunk_size, batch_size)

            x_in = x[idx_beg:idx_end]
            m_in = m[idx_beg:idx_end] if m is not None else None
            x_hat, sample, mu, logvar = self.forward(x_in, dataset=dataset, use_mean=False)

            # reset losses
            loss_dict_torch = {loss: 0 for loss in loss_strs}

            # data log-likelihood
            loss_dict_torch['loss_ll'] = losses.gaussian_ll(x_in, x_hat, m_in)
            loss_dict_torch['loss'] -= loss_dict_torch['loss_ll']

            # compute all terms of decomposed elbo at once
            index_code_mi, total_correlation, dimension_wise_kl = losses.decomposed_kl(
                sample, mu, logvar)

            # unsupervised latents index-code mutual information
            loss_dict_torch['loss_mi'] = index_code_mi
            loss_dict_torch['loss'] += kl * loss_dict_torch['loss_mi']

            # unsupervised latents total correlation
            loss_dict_torch['loss_tc'] = total_correlation
            loss_dict_torch['loss'] += beta * loss_dict_torch['loss_tc']

            # unsupervised latents dimension-wise kl
            loss_dict_torch['loss_dwkl'] = dimension_wise_kl
            loss_dict_torch['loss'] += kl * loss_dict_torch['loss_dwkl']

            if accumulate_grad:
                loss_dict_torch['loss'].backward()

            # get loss value (weighted by chunk size)
            bs = idx_end - idx_beg
            for key, val in loss_dict_torch.items():
                loss_dict_vals[key] += val.item() * bs
            # convert this chunk's log-likelihood to mse; the running total in `loss_dict_vals`
            # mixes all chunks seen so far and must not be used here
            loss_dict_vals['loss_mse'] += losses.gaussian_ll_to_mse(
                loss_dict_torch['loss_ll'].item(), n_dims) * bs

        # compile (properly weighted) loss terms
        for key in loss_dict_vals:
            loss_dict_vals[key] /= batch_size

        # store hyperparams
        loss_dict_vals['beta'] = beta

        return loss_dict_vals
class PSVAE(AE):
    """Partitioned subspace variational autoencoder class.

    The latent space is partitioned into two groups: the first `n_labels` latents are
    "supervised" (they are passed through `encoding.D` to predict the labels and receive a
    standard KL penalty), and the remaining latents are "unsupervised" (they receive the
    decomposed KL penalty with an up-weighted total correlation term).

    """

    def __init__(self, hparams):
        """See constructor documentation of AE for hparams details.

        Note that `hparams` is modified in place: 'variational' is set to True.

        Parameters
        ----------
        hparams : :obj:`dict`
            in addition to the standard keys, must also contain:

            - 'n_labels' (:obj:`int`)
            - 'ps_vae.alpha' (:obj:`float`)
            - 'ps_vae.beta' (:obj:`float`)
            - 'ps_vae.anneal_epochs' (:obj:`int`): optional, defaults to 0 (no annealing)
            - 'max_n_epochs' (:obj:`int`): sets the length of the precomputed schedules

        Raises
        ------
        NotImplementedError
            if :obj:`hparams['model_type']` is 'linear'
        ValueError
            if there are fewer latents than labels

        """
        if hparams['model_type'] == 'linear':
            raise NotImplementedError
        if hparams['n_ae_latents'] < hparams['n_labels']:
            raise ValueError('PS-VAE model must contain at least as many latents as labels')

        self.n_latents = hparams['n_ae_latents']
        self.n_labels = hparams['n_labels']

        hparams['variational'] = True
        super().__init__(hparams)

        # set up beta annealing
        anneal_epochs = self.hparams.get('ps_vae.anneal_epochs', 0)
        self.curr_epoch = 0  # must be modified by training script
        beta = hparams['ps_vae.beta']
        # TODO: these values should not be precomputed
        if anneal_epochs > 0:
            # annealing for total correlation term
            self.beta_vals = np.append(
                np.linspace(0, beta, anneal_epochs),  # USED TO START AT 1!!
                beta * np.ones(hparams['max_n_epochs'] + 1))  # sloppy addition to cover rest
            # annealing for remaining kl terms - index code mutual info and dim-wise kl
            self.kl_anneal_vals = np.append(
                np.linspace(0, 1, anneal_epochs),
                np.ones(hparams['max_n_epochs'] + 1))  # sloppy addition to fully cover rest
        else:
            self.beta_vals = beta * np.ones(hparams['max_n_epochs'] + 1)
            self.kl_anneal_vals = np.ones(hparams['max_n_epochs'] + 1)

    def build_model(self):
        """Construct the model using hparams.

        Raises
        ------
        NotImplementedError
            if the model type is 'linear'
        ValueError
            if the model type is neither 'conv' nor 'linear'

        """
        self.hparams['hidden_layer_size'] = self.hparams['n_ae_latents']
        if self.model_type == 'conv':
            self.encoding = ConvAEPSEncoder(self.hparams)
            self.decoding = ConvAEDecoder(self.hparams)
        elif self.model_type == 'linear':
            raise NotImplementedError
        else:
            raise ValueError('"%s" is an invalid model_type' % self.model_type)

    def forward(self, x, dataset=None, use_mean=False, **kwargs):
        """Process input data.

        Parameters
        ----------
        x : :obj:`torch.Tensor` object
            input data of shape (n_frames, n_channels, y_pix, x_pix)
        dataset : :obj:`int`
            used with session-specific io layers
        use_mean : :obj:`bool`
            True to skip sampling step

        Returns
        -------
        :obj:`tuple`
            - x_hat (:obj:`torch.Tensor`): output of shape (n_frames, n_channels, y_pix, x_pix)
            - z (:obj:`torch.Tensor`): sampled latent variable of shape (n_frames, n_latents)
            - mu (:obj:`torch.Tensor`): mean parameter of shape (n_frames, n_latents)
            - logvar (:obj:`torch.Tensor`): logvar parameter of shape (n_frames, n_latents)
            - y_hat (:obj:`torch.Tensor`): output of shape (n_frames, n_labels)

        """
        y, w, logvar, pool_idx, outsize = self.encoding(x, dataset=dataset)
        # supervised means first, then unsupervised means; the loss relies on this ordering
        mu = torch.cat([y, w], dim=1)
        if use_mean:
            z = mu
        else:
            z = reparameterize(mu, logvar)
        x_hat = self.decoding(z, pool_idx, outsize, dataset=dataset)
        # labels are always predicted from the supervised means, not from the sample
        y_hat = self.encoding.D(y)
        return x_hat, z, mu, logvar, y_hat

    def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200):
        """Calculate modified ELBO loss for PSVAE.

        The batch is split into chunks if larger than `chunk_size` to keep memory requirements
        low; gradients are accumulated across all chunks before a gradient step is taken.

        Parameters
        ----------
        data : :obj:`dict`
            batch of data; keys should include 'images' and 'labels', plus 'masks' and
            'labels_masks' if necessary
        dataset : :obj:`int`, optional
            used for session-specific io layers
        accumulate_grad : :obj:`bool`, optional
            accumulate gradient for training step
        chunk_size : :obj:`int`, optional
            batch is split into chunks of this size to keep memory requirements low

        Returns
        -------
        :obj:`dict`
            - 'loss' (:obj:`float`): full (modified) elbo
            - 'loss_data_ll' (:obj:`float`): data log-likelihood
            - 'loss_label_ll' (:obj:`float`): label log-likelihood
            - 'loss_zs_kl' (:obj:`float`): kl of supervised latents
            - 'loss_zu_mi' (:obj:`float`): index-code mutual info of unsupervised latents
            - 'loss_zu_tc' (:obj:`float`): total correlation of unsupervised latents
            - 'loss_zu_dwkl' (:obj:`float`): dim-wise kl of unsupervised latents
            - 'loss_data_mse' (:obj:`float`): data mse (without gaussian constants)
            - 'alpha' (:obj:`float`): weight in front of label log-likelihood
            - 'beta' (:obj:`float`): weight in front of total correlation term
            - 'label_r2' (:obj:`float`): variance-weighted r2 of label predictions

        """
        x = data['images'][0]
        y = data['labels'][0]
        m = data['masks'][0] if 'masks' in data else None
        n = data['labels_masks'][0] if 'labels_masks' in data else None
        batch_size = x.shape[0]
        n_chunks = int(np.ceil(batch_size / chunk_size))
        n_dims = np.prod(x.shape[1:])
        n_labels = self.hparams['n_labels']

        # compute hyperparameters
        alpha = self.hparams['ps_vae.alpha']
        beta = self.beta_vals[self.curr_epoch]
        kl = self.kl_anneal_vals[self.curr_epoch]

        loss_strs = [
            'loss', 'loss_data_ll', 'loss_label_ll', 'loss_zs_kl', 'loss_zu_mi', 'loss_zu_tc',
            'loss_zu_dwkl']

        loss_dict_vals = {loss: 0 for loss in loss_strs}
        loss_dict_vals['loss_data_mse'] = 0

        y_hat_all = []

        for chunk in range(n_chunks):

            idx_beg = chunk * chunk_size
            idx_end = min((chunk + 1) * chunk_size, batch_size)

            x_in = x[idx_beg:idx_end]
            y_in = y[idx_beg:idx_end]
            m_in = m[idx_beg:idx_end] if m is not None else None
            n_in = n[idx_beg:idx_end] if n is not None else None
            x_hat, sample, mu, logvar, y_hat = self.forward(
                x_in, dataset=dataset, use_mean=False)

            # reset losses
            loss_dict_torch = {loss: 0 for loss in loss_strs}

            # data log-likelihood
            loss_dict_torch['loss_data_ll'] = losses.gaussian_ll(x_in, x_hat, m_in)
            loss_dict_torch['loss'] -= loss_dict_torch['loss_data_ll']

            # label log-likelihood
            loss_dict_torch['loss_label_ll'] = losses.gaussian_ll(y_in, y_hat, n_in)
            loss_dict_torch['loss'] -= alpha * loss_dict_torch['loss_label_ll']

            # supervised latents kl
            loss_dict_torch['loss_zs_kl'] = losses.kl_div_to_std_normal(
                mu[:, :n_labels], logvar[:, :n_labels])
            loss_dict_torch['loss'] += loss_dict_torch['loss_zs_kl']

            # compute all terms of decomposed elbo at once
            index_code_mi, total_correlation, dimension_wise_kl = losses.decomposed_kl(
                sample[:, n_labels:], mu[:, n_labels:], logvar[:, n_labels:])

            # unsupervised latents index-code mutual information
            loss_dict_torch['loss_zu_mi'] = index_code_mi
            loss_dict_torch['loss'] += kl * loss_dict_torch['loss_zu_mi']

            # unsupervised latents total correlation
            loss_dict_torch['loss_zu_tc'] = total_correlation
            loss_dict_torch['loss'] += beta * loss_dict_torch['loss_zu_tc']

            # unsupervised latents dimension-wise kl
            loss_dict_torch['loss_zu_dwkl'] = dimension_wise_kl
            loss_dict_torch['loss'] += kl * loss_dict_torch['loss_zu_dwkl']

            if accumulate_grad:
                loss_dict_torch['loss'].backward()

            # get loss value (weighted by chunk size)
            bs = idx_end - idx_beg
            for key, val in loss_dict_torch.items():
                loss_dict_vals[key] += val.item() * bs
            # convert this chunk's log-likelihood to mse; the running total in `loss_dict_vals`
            # mixes all chunks seen so far and must not be used here
            loss_dict_vals['loss_data_mse'] += losses.gaussian_ll_to_mse(
                loss_dict_torch['loss_data_ll'].item(), n_dims) * bs

            # collect predicted labels to compute R2
            y_hat_all.append(y_hat.cpu().detach().numpy())

        # use variance-weighted r2s to ignore small-variance latents
        y_hat_all = np.concatenate(y_hat_all, axis=0)
        y_all = y.cpu().detach().numpy()
        if n is not None:
            # only score entries whose label mask equals 1
            n_np = n.cpu().detach().numpy()
            r2 = r2_score(y_all[n_np == 1], y_hat_all[n_np == 1], multioutput='variance_weighted')
        else:
            r2 = r2_score(y_all, y_hat_all, multioutput='variance_weighted')

        # compile (properly weighted) loss terms
        for key in loss_dict_vals:
            loss_dict_vals[key] /= batch_size

        # store hyperparams
        loss_dict_vals['alpha'] = alpha
        loss_dict_vals['beta'] = beta
        loss_dict_vals['label_r2'] = r2

        return loss_dict_vals

    def get_predicted_labels(self, x, dataset=None, use_mean=True):
        """Process input data to get predicted labels.

        Parameters
        ----------
        x : :obj:`torch.Tensor` object
            input data of shape (n_frames, n_channels, y_pix, x_pix)
        dataset : :obj:`int`
            used with session-specific io layers
        use_mean : :obj:`bool`
            True to skip sampling step

        Returns
        -------
        :obj:`torch.Tensor`
            output of shape (n_frames, n_labels)

        """
        # only the supervised means and the log variances are needed here
        y, _, logvar, _, _ = self.encoding(x, dataset=dataset)
        if not use_mean:
            y = reparameterize(y, logvar[:, :self.n_labels])
        y_hat = self.encoding.D(y)
        return y_hat
[docs]class MSPSVAE(PSVAE):
"""Partitioned subspace variational autoencoder class for multiple sessions."""
def __init__(self, hparams):
    """See constructor documentation of AE for hparams details.

    Parameters
    ----------
    hparams : :obj:`dict`
        in addition to the standard keys, must also contain:

        - 'n_labels' (:obj:`int`): number of supervised dims (number of labels)
        - 'ps_vae.alpha' (:obj:`float`): weight on label reconstruction loss
        - 'ps_vae.beta' (:obj:`float`): weight on unsupervised TC loss
        - 'ps_vae.gamma' (:obj:`float`): weight on orthogonalization loss
        - 'ps_vae.delta' (:obj:`float`): weight on background embedding loss
        - 'n_background' (:obj:`int`): dimensionality of background latent space; set to 4
          if absent after the parent constructor has run
        - 'n_sessions_per_batch' (:obj:`int`): data generator param, >1
        - 'ps_vae.ms_loss' (:obj:`str`): multi-session loss: 'triplet' | 'classification'

    Raises
    ------
    ValueError
        if :obj:`hparams['n_sessions_per_batch']` equals 1

    """
    single_session = hparams['n_sessions_per_batch'] == 1
    if single_session:
        raise ValueError('must choose "n_sessions_per_batch" > 1 in hparams')

    super().__init__(hparams)

    # write the (possibly defaulted) value back so that it gets saved with the hparams
    self.hparams['n_background'] = self.hparams.get('n_background', 4)

    self.TripletLoss = nn.TripletMarginLoss(margin=1.0, p=2)
def build_model(self):
    """Construct the encoder and decoder using hparams.

    Raises
    ------
    NotImplementedError
        if the model type is 'linear' (linear encoders/decoders are not supported)
    ValueError
        if the model type is neither 'conv' nor 'linear'

    """
    self.hparams['hidden_layer_size'] = self.hparams['n_ae_latents']

    # reject unsupported model types up front
    if self.model_type == 'linear':
        raise NotImplementedError
    if self.model_type != 'conv':
        raise ValueError('"%s" is an invalid model_type' % self.model_type)

    self.encoding = ConvAEMSPSEncoder(self.hparams)
    self.decoding = ConvAEDecoder(self.hparams)
def forward(self, x, dataset=None, use_mean=False, **kwargs):
    """Process input data.

    Parameters
    ----------
    x : :obj:`torch.Tensor` object
        input data
    dataset : :obj:`int`
        used with session-specific io layers
    use_mean : :obj:`bool`
        True to skip sampling step

    Returns
    -------
    :obj:`tuple`
        - x_hat (:obj:`torch.Tensor`): output of shape (n_frames, n_channels, y_pix, x_pix)
        - z (:obj:`torch.Tensor`): sampled latent variable of shape (n_frames, n_latents)
        - mu (:obj:`torch.Tensor`): mean parameter of shape (n_frames, n_latents)
        - logvar (:obj:`torch.Tensor`): logvar parameter of shape (n_frames, n_latents)
        - y_hat (:obj:`torch.Tensor`): predicted labels computed from the supervised means

    """
    z_s, z_b, z_u, logvar, pool_idx, outsize = self.encoding(x, dataset=dataset)
    # column order matters: the loss slices `mu` as supervised, then background, then
    # unsupervised latents
    mu = torch.cat([z_s, z_b, z_u], dim=1)
    z = mu if use_mean else reparameterize(mu, logvar)
    x_hat = self.decoding(z, pool_idx, outsize, dataset=dataset)
    # labels are always predicted from the supervised means, not from the sample
    y_hat = self.encoding.D(z_s)
    return x_hat, z, mu, logvar, y_hat
def loss(self, datas, dataset=None, accumulate_grad=True, chunk_size=None):
    """Calculate modified ELBO loss for MSPSVAE.

    The whole batch is processed in a single forward pass (no chunking). When `datas` is a
    list, the per-session batches are concatenated and a triplet loss on the background
    latents is added; otherwise the triplet term is skipped.

    Parameters
    ----------
    datas : :obj:`list` of :obj:`dict`, or :obj:`dict`
        batch(es) of data; keys should include 'images' and 'labels', plus 'masks' and
        'labels_masks' if necessary
    dataset : :obj:`list` of :obj:`int`
        session index for each entry of `datas`; used for the triplet loss and only read when
        `datas` is a list
    accumulate_grad : :obj:`bool`, optional
        accumulate gradient for training step
    chunk_size : :obj:`int`, optional
        deprecated

    Returns
    -------
    :obj:`dict`
        - 'loss' (:obj:`float`): full (modified) elbo
        - 'loss_data_ll' (:obj:`float`): data log-likelihood
        - 'loss_label_ll' (:obj:`float`): label log-likelihood
        - 'loss_zs_kl' (:obj:`float`): kl of supervised latents
        - 'loss_zu_mi' (:obj:`float`): index-code mutual info of unsupervised latents
        - 'loss_zu_tc' (:obj:`float`): total correlation of unsupervised latents
        - 'loss_zu_dwkl' (:obj:`float`): dim-wise kl of unsupervised latents
        - 'loss_triplet' (:obj:`float`): triplet loss (0 when `datas` is not a list)
        - 'loss_data_mse' (:obj:`float`): data mse (without gaussian constants)
        - 'alpha' (:obj:`float`): weight in front of label log-likelihood
        - 'beta' (:obj:`float`): weight in front of total correlation term
        - 'delta' (:obj:`float`): weight in front of triplet loss
        - 'label_r2' (:obj:`float`): variance-weighted r2 of label predictions

    """
    multi_session = isinstance(datas, list)

    if multi_session:

        def _stack(key):
            # concatenate one entry across sessions along the frame dimension
            return torch.cat([data[key][0] for data in datas], dim=0)

        x = _stack('images')
        y = _stack('labels')
        m = _stack('masks') if 'masks' in datas[0] else None
        n = _stack('labels_masks') if 'labels_masks' in datas[0] else None
        # one session id per frame, in the same order as the concatenated batch
        datasets = np.concatenate([
            d * np.ones(datas[d_idx]['images'].shape[1]) for d_idx, d in enumerate(dataset)])
    else:
        x = datas['images'][0]
        y = datas['labels'][0]
        m = datas['masks'][0] if 'masks' in datas else None
        n = datas['labels_masks'][0] if 'labels_masks' in datas else None
        datasets = None

    n_labels = self.hparams['n_labels']
    n_background = self.hparams['n_background']

    # compute hyperparameters
    alpha = self.hparams['ps_vae.alpha']
    beta = self.beta_vals[self.curr_epoch]
    delta = self.hparams['ps_vae.delta']
    kl = self.kl_anneal_vals[self.curr_epoch]

    loss_strs = [
        'loss', 'loss_data_ll', 'loss_label_ll', 'loss_zs_kl', 'loss_zu_mi', 'loss_zu_tc',
        'loss_zu_dwkl', 'loss_triplet']
    loss_dict_vals = dict.fromkeys(loss_strs, 0)
    loss_dict_vals['loss_data_mse'] = 0

    x_hat, sample, mu, logvar, y_hat = self.forward(x, dataset=None, use_mean=False)

    # first column of the unsupervised latents
    u_beg = n_labels + n_background

    # data log-likelihood
    data_ll = losses.gaussian_ll(x, x_hat, m)
    total = 0 - data_ll

    # label log-likelihood
    label_ll = losses.gaussian_ll(y, y_hat, n)
    total -= alpha * label_ll

    # supervised latents kl
    zs_kl = losses.kl_div_to_std_normal(mu[:, :n_labels], logvar[:, :n_labels])
    total += zs_kl

    # unsupervised latents: index-code mutual info, total correlation, dimension-wise kl
    zu_mi, zu_tc, zu_dwkl = losses.decomposed_kl(
        sample[:, u_beg:], mu[:, u_beg:], logvar[:, u_beg:])
    total += kl * zu_mi
    total += beta * zu_tc
    total += kl * zu_dwkl

    terms = {
        'loss_data_ll': data_ll, 'loss_label_ll': label_ll, 'loss_zs_kl': zs_kl,
        'loss_zu_mi': zu_mi, 'loss_zu_tc': zu_tc, 'loss_zu_dwkl': zu_dwkl}

    # triplet loss on the background latents; not recorded for single-session input
    if multi_session:
        triplet = losses.triplet_loss(self.TripletLoss, mu[:, n_labels:u_beg], datasets)
        total += delta * triplet
        terms['loss_triplet'] = triplet

    if accumulate_grad:
        total.backward()

    # get loss values as scalars
    loss_dict_vals['loss'] += total.item()
    for key, val in terms.items():
        loss_dict_vals[key] += val.item()
    loss_dict_vals['loss_data_mse'] += losses.gaussian_ll_to_mse(
        loss_dict_vals['loss_data_ll'], np.prod(x.shape[1:]))

    # use variance-weighted r2s to ignore small-variance latents
    y_pred_np = y_hat.cpu().detach().numpy()
    y_true_np = y.cpu().detach().numpy()
    if n is None:
        r2 = r2_score(y_true_np, y_pred_np, multioutput='variance_weighted')
    else:
        # only score entries whose label mask equals 1
        keep = n.cpu().detach().numpy() == 1
        r2 = r2_score(y_true_np[keep], y_pred_np[keep], multioutput='variance_weighted')

    # store hyperparams
    loss_dict_vals['alpha'] = alpha
    loss_dict_vals['beta'] = beta
    loss_dict_vals['delta'] = delta
    loss_dict_vals['label_r2'] = r2

    return loss_dict_vals
def get_predicted_labels(self, x, dataset=None, use_mean=True):
    """Process input data to get predicted labels.

    Parameters
    ----------
    x : :obj:`torch.Tensor` object
        input data
    dataset : :obj:`int`
        used with session-specific io layers
    use_mean : :obj:`bool`
        True to skip sampling step

    Returns
    -------
    :obj:`torch.Tensor`
        predicted labels, computed by passing the supervised latents through `encoding.D`

    """
    # only the supervised means and the log variances are needed here
    z_s, _, _, logvar, _, _ = self.encoding(x, dataset=dataset)
    if not use_mean:
        # sample the supervised latents using the first `n_labels` log variances
        z_s = reparameterize(z_s, logvar[:, :self.n_labels])
    y_hat = self.encoding.D(z_s)
    return y_hat
def export_latents(self, data_gen, filename=None):
    """Encode all batches with this model's encoder and pickle the latents per dataset.

    A standard data generator (one session per batch, all trials in the training split) is
    rebuilt from a copy of the model hparams, so the generator passed in is only used for its
    `datasets_info` attribute. For every batch the first three encoder outputs are concatenated
    along the latent dimension and stored under the batch's trial index. Trials that are never
    visited keep the empty array they were initialized with.

    Side effect: the model is put in eval mode and is not switched back.

    Parameters
    ----------
    data_gen : data generator object
        only `data_gen.datasets_info` is read
    filename : :obj:`str`, optional
        path of the pickle file to write; if :obj:`None`, one file per dataset is written to
        `hparams['expt_dir']/version_<self.version>/<lab>_<expt>_<animal>_<session>_latents.pkl`

    Returns
    -------
    :obj:`list`
        path of the file written for each dataset

    """
    import os
    import pickle
    from copy import deepcopy

    from behavenet.data.utils import build_data_generator

    hp_new = deepcopy(self.hparams)
    hp_new['n_sessions_per_batch'] = 1  # force standard data generator
    hp_new['train_frac'] = 1  # use all training batches
    hp_new['trial_splits'] = '1;0;0;0'  # no gaps
    data_generator = build_data_generator(hp_new, data_gen.datasets_info)

    self.eval()

    def _encode(frames, sess):
        """Concatenate the first three encoder outputs and return them as a numpy array."""
        output = self.encoding(frames, dataset=sess)
        curr_latents = torch.cat([output[0], output[1], output[2]], dim=1)
        return curr_latents.cpu().detach().numpy()

    # initialize container for latents: one (initially empty) array per trial
    latents = [[] for _ in range(data_generator.n_datasets)]
    for sess, dataset in enumerate(data_generator.datasets):
        latents[sess] = [np.array([]) for _ in range(dataset.n_trials)]

    # max number of frames pushed through the encoder at once; larger batches are chunked so
    # that they fit on the gpu
    chunk_size = 200

    # inference only: do not build an autograd graph for the forward passes
    with torch.no_grad():
        for dtype in ['train', 'val', 'test']:
            data_generator.reset_iterators(dtype)
            for _ in range(data_generator.n_tot_batches[dtype]):
                data, sess = data_generator.next_batch(dtype)
                y = data['images'][0]
                batch_size = y.shape[0]
                trial = data['batch_idx'].item()
                if batch_size > chunk_size:
                    # preallocate with nans, then fill one chunk at a time
                    trial_latents = np.full(
                        shape=(batch_size, self.hparams['n_ae_latents']),
                        fill_value=np.nan)
                    for idx_beg in range(0, batch_size, chunk_size):
                        idx_end = min(idx_beg + chunk_size, batch_size)
                        trial_latents[idx_beg:idx_end, :] = _encode(y[idx_beg:idx_end], sess)
                    latents[sess][trial] = trial_latents
                else:
                    latents[sess][trial] = _encode(y, sess)

    # save latents separately for each dataset
    filenames = []
    for sess, dataset in enumerate(data_generator.datasets):
        if filename is None:
            # get save name which includes lab/expt/animal/session
            sess_id = str('%s_%s_%s_%s_latents.pkl' % (
                dataset.lab, dataset.expt, dataset.animal, dataset.session))
            filename_save = os.path.join(
                self.hparams['expt_dir'], 'version_%i' % self.version, sess_id)
        else:
            # NOTE(review): with more than one dataset every session is written to this same
            # path, so only the last one survives - confirm callers pass a single dataset
            filename_save = filename
        # save out array in pickle file
        print('saving latents %i of %i:\n%s' % (
            sess + 1, data_generator.n_datasets, filename_save))
        latents_dict = {'latents': latents[sess], 'trials': dataset.batch_idxs}
        with open(filename_save, 'wb') as f:
            pickle.dump(latents_dict, f)
        filenames.append(filename_save)

    return filenames
class ConvAEPSEncoder(ConvAEEncoder):
    """Convolutional encoder whose output is split into label and non-label subspaces."""

    def __init__(self, hparams):
        """Build the base encoder, then add the fixed projections `A`, `B` and the map `D`.

        Parameters
        ----------
        hparams : :obj:`dict`
            passed to the parent constructor; 'n_ae_latents' and 'n_labels' are read from
            `self.hparams` afterwards

        """
        from behavenet.models.base import DiagLinear

        super().__init__(hparams)

        latent_dim = self.hparams['n_ae_latents']
        label_dim = self.hparams['n_labels']

        # projections from the NN output onto the two subspaces, plus the label read-out
        self.A = nn.Linear(latent_dim, label_dim, bias=False)  # -> constrained latents
        self.B = nn.Linear(latent_dim, latent_dim - label_dim, bias=False)  # -> unconstrained
        self.D = DiagLinear(label_dim, bias=True)  # constrained latents -> labels

        # A and B are the row blocks of one random orthogonal matrix and are frozen
        from scipy.stats import ortho_group
        rotation = ortho_group.rvs(dim=latent_dim).astype('float32')
        label_rows = rotation[:label_dim, :]
        other_rows = rotation[label_dim:, :]
        with torch.no_grad():
            self.A.weight = nn.Parameter(torch.from_numpy(label_rows), requires_grad=False)
            self.B.weight = nn.Parameter(torch.from_numpy(other_rows), requires_grad=False)

    def __str__(self):
        """Pretty print encoder architecture."""
        row = ' {:02d}: {}\n'
        pieces = ['Encoder architecture:\n']
        pieces.extend(str(row.format(k, module)) for k, module in enumerate(self.encoder))
        # the feedforward layer and the linear maps after it all share the next index
        k = len(pieces) - 1
        pieces.append(str(row.format(k, self.FF)))
        pieces.append(str(' {:02d}: {} (to constrained latents)\n'.format(k, self.A)))
        pieces.append(str(' {:02d}: {} (to unconstrained latents)\n'.format(k, self.B)))
        pieces.append(str(' {:02d}: {} (constrained latents to labels)\n'.format(k, self.D)))
        return ''.join(pieces)

    def forward(self, x, dataset=None):
        """Process input data.

        Parameters
        ----------
        x : :obj:`torch.Tensor` object
            input data
        dataset : :obj:`int`
            used with session-specific io layers

        Returns
        -------
        :obj:`tuple`
            - encoder output y (:obj:`torch.Tensor`): constrained latents (predicted labels) of
              shape (n_labels)
            - encoder output z (:obj:`torch.Tensor`): unconstrained latents of shape
              (n_latents - n_labels)
            - logvar (:obj:`torch.Tensor`): log variance of latents of shape (n_latents)
            - pool_idx (:obj:`list`): max pooling indices for each layer
            - output_size (:obj:`list`): output size for each layer

        """
        # max pooling layers need their indices and input sizes recorded for later unpooling
        pooling_indices = []
        pre_pool_sizes = []
        for layer in self.encoder:
            if isinstance(layer, nn.ModuleList):
                # session-specific layer: pick the one belonging to this dataset
                x = layer[dataset](x)
            elif isinstance(layer, nn.MaxPool2d):
                pre_pool_sizes.append(x.size())
                x, indices = layer(x)
                pooling_indices.append(indices)
            else:
                x = layer(x)

        # flatten for the feedforward layer
        flat = x.view(x.size(0), -1)
        hidden = self.FF(flat)

        # project onto the constrained/unconstrained subspaces; logvar reads the flattened input
        return self.A(hidden), self.B(hidden), self.logvar(flat), pooling_indices, pre_pool_sizes
class ConvAEMSPSEncoder(ConvAEEncoder):
    """Convolutional encoder that separates supervised, background and unsupervised subspaces.

    The output of the base encoder's feedforward layer is projected by three linear maps whose
    weights are disjoint row blocks of a single random orthogonal matrix: `A` (supervised
    latents), `C` (background latents) and `B` (unsupervised latents). A further layer `D` maps
    supervised latents to labels.
    """

    def __init__(self, hparams):
        """Build the base encoder, then add the projections `A`, `B`, `C` and the map `D`.

        Parameters
        ----------
        hparams : :obj:`dict`
            passed to the parent constructor; 'n_ae_latents', 'n_labels' and 'n_background' are
            read from `self.hparams` afterwards

        """
        from behavenet.models.base import DiagLinear
        from scipy.stats import ortho_group

        super().__init__(hparams)

        # add linear transformations mapping from NN output to the three subspaces
        n_latents = self.hparams['n_ae_latents']
        n_labels = self.hparams['n_labels']
        n_background = self.hparams['n_background']

        # NN -> supervised latents
        self.A = nn.Linear(n_latents, n_labels, bias=False)
        # NN -> unsupervised latents
        self.B = nn.Linear(n_latents, n_latents - n_labels - n_background, bias=False)
        # NN -> background latents
        # NOTE(review): unlike A and B this layer has a bias, which stays trainable - confirm
        # that this is intended
        self.C = nn.Linear(n_latents, n_background, bias=True)
        # supervised latents -> labels (diagonal matrix + bias)
        self.D = DiagLinear(n_labels, bias=True)

        # fix the weights of A, B, C to row blocks of one orthogonal matrix (weights are frozen)
        m = ortho_group.rvs(dim=n_latents).astype('float32')
        n_constrained = n_labels + n_background
        with torch.no_grad():
            self.A.weight = nn.Parameter(
                torch.from_numpy(m[:n_labels, :]), requires_grad=False)
            self.B.weight = nn.Parameter(
                torch.from_numpy(m[n_constrained:, :]), requires_grad=False)
            self.C.weight = nn.Parameter(
                torch.from_numpy(m[n_labels:n_constrained, :]), requires_grad=False)

    def __str__(self):
        """Pretty print encoder architecture."""
        row = ' {:02d}: {}\n'
        lines = ['Encoder architecture:\n']
        lines.extend(row.format(i, module) for i, module in enumerate(self.encoder))
        # the final ff layer and the linear transformations are all printed with the next index
        i = len(lines) - 1
        lines.append(row.format(i, self.FF))
        lines.append(' {:02d}: {} (to supervised latents)\n'.format(i, self.A))
        lines.append(' {:02d}: {} (to unsupervised latents)\n'.format(i, self.B))
        lines.append(' {:02d}: {} (to background latents)\n'.format(i, self.C))
        lines.append(' {:02d}: {} (supervised latents to labels)\n'.format(i, self.D))
        return ''.join(lines)

    def forward(self, x, dataset=None):
        """Process input data.

        Parameters
        ----------
        x : :obj:`torch.Tensor` object
            input data
        dataset : :obj:`int`
            used with session-specific io layers

        Returns
        -------
        :obj:`tuple`
            - z_s (:obj:`torch.Tensor`): supervised latents, output of `A`; last dimension is
              n_labels
            - z_b (:obj:`torch.Tensor`): background latents, output of `C`; last dimension is
              n_background
            - z (:obj:`torch.Tensor`): unsupervised latents, output of `B`; last dimension is
              n_latents - n_labels - n_background
            - logvar (:obj:`torch.Tensor`): output of `self.logvar` applied to the flattened
              convolutional features; presumably the log variance of all n_latents latents
            - pool_idx (:obj:`list`): max pooling indices for each max pooling layer
            - output_size (:obj:`list`): input size of each max pooling layer

        """
        # loop over layers, have to collect pool_idx and output sizes if using max pooling to use
        # in unpooling
        pool_idx = []
        target_output_size = []
        for layer in self.encoder:
            if isinstance(layer, nn.MaxPool2d):
                target_output_size.append(x.size())
                x, idx = layer(x)
                pool_idx.append(idx)
            elif isinstance(layer, nn.ModuleList):
                # session-specific layer: use the one belonging to this dataset
                x = layer[dataset](x)
            else:
                x = layer(x)

        # reshape for ff layer
        x1 = x.view(x.size(0), -1)
        x = self.FF(x1)

        # push through linear transformations
        z_s = self.A(x)  # NN -> supervised latents
        z = self.B(x)  # NN -> unsupervised latents
        z_b = self.C(x)  # NN -> background latents

        return z_s, z_b, z, self.logvar(x1), pool_idx, target_output_size