# Source code for behavenet.fitting.losses

"""Custom losses for PyTorch models."""

import numpy as np
import torch
from torch.nn.modules.loss import _Loss
from torch.distributions.multivariate_normal import MultivariateNormal

# to ignore imports for sphinx-autoapidoc: only the names listed here are exported as the
# module's public API (note that the `GaussianNegLogProb` class is not in this list)
__all__ = [
    'mse', 'gaussian_ll', 'gaussian_ll_to_mse', 'kl_div_to_std_normal', 'index_code_mi',
    'total_correlation', 'dimension_wise_kl_to_std_normal', 'decomposed_kl', 'subspace_overlap',
    'triplet_loss']

# log(2 * pi); the additive constant that appears in every Gaussian log-density in this module
LN2PI = np.log(2 * np.pi)


class GaussianNegLogProb(_Loss):
    """Minimize negative Gaussian log probability with learned covariance matrix.

    For now the covariance matrix is not data-dependent.

    Only ``reduction='mean'`` is supported; any other value raises
    :obj:`NotImplementedError` at construction time.
    """

    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        if reduction != 'mean':
            raise NotImplementedError
        super().__init__(size_average, reduce, reduction)

    def forward(self, input, target, precision):
        """Compute the mean negative log probability of `target` under a Gaussian.

        Parameters
        ----------
        input : :obj:`torch.Tensor`
            mean of the Gaussian; used as `loc` of the distribution
        target : :obj:`torch.Tensor`
            observed data; its second dimension defines the size of the covariance matrix
        precision : :obj:`torch.Tensor`
            NOTE: despite its name, this tensor is added to a small diagonal term and passed
            to the distribution as its *covariance* matrix

        Returns
        -------
        :obj:`torch.Tensor`
            negative log probability averaged over all entries returned by `log_prob`

        """
        output_dim = target.shape[1]
        # Build the diagonal jitter on the same device as the data; a bare `torch.eye(n)` is
        # always created on the default device and cannot be added to tensors living elsewhere.
        jitter = 1e-3 * torch.eye(output_dim, device=input.device)
        dist = MultivariateNormal(loc=input, covariance_matrix=jitter + precision)
        return torch.mean(-dist.log_prob(target))


def mse(y_pred, y_true, masks=None):
    """Compute mean square error (MSE) loss, optionally weighted by a mask.

    Parameters
    ----------
    y_pred : :obj:`torch.Tensor`
        predicted data
    y_true : :obj:`torch.Tensor`
        true data
    masks : :obj:`torch.Tensor`, optional
        binary mask that is the same size as `y_pred` and `y_true`; entries set to 0 zero out
        the corresponding squared errors, so those dimensions do not contribute to the loss or
        to parameter updates. The mean is still taken over all entries, masked or not.

    Returns
    -------
    :obj:`torch.Tensor`
        mean square error computed across all dimensions

    """
    squared_error = (y_pred - y_true) ** 2
    if masks is not None:
        squared_error = squared_error * masks
    return torch.mean(squared_error)
def gaussian_ll(y_pred, y_mean, masks=None, std=1):
    """Compute multivariate Gaussian log-likelihood with a fixed diagonal noise covariance.

    Parameters
    ----------
    y_pred : :obj:`torch.Tensor`
        predicted data of shape (n_frames, ...)
    y_mean : :obj:`torch.Tensor`
        true data of shape (n_frames, ...)
    masks : :obj:`torch.Tensor`, optional
        binary mask that is the same size as `y_pred` and `y_mean`; entries set to 0 remove the
        corresponding squared errors from the loss. The normalization constant is still
        computed from the full number of dimensions, masked or not.
    std : :obj:`float`, optional
        fixed standard deviation for all dimensions in the multivariate Gaussian

    Returns
    -------
    :obj:`torch.Tensor`
        Gaussian log-likelihood summed across dims, averaged across batch

    """
    shape = y_pred.shape
    n_dims = np.prod(shape[1:])  # everything except the leading n_frames axis
    log_var = np.log(std ** 2)

    squared_error = (y_pred - y_mean) ** 2
    if masks is not None:
        squared_error = squared_error * masks

    # reduce over every axis except the batch axis
    non_batch_axes = tuple(range(1, len(shape)))
    normalizer = - (0.5 * LN2PI + 0.5 * log_var) * n_dims
    quadratic = (0.5 / (std ** 2)) * squared_error.sum(axis=non_batch_axes)
    ll = normalizer - quadratic
    return torch.mean(ll)
