From d6125e48edd94197774d5f14dacdc36339b504b7 Mon Sep 17 00:00:00 2001 From: penguin-wwy <940375606@qq.com> Date: Fri, 18 Jul 2025 10:56:06 +0800 Subject: [PATCH] [Stablehlo] Refactor utility functions for reduction --- lib/Conversion/TorchToStablehlo/Reduction.cpp | 53 +++++++------------ 1 file changed, 18 insertions(+), 35 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index d44e5db66b59..c007ea7a69f5 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -45,85 +45,68 @@ static SmallVector getReduceOutputShape(ArrayRef inputShape, static Value createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); + DenseElementsAttr constAttr = nullptr; if (isa(op)) { if (isa(elementTy)) { - auto constAttr = DenseElementsAttr::get( + constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( cast(elementTy).getFloatSemantics(), /*negative=*/false)}); - return rewriter.create(op->getLoc(), constType, - constAttr); } else if (isa(elementTy)) { - auto constAttr = DenseElementsAttr::get( + constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); } } if (isa(op)) { if (isa(elementTy)) { - auto constAttr = DenseElementsAttr::get( + constAttr = DenseElementsAttr::get( constType, {APFloat::getInf(cast(elementTy).getFloatSemantics(), /*negative=*/true)}); - return rewriter.create(op->getLoc(), constType, - constAttr); } else if (isa(elementTy)) { - auto constAttr = DenseElementsAttr::get( + constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); } } if (isa(op)) { if (isa(elementTy)) { - auto constAttr = DenseElementsAttr::get( + constAttr = DenseElementsAttr::get( constType, {APFloat::getInf(cast(elementTy).getFloatSemantics(), /*negative=*/false)}); - return rewriter.create(op->getLoc(), constType, - constAttr); } else if (isa(elementTy)) { - auto constAttr = DenseElementsAttr::get( + constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); } } if (isa(op)) { if (isa(elementTy)) { APFloat one(cast(elementTy).getFloatSemantics(), 1); - auto constAttr = DenseElementsAttr::get(constType, one); - return rewriter.create(op->getLoc(), constType, - constAttr); + constAttr = DenseElementsAttr::get(constType, one); } else if (isa(elementTy)) { APInt one(elementTy.getIntOrFloatBitWidth(), 1); - auto constAttr = DenseElementsAttr::get(constType, one); - return rewriter.create(op->getLoc(), constType, - constAttr); + constAttr = DenseElementsAttr::get(constType, one); } } if (isa(op)) { - auto constAttr = - DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + constAttr = DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)}); } if (isa(op)) { - auto constAttr = - DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)}); + constAttr = DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)}); + } + + if (constAttr != nullptr) { return rewriter.create(op->getLoc(), constType, constAttr); } - op->emitError("unimplemented lowering in " "createInitialValueForReduceOp"); return nullptr; @@ -483,7 +466,7 @@ class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp { return rewriter.notifyMatchFailure( op, "non-const integer `dim` is not supported"); } - if (inputDims.size() == 0) { + if (inputDims.empty()) { dims = llvm::to_vector(llvm::seq(0, inputTy.getRank())); } else { for (auto d : inputDims) { @@ -570,7 +553,7 @@ class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } - auto inputShapeVec = *inputShapeInfo; + auto &inputShapeVec = *inputShapeInfo; if (op.getResult(1).use_empty()) { llvm::SmallVector outputShape(inputTy.getShape()); @@ -643,7 +626,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "non-const integer `dim` is not supported"); } - if (inputDims.size() == 0) { + if (inputDims.empty()) { rewriter.replaceOp(op, input); return success(); } @@ -722,7 +705,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "non-const integer `dim` is not supported"); } - if (inputDims.size() == 0) { + if (inputDims.empty()) { inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); } }