triplet_loss

behavenet.fitting.losses.triplet_loss(triplet_loss_obj, z, datasets)

Compute a triplet loss that encourages the embeddings of samples from different datasets to occupy separate regions of the embedding space.

Currently only supported for batches that contain samples from 2 or 3 datasets.
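The docstring does not specify how triplets are formed from a batch. One plausible scheme is sketched below, assuming that samples from the same dataset act as anchor/positive pairs and samples from any other dataset act as negatives. This is an illustration only, not necessarily what behavenet does; the function name triplet_loss_sketch is hypothetical.

```python
import torch


def triplet_loss_sketch(triplet_loss_obj, z, datasets):
    """Hypothetical sketch (not behavenet's implementation).

    Assumption: same-dataset samples are positives, other-dataset samples are negatives.
    triplet_loss_obj: e.g. torch.nn.TripletMarginLoss(margin=1.0)
    z: embeddings, shape (N, d); datasets: dataset index per sample, shape (N,)
    """
    ids = torch.unique(datasets)
    if len(ids) not in (2, 3):
        raise ValueError("sketch only handles batches with 2 or 3 datasets")
    anchors, positives, negatives = [], [], []
    for d_id in ids:
        same = z[datasets == d_id]   # embeddings from this dataset
        other = z[datasets != d_id]  # embeddings from all other datasets
        n = min(same.shape[0], other.shape[0])
        if n < 2:
            continue  # need at least 2 same-dataset samples to form a positive pair
        same = same[:n]
        anchors.append(same)
        # positives: same-dataset samples shifted by one, so each anchor is
        # paired with a different sample from its own dataset
        positives.append(torch.roll(same, shifts=1, dims=0))
        # negatives: a random subset of n samples from the other datasets
        negatives.append(other[torch.randperm(other.shape[0])[:n]])
    if not anchors:
        raise ValueError("not enough samples per dataset to form triplets")
    return triplet_loss_obj(
        torch.cat(anchors), torch.cat(positives), torch.cat(negatives)
    )
```

With well-separated datasets the margin is satisfied and the loss drops to zero; overlapping embeddings give a positive loss whose gradient pushes the datasets apart.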

Parameters:
  • triplet_loss_obj (torch.nn.TripletMarginLoss object) – an already instantiated triplet loss object; this function splits the data into anchor, positive, and negative samples and passes them to this object

  • z (torch.Tensor) – low-dimensional data embeddings; shape (N, d), where N is the number of samples and d is the embedding dimension

  • datasets (torch.Tensor) – identifies the dataset that each sample in z belongs to; shape (N,)
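The z and datasets arguments are row-aligned: row j of z belongs to the dataset given by datasets[j]. A minimal sketch of building them from per-dataset embedding batches follows; the helper name stack_batches is hypothetical and not part of behavenet, and integer dataset indices starting at 0 are an assumption.

```python
import torch


def stack_batches(batches):
    """Hypothetical helper: concatenate per-dataset embeddings into (z, datasets).

    batches: list of tensors, the i-th of shape (n_i, d), from dataset i.
    Returns z of shape (N, d) with N = sum(n_i), and datasets of shape (N,)
    with datasets[j] == i when row j of z came from batches[i].
    """
    z = torch.cat(batches, dim=0)
    datasets = torch.cat(
        [torch.full((b.shape[0],), i, dtype=torch.long) for i, b in enumerate(batches)]
    )
    return z, datasets


z, datasets = stack_batches([torch.randn(3, 8), torch.randn(2, 8)])
print(z.shape, datasets.tolist())  # torch.Size([5, 8]) [0, 0, 0, 1, 1]
```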

Returns:

triplet loss as a scalar value

Return type:

torch.Tensor
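Because the return value is a 0-dimensional torch.Tensor, it can be weighted, added to other loss terms, and backpropagated like any other PyTorch loss. The snippet below uses torch.nn.TripletMarginLoss directly as a stand-in for the returned tensor; the weight triplet_weight is a hypothetical hyperparameter, not a behavenet setting.

```python
import torch

torch.manual_seed(0)
# Stand-in for the value returned by triplet_loss: with the default
# reduction="mean", torch.nn.TripletMarginLoss returns a 0-dim tensor.
z = torch.randn(6, 4, requires_grad=True)  # embeddings, shape (N, d)
loss_fn = torch.nn.TripletMarginLoss(margin=1.0)
loss = loss_fn(z[0:2], z[2:4], z[4:6])  # anchors, positives, negatives

triplet_weight = 0.5  # hypothetical weighting against other loss terms
total = triplet_weight * loss
total.backward()  # gradients flow back to the embeddings

print(loss.dim(), z.grad.shape)  # 0 torch.Size([6, 4])
```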