def gaussian_ll_to_mse(ll, n_dims, gaussian_std=1, mse_std=1):
    """Convert a Gaussian log-likelihood term to MSE by removing constants and swapping variances.

    - NOTE: does not currently return correct values if gaussian ll is computed with masks

    Parameters
    ----------
    ll : :obj:`float`
        original Gaussian log-likelihood
    n_dims : :obj:`int`
        number of dimensions in multivariate Gaussian
    gaussian_std : :obj:`float`
        std used to compute Gaussian log-likelihood
    mse_std : :obj:`float`
        std used to compute MSE

    Returns
    -------
    :obj:`float`
        MSE value (a numpy copy of `ll`, transformed in place)

    """
    # per-dimension normalization constant of the Gaussian log-density
    per_dim_const = 0.5 * LN2PI + 0.5 * np.log(gaussian_std ** 2)

    # work on a copy so the caller's value is never modified
    mse_val = np.copy(ll)
    mse_val += per_dim_const * n_dims  # remove constant
    mse_val *= -(gaussian_std ** 2) / 0.5  # undo scaling by variance (and the sign)
    mse_val /= n_dims  # change sum to mean
    mse_val *= 1.0 / (mse_std ** 2)  # scale by mse variance
    return mse_val
def kl_div_to_std_normal(mu, logvar):
    """Compute KL(q(z) || N(0, 1)) where q(z) is a diagonal normal parameterized by mu, logvar.

    Parameters
    ----------
    mu : :obj:`torch.Tensor`
        mean parameter of shape (n_frames, n_dims)
    logvar : :obj:`torch.Tensor`
        log variance parameter of shape (n_frames, n_dims)

    Returns
    -------
    :obj:`torch.Tensor`
        KL divergence summed across dims, averaged across batch

    """
    # closed-form KL between a diagonal Gaussian and the standard normal, per element
    elementwise = logvar.exp() - logvar + mu.pow(2) - 1
    kl_per_frame = 0.5 * torch.sum(elementwise, dim=1)
    return torch.mean(kl_per_frame)
def index_code_mi(z, mu, logvar):
    """Estimate index code mutual information in a batch.

    We ignore the constant as it does not matter for the minimization. The constant should be
    equal to log(n_frames * dataset_size).

    Parameters
    ----------
    z : :obj:`torch.Tensor`
        sample of shape (n_frames, n_dims)
    mu : :obj:`torch.Tensor`
        mean parameter of shape (n_frames, n_dims)
    logvar : :obj:`torch.Tensor`
        log variance parameter of shape (n_frames, n_dims)

    Returns
    -------
    :obj:`torch.Tensor`
        index code mutual information for batch, scalar value

    """
    # log q(z(x_j)_l | x_i) for every pair of frames and every latent dimension. Broadcasting
    #   z[:, None]       (n_frames, 1, n_dims)
    #   mu[None, :]      (1, n_frames, n_dims)
    #   logvar[None, :]  (1, n_frames, n_dims)
    # yields a tensor of shape (n_frames, n_frames, n_dims) indexed by [j, i, l].
    pairwise_log_q = _gaussian_log_density_unsummed(z[:, None], mu[None, :], logvar[None, :])

    # sum over latent dims: log prod_l q(z(x_j)_l | x_i), indexed by [j, i]
    joint_log_q = torch.sum(pairwise_log_q, dim=2)

    # log q(z(x_j)) = log sum_i q(z(x_j) | x_i) + constant
    log_qz = torch.logsumexp(joint_log_q, dim=1)

    # log q(z(x_j) | x_j): the i == j entries
    log_qz_given_x = torch.diag(joint_log_q)

    return torch.mean(log_qz_given_x - log_qz)
