
Commit 8cee8ed

[TORCH] Add support for aten.heaviside Op (#4220)
- Decomposed the heaviside op into existing ATen ops.
- Added test cases to the e2e suite.

Closes #4211.
1 parent 4f3a60b commit 8cee8ed
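For context, aten.heaviside is PyTorch's elementwise step function: 0 where the input is negative, 1 where it is positive, and the corresponding element of `values` where the input is exactly 0. A quick eager-mode illustration of the semantics (illustrative only, not part of the patch):

import torch

x = torch.tensor([-1.5, 0.0, 2.0])
values = torch.tensor([0.5])           # only consulted where x == 0
print(torch.heaviside(x, values))      # tensor([0.0000, 0.5000, 1.0000])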

8 files changed: +211 additions, -0 deletions

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 47 additions & 0 deletions
@@ -13181,6 +13181,53 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
   let hasFolder = 1;
 }
 
+def Torch_AtenHeavisideOp : Torch_Op<"aten.heaviside", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::heaviside : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$values
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHeavisideOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenHeavisideOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenHeaviside_Op : Torch_Op<"aten.heaviside_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::heaviside_ : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_NonValueTensorType:$values
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHeaviside_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenHeaviside_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [
   AllowsTypeRefinement,
   HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 12 additions & 0 deletions
@@ -9675,6 +9675,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.heaviside\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -15283,6 +15287,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.heaviside\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 71 additions & 0 deletions
@@ -11304,6 +11304,76 @@ class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
 };
 } // namespace
 
+namespace {
+// Decomposed aten.heaviside op into
+// using aten.eq, aten.lt, aten.logical_or, aten.where
+// Heaviside(x, y) returns
+// 0 if x < 0
+// y if x == 0
+// 1 if x > 0
+class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenHeavisideOp op,
+                                PatternRewriter &rewriter) const override {
+    auto input = op.getSelf();
+    auto value = op.getValues();
+    auto loc = op.getLoc();
+    auto inputTy = dyn_cast<BaseTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasDtype() || !inputTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "input must have dtype and size.");
+
+    auto valueTy = dyn_cast<BaseTensorType>(value.getType());
+    if (!valueTy || !valueTy.hasDtype() || !valueTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "value must have dtype and size.");
+    auto resultTy = dyn_cast<BaseTensorType>(op.getType());
+    SmallVector<int64_t> broadcastShape;
+    SmallVector<Value> broadcastShapeValue;
+    computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
+                          broadcastShapeValue);
+
+    auto broadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), resultTy.getDtype());
+    auto boolBroadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), rewriter.getI1Type());
+    Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        broadcastShapeValue);
+    auto inputBroadcasted = rewriter.create<AtenBroadcastToOp>(
+        loc, broadcastType, input, indexBroadcastShapeTorchList);
+    auto valueBroadcasted = rewriter.create<AtenBroadcastToOp>(
+        loc, broadcastType, value, indexBroadcastShapeTorchList);
+
+    Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
+                                                   resultTy.getDtype());
+    Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
+                                                  resultTy.getDtype());
+    // Compute mask: input == 0
+    auto inputEqZero = rewriter
+                           .create<AtenEqScalarOp>(loc, boolBroadcastType,
+                                                   inputBroadcasted, zero)
+                           ->getResult(0);
+    // Compute mask: input < 0
+    auto inputLtZero = rewriter.create<AtenLtScalarOp>(loc, boolBroadcastType,
+                                                       inputBroadcasted, zero);
+    // Compute mask: isnan(input)
+    auto isNan =
+        rewriter.create<AtenIsnanOp>(loc, boolBroadcastType, inputBroadcasted);
+    // Combine: input < 0 || isnan(input)
+    auto inputNegativeOrNan = rewriter.create<AtenLogicalOrOp>(
+        loc, boolBroadcastType, inputLtZero, isNan);
+    // Select 0 if input < 0 or input is nan, else 1
+    auto zerosOrOnes = rewriter.create<AtenWhereScalarOp>(
+        loc, resultTy, inputNegativeOrNan, zero, one);
+    // Final result: if input == 0, take from valueBroadcasted, else take from
+    // zerosOrOnes
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resultTy, inputEqZero,
+                                                 valueBroadcasted, zerosOrOnes);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Unconditionally decompose `torch.type_as` into `prim.dtype` +
 // `torch.to.dtype`.
@@ -12528,6 +12598,7 @@ class DecomposeComplexOpsPass
       DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
+  addPatternIfTargetOpIsIllegal<DecomposeAtenHeaviside>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
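For reference, the rewrite pattern above corresponds to the following rough eager-mode sketch (the helper name `heaviside_decomposed` is illustrative only, not part of the patch):

import torch

def heaviside_decomposed(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # Broadcast both operands to the common shape, mirroring the
    # aten.broadcast_to ops emitted by the pattern.
    x_b, v_b = torch.broadcast_tensors(x, values)
    # where(x < 0 or isnan(x), 0, 1): mirrors aten.lt, aten.isnan,
    # aten.logical_or, and the first aten.where.
    zeros_or_ones = torch.where(x_b.lt(0) | x_b.isnan(),
                                torch.zeros_like(x_b), torch.ones_like(x_b))
    # where(x == 0, values, zeros_or_ones): mirrors aten.eq and the final aten.where.
    return torch.where(x_b.eq(0), v_b, zeros_or_ones)

x = torch.tensor([1.0, -2.0, torch.inf, torch.nan, -torch.inf])
v = torch.tensor([5.0])
assert torch.equal(heaviside_decomposed(x, v), torch.heaviside(x, v))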

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -461,6 +461,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenSquareOp>();
   target.addIllegalOp<AtenVarOp>();
   target.addIllegalOp<AtenStdOp>();
+  target.addIllegalOp<AtenHeavisideOp>();
   target.addIllegalOp<Aten_UnsafeViewOp>();
   target.addIllegalOp<Aten_ReshapeAliasOp>();
   target.addIllegalOp<AtenBernoulliOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 8 additions & 0 deletions
@@ -1259,6 +1259,7 @@
     "ElementwiseToDtypeI64ToI8Module_basic",
     "ElementwiseToDtypeIdentityModule_basic",
     "ElementwiseUnaryModule_basic",
+    "ElementwiseHeavisideModule_basic",
     "EmptyLikeMemoryFormatModule_basic",
     "EmptyLikeModule_defaultDtype",
     "EmptyLikeModule_falsePinMemory",
@@ -1863,6 +1864,7 @@
     "ElementwiseFracModule_basic",
     "ElementwiseLdexpModule_basic",
     "ElementwiseSignbitIntModule_basic",
+    "ElementwiseHeavisideModule_basic",
     "Exp2StaticIntModule_basic",
     "MaxPool1dEmptyStrideStaticModule_basic",
     "MaxPool1dStaticCeilModeTrueModule_basic",
@@ -2976,6 +2978,9 @@
     "GtFloatIntModule_basic",
     "GtIntModule_basic",
     "HardtanhBackward_basic",
+    "ElementwiseHeavisideModule_basic",
+    "ElementwiseHeavisideIntModule_basic",
+    "ElementwiseHeavisideNoBroadcastModule_basic",
     "HstackBasicComplexModule_basic",
     "HstackBasicFloatModule_basic",
     "HstackBasicIntFloatModule_basic",
@@ -4002,6 +4007,9 @@
     "ElementwiseRreluWithNoiseEvalStaticModule_basic",
     "ElementwiseRreluWithNoiseTrainModule_basic",
     "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "ElementwiseHeavisideModule_basic",
+    "ElementwiseHeavisideIntModule_basic",
+    "ElementwiseHeavisideNoBroadcastModule_basic",
     "RreluWithNoiseBackwardEvalModule_basic",
    "RreluWithNoiseBackwardEvalStaticModule_basic",
    "RreluWithNoiseBackwardTrainModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 11 additions & 0 deletions
@@ -1775,6 +1775,9 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot
 def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(condition, other)
 
+def aten〇heaviside〡shape(self: List[int], values: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, values)
+
 def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -5114,6 +5117,14 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel
     dtypes = [get_dtype_of_scalar(self), other_dtype]
     return promote_dtypes(ranks, dtypes)
 
+def aten〇heaviside〡dtype(self_rank_dtype: Tuple[int, int], values_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    values_rank, values_dtype = values_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, values_rank]
+    dtypes = [self_dtype, values_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int:
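The aten〇heaviside shape and dtype functions above follow standard broadcasting and type-promotion rules; a rough illustration using plain PyTorch utilities rather than the library generator's helpers (assumed equivalent behavior, for intuition only):

import torch

# Shape: broadcast of self and values, e.g. [1, 2, 3] with [1, 1, 1, 1] -> [1, 1, 2, 3].
print(torch.broadcast_shapes((1, 2, 3), (1, 1, 1, 1)))   # torch.Size([1, 1, 2, 3])
# Dtype: promotion of the two operand dtypes.
print(torch.promote_types(torch.float32, torch.int64))   # torch.float32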

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -964,6 +964,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True
     )
+    emit_with_mutating_variants("aten::heaviside : (Tensor, Tensor) -> (Tensor)")
     emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)")
     emit(
         "aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)",

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py

Lines changed: 60 additions & 0 deletions
@@ -298,6 +298,66 @@ def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseHeavisideModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([5], torch.float32, True), ([1], torch.float32, True)])
+    def forward(self, x, values):
+        return torch.heaviside(x, values)
+
+
+@register_test_case(module_factory=lambda: ElementwiseHeavisideModule())
+def ElementwiseHeavisideModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.tensor([1.0, -2.0, torch.inf, torch.nan, -torch.inf]), torch.tensor([5.0])
+    )
+
+
+class ElementwiseHeavisideIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([-1, -1, -1], torch.int64, True), ([-1, -1, -1, -1], torch.int64, True)]
+    )
+    def forward(self, x, values):
+        return torch.heaviside(x, values)
+
+
+@register_test_case(module_factory=lambda: ElementwiseHeavisideIntModule())
+def ElementwiseHeavisideIntModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(1, 2, 3, low=-100, high=1000),
+        tu.randint(1, 1, 1, 1, low=-100, high=1000),
+    )
+
+
+class ElementwiseHeavisideNoBroadcastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
+    )
+    def forward(self, x, values):
+        return torch.heaviside(x, values)
+
+
+@register_test_case(module_factory=lambda: ElementwiseHeavisideNoBroadcastModule())
+def ElementwiseHeavisideNoBroadcastModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(5, 8),
+        tu.rand(5, 8),
+    )
+
+
+# ==============================================================================
+
+
 class ElementwiseLtIntScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
0 commit comments
