triplet_loss
behavenet.fitting.losses.triplet_loss(triplet_loss_obj, z, datasets)
Compute triplet loss to learn a separated embedding space.
Currently only supported for batches containing samples from 2 or 3 datasets.
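For reference, `torch.nn.TripletMarginLoss` (the kind of object passed in as `triplet_loss_obj`) computes the following over M triplets of anchors a_i, positives p_i, and negatives n_i when left at its default `reduction='mean'`. The `margin` and the norm order p are set when the object is instantiated, and PyTorch adds a small `eps` inside the norm for numerical stability. Minimizing this pulls each anchor toward its positive and pushes it at least `margin` farther away from its negative, which is what separates the embedding space.

```latex
\mathcal{L} = \frac{1}{M} \sum_{i=1}^{M} \max\bigl\{\, d(a_i, p_i) - d(a_i, n_i) + \text{margin},\ 0 \,\bigr\},
\qquad d(x, y) = \lVert x - y \rVert_p
```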
- Parameters
triplet_loss_obj (torch.nn.TripletMarginLoss object) – an already instantiated triplet loss object; this function splits the data into anchor, positive, and negative sets and passes them to this object
z (torch.Tensor) – low-dimensional data embeddings; shape (N, d), where N is the number of samples and d is the embedding dimension
datasets (torch.Tensor) – identifies the dataset that each sample belongs to; shape (N,)
- Returns
scalar value; triplet loss
- Return type
torch.Tensor
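This page does not say how `z` is divided into anchors, positives, and negatives, so the following is only a hypothetical sketch of one way a function with this signature could do it. The name `triplet_loss_sketch` and the pairing scheme (first half of a dataset's samples as anchors, second half as positives, samples from the next dataset as negatives) are assumptions for illustration, not behavenet's implementation; only `torch.nn.TripletMarginLoss` and its `(anchor, positive, negative)` call are real PyTorch API.

```python
import torch


def triplet_loss_sketch(triplet_loss_obj, z, datasets):
    """Hypothetical re-implementation for illustration; NOT behavenet's code.

    Assumed pairing scheme (not documented behavior): for each dataset, the
    first half of its samples are anchors, the second half are positives, and
    samples from the next dataset (cyclically) are negatives.
    """
    ids = torch.unique(datasets)  # sorted unique dataset labels
    if len(ids) not in (2, 3):
        raise ValueError("only 2- or 3-dataset batches are supported")
    losses = []
    for i, d in enumerate(ids):
        z_same = z[datasets == d]  # embeddings from this dataset
        z_other = z[datasets == ids[(i + 1) % len(ids)]]  # from another dataset
        n = min(z_same.shape[0] // 2, z_other.shape[0])  # triplets we can form
        if n == 0:
            continue  # too few samples to build a triplet for this dataset
        anchor, positive = z_same[:n], z_same[n:2 * n]
        negative = z_other[:n]
        losses.append(triplet_loss_obj(anchor, positive, negative))
    if not losses:
        return z.new_zeros(())  # no valid triplets: zero-valued scalar
    return torch.stack(losses).mean()  # average over datasets -> scalar tensor


# Usage: two datasets of six 4-dim embeddings each
torch.manual_seed(0)
loss_fn = torch.nn.TripletMarginLoss(margin=1.0)
z = torch.randn(12, 4, requires_grad=True)
datasets = torch.tensor([0] * 6 + [1] * 6)
loss = triplet_loss_sketch(loss_fn, z, datasets)
loss.backward()  # gradients flow back to the embeddings
print(loss.item())
```

Averaging the per-dataset losses keeps the result a scalar tensor, consistent with the documented return type, and the loss stays differentiable with respect to `z`. The actual behavenet function may pair samples differently.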