def total_correlation(z, mu, logvar):
    """Estimate total correlation in a batch.

    Compute the expectation over a batch of:

    E_j [log(q(z(x_j))) - log(prod_l q(z(x_j)_l))]

    We ignore the constant as it does not matter for the minimization. The constant should be
    equal to (n_dims - 1) * log(n_frames * dataset_size).

    Code modified from https://github.com/julian-carpenter/beta-TCVAE/blob/master/nn/losses.py

    Parameters
    ----------
    z : :obj:`torch.Tensor`
        sample of shape (n_frames, n_dims)
    mu : :obj:`torch.Tensor`
        mean parameter of shape (n_frames, n_dims)
    logvar : :obj:`torch.Tensor`
        log variance parameter of shape (n_frames, n_dims)

    Returns
    -------
    :obj:`torch.Tensor`
        total correlation for batch, scalar value

    """
    # log q(z(x_j)_l | x_i) for every pair of frames and every latent dimension. Broadcasting
    #   z[:, None]       (n_frames, 1, n_dims)
    #   mu[None, :]      (1, n_frames, n_dims)
    #   logvar[None, :]  (1, n_frames, n_dims)
    # yields a tensor of shape (n_frames, n_frames, n_dims) indexed by [j, i, l].
    pairwise_log_q = _gaussian_log_density_unsummed(z[:, None], mu[None, :], logvar[None, :])

    # log prod_l q(z(x_j)_l) = sum_l (log sum_i q(z(x_j)_l | x_i)) + constant; one value per frame
    per_dim_marginal = torch.logsumexp(pairwise_log_q, dim=1)  # logsumexp over batch
    log_qz_product = torch.sum(per_dim_marginal, dim=1)  # sum over gaussian dims

    # log q(z(x_j)) = log sum_i prod_l q(z(x_j)_l | x_i) + constant; one value per frame
    joint_log_q = torch.sum(pairwise_log_q, dim=2)  # sum over gaussian dims
    log_qz = torch.logsumexp(joint_log_q, dim=1)  # logsumexp over batch

    return torch.mean(log_qz - log_qz_product)
def dimension_wise_kl_to_std_normal(z, mu, logvar):
    """Estimate dimension-wise KL divergence to standard normal in a batch.

    Parameters
    ----------
    z : :obj:`torch.Tensor`
        sample of shape (n_frames, n_dims)
    mu : :obj:`torch.Tensor`
        mean parameter of shape (n_frames, n_dims)
    logvar : :obj:`torch.Tensor`
        log variance parameter of shape (n_frames, n_dims)

    Returns
    -------
    :obj:`torch.Tensor`
        dimension-wise KL to standard normal for batch, scalar value

    """
    # log q(z(x_j)_l | x_i) for every pair of frames and every latent dimension. Broadcasting
    #   z[:, None]       (n_frames, 1, n_dims)
    #   mu[None, :]      (1, n_frames, n_dims)
    #   logvar[None, :]  (1, n_frames, n_dims)
    # yields a tensor of shape (n_frames, n_frames, n_dims) indexed by [j, i, l].
    pairwise_log_q = _gaussian_log_density_unsummed(z[:, None], mu[None, :], logvar[None, :])

    # log prod_l q(z(x_j)_l) = sum_l (log sum_i q(z(x_j)_l | x_i)) + constant; one value per frame
    per_dim_marginal = torch.logsumexp(pairwise_log_q, dim=1)  # logsumexp over batch
    log_qz_product = torch.sum(per_dim_marginal, dim=1)  # sum over gaussian dims

    # sum_l log p(z(x_j)_l) under the standard normal prior; one value per frame
    prior_log_p = _gaussian_log_density_unsummed_std_normal(z)
    log_pz_product = torch.sum(prior_log_p, dim=1)  # sum over gaussian dims

    return torch.mean(log_qz_product - log_pz_product)
