diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 7716c059c874..65ec383410b7 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11041,6 +11041,7 @@ def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
+  let hasFolder = 1;
 }

 def Torch_AtenAllDimOp : Torch_Op<"aten.all.dim", [
@@ -12027,6 +12028,29 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
   let hasFolder = 1;
 }

+def Torch_AtenBroadcastTensorsOp : Torch_Op<"aten.broadcast_tensors", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::broadcast_tensors : (Tensor[]) -> (Tensor[])`";
+  let arguments = (ins
+    AnyTorchListOfTensorType:$tensors
+  );
+  let results = (outs
+    AnyTorchListOfTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenBroadcastTensorsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenBroadcastTensorsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,
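For context on the new ODS entry: `aten::broadcast_tensors` takes a list of tensors and returns a list in which every tensor has been expanded to their common broadcast shape. A minimal eager-mode PyTorch illustration (not part of the patch; the shapes match the e2e tests added further down):

```python
import torch

# broadcast_tensors expands every input to the common broadcast shape.
x = torch.rand(1, 3)
y = torch.rand(2, 1)
x1, y1 = torch.broadcast_tensors(x, y)
assert x1.shape == y1.shape == torch.Size([2, 3])
```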
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index a000b7ab2f98..cd93af1fd8d4 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -60,10 +60,10 @@ Type getBuiltInTypeForTorchScalar(Type type);
 Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                               Type dtype);

-// Checks whether the `inputA` and `inputB` are broadcast compatible or not. If
+// Checks whether the inputs are broadcast compatible or not. If
 // yes, then computes the final broadcast shape.
 void computeBroadcastShape(PatternRewriter &rewriter, Location loc,
-                           Value inputA, Value inputB,
+                           SmallVector<Value> inputs,
                            SmallVector<int64_t> &resultShape,
                            SmallVector<Value> &resultShapeValue);

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index fbaf8a1f756b..d2e3c94733e9 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1065,9 +1065,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         } else {
           SmallVector<int64_t> resultBroadcastShapeInt;
           SmallVector<Value> resultBroadcastShapeValue;
-          Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr,
-                                       valList[i], resultBroadcastShapeInt,
-                                       resultBroadcastShapeValue);
+          Torch::computeBroadcastShape(
+              rewriter, binder.getLoc(), {curr, valList[i]},
+              resultBroadcastShapeInt, resultBroadcastShapeValue);
           auto baseType = Torch::ValueTensorType::get(
               binder.op->getContext(), resultBroadcastShapeInt,
               resultType.getOptionalDtype());
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index e50be5ff97ae..7f1490a25a01 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2838,6 +2838,24 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }

+//===----------------------------------------------------------------------===//
+// AtenAllBoolOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenAllBoolOp::fold(FoldAdaptor adaptor) {
+  auto inputConstruct = getSelf().getDefiningOp<Torch::PrimListConstructOp>();
+  if (!inputConstruct || isListPotentiallyMutated(inputConstruct))
+    return nullptr;
+  // If all operands are a constant true, return true.
+  for (auto operand : inputConstruct.getOperands()) {
+    bool b = true;
+    if (!matchPattern(operand, m_TorchConstantBool(&b)) || !b) {
+      return nullptr;
+    }
+  }
+  return getI1IntegerAttr(getContext(), true);
+}
+
 //===----------------------------------------------------------------------===//
 // AtenFloatScalarOp
 //===----------------------------------------------------------------------===//
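The new fold above is intentionally conservative: it only fires when the operand list comes from an unmutated `torch.prim.ListConstruct` and every element is the constant `true`. A rough Python model of that decision (illustrative only; `None` stands in for an operand whose value is not statically known):

```python
from typing import List, Optional

def fold_all_bool(operands: List[Optional[bool]]) -> Optional[bool]:
    """Mimic AtenAllBoolOp::fold: fold to True only when every operand is a
    statically known True; otherwise decline to fold (return None)."""
    for operand in operands:
        if operand is not True:  # non-constant or constant False: bail out
            return None
    return True

assert fold_all_bool([True, True, True]) is True
assert fold_all_bool([True, None]) is None    # unknown element, no fold
assert fold_all_bool([True, False]) is None   # matches the C++: even a constant False bails out
```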
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 23f1814cc008..e14d61c6c2d9 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7796,6 +7796,37 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.broadcast_tensors\"(%arg0: !torch.list<list<int>>) -> !torch.list<list<int>> {\n"
+"    %true = torch.constant.bool true\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
+"    %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<list<int>>) {\n"
+"      %3 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
+"      torch.prim.If.yield %3 : !torch.list<list<int>>\n"
+"    } else {\n"
+"      %3 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
+"      %4 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
+"      %5 = torch.aten.__range_length %int1, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"      %6 = torch.prim.Loop %5, %true, init(%3) {\n"
+"      ^bb0(%arg1: !torch.int, %arg2: !torch.list<int>):\n"
+"        %9 = torch.aten.__derive_index %arg1, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"        %10 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
+"        %11 = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg2, %10) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"        torch.prim.Loop.condition %true, iter(%11 : !torch.list<int>)\n"
+"      } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
+"      %7 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
+"      %8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
+"      torch.prim.Loop %8, %true, init() {\n"
+"      ^bb0(%arg1: !torch.int):\n"
+"        %9 = torch.aten.append.t %7, %6 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
+"        torch.prim.Loop.condition %true, iter()\n"
+"      } : (!torch.int, !torch.bool) -> ()\n"
+"      torch.prim.If.yield %7 : !torch.list<list<int>>\n"
+"    }\n"
+"    return %2 : !torch.list<list<int>>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.view\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -12447,6 +12478,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.broadcast_tensors\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.list<tuple<int, int>> {\n"
+"    %true = torch.constant.bool true\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
+"    %1 = torch.prim.Loop %0, %true, init(%int0) {\n"
+"    ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n"
+"      %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
+"      %5 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
+"      %6 = torch.aten.gt.int %5, %arg2 : !torch.int, !torch.int -> !torch.bool\n"
+"      %7 = torch.prim.If %6 -> (!torch.int) {\n"
+"        %8 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
+"        torch.prim.If.yield %8 : !torch.int\n"
+"      } else {\n"
+"        torch.prim.If.yield %arg2 : !torch.int\n"
+"      }\n"
+"      torch.prim.Loop.condition %true, iter(%7 : !torch.int)\n"
+"    } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
+"    %2 = torch.prim.ListConstruct : () -> !torch.list<tuple<int, int>>\n"
+"    %3 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
+"    torch.prim.Loop %3, %true, init() {\n"
+"    ^bb0(%arg1: !torch.int):\n"
+"      %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
+"      %5:2 = torch.prim.TupleUnpack %4 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"      %6 = torch.prim.TupleConstruct %1, %5#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
+"      %7 = torch.aten.append.t %2, %6 : !torch.list<tuple<int, int>>, !torch.tuple<int, int> -> !torch.list<tuple<int, int>>\n"
+"      torch.prim.Loop.condition %true, iter()\n"
+"    } : (!torch.int, !torch.bool) -> ()\n"
+"    return %2 : !torch.list<tuple<int, int>>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.cosine_similarity\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n"
 "    %int7 = torch.constant.int 7\n"
 "    %int6 = torch.constant.int 6\n"
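The two library entries above are generated from the Python functions added to `abstract_interp_lib_gen.py` later in this patch: the shape function folds the upstream pairwise broadcast over the list and reports that common shape for every output, and the dtype function keeps each input's dtype while reporting the maximum input rank. A plain-Python sketch of the shape rule (with `broadcast_two` as an illustrative stand-in for `torch.jit._shape_functions.broadcast`):

```python
from typing import List

def broadcast_two(a: List[int], b: List[int]) -> List[int]:
    """Stand-in for torch.jit._shape_functions.broadcast (two-shape rule)."""
    out: List[int] = []
    for i in range(max(len(a), len(b))):
        da = a[-1 - i] if i < len(a) else 1
        db = b[-1 - i] if i < len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError("shapes are not broadcastable")
        out.insert(0, da if db == 1 else db)
    return out

def broadcast_tensors_shape(shapes: List[List[int]]) -> List[List[int]]:
    """Every output of aten.broadcast_tensors gets the common broadcast shape."""
    if not shapes:
        return []
    result = shapes[0]
    for s in shapes[1:]:
        result = broadcast_two(result, s)
    return [result for _ in shapes]

assert broadcast_tensors_shape([[3], [2, 1], [2, 1, 1]]) == [[2, 2, 3]] * 3
```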
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index cb49fa97b86a..b75b227adde1 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -24,7 +24,6 @@
 #include "llvm/ADT/StringSet.h"
 #include
 #include
-
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
@@ -3415,7 +3414,7 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern<AtenLinalgCrossOp> {
     // calculate common shape for broadcast
     SmallVector<int64_t> broadcastShape;
     SmallVector<Value> broadcastShapeValue;
-    computeBroadcastShape(rewriter, loc, self, other, broadcastShape,
+    computeBroadcastShape(rewriter, loc, {self, other}, broadcastShape,
                           broadcastShapeValue);

     Type broadcastType = ValueTensorType::get(
@@ -8962,7 +8961,7 @@ class DecomposeAtenCosineSimilarityOp
     // Broadcast x1 and x2 to the same shape
     SmallVector<int64_t> indexBroadcastShapeInt;
     SmallVector<Value> indexBroadcastShapeValue;
-    computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt,
+    computeBroadcastShape(rewriter, loc, {x1, x2}, indexBroadcastShapeInt,
                           indexBroadcastShapeValue);
     Type dtype = cast<BaseTensorType>(x1.getType()).getOptionalDtype();
     Type broadcastType = ValueTensorType::get(
@@ -11329,7 +11328,7 @@ class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
     auto resultTy = dyn_cast<BaseTensorType>(op.getType());
     SmallVector<int64_t> broadcastShape;
     SmallVector<Value> broadcastShapeValue;
-    computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
+    computeBroadcastShape(rewriter, loc, {input, value}, broadcastShape,
                           broadcastShapeValue);

     auto broadcastType = ValueTensorType::get(
@@ -12427,6 +12426,52 @@ class DecomposeAtenRoundDecimalsOp
 };
 } // namespace

+namespace {
+class DecomposeAtenBroadcastTensorsOp
+    : public OpRewritePattern<AtenBroadcastTensorsOp> {
+public:
+  using OpRewritePattern<AtenBroadcastTensorsOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenBroadcastTensorsOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    SmallVector<Value> tensors;
+    if (!getListConstructElements(op.getTensors(), tensors))
+      return rewriter.notifyMatchFailure(op, "Unable to get tensors");
+    int64_t numTensors = tensors.size();
+
+    SmallVector<int64_t> broadcastShape;
+    SmallVector<Value> broadcastShapeValue;
+
+    computeBroadcastShape(rewriter, loc, tensors, broadcastShape,
+                          broadcastShapeValue);
+
+    auto resType = cast<BaseTensorType>(tensors[0].getType());
+    auto dtype = resType.getDtype();
+    Type broadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), dtype);
+
+    Value broadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        broadcastShapeValue);
+
+    SmallVector<Value> broadcastedValues;
+    for (int64_t i = 0; i < numTensors; i++) {
+      auto inputTensor = tensors[i];
+      auto broadcastedVal = rewriter.create<AtenBroadcastToOp>(
+          loc, broadcastType, inputTensor, broadcastShapeTorchList);
+      broadcastedValues.push_back(broadcastedVal);
+    }
+
+    auto broadcastedValuesList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(broadcastType), broadcastedValues);
+
+    rewriter.replaceOp(op, broadcastedValuesList);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12628,6 +12673,7 @@ class DecomposeComplexOpsPass
       DecomposeAtenAdaptivePool2dOp>(patterns);
   addPatternIfTargetOpIsIllegal<
       DecomposeAtenAdaptivePool2dOp>(patterns);
+  addPatternIfTargetOpIsIllegal<DecomposeAtenBroadcastTensorsOp>(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
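The new `DecomposeAtenBroadcastTensorsOp` pattern rewrites `aten.broadcast_tensors` into one `aten.broadcast_to` per input, all targeting the common shape produced by `computeBroadcastShape`. An eager-mode equivalent of that rewrite, useful for checking expectations (this sketch uses the public `torch.broadcast_shapes` helper instead of the IR-level utility):

```python
import torch

def decompose_broadcast_tensors(tensors):
    # Same effect as the pattern: compute the common shape once,
    # then broadcast_to every input to it.
    common = torch.broadcast_shapes(*(t.shape for t in tensors))
    return [t.broadcast_to(common) for t in tensors]

a, b = torch.rand(1, 3), torch.rand(2, 1)
assert all(
    torch.equal(got, want)
    for got, want in zip(decompose_broadcast_tensors([a, b]),
                         torch.broadcast_tensors(a, b))
)
```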
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 6d6ed9cad50d..f11732d56097 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -519,6 +519,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenBroadcastTensorsOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 388e31353571..bc11f10f8fce 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -479,78 +479,114 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
   return unsqueezed;
 }

-// Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If
+// Checks whether the inputs are broadcast compatible or not. If
 // yes, then computes the final broadcast shape.
 void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,
-                                  Value inputA, Value inputB,
+                                  SmallVector<Value> inputs,
                                   SmallVector<int64_t> &resultShape,
                                   SmallVector<Value> &resultShapeValue) {
-  SmallVector<int64_t> shapeA{
-      cast<BaseTensorType>(inputA.getType()).getSizes()};
-  SmallVector<int64_t> shapeB{
-      cast<BaseTensorType>(inputB.getType()).getSizes()};
-  unsigned rankA = shapeA.size();
-  unsigned rankB = shapeB.size();
-  unsigned minRank = rankA > rankB ? rankB : rankA;
+
+  SmallVector<SmallVector<int64_t>> shapes;
+  SmallVector<unsigned> ranks;
+  SmallVector<Value> maxShapeValues;
+
+  for (auto input : inputs) {
+    SmallVector<int64_t> shape{
+        cast<BaseTensorType>(input.getType()).getSizes()};
+    shapes.push_back(shape);
+    ranks.push_back(shape.size());
+  }
+
+  Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
+      loc, rewriter.getI64IntegerAttr(1));
+  unsigned maxRank = *std::max_element(ranks.begin(), ranks.end());
+
   // Check whether the shapes of the tensors are broadcastable or not.
   // Two tensors are “broadcastable” if the following rules hold:
   // 1.) Each tensor has at least one dimension.
   // 2.) When iterating over the dimension sizes, starting at the trailing
   // dimension, the dimension sizes must either be equal, one of them is 1, or
   // one of them does not exist.
-  for (unsigned i = 0; i < minRank; i++) {
-    Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankA - i - 1));
-    Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankB - i - 1));
-    Value sizeInputA =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
-    Value sizeInputB =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
-    Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(1));
-    Value cmpSizeAEqualsSizeB =
-        rewriter.create<AtenEqIntOp>(loc, sizeInputA, sizeInputB);
-    Value cmpSizeAEqualsOne =
-        rewriter.create<AtenEqIntOp>(loc, sizeInputA, torchCstOne);
-    Value cmpSizeBEqualsOne =
-        rewriter.create<AtenEqIntOp>(loc, sizeInputB, torchCstOne);
-    Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(cmpSizeAEqualsOne.getType()),
-        SmallVector<Value>{cmpSizeAEqualsSizeB, cmpSizeAEqualsOne,
-                           cmpSizeBEqualsOne});
-    Value cmp = rewriter.create<Torch::AtenAnyBoolOp>(loc, anyBoolOpList);
-    rewriter.create<Torch::RuntimeAssertOp>(
-        loc, cmp, "tensors are not broadcast compatible");
+  for (unsigned i = 0; i < maxRank; i++) {
+
+    SmallVector<Value> sizeInputs;
+    for (auto [idx, input] : llvm::enumerate(inputs)) {
+      int sizeDimIdx = ranks[idx] - i - 1;
+      if (sizeDimIdx >= 0) {
+        auto sizeDim = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(sizeDimIdx));
+        sizeInputs.push_back(
+            rewriter.createOrFold<AtenSizeIntOp>(loc, input, sizeDim));
+      }
+    }
+
+    // Compute shape value of broadcast result,
+    // which is the maximum of dimension sizes across all inputs
+    Value maxShapeVal = sizeInputs.front();
+    for (auto sizeInput : sizeInputs) {
+      maxShapeVal = rewriter.create<PrimMaxIntOp>(loc, maxShapeVal, sizeInput);
+    }
+    maxShapeValues.push_back(maxShapeVal);
+
+    SmallVector<Value> predicates;
+    for (auto sizeVal : sizeInputs) {
+      Value cmpSizeEquals =
+          rewriter.create<AtenEqIntOp>(loc, sizeVal, maxShapeVal);
+      Value cmpSizeEqualsOne =
+          rewriter.create<AtenEqIntOp>(loc, sizeVal, torchCstOne);
+      Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
+          loc, Torch::ListType::get(cmpSizeEquals.getType()),
+          SmallVector<Value>{cmpSizeEquals, cmpSizeEqualsOne});
+      Value cmp = rewriter.create<Torch::AtenAnyBoolOp>(loc, anyBoolOpList);
+      predicates.push_back(cmp);
+    }
+
+    if (!predicates.empty()) {
+      Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
+          loc, Torch::ListType::get(predicates.front().getType()), predicates);
+      Value cmp = rewriter.create<Torch::AtenAllBoolOp>(loc, anyBoolOpList);
+      rewriter.create<Torch::RuntimeAssertOp>(
+          loc, cmp, "tensors are not broadcast compatible");
+    }
   }
+
   // If we reach here then it means both the shapes are broadcast compatible.
-  resultShape = rankA >= rankB ? shapeA : shapeB;
-  Value shapeTensor = rankA >= rankB ? inputA : inputB;
+  auto maxRankIdx =
+      std::max_element(ranks.begin(), ranks.end()) - ranks.begin();
+  resultShape = shapes[maxRankIdx];
+  Value shapeTensor = inputs[maxRankIdx];
+
   for (unsigned i = 0; i < resultShape.size(); i++) {
     Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(i));
     resultShapeValue.push_back(
         rewriter.createOrFold<AtenSizeIntOp>(loc, shapeTensor, sizeDim));
   }
+
   unsigned resultRank = resultShape.size();
-  for (unsigned i = 0; i < minRank; i++) {
-    Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankA - i - 1));
-    Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rankB - i - 1));
-    Value sizeInputA =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
-    Value sizeInputB =
-        rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
-    resultShapeValue[resultRank - i - 1] =
-        rewriter.create<PrimMaxIntOp>(loc, sizeInputA, sizeInputB);
-    if (shapeA[rankA - i - 1] == kUnknownSize ||
-        shapeB[rankB - i - 1] == kUnknownSize) {
+  for (unsigned i = 0; i < maxRank; i++) {
+
+    resultShapeValue[resultRank - i - 1] = maxShapeValues[i];
+
+    // Compute result shape if all input shapes are known
+    bool unknownSize = false;
+    for (auto [idx, shape] : llvm::enumerate(shapes)) {
+      if (ranks[idx] - i - 1 < shape.size() &&
+          shape[ranks[idx] - i - 1] == kUnknownSize) {
+        unknownSize = true;
+      }
+    }
+
+    if (unknownSize) {
       resultShape[resultRank - i - 1] = kUnknownSize;
     } else {
-      resultShape[resultRank - i - 1] =
-          std::max(shapeA[rankA - i - 1], shapeB[rankB - i - 1]);
+
+      int64_t maxShape = 1;
+      for (auto [idx, shape] : llvm::enumerate(shapes)) {
+        if (ranks[idx] - i - 1 < shape.size()) {
+          maxShape = std::max(maxShape, shape[ranks[idx] - i - 1]);
+        }
+      }
+      resultShape[resultRank - i - 1] = maxShape;
     }
   }
 }
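On the static-shape side, `computeBroadcastShape` now walks the trailing dimensions of all inputs at once: the result dimension is `kUnknownSize` if any contributing dimension is unknown, and otherwise the maximum of the known sizes. A small Python model of just that bookkeeping (with `-1` standing in for `kUnknownSize`; the helper name is illustrative):

```python
from typing import List

K_UNKNOWN_SIZE = -1  # stand-in for torch-mlir's kUnknownSize

def static_broadcast_shape(shapes: List[List[int]]) -> List[int]:
    """Mirror the static result-shape loop in computeBroadcastShape."""
    max_rank = max(len(s) for s in shapes)
    result = [1] * max_rank
    for i in range(max_rank):
        dims = [s[len(s) - i - 1] for s in shapes if i < len(s)]
        if any(d == K_UNKNOWN_SIZE for d in dims):
            result[max_rank - i - 1] = K_UNKNOWN_SIZE
        else:
            result[max_rank - i - 1] = max(dims)
    return result

assert static_broadcast_shape([[1, 3], [2, 1]]) == [2, 3]
assert static_broadcast_shape([[-1, 3], [2, 1]]) == [-1, 3]
```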
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 62590eb48365..4d7233772d04 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3469,6 +3469,7 @@
     "StdCorrectionEmptyDimModule_basic",
     "VarCorrectionEmptyDimModule_basic",
     "VarDimEmptyDimModule_basic",
+    "BroadcastTensorsModule_basic",
     # Runtime op verification: rank mismatch in memref.cast
     "ViewSizeFromOtherTensor_basic",
     "SliceOutOfLowerBoundEndIndexModule_basic",
@@ -4237,6 +4238,8 @@
     "BoolIntTrueModule_basic",
     "BroadcastDynamicDimModule_basic",
     "BroadcastToModule_basic",
+    "BroadcastTensorsModule_basic",
+    "BroadcastTensorsModuleList_multiple_ranks",
     "BucketizeTensorFloatModule_basic",
     "BucketizeTensorModule_basic",
     "BucketizeTensorOutInt32RightModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 6af2292dea57..2966f6f17be3 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -949,6 +949,17 @@ def aten〇expand_as〡shape(self: List[int], other: List[int]) -> List[int]:
 def aten〇broadcast_to〡shape(self: List[int], size: List[int]) -> List[int]:
     return upstream_shape_functions.expand(self, size)

+def aten〇broadcast_tensors〡shape(tensors: List[List[int]]) -> List[List[int]]:
+    if len(tensors) == 0:
+        return []
+    result = tensors[0]
+    for i in range(1, len(tensors)):
+        result = upstream_shape_functions.broadcast(result, tensors[i])
+    out: List[List[int]] = []
+    for _ in tensors:
+        out.append(result)
+    return out
+
 def aten〇view〡shape(self: List[int], size: List[int]) -> List[int]:
     return upstream_shape_functions.view(self, size)

@@ -3148,6 +3159,17 @@ def aten〇broadcast_to〡dtype(self_rank_dtype: Tuple[int, int], size: List[int
     self_rank, self_dtype = self_rank_dtype
     return self_dtype

+def aten〇broadcast_tensors〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
+    max_rank = 0
+    for rd in tensors_rank_dtype:
+        if rd[0] > max_rank:
+            max_rank = rd[0]
+    out: List[Tuple[int, int]] = []
+    for tensor_rank_dtype in tensors_rank_dtype:
+        tensor_rank, tensor_dtype = tensor_rank_dtype
+        out.append((max_rank, tensor_dtype))
+    return out
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=2,dim=0, error_types={torch.complex128, torch.complex64, *all_integer_dtypes()}))
 def aten〇cosine_similarity〡dtype(x1_rank_dtype: Tuple[int, int], x2_rank_dtype: Tuple[int, int], dim: int = 1, eps: float = 1e-08) -> int:
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 902e95fd3d97..9bfc574b8393 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -847,7 +847,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::isneginf : (Tensor) -> (Tensor)")
     emit("aten::isposinf : (Tensor) -> (Tensor)")
     emit("aten::all : (Tensor) -> (Tensor)")
-    emit("aten::all.bool : (bool[]) -> (bool)")
+    emit("aten::all.bool : (bool[]) -> (bool)", has_folder=True)
     emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)")
     emit("aten::any : (Tensor) -> (Tensor)")
     emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
@@ -899,6 +899,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)")
     emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
+    emit("aten::broadcast_tensors : (Tensor[]) -> (Tensor[])")
     emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
     emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
     emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True)
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 1ad698db9cc1..ad506a6e6cdd 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -2146,6 +2146,57 @@ def BroadcastToModule_basic(module, tu: TestUtils):
 # ==============================================================================


+class BroadcastTensorsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        x1, y1 = torch.broadcast_tensors(x, y)
+        return x1, y1
+
+
+@register_test_case(module_factory=lambda: BroadcastTensorsModule())
+def BroadcastTensorsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 3), tu.rand(2, 1))
+
+
+# ==============================================================================
+
+
+class BroadcastTensorsModuleList(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3], torch.float32, True),
+            ([2, 1], torch.float32, True),
+            ([2, 1, 1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y, z):
+        x1, y1, z1 = torch.broadcast_tensors(x, y, z)
+        return x1, y1, z1
+
+
+@register_test_case(module_factory=lambda: BroadcastTensorsModuleList())
+def BroadcastTensorsModuleList_multiple_ranks(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(2, 1), tu.rand(2, 1, 1))
+
+
+# ==============================================================================
+
+
 class BroadcastToSameRankStaticModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index a025ec09726d..0a529fee772f 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -2685,6 +2685,16 @@ func.func @torch.aten.any.bool$fold() -> !torch.bool {
   return %0 : !torch.bool
 }

+// CHECK-LABEL:   func.func @torch.aten.all.bool$fold() -> !torch.bool {
+// CHECK:           %[[CST_TRUE:.*]] = torch.constant.bool true
+// CHECK:           return %[[CST_TRUE]] : !torch.bool
+func.func @torch.aten.all.bool$fold() -> !torch.bool {
+  %true = torch.constant.bool true
+  %input = torch.prim.ListConstruct %true, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  %0 = torch.aten.all.bool %input : !torch.list<bool> -> !torch.bool
+  return %0 : !torch.bool
+}
+
 // CHECK-LABEL:   func.func @torch.aten.floor$canonicalize
 // CHECK-SAME:            %[[ARG:.*]]: !torch.vtensor<[?,?],si64>
 // CHECK-NEXT:       return %[[ARG]] : !torch.vtensor<[?,?],si64>
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 7644c00de069..3e921247cdc2 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -846,3 +846,24 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf16>,
   %result, %mean, %rstd = torch.aten.native_layer_norm %input, %normalized_shape, %weight, %bias, %eps : !torch.vtensor<[1,56,56,96],bf16>, !torch.list<int>, !torch.vtensor<[96],bf16>, !torch.vtensor<[96],bf16>, !torch.float -> !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
   return %result, %mean, %rstd : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten.broadcast_tensors
+// CHECK-SAME:      (%[[ARG0:.*]]: !torch.vtensor<[1,3],f32>,
+// CHECK-SAME:       %[[ARG1:.*]]: !torch.vtensor<[2,1],f32>
+// CHECK-DAG:       %[[INT2:.*]] = torch.constant.int 2
+// CHECK-DAG:       %[[INT3:.*]] = torch.constant.int 3
+// CHECK-DAG:       %[[TRUE:.*]] = torch.constant.bool true
+// CHECK:           torch.runtime.assert %[[TRUE]], "tensors are not broadcast compatible"
+// CHECK:           torch.runtime.assert %[[TRUE]], "tensors are not broadcast compatible"
+// CHECK:           %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           %[[B0:.*]] = torch.aten.broadcast_to %[[ARG0]], %[[SHAPE]] : !torch.vtensor<[1,3],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
+// CHECK:           %[[B1:.*]] = torch.aten.broadcast_to %[[ARG1]], %[[SHAPE]] : !torch.vtensor<[2,1],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
+// CHECK:           %[[LIST:.*]] = torch.prim.ListConstruct %[[B0]], %[[B1]] : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>>
+// CHECK:           return %[[LIST]] : !torch.list<vtensor<[2,3],f32>>
+func.func @torch.aten.broadcast_tensors(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !torch.vtensor<[2,1],f32>) -> !torch.list<vtensor<[2,3],f32>> {
+  %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[1,3],f32>, !torch.vtensor<[2,1],f32>) -> !torch.list<vtensor>
+  %1 = torch.aten.broadcast_tensors %0 : !torch.list<vtensor> -> !torch.list<vtensor<[2,3],f32>>
+  return %1 : !torch.list<vtensor<[2,3],f32>>
+}
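Taken together, the new e2e tests above exercise exactly the eager-mode behaviour below; running it is a quick way to sanity-check the expected result shapes (a convenience snippet, not part of the test suite):

```python
import torch

x, y, z = torch.rand(3), torch.rand(2, 1), torch.rand(2, 1, 1)
x1, y1, z1 = torch.broadcast_tensors(x, y, z)
# All outputs share the common broadcast shape (2, 2, 3), matching
# BroadcastTensorsModuleList_multiple_ranks above.
assert x1.shape == y1.shape == z1.shape == torch.Size([2, 2, 3])
```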