triplet_loss
- behavenet.fitting.losses.triplet_loss(triplet_loss_obj, z, datasets)
Compute triplet loss to learn separated embedding space.
Currently supported only for batches that contain samples from 2 or 3 datasets.
- Parameters:
triplet_loss_obj (torch.nn.TripletMarginLoss object) – already instantiated triplet loss object; this function splits up the data to pass to this object
z (torch.Tensor) – low-dimensional data embeddings; shape (N, d), where N is the number of samples and d is the embedding dimension
datasets (torch.Tensor) – identifies the dataset that each sample belongs to; shape (N,)
- Returns:
scalar value; triplet loss
- Return type:
torch.Tensor
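
The following is a minimal sketch of how a two-dataset batch could be split by dataset label and passed to an already instantiated `torch.nn.TripletMarginLoss`. It is an illustration only and is not BehaveNet's implementation: the helper name `two_dataset_triplet_sketch`, the choice of anchors and positives from one dataset with negatives from the other, and the averaging over both directions are all assumptions made for this example.

```python
import torch


def two_dataset_triplet_sketch(triplet_loss_obj, z, datasets):
    """Hypothetical helper (NOT the BehaveNet function) for a 2-dataset batch.

    Anchors and positives come from the same dataset; negatives come from
    the other dataset. The loss is averaged over both directions.
    """
    ids = torch.unique(datasets)
    if ids.numel() != 2:
        raise ValueError("this sketch handles exactly two datasets")

    # Split the embeddings by the dataset each sample belongs to.
    z_a = z[datasets == ids[0]]
    z_b = z[datasets == ids[1]]

    # Each triplet needs two samples from one dataset and one from the other.
    n = min(z_a.shape[0], z_b.shape[0]) // 2
    if n == 0:
        raise ValueError("need at least two samples from each dataset")

    # Call signature of TripletMarginLoss: (anchor, positive, negative).
    loss_a = triplet_loss_obj(z_a[:n], z_a[n:2 * n], z_b[:n])
    loss_b = triplet_loss_obj(z_b[:n], z_b[n:2 * n], z_a[:n])
    return 0.5 * (loss_a + loss_b)


# Assumed toy inputs: N = 8 samples, d = 3 embedding dimensions.
loss_obj = torch.nn.TripletMarginLoss(margin=1.0)
datasets = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

# Well-separated datasets: the margin is already satisfied.
z_separated = torch.cat([torch.zeros(4, 3), 10.0 * torch.ones(4, 3)])
loss_separated = two_dataset_triplet_sketch(loss_obj, z_separated, datasets)

# Fully overlapping datasets: every triplet violates the margin.
z_overlapping = torch.zeros(8, 3)
loss_overlapping = two_dataset_triplet_sketch(loss_obj, z_overlapping, datasets)
```

With the default `reduction='mean'`, the result is a scalar `torch.Tensor`, matching the return type documented above. The loss is zero once embeddings from different datasets are separated by more than the margin, which is what drives the embedding space apart.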