Skip to content

Commit 9d5b896

Browse files
committed
add the code for aten.hinge_embedding_loss op
1 parent 38d5f99 commit 9d5b896

File tree

8 files changed

+247
-0
lines changed

8 files changed

+247
-0
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9452,6 +9452,32 @@ def Torch_AtenNllLossBackwardOp : Torch_Op<"aten.nll_loss_backward", [
94529452
}];
94539453
}
94549454

9455+
// Torch dialect op definition for `aten::hinge_embedding_loss`.
// Arguments: two tensors (`self`, `target`), a float `margin` and an int
// `reduction`; produces a single tensor result.
// The summary marks this as a generated op, so hand edits here are presumably
// overwritten when the ODS file is regenerated -- confirm before editing.
def Torch_AtenHingeEmbeddingLossOp : Torch_Op<"aten.hinge_embedding_loss", [
9456+
AllowsTypeRefinement,
9457+
HasValueSemantics,
9458+
ReadOnly
9459+
]> {
9460+
let summary = "Generated op for `aten::hinge_embedding_loss : (Tensor, Tensor, float, int) -> (Tensor)`";
9461+
let arguments = (ins
9462+
AnyTorchTensorType:$self,
9463+
AnyTorchTensorType:$target,
9464+
Torch_FloatType:$margin,
9465+
Torch_IntType:$reduction
9466+
);
9467+
let results = (outs
9468+
AnyTorchOptionalTensorType:$result
9469+
);
9470+
let hasCustomAssemblyFormat = 1;
9471+
// Parsing and printing delegate to the shared default Torch op helpers; the
// `4, 1` arguments match the four arguments and one result declared above.
let extraClassDefinition = [{
9472+
ParseResult AtenHingeEmbeddingLossOp::parse(OpAsmParser &parser, OperationState &result) {
9473+
return parseDefaultTorchOp(parser, result, 4, 1);
9474+
}
9475+
void AtenHingeEmbeddingLossOp::print(OpAsmPrinter &printer) {
9476+
printDefaultTorchOp(printer, *this, 4, 1);
9477+
}
9478+
}];
9479+
}
9480+
94559481
def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
94569482
AllowsTypeRefinement,
94579483
HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10696,6 +10696,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1069610696
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
1069710697
" return %0 : !torch.list<int>\n"
1069810698
" }\n"
10699+
" func.func @\"__torch_mlir_shape_fn.aten.hinge_embedding_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.int) -> !torch.list<int> {\n"
10700+
" %int1 = torch.constant.int 1\n"
10701+
" %int2 = torch.constant.int 2\n"
10702+
" %0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
10703+
" %1 = torch.aten.__contains__.int_list %0, %arg3 : !torch.list<int>, !torch.int -> !torch.bool\n"
10704+
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
10705+
" %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
10706+
" torch.prim.If.yield %3 : !torch.list<int>\n"
10707+
" } else {\n"
10708+
" %3 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
10709+
" torch.prim.If.yield %3 : !torch.list<int>\n"
10710+
" }\n"
10711+
" return %2 : !torch.list<int>\n"
10712+
" }\n"
1069910713
" func.func @\"__torch_mlir_shape_fn.aten.mse_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.list<int> {\n"
1070010714
" %int0 = torch.constant.int 0\n"
1070110715
" %0 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
@@ -13340,6 +13354,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1334013354
" }\n"
1334113355
" return %6 : !torch.int\n"
1334213356
" }\n"
13357+
" func.func @\"__torch_mlir_dtype_fn.aten.hinge_embedding_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.int) -> !torch.int {\n"
13358+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
13359+
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
13360+
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
13361+
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
13362+
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
13363+
" return %4 : !torch.int\n"
13364+
" }\n"
1334313365
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.tuple<int, int>) -> !torch.int {\n"
1334413366
" %none = torch.constant.none\n"
1334513367
" %str = torch.constant.str \"AssertionError: \"\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10471,6 +10471,89 @@ class DecomposeAtenNllLossForwardOp
1047110471
};
1047210472
} // namespace
1047310473

