Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit bd270fa

Browse files
Michael Marlenfacebook-github-bot
authored andcommitted
Output Layer
Summary: Enabling Multi Label output layer. This is done by migrating domain classifier into a multi label. Reviewed By: shreydesai Differential Revision: D25440015 fbshipit-source-id: ca917d723a1cd06618e6592bdc122f132a5a071d
1 parent 5068b68 commit bd270fa

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

pytext/models/output_layers/doc_classification_output_layer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
KLDivergenceCELoss,
1818
LabelSmoothedCrossEntropyLoss,
1919
MultiLabelSoftMarginLoss,
20+
BinaryCrossEntropyWithLogitsLoss,
2021
)
2122
from pytext.utils.label import get_label_weights
2223
from torch import jit
@@ -43,6 +44,7 @@ class Config(OutputLayerBase.Config):
4344
loss: Union[
4445
CrossEntropyLoss.Config,
4546
BinaryCrossEntropyLoss.Config,
47+
BinaryCrossEntropyWithLogitsLoss.Config,
4648
MultiLabelSoftMarginLoss.Config,
4749
AUCPRHingeLoss.Config,
4850
KLDivergenceBCELoss.Config,
@@ -83,6 +85,8 @@ def from_config(
8385
cls = BinaryClassificationOutputLayer
8486
elif isinstance(loss, MultiLabelSoftMarginLoss):
8587
cls = MultiLabelOutputLayer
88+
elif isinstance(loss, BinaryCrossEntropyWithLogitsLoss):
89+
cls = MultiLabelOutputLayer
8690
else:
8791
cls = MulticlassOutputLayer
8892

0 commit comments

Comments
 (0)