Support decomposition of torch.broadcast_tensors #4253

Open · wants to merge 7 commits into main

24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11041,6 +11041,7 @@ def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenAllDimOp : Torch_Op<"aten.all.dim", [
@@ -12027,6 +12028,29 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
let hasFolder = 1;
}

def Torch_AtenBroadcastTensorsOp : Torch_Op<"aten.broadcast_tensors", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::broadcast_tensors : (Tensor[]) -> (Tensor[])`";
let arguments = (ins
AnyTorchListOfTensorType:$tensors
);
let results = (outs
AnyTorchListOfTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBroadcastTensorsOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenBroadcastTensorsOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
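For context, a minimal PyTorch example of the semantics the new op models: `torch.broadcast_tensors` takes a list of tensors and returns one view per input, all broadcast to a common shape. This is illustrative usage, not part of the patch:

```python
import torch

a = torch.randn(3, 1)
b = torch.randn(1, 4)
c = torch.randn(4)

# Each output is a view of its input, expanded to the common
# broadcast shape (3, 4).
x, y, z = torch.broadcast_tensors(a, b, c)
assert x.shape == y.shape == z.shape == (3, 4)
```
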
4 changes: 2 additions & 2 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -60,10 +60,10 @@ Type getBuiltInTypeForTorchScalar(Type type);
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
Type dtype);

// Checks whether the `inputA` and `inputB` are broadcast compatible or not. If
// Checks whether the inputs are broadcast compatible or not. If
// yes, then computes the final broadcast shape.
void computeBroadcastShape(PatternRewriter &rewriter, Location loc,
Value inputA, Value inputB,
SmallVector<Value> inputs,
SmallVector<int64_t> &resultShape,
SmallVector<Value> &resultShapeValue);

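The helper now takes an arbitrary list of inputs instead of exactly two. A rough Python sketch of the rule it implements, right-aligned NumPy-style broadcasting folded over the list (the names here are illustrative, not the C++ API; static shapes only, whereas the real helper also materializes dynamic dimension Values):

```python
from functools import reduce

def broadcast_two(a, b):
    """Pairwise broadcast of two shapes, aligned from the right."""
    result = []
    for i in range(max(len(a), len(b))):
        da = a[-1 - i] if i < len(a) else 1
        db = b[-1 - i] if i < len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"incompatible shapes {a} and {b}")
        result.append(max(da, db))
    return list(reversed(result))

def compute_broadcast_shape(shapes):
    """Fold the pairwise rule over the whole list of shapes."""
    return reduce(broadcast_two, shapes)

assert compute_broadcast_shape([[3, 1], [1, 4], [4]]) == [3, 4]
```
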
6 changes: 3 additions & 3 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1065,9 +1065,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
} else {
SmallVector<int64_t> resultBroadcastShapeInt;
SmallVector<Value> resultBroadcastShapeValue;
Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr,
valList[i], resultBroadcastShapeInt,
resultBroadcastShapeValue);
Torch::computeBroadcastShape(
rewriter, binder.getLoc(), {curr, valList[i]},
resultBroadcastShapeInt, resultBroadcastShapeValue);
auto baseType = Torch::ValueTensorType::get(
binder.op->getContext(), resultBroadcastShapeInt,
resultType.getOptionalDtype());
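The updated call site passes the running result and the next operand as a two-element list, so this ONNX lowering still broadcasts pairwise inside its loop. Broadcasting composes associatively, so that matches a single variadic computation; a quick PyTorch sanity check (using `torch.broadcast_shapes`, not the project's helper):

```python
import torch
from functools import reduce

shapes = [(3, 1), (1, 4), (4,)]
# Pairwise folding, as the loop above does...
pairwise = reduce(lambda a, b: torch.broadcast_shapes(a, b), shapes)
# ...agrees with one variadic computation.
variadic = torch.broadcast_shapes(*shapes)
assert pairwise == variadic == (3, 4)
```
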
18 changes: 18 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2838,6 +2838,24 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenAllBoolOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAllBoolOp::fold(FoldAdaptor adaptor) {
auto inputConstruct = getSelf().getDefiningOp<Torch::PrimListConstructOp>();
if (!inputConstruct || isListPotentiallyMutated(inputConstruct))
return nullptr;
// If all operands are a constant true, return true.
for (auto operand : inputConstruct.getOperands()) {
bool b = true;
if (!matchPattern(operand, m_TorchConstantBool(&b)) || !b) {
return nullptr;
}
}
return getI1IntegerAttr(getContext(), true);
}

//===----------------------------------------------------------------------===//
// AtenFloatScalarOp
//===----------------------------------------------------------------------===//
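The new `AtenAllBoolOp` folder only fires when the defining list is not potentially mutated and every element is the constant `true`. A small Python sketch of that decision logic (hypothetical helper; `None` plays the role of the C++ `nullptr` "cannot fold" result):

```python
def fold_all_bool(elems):
    """elems: list of Optional[bool]; None marks a value that is not
    a compile-time constant. Returns the folded result, or None when
    the op must be left in place (the nullptr case in C++)."""
    for e in elems:
        # Like the C++ folder, bail on anything that is not a
        # constant true -- including a constant false, where folding
        # to False would also be sound.
        if e is not True:
            return None
    return True

assert fold_all_bool([True, True]) is True
assert fold_all_bool([True, None]) is None
```
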
60 changes: 60 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7796,6 +7796,37 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.broadcast_tensors\"(%arg0: !torch.list<list<int>>) -> !torch.list<list<int>> {\n"
" %true = torch.constant.bool true\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.list<list<int>>) {\n"
" %3 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
" torch.prim.If.yield %3 : !torch.list<list<int>>\n"
" } else {\n"
" %3 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %4 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
" %5 = torch.aten.__range_length %int1, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %6 = torch.prim.Loop %5, %true, init(%3) {\n"
" ^bb0(%arg1: !torch.int, %arg2: !torch.list<int>):\n"
" %9 = torch.aten.__derive_index %arg1, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %10 = torch.aten.__getitem__.t %arg0, %9 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %11 = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg2, %10) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter(%11 : !torch.list<int>)\n"
" } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
" %7 = torch.prim.ListConstruct : () -> !torch.list<list<int>>\n"
" %8 = torch.aten.len.t %arg0 : !torch.list<list<int>> -> !torch.int\n"
" torch.prim.Loop %8, %true, init() {\n"
" ^bb0(%arg1: !torch.int):\n"
" %9 = torch.aten.append.t %7, %6 : !torch.list<list<int>>, !torch.list<int> -> !torch.list<list<int>>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" torch.prim.If.yield %7 : !torch.list<list<int>>\n"
" }\n"
" return %2 : !torch.list<list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.view\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -12447,6 +12478,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.broadcast_tensors\"(%arg0: !torch.list<tuple<int, int>>) -> !torch.list<tuple<int, int>> {\n"
" %true = torch.constant.bool true\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" %1 = torch.prim.Loop %0, %true, init(%int0) {\n"
" ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n"
" %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
" %5 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" %6 = torch.aten.gt.int %5, %arg2 : !torch.int, !torch.int -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
" %8 = torch.prim.TupleIndex %4, %int0 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" torch.prim.If.yield %8 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg2 : !torch.int\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%7 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" %2 = torch.prim.ListConstruct : () -> !torch.list<tuple<int, int>>\n"
" %3 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" torch.prim.Loop %3, %true, init() {\n"
" ^bb0(%arg1: !torch.int):\n"
" %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
" %5:2 = torch.prim.TupleUnpack %4 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %6 = torch.prim.TupleConstruct %1, %5#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" %7 = torch.aten.append.t %2, %6 : !torch.list<tuple<int, int>>, !torch.tuple<int, int> -> !torch.list<tuple<int, int>>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" return %2 : !torch.list<tuple<int, int>>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cosine_similarity\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n"
" %int7 = torch.constant.int 7\n"
" %int6 = torch.constant.int 6\n"
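Deobfuscated, the two generated functions above are short. The shape function folds pairwise broadcasting over the input shapes and returns one copy of the common shape per input; roughly, in Python (with `broadcast` standing in for `torch.jit._shape_functions.broadcast`):

```python
def broadcast_tensors_shape(shapes, broadcast):
    # shapes: list of shape lists, one per input tensor.
    if len(shapes) == 0:
        return []
    common = shapes[0]
    for s in shapes[1:]:
        common = broadcast(common, s)
    # Every result shares the common broadcast shape.
    return [common for _ in shapes]
```

The dtype function takes the maximum rank across the inputs and pairs it with each input's own dtype, so every result keeps its source dtype:

```python
def broadcast_tensors_dtype(ranks_and_dtypes):
    # ranks_and_dtypes: list of (rank, dtype) pairs, one per input.
    max_rank = max((rank for rank, _ in ranks_and_dtypes), default=0)
    return [(max_rank, dtype) for _, dtype in ranks_and_dtypes]
```
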
54 changes: 50 additions & 4 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -24,7 +24,6 @@
#include "llvm/ADT/StringSet.h"
#include <cstdint>
#include <set>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
@@ -3415,7 +3414,7 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern<AtenLinalgCrossOp> {
// calculate common shape for broadcast
SmallVector<int64_t> broadcastShape;
SmallVector<Value> broadcastShapeValue;
computeBroadcastShape(rewriter, loc, self, other, broadcastShape,
computeBroadcastShape(rewriter, loc, {self, other}, broadcastShape,
broadcastShapeValue);

Type broadcastType = ValueTensorType::get(
@@ -8962,7 +8961,7 @@ class DecomposeAtenCosineSimilarityOp
// Broadcast x1 and x2 to the same shape
SmallVector<int64_t> indexBroadcastShapeInt;
SmallVector<Value> indexBroadcastShapeValue;
computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt,
computeBroadcastShape(rewriter, loc, {x1, x2}, indexBroadcastShapeInt,
indexBroadcastShapeValue);
Type dtype = cast<BaseTensorType>(x1.getType()).getOptionalDtype();
Type broadcastType = ValueTensorType::get(
@@ -11329,7 +11328,7 @@ class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
auto resultTy = dyn_cast<BaseTensorType>(op.getType());
SmallVector<int64_t> broadcastShape;
SmallVector<Value> broadcastShapeValue;
computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
computeBroadcastShape(rewriter, loc, {input, value}, broadcastShape,
broadcastShapeValue);

auto broadcastType = ValueTensorType::get(
@@ -12427,6 +12426,52 @@ class DecomposeAtenRoundDecimalsOp
};
} // namespace

namespace {
class DecomposeAtenBroadcastTensorsOp
: public OpRewritePattern<AtenBroadcastTensorsOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenBroadcastTensorsOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
SmallVector<Value> tensors;
if (!getListConstructElements(op.getTensors(), tensors))
return rewriter.notifyMatchFailure(op, "Unable to get tensors");
int64_t numTensors = tensors.size();

SmallVector<int64_t> broadcastShape;
SmallVector<Value> broadcastShapeValue;

computeBroadcastShape(rewriter, loc, tensors, broadcastShape,
broadcastShapeValue);

auto resType = cast<BaseTensorType>(tensors[0].getType());
auto dtype = resType.getDtype();
Type broadcastType = ValueTensorType::get(
op.getContext(), llvm::ArrayRef(broadcastShape), dtype);

Value broadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
broadcastShapeValue);

SmallVector<Value> broadcastedValues;
for (int64_t i = 0; i < numTensors; i++) {
auto inputTensor = tensors[i];
auto broadcastedVal = rewriter.create<AtenBroadcastToOp>(
loc, broadcastType, inputTensor, broadcastShapeTorchList);
broadcastedValues.push_back(broadcastedVal);
}

auto broadcastedValuesList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(broadcastType), broadcastedValues);

rewriter.replaceOp(op, broadcastedValuesList);
return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12628,6 +12673,7 @@ class DecomposeComplexOpsPass
DecomposeAtenAdaptivePool2dOp<AtenAdaptiveMaxPool2dOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAdaptivePool2dOp<AtenAdaptiveAvgPool2dOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBroadcastTensorsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
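Behaviorally, the new pattern replaces one `aten.broadcast_tensors` with a `prim.ListConstruct` of `aten.broadcast_to` ops at the common shape. The same rewrite in eager PyTorch terms (a sketch of the semantics, not the compiler code):

```python
import torch

def decomposed_broadcast_tensors(tensors):
    # Compute the common broadcast shape once...
    shape = torch.broadcast_shapes(*(t.shape for t in tensors))
    # ...then emit one broadcast_to per input, as the pattern does.
    return [t.broadcast_to(shape) for t in tensors]

a, b = torch.randn(3, 1), torch.randn(1, 4)
expected = torch.broadcast_tensors(a, b)
actual = decomposed_broadcast_tensors([a, b])
assert all(torch.equal(x, y) for x, y in zip(expected, actual))
```
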
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -519,6 +519,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
target.addIllegalOp<AtenAdaptiveMaxPool1dOp>();
target.addIllegalOp<AtenAdaptiveMaxPool2dOp>();
target.addIllegalOp<AtenBroadcastTensorsOp>();
target.addIllegalOp<AtenClampMinOp>();
target.addIllegalOp<AtenClampMinTensorOp>();
target.addIllegalOp<AtenClampMaxOp>();