10474+
namespace {
10475+
// Decompostion of aten.hinge_embedding_loss op
10476+
// Ref:
10477+
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L182
10478+
// The Hinge Embedding Loss:
10479+
// | input, if target == 1
10480+
// loss(x) = |
10481+
// | max(0, margin - input), if target == -1
10482+
class DecomposeHingeEmbeddingLoss
10483+
: public OpRewritePattern<AtenHingeEmbeddingLossOp> {
10484+
using OpRewritePattern<AtenHingeEmbeddingLossOp>::OpRewritePattern;
10485+
LogicalResult matchAndRewrite(AtenHingeEmbeddingLossOp op,
10486+
PatternRewriter &rewriter) const override {
10487+
Location loc = op.getLoc();
10488+
auto input = op.getSelf();
10489+
auto target = op.getTarget();
10490+
10491+
auto inputTy = dyn_cast<ValueTensorType>(input.getType());
10492+
if (!inputTy.hasDtype() || !inputTy.hasSizes())
10493+
return rewriter.notifyMatchFailure(op, "input must have dtype and size");
10494+
10495+
auto targetTy = dyn_cast<ValueTensorType>(target.getType());
10496+
if (!targetTy.hasDtype() || !targetTy.hasSizes())
10497+
return rewriter.notifyMatchFailure(op, "target must have dtype and size");
10498+
auto resultTy = dyn_cast<ValueTensorType>(op.getType());
10499+
Value minusOne = getConstantWithGivenDtypeAndValue(rewriter, loc, -1,
10500+
targetTy.getDtype());
10501+
Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
10502+
targetTy.getDtype());
10503+
Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
10504+
targetTy.getDtype());
10505+
Value alpha =
10506+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
10507+
auto boolType = targetTy.getWithSizesAndDtype(targetTy.getSizes(),
10508+
rewriter.getI1Type());
10509+
// input - margin
10510+
auto inputMinusMargin = rewriter.create<AtenSubScalarOp>(
10511+
loc, inputTy, input, op.getMargin(), alpha);
10512+
// multiply by -1 to get margin - input
10513+
auto marginDiff = rewriter.create<AtenMulScalarOp>(
10514+
loc, inputTy, inputMinusMargin, minusOne);
10515+
// max(0, margin - input) => clamping the minimum value of margin - input at
10516+
// 0
10517+
auto marginClamp =
10518+
rewriter.create<AtenClampMinOp>(loc, inputTy, marginDiff, zero);
10519+
// Compute mask: target != 1
10520+
auto targetNotOne =
10521+
rewriter.create<AtenNeScalarOp>(loc, boolType, target, one);
10522+
// If target != -1 use marginClamp otherwise 0.
10523+
auto outputMargin = rewriter.create<AtenWhereScalarOtherOp>(
10524+
loc, inputTy, targetNotOne, marginClamp, zero);
10525+
// Compute mask: target != 1
10526+
auto targetNotMinusOne =
10527+
rewriter.create<AtenNeScalarOp>(loc, boolType, target, minusOne);
10528+
// If target != 1 use the original input. Otherwise 0.
10529+
auto outputSelf = rewriter.create<AtenWhereScalarOtherOp>(
10530+
loc, inputTy, targetNotMinusOne, input, zero);
10531+
// Add : outputMargin + outputSelf
10532+
auto output = rewriter.create<AtenAddTensorOp>(loc, inputTy, outputMargin,
10533+
outputSelf, /*alpha=*/alpha);
10534+
int64_t reduction;
10535+
if (!matchPattern(op.getReduction(), m_TorchConstantInt(&reduction))) {
10536+
return rewriter.notifyMatchFailure(op,
10537+
"reduction should be a constant int!");
10538+
}
10539+
Value loss;
10540+
Value none = rewriter.create<ConstantNoneOp>(loc);
10541+
// reduction: mean
10542+
if (reduction == 1) {
10543+
loss = rewriter.create<AtenMeanOp>(loc, resultTy, output, none);
10544+
} else if (reduction == 2) {
10545+
// reduction: sum
10546+
loss = rewriter.create<AtenSumOp>(loc, resultTy, output, none);
10547+
} else {
10548+
// reduction: none
10549+
loss = output;
10550+
}
10551+
rewriter.replaceOp(op, loss);
10552+
return success();
10553+
}
10554+
};
10555+
} // namespace
10556+
1047410557
namespace {
1047510558
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
1047610559
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12384,6 +12467,7 @@ class DecomposeComplexOpsPass
1238412467
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
1238512468
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
1238612469
addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
12470+
addPatternIfTargetOpIsIllegal<DecomposeHingeEmbeddingLoss>(patterns);
1238712471
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
1238812472
patterns);
1238912473
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
537537
target.addIllegalOp<AtenLerpTensorOp>();
538538
target.addIllegalOp<AtenMseLossOp>();
539539
target.addIllegalOp<AtenL1LossOp>();
540+
target.addIllegalOp<AtenHingeEmbeddingLossOp>();
540541
target.addIllegalOp<AtenRandintLowOp>();
541542
target.addIllegalOp<AtenRandintOp>();
542543
target.addIllegalOp<AtenVarMeanCorrectionOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,6 +1789,9 @@
17891789
"L1LossMeanReductionModule_basic",
17901790
"L1LossNoReductionModule_basic",
17911791
"L1LossSumReductionModule_basic",
1792+
"HingeEmbeddingLossReductionMeanModule_basic",
1793+
"HingeEmbeddingLossReductionSumModule_basic",
1794+
"HingeEmbeddingLossWithoutReductionModule_basic",
17921795
"PixelShuffleModuleStaticRank3Int64_basic",
17931796
"PixelShuffleModuleStaticRank4Float32_basic",
17941797
"RandIntLowModule_basic",
@@ -2958,6 +2961,10 @@
29582961
"GtFloatIntModule_basic",
29592962
"GtIntModule_basic",
29602963
"HardtanhBackward_basic",
2964+
"HingeEmbeddingLossBasicModule_basic",
2965+
"HingeEmbeddingLossReductionMeanModule_basic",
2966+
"HingeEmbeddingLossReductionSumModule_basic",
2967+
"HingeEmbeddingLossWithoutReductionModule_basic",
29612968
"HstackBasicComplexModule_basic",
29622969
"HstackBasicFloatModule_basic",
29632970
"HstackBasicIntFloatModule_basic",
@@ -3953,6 +3960,10 @@
39533960
"NllLossStaticModule_mean_basic",
39543961
"NllLossStaticModule_sum_basic",
39553962
"NllLossStaticModule_weight_basic",
3963+
"HingeEmbeddingLossBasicModule_basic",
3964+
"HingeEmbeddingLossReductionMeanModule_basic",
3965+
"HingeEmbeddingLossReductionSumModule_basic",
3966+
"HingeEmbeddingLossWithoutReductionModule_basic",
39563967
"Exp2StaticModule_basic",
39573968
"ElementwiseRreluWithNoiseEvalModule_basic",
39583969
"ElementwiseRreluWithNoiseEvalStaticModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2183,6 +2183,12 @@ def aten〇nll_loss_forward〡shape(self: List[int], target: List[int], weight:
21832183
def aten〇nll_loss_backward〡shape(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]:
21842184
return upstream_shape_functions.unary(self)
21852185