def decomposed_kl(z, mu, logvar):
    """Decompose KL term in VAE loss.

    Decomposes the KL divergence loss term of the variational autoencoder into three terms:

    1. index code mutual information
    2. total correlation
    3. dimension-wise KL

    None of these terms can be computed exactly when using stochastic gradient descent. This
    function instead computes approximations as detailed in
    https://arxiv.org/pdf/1802.04942.pdf.

    Parameters
    ----------
    z : :obj:`torch.Tensor`
        sample of shape (n_frames, n_dims)
    mu : :obj:`torch.Tensor`
        mean parameter of shape (n_frames, n_dims)
    logvar : :obj:`torch.Tensor`
        log variance parameter of shape (n_frames, n_dims)

    Returns
    -------
    :obj:`tuple`
        - index code mutual information (:obj:`torch.Tensor`)
        - total correlation (:obj:`torch.Tensor`)
        - dimension-wise KL (:obj:`torch.Tensor`)

    """
    # log q(z(x_j)_l | x_i) for every pair of frames and every latent dimension. Inserting
    # `None` adds a singleton axis so that torch broadcasting forms all (j, i) pairs:
    #   z[:, None]       (n_frames, 1, n_dims)
    #   mu[None, :]      (1, n_frames, n_dims)
    #   logvar[None, :]  (1, n_frames, n_dims)
    # The result has shape (n_frames, n_frames, n_dims) and is indexed by [j, i, l].
    pairwise_log_q = _gaussian_log_density_unsummed(z[:, None], mu[None, :], logvar[None, :])

    # sum over latent dims: log prod_l q(z(x_j)_l | x_i), indexed by [j, i]
    # (assumes q is factorized across dimensions)
    joint_log_q = torch.sum(pairwise_log_q, dim=2)

    # log q(z(x_j)) = log sum_i exp(sum_l log q(z(x_j)_l | x_i)) + constant
    log_qz = torch.logsumexp(joint_log_q, dim=1)

    # log q(z(x_j) | x_j) = sum_l log q(z(x_j)_l | x_j): the i == j entries
    log_qz_given_x = torch.diag(joint_log_q)

    # log prod_l q(z(x_j)_l) = sum_l (log sum_i q(z(x_j)_l | x_i)) + constant
    per_dim_marginal = torch.logsumexp(pairwise_log_q, dim=1)  # logsumexp over batch
    log_qz_product = torch.sum(per_dim_marginal, dim=1)  # sum over gaussian dims

    # sum_l log p(z(x_j)_l) under the standard normal prior
    prior_log_p = _gaussian_log_density_unsummed_std_normal(z)
    log_pz_product = torch.sum(prior_log_p, dim=1)  # sum over gaussian dims

    idx_code_mi = torch.mean(log_qz_given_x - log_qz)
    total_corr = torch.mean(log_qz - log_qz_product)
    dim_wise_kl = torch.mean(log_qz_product - log_pz_product)

    return idx_code_mi, total_corr, dim_wise_kl
def _gaussian_log_density_unsummed(z, mu, logvar):
    """Element-wise Gaussian log-density of `z`, without summing over dimensions.

    Assumes a diagonal noise covariance matrix, so every dimension is scored independently
    with mean `mu` and log variance `logvar`. Inputs are combined with standard broadcasting.
    """
    precision = torch.exp(-logvar)
    squared_dev = (z - mu) ** 2
    return - 0.5 * (precision * squared_dev + logvar + LN2PI)


def _gaussian_log_density_unsummed_std_normal(z):
    """Element-wise standard normal log-density of `z`, without summing over dimensions.

    Special case of the diagonal Gaussian log-density with zero mean and unit variance.
    """
    squared_dev = z ** 2
    return - 0.5 * (squared_dev + LN2PI)
