Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytext/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .loss import (
AUCPRHingeLoss,
BinaryCrossEntropyLoss,
BinaryCrossEntropyWithLogitsLoss,
CosineEmbeddingLoss,
CrossEntropyLoss,
KLDivergenceBCELoss,
Expand All @@ -26,6 +27,7 @@
"CrossEntropyLoss",
"CosineEmbeddingLoss",
"BinaryCrossEntropyLoss",
"BinaryCrossEntropyWithLogitsLoss",
"MultiLabelSoftMarginLoss",
"KLDivergenceBCELoss",
"KLDivergenceCELoss",
Expand Down
23 changes: 23 additions & 0 deletions pytext/loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,29 @@ def __call__(self, log_probs, targets, reduce=True):
)


class BinaryCrossEntropyWithLogitsLoss(Loss):
class Config(ConfigBase):
reduce: bool = True

def __call__(self, logits, targets, reduce=True):
"""
Computes 1-vs-all binary cross entropy loss for multiclass classification. However, unlike BinaryCrossEntropyLoss, we require targets to be a one-hot vector.
"""

target_labels = targets[0].float()

"""
`F.binary_cross_entropy_with_logits` requires the
output of the previous function be already a FloatTensor.
"""

loss = F.binary_cross_entropy_with_logits(
precision.maybe_float(logits), target_labels, reduction="none"
)

return loss.sum(-1).mean() if reduce else loss.sum(-1)


class BinaryCrossEntropyLoss(Loss):
class Config(ConfigBase):
reweight_negative: bool = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
KLDivergenceCELoss,
LabelSmoothedCrossEntropyLoss,
MultiLabelSoftMarginLoss,
BinaryCrossEntropyWithLogitsLoss,
)
from pytext.utils.label import get_label_weights
from torch import jit
Expand All @@ -43,6 +44,7 @@ class Config(OutputLayerBase.Config):
loss: Union[
CrossEntropyLoss.Config,
BinaryCrossEntropyLoss.Config,
BinaryCrossEntropyWithLogitsLoss.Config,
MultiLabelSoftMarginLoss.Config,
AUCPRHingeLoss.Config,
KLDivergenceBCELoss.Config,
Expand Down Expand Up @@ -83,6 +85,8 @@ def from_config(
cls = BinaryClassificationOutputLayer
elif isinstance(loss, MultiLabelSoftMarginLoss):
cls = MultiLabelOutputLayer
elif isinstance(loss, BinaryCrossEntropyWithLogitsLoss):
cls = MultiLabelOutputLayer
else:
cls = MulticlassOutputLayer

Expand Down