2186+
# Shape function for aten.hinge_embedding_loss.
# When `reduction` is 1 or 2 the result has an empty size list (rank 0); for
# any other value the result takes the shape of `self` via `unary`.
# `target` and `margin` are not consulted.
# NOTE(review): presumably assumes `target` broadcasts to the shape of `self`
# in the unreduced case -- confirm.
def aten〇hinge_embedding_loss〡shape(self: List[int], target: List[int], margin: float = 1., reduction: int = 1) -> List[int]:
2187+
if reduction in [1,2]:
2188+
return []
2189+
else:
2190+
return upstream_shape_functions.unary(self)
2191+
21862192
# TODO: upstream this
21872193
def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = 1) -> List[int]:
21882194
if reduction == 0:
@@ -3953,6 +3959,13 @@ def aten〇nll_loss_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], se
39533959
return torch.int64
39543960
return result
39553961

3962+
# Dtype function for aten.hinge_embedding_loss.
# Unpacks the (rank, dtype) pairs of `self` and `target` and returns the
# result of `promote_dtypes` over both; `margin` and `reduction` are unused.
def aten〇hinge_embedding_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], margin: float = 1., reduction: int = 1) -> int:
3963+
self_rank, self_dtype = self_rank_dtype
3964+
target_rank, target_dtype = target_rank_dtype
3965+
# Ranks are passed alongside the dtypes to the promotion helper.
ranks: List[Optional[int]] = [self_rank, target_rank]
3966+
dtypes = [self_dtype, target_dtype]
3967+
return promote_dtypes(ranks, dtypes)
3968+
39563969
@check_dtype_function(_check_tensors_with_the_same_dtype(
39573970
None, [(2, 4, 7, 6), (2, 4, 6, 5)], None, None,
39583971
[2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)) +

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,7 @@ def emit_with_mutating_variants(key, **kwargs):
760760
emit(
761761
"aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)"
762762
)
763+
emit("aten::hinge_embedding_loss : (Tensor, Tensor, float, int) -> (Tensor)")
763764
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
764765
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
765766
emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2455,6 +2455,95 @@ def BinaryCrossEntropyWithLogitsStaticModule_basic(module, tu: TestUtils):
24552455
# ==============================================================================
24562456

24572457

2458+
class HingeEmbeddingLossBasicModule(torch.nn.Module):
    """Hinge embedding loss on dynamically shaped rank-3 float32 tensors.

    Calls the aten op with an explicit margin of 1.5 and reduction value 1.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1, -1], torch.float32, True),
            ([-1, -1, -1], torch.float32, True),
        ]
    )
    def forward(self, input, target):
        loss = torch.ops.aten.hinge_embedding_loss(
            input, target, margin=1.5, reduction=1
        )
        return loss
2474+
2475+
2476+
@register_test_case(module_factory=lambda: HingeEmbeddingLossBasicModule())
def HingeEmbeddingLossBasicModule_basic(module, tu: TestUtils):
    """Runs the module with inputs that exercise both branches of the loss."""
    # Inputs span both sides of the margin so the max(0, margin - input)
    # clamp is actually hit.
    input = tu.rand(1, 2, 3, low=-2.0, high=2.0)
    # The loss is defined for targets in {-1, 1}; uniform values in [0, 1)
    # would never select either target-specific branch.
    target = 2.0 * tu.randint(1, 2, 3, low=0, high=2).to(torch.float32) - 1.0
    module.forward(input, target)
2479+
2480+
2481+
class HingeEmbeddingLossReductionMeanModule(torch.nn.Module):
    """Hinge embedding loss with reduction value 1 on static shapes.

    The [1, 1] target is broadcast against the [8, 1] input; the margin is
    left at the op's default.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([8, 1], torch.float32, True),
            ([1, 1], torch.float32, True),
        ]
    )
    def forward(self, input, target):
        loss = torch.ops.aten.hinge_embedding_loss(input, target, reduction=1)
        return loss