def subspace_overlap(A, B, C=None):
    """Compute inner product between subspaces defined by matrices A and B (and optionally C).

    The matrices are stacked row-wise into a single matrix U, and the returned value measures
    how far U U^T is from the identity.

    Parameters
    ----------
    A : :obj:`torch.Tensor`
        shape (a, d)
    B : :obj:`torch.Tensor`
        shape (b, d)
    C : :obj:`torch.Tensor`, optional
        shape (c, d)

    Returns
    -------
    :obj:`torch.Tensor`
        scalar value; squared Frobenius norm of (UU^T - I) divided by number of entries

    """
    blocks = [A, B] if C is None else [A, B, C]
    U = torch.cat(blocks, dim=0)

    n_rows = U.shape[0]
    identity = torch.eye(n_rows, device=U.device)
    gram = torch.matmul(U, torch.transpose(U, 1, 0))

    return torch.mean((gram - identity).pow(2))
def triplet_loss(triplet_loss_obj, z, datasets):
    """Compute triplet loss to learn separated embedding space.

    Supported for batches that contain samples from exactly 2, 3 or 4 datasets.

    The samples of every dataset are randomly permuted and split into equally sized chunks.
    For every ordered pair of distinct datasets (X, Y), one triplet term is formed with anchor
    and positive chunks from X and a negative chunk from Y; the mean pairwise distance between
    each anchor/positive pair is added as well. No chunk is used more than once as an
    anchor/positive/negative within a call.

    Parameters
    ----------
    triplet_loss_obj : :obj:`torch.TripletMarginLoss` object
        already instantiated triplet loss object; this function splits up the data to give to
        this object
    z : :obj:`torch.Tensor`
        low-dim data embeddings; shape (N, d), where N is number of samples and d is embedding
        dim
    datasets : :obj:`torch.Tensor`
        identifies the dataset that each sample belongs to; shape (N,)

    Returns
    -------
    :obj:`torch.Tensor`
        scalar value; triplet loss

    Raises
    ------
    NotImplementedError
        if the batch does not contain exactly 2, 3 or 4 distinct datasets

    """
    dataset_ids = np.unique(datasets)
    n_datasets = len(dataset_ids)
    if n_datasets not in (2, 3, 4):
        raise NotImplementedError

    # Every dataset is paired once with each of the other datasets. Each pairing consumes two
    # chunks (anchor, positive) of the anchor dataset; one extra chunk per pairing is reserved
    # to serve as negatives for the other datasets.
    n_pairs = n_datasets - 1
    first_neg_chunk = 2 * n_pairs
    n_chunks = 3 * n_pairs

    # randomly permute the sample indices of each dataset (drawn in dataset order)
    shuffled = [
        np.random.permutation(np.where(datasets == ds_id)[0]) for ds_id in dataset_ids]

    # make sure all chunks are the same length; chunks are empty if any dataset has fewer than
    # `n_chunks` samples
    m = np.min([len(idxs) // n_chunks for idxs in shuffled])
    chunks = [[idxs[c::n_chunks][:m] for c in range(n_chunks)] for idxs in shuffled]

    # (anchor dataset, pair number, negative dataset) for every ordered dataset pair
    pairings = [
        (i, k, j)
        for i in range(n_datasets)
        for k, j in enumerate(o for o in range(n_datasets) if o != i)]

    # triplet terms: the negative chunk of dataset j used by anchor dataset i is selected by the
    # rank of i among the datasets other than j, so each negative chunk is used exactly once
    terms = []
    for i, k, j in pairings:
        neg_chunk = first_neg_chunk + (i if i < j else i - 1)
        terms.append(triplet_loss_obj(
            z[chunks[i][2 * k]], z[chunks[i][2 * k + 1]], z[chunks[j][neg_chunk]]))

    # mean distance between each anchor/positive pair
    for i, k, _ in pairings:
        terms.append(
            torch.pairwise_distance(z[chunks[i][2 * k]], z[chunks[i][2 * k + 1]]).mean())

    # accumulate left to right: all triplet terms first, then all pairwise terms
    loss = terms[0]
    for term in terms[1:]:
        loss = loss + term

    if n_datasets == 2:
        n_loss_terms = 3  # legacy error for now
    else:
        n_loss_terms = n_datasets * n_pairs

    return loss / n_loss_terms