@@ -10553,6 +10553,89 @@ class DecomposeAtenNllLossForwardOp
10553
10553
};
10554
10554
} // namespace
10555
10555
10556
+ namespace {
10557
+ // Decompostion of aten.hinge_embedding_loss op
10558
+ // Ref:
10559
+ // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L182
10560
+ // The Hinge Embedding Loss:
10561
+ // | input, if target == 1
10562
+ // loss(x) = |
10563
+ // | max(0, margin - input), if target == -1
10564
+ class DecomposeHingeEmbeddingLoss
10565
+ : public OpRewritePattern<AtenHingeEmbeddingLossOp> {
10566
+ using OpRewritePattern<AtenHingeEmbeddingLossOp>::OpRewritePattern;
10567
+ LogicalResult matchAndRewrite(AtenHingeEmbeddingLossOp op,
10568
+ PatternRewriter &rewriter) const override {
10569
+ Location loc = op.getLoc();
10570
+ auto input = op.getSelf();
10571
+ auto target = op.getTarget();
10572
+
10573
+ auto inputTy = dyn_cast<ValueTensorType>(input.getType());
10574
+ if (!inputTy.hasDtype() || !inputTy.hasSizes())
10575
+ return rewriter.notifyMatchFailure(op, "input must have dtype and size");
10576
+
10577
+ auto targetTy = dyn_cast<ValueTensorType>(target.getType());
10578
+ if (!targetTy.hasDtype() || !targetTy.hasSizes())
10579
+ return rewriter.notifyMatchFailure(op, "target must have dtype and size");
10580
+ auto resultTy = dyn_cast<ValueTensorType>(op.getType());
10581
+ Value minusOne = getConstantWithGivenDtypeAndValue(rewriter, loc, -1,
10582
+ targetTy.getDtype());
10583
+ Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
10584
+ targetTy.getDtype());
10585
+ Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
10586
+ targetTy.getDtype());
10587
+ Value alpha =
10588
+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
10589
+ auto boolType = targetTy.getWithSizesAndDtype(targetTy.getSizes(),
10590
+ rewriter.getI1Type());
10591
+ // input - margin
10592
+ auto inputMinusMargin = rewriter.create<AtenSubScalarOp>(
10593
+ loc, inputTy, input, op.getMargin(), alpha);
10594
+ // multiply by -1 to get margin - input
10595
+ auto marginDiff = rewriter.create<AtenMulScalarOp>(
10596
+ loc, inputTy, inputMinusMargin, minusOne);
10597
+ // max(0, margin - input) => clamping the minimum value of margin - input at
10598
+ // 0
10599
+ auto marginClamp =
10600
+ rewriter.create<AtenClampMinOp>(loc, inputTy, marginDiff, zero);
10601
+ // Compute mask: target != 1
10602
+ auto targetNotOne =
10603
+ rewriter.create<AtenNeScalarOp>(loc, boolType, target, one);
10604
+ // If target != 1 use marginClamp otherwise 0.
10605
+ auto outputMargin = rewriter.create<AtenWhereScalarOtherOp>(
10606
+ loc, inputTy, targetNotOne, marginClamp, zero);
10607
+ // Compute mask: target != -1
10608
+ auto targetNotMinusOne =
10609
+ rewriter.create<AtenNeScalarOp>(loc, boolType, target, minusOne);
10610
+ // If target != -1 use the original input. Otherwise 0.
10611
+ auto outputSelf = rewriter.create<AtenWhereScalarOtherOp>(
10612
+ loc, inputTy, targetNotMinusOne, input, zero);
10613
+ // Add : outputMargin + outputSelf
10614
+ auto output = rewriter.create<AtenAddTensorOp>(loc, inputTy, outputMargin,
10615
+ outputSelf, /*alpha=*/alpha);
10616
+ int64_t reduction;
10617
+ if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) {
10618
+ return rewriter.notifyMatchFailure(op,
10619
+ "reduction should be a constant int!");
10620
+ }
10621
+ Value loss;
10622
+ Value none = rewriter.create<ConstantNoneOp>(loc);
10623
+ // reduction: mean
10624
+ if (reduction == 1) {
10625
+ loss = rewriter.create<AtenMeanOp>(loc, resultTy, output, none);
10626
+ } else if (reduction == 2) {
10627
+ // reduction: sum
10628
+ loss = rewriter.create<AtenSumOp>(loc, resultTy, output, none);
10629
+ } else {
10630
+ // reduction: none
10631
+ loss = output;
10632
+ }
10633
+ rewriter.replaceOp(op, loss);
10634
+ return success();
10635
+ }
10636
+ };
10637
+ } // namespace
10638
+
10556
10639
namespace {
10557
10640
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
10558
10641
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12467,6 +12550,7 @@ class DecomposeComplexOpsPass
12467
12550
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
12468
12551
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
12469
12552
addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
12553
+ addPatternIfTargetOpIsIllegal<DecomposeHingeEmbeddingLoss>(patterns);
12470
12554
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
12471
12555
patterns);
12472
12556
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
0 commit comments