2495+
2496+
2497+
@register_test_case(module_factory=lambda: HingeEmbeddingLossReductionMeanModule())
def HingeEmbeddingLossReductionMeanModule_basic(module, tu: TestUtils):
    """Runs the mean-reduction module with a valid {-1, 1} target."""
    # Inputs span both sides of the default margin so the clamp is exercised.
    input = tu.rand(8, 1, low=-2.0, high=2.0)
    # The loss is defined for targets in {-1, 1}; uniform values in [0, 1)
    # would never select either target-specific branch.
    target = 2.0 * tu.randint(1, 1, low=0, high=2).to(torch.float32) - 1.0
    module.forward(input, target)
2500+
2501+
2502+
class HingeEmbeddingLossReductionSumModule(torch.nn.Module):
    """Hinge embedding loss with reduction value 2 on static shapes.

    The [1, 1] target is broadcast against the [2, 5] input; the margin is
    left at the op's default.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 5], torch.float32, True),
            ([1, 1], torch.float32, True),
        ]
    )
    def forward(self, input, target):
        loss = torch.ops.aten.hinge_embedding_loss(input, target, reduction=2)
        return loss
2516+
2517+
2518+
@register_test_case(module_factory=lambda: HingeEmbeddingLossReductionSumModule())
def HingeEmbeddingLossReductionSumModule_basic(module, tu: TestUtils):
    """Runs the sum-reduction module with a valid {-1, 1} target."""
    # Inputs span both sides of the default margin so the clamp is exercised.
    input = tu.rand(2, 5, low=-2.0, high=2.0)
    # The loss is defined for targets in {-1, 1}; uniform values in [0, 1)
    # would never select either target-specific branch.
    target = 2.0 * tu.randint(1, 1, low=0, high=2).to(torch.float32) - 1.0
    module.forward(input, target)
2521+
2522+
2523+
class HingeEmbeddingLossWithoutReductionModule(torch.nn.Module):
    """Hinge embedding loss with no reduction (element-wise result).

    The [1] target is broadcast against the [8, 5] input.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([8, 5], torch.float32, True),
            ([1], torch.float32, True),
        ]
    )
    def forward(self, input, target):
        # `reduction` must be passed explicitly: the aten op defaults to mean
        # reduction, so omitting it would not test the unreduced path.
        return torch.ops.aten.hinge_embedding_loss(
            input, target, margin=1.0, reduction=0
        )
2537+
2538+
2539+
@register_test_case(module_factory=lambda: HingeEmbeddingLossWithoutReductionModule())
def HingeEmbeddingLossWithoutReductionModule_basic(module, tu: TestUtils):
    """Runs the module with a valid {-1, 1} target broadcast over the input."""
    # Inputs span both sides of the margin so the clamp is exercised.
    input = tu.rand(8, 5, low=-2.0, high=2.0)
    # The loss is defined for targets in {-1, 1}; uniform values in [0, 1)
    # would never select either target-specific branch.
    target = 2.0 * tu.randint(1, low=0, high=2).to(torch.float32) - 1.0
    module.forward(input, target)
2542+
2543+
2544+
# ==============================================================================
2545+
2546+
24582547
class TraceModule(torch.nn.Module):
24592548
def __init__(self) -> None:
24602549
super().__init__()

0 commit comments

Comments
 (0)