Skip to content

[Torch] Add support for Huber Loss function #4248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -9532,6 +9532,32 @@ def Torch_AtenKlDivOp : Torch_Op<"aten.kl_div", [
}];
}

def Torch_AtenHuberLossOp : Torch_Op<"aten.huber_loss", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::huber_loss : (Tensor, Tensor, int, float) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$target,
Torch_IntType:$reduction,
Torch_FloatType:$delta
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenHuberLossOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenHuberLossOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
33 changes: 33 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10721,6 +10721,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.huber_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: Invalid reduction value.\"\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %0 = torch.prim.Uninitialized : !torch.list<int>\n"
" %1 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
" %3 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" } else {\n"
" %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = torch.aten.__contains__.int_list %3, %arg2 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
" %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" torch.prim.If.yield %6 : !torch.list<int>\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %0 : !torch.list<int>\n"
" }\n"
" torch.prim.If.yield %5 : !torch.list<int>\n"
" }\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>>\n"
Expand Down Expand Up @@ -14612,6 +14637,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.huber_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
Expand Down
78 changes: 78 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10707,6 +10707,83 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
};
} // namespace

namespace {
class DecomposeAtenHuberLossOp : public OpRewritePattern<AtenHuberLossOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenHuberLossOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value target = op.getTarget();
Value reductionValue = op.getReduction();
Value deltaValue = op.getDelta();

auto selfTy = cast<ValueTensorType>(self.getType());
auto targetTy = cast<ValueTensorType>(target.getType());
auto outTy = cast<ValueTensorType>(op.getType());
if (!selfTy.hasSizes() || !targetTy.hasSizes() || !outTy.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "require self, target and output having sizes!");
}
if (!selfTy.hasDtype() || !targetTy.hasDtype() || !outTy.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "require self, target and output having dtype!");
}

// Squared term: 0.5 * (input - target)^2
Value constOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value constHalf =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.5));
Value inputMinusTarget =
rewriter.create<AtenSubTensorOp>(loc, selfTy, self, target, constOne);
Value squaredValue =
rewriter.create<AtenSquareOp>(loc, selfTy, inputMinusTarget);
Value squaredTerm =
rewriter.create<AtenMulScalarOp>(loc, selfTy, squaredValue, constHalf);

// Delta scaled term: delta * (|input - target| - 0.5 * delta)
Value absDiffValue =
rewriter.create<AtenAbsOp>(loc, selfTy, inputMinusTarget);
Value halfOfDelta = rewriter.create<AtenMulOp>(
loc, rewriter.getType<Torch::FloatType>(), constHalf, deltaValue);
Value absDiffMinusDeltaHalf = rewriter.create<AtenSubScalarOp>(
loc, selfTy, absDiffValue, halfOfDelta, constOne);
Value deltaScaledTerm = rewriter.create<AtenMulScalarOp>(
loc, selfTy, absDiffMinusDeltaHalf, deltaValue);

// Loss calculation based on the condition: |input - target| < delta
ValueTensorType boolTy = ValueTensorType::get(
op.getContext(), selfTy.getSizes(), rewriter.getI1Type());
Value cmpValue =
rewriter.create<AtenLeScalarOp>(loc, boolTy, absDiffValue, deltaValue);
Value lossPointwise = rewriter.create<AtenWhereSelfOp>(
loc, selfTy, cmpValue, squaredTerm, deltaScaledTerm);

// Extract reduction int value from reduction argument
int64_t reduction;
if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
return rewriter.notifyMatchFailure(op,
"reduction should be a constant int!");
}
Value loss;
Value none = rewriter.create<ConstantNoneOp>(loc);
// reduction: mean
if (reduction == 1) {
loss = rewriter.create<AtenMeanOp>(loc, outTy, lossPointwise, none);
} else if (reduction == 2) {
// reduction: sum
loss = rewriter.create<AtenSumOp>(loc, outTy, lossPointwise, none);
} else {
// reduction: none
loss = lossPointwise;
}
rewriter.replaceOp(op, loss);
return success();
}
};
} // namespace

namespace {
class DecomposeAtenBinaryCrossEntropyWithLogitsOp
: public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
Expand Down Expand Up @@ -12696,6 +12773,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenKlDivOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHuberLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArgsortOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLogaddexpOp>();
target.addIllegalOp<AtenLogaddexp2Op>();
target.addIllegalOp<AtenKlDivOp>();
target.addIllegalOp<AtenHuberLossOp>();

for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
Expand Down
8 changes: 8 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3103,6 +3103,10 @@
"PoissonNLLLossSumReductionModule_basic",
"PoissonNLLLossNonDefaultEpsModule_basic",
"KlDivLossModule_batchmean_reduction_basic",
"HuberLossModule_default_basic",
"HuberLossModule_reduction_is_none_basic",
"HuberLossModule_mean_reduction_basic",
"HuberLossModule_sum_reduction_basic",
"NormScalarComplexModule_basic",
"NormScalarModule_basic",
"NormScalarOptDimKeepDimComplexModule_basic",
Expand Down Expand Up @@ -4682,6 +4686,10 @@
"NllLossModule_ignore_index_out_of_bounds_basic",
"NllLossModule_mean_basic",
"NllLossModule_sum_basic",
"HuberLossModule_default_basic",
"HuberLossModule_reduction_is_none_basic",
"HuberLossModule_mean_reduction_basic",
"HuberLossModule_sum_reduction_basic",
"NormScalarComplexModule_basic",
"NormScalarModule_basic",
"NormScalarOptDimKeepDimComplexModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2185,6 +2185,14 @@ def aten〇kl_div〡shape(self: List[int], target: List[int], reduction: int = 1
else:
assert False, "Invalid reduction value."

def aten〇huber_loss〡shape(self: List[int], target: List[int], reduction: int = 1, delta: float = 1.) -> List[int]:
if reduction == 0:
return upstream_shape_functions.unary(self)
elif reduction in [1, 2]:
return []
else:
assert False, "Invalid reduction value."

@check_shape_function([
Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.
Expand Down Expand Up @@ -4571,6 +4579,14 @@ def aten〇kl_div〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: T
promoted_dtype = promote_dtypes(ranks, dtypes)
return promoted_dtype

def aten〇huber_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1, delta: float = 1.) -> int:
self_rank, self_dtype = self_rank_dtype
target_rank, target_dtype = target_rank_dtype
ranks: List[Optional[int]] = [self_rank, target_rank]
dtypes = [self_dtype, target_dtype]
promoted_dtype = promote_dtypes(ranks, dtypes)
return promoted_dtype

@check_dtype_function(_check_two_tensor_op(
output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}))
def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)"
)
emit("aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)")
emit("aten::huber_loss : (Tensor, Tensor, int, float) -> (Tensor)")
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,4 @@ def register_all_tests():
from . import meshgrid
from . import timeout
from . import kl_div_loss
from . import huber_loss
108 changes: 108 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/huber_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

# ==============================================================================


class HuberLossModule_default(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x, y):
return torch.ops.aten.huber_loss(x, y)


@register_test_case(module_factory=lambda: HuberLossModule_default())
def HuberLossModule_default_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))


# ==============================================================================


class HuberLossModule_reduction_is_none(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x, y):
return torch.ops.aten.huber_loss(x, y, delta=2.3, reduction=0)


@register_test_case(module_factory=lambda: HuberLossModule_reduction_is_none())
def HuberLossModule_reduction_is_none_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))


# ==============================================================================


class HuberLossModule_mean_reduction(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x, y):
return torch.ops.aten.huber_loss(x, y, reduction=1)


@register_test_case(module_factory=lambda: HuberLossModule_mean_reduction())
def HuberLossModule_mean_reduction_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))


# ==============================================================================


class HuberLossModule_sum_reduction(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, x, y):
return torch.ops.aten.huber_loss(x, y, reduction=2)


@register_test_case(module_factory=lambda: HuberLossModule_sum_reduction())
def HuberLossModule_sum_reduction_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))


# ==============================================================================
Loading