From f6e6e10d44923f518120eadbae690e2f197f9e7d Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 25 Aug 2025 18:59:31 +0000 Subject: [PATCH 01/13] Add vector.step distribution pattern --- .../Transforms/XeGPUWgToSgDistribute.cpp | 204 +++++++++++++++++- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 40 ++++ 2 files changed, 239 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 93b4efcd125ec..39d0e9ea2e91d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -737,8 +737,18 @@ struct WgToSgArithConstantOp : public OpConversionPattern { if (!vecAttr || !vecAttr.isSplat() || !vecType) return failure(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); - if (!layout || !layout.getSgLayout()) + auto layoutName = xegpu::getLayoutName(op->getResult(0)); + auto attr = op->getAttr(layoutName); + + xegpu::DistributeLayoutAttr layout = nullptr; + // Try to get either SliceAttr or LayoutAttr, and keep as is + if (auto trySlice = dyn_cast_if_present(attr)) { + layout = trySlice; + } else if (auto tryLayout = dyn_cast_if_present(attr)) { + layout = tryLayout; + } + + if (!layout) return failure(); ArrayRef wgShape = vecType.getShape(); @@ -754,8 +764,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern { auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); - if (auto newLayout = layout.dropSgLayoutAndData()) - xegpu::setLayoutAttr(cstOp->getResult(0), newLayout); + // Do nothing if layout is a SliceAttr + if (auto layoutAttr = dyn_cast(layout)) { + if (auto newLayout = layoutAttr.dropSgLayoutAndData()) { + xegpu::setLayoutAttr(cstOp->getResult(0), newLayout); + } + } SmallVector newConsts(count, cstOp); rewriter.replaceOpWithMultiple(op, {newConsts}); @@ -763,6 +777,139 @@ struct WgToSgArithConstantOp : public OpConversionPattern { } }; +<<<<<<< HEAD +======= +// This pattern transforms the LoadGatherOp with explicit offsets to load +// subgroup data, similar to WgToSgLoadNdOpWithOffset. +struct WgToSgLoadGatherOpWithOffset + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getOffsets()) + return failure(); + + Location loc = op.getLoc(); + VectorType resultType = op.getResult().getType(); + ArrayRef wgShape = resultType.getShape(); + + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); + if (!layout || !layout.getSgLayout()) + return failure(); + + SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; + + SmallVector newLoadOps; + auto chunkSizeAttr = + rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1)); + VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); + for (auto [offsets, mask] : + llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { + auto newLoadOp = rewriter.create( + loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); + xegpu::setLayoutAttr(newLoadOp->getResult(0), + layout.dropSgLayoutAndData()); + newLoadOps.push_back(newLoadOp); + } + rewriter.replaceOpWithMultiple(op, {newLoadOps}); + return success(); + } +}; + +// This pattern transforms the StoreScatterOp with explicit offsets to store +// subgroup data, similar to WgToSgStoreNdOpWithOffset. +struct WgToSgStoreScatterOpWithOffset + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getOffsets()) + return failure(); + + Location loc = op.getLoc(); + VectorType valueType = dyn_cast(op.getValue().getType()); + if (!valueType) + return failure(); + + ArrayRef wgShape = valueType.getShape(); + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue()); + if (!layout || !layout.getSgLayout()) + return failure(); + + auto chunkSizeOpt = op.getChunkSize(); + int64_t chunkSize = chunkSizeOpt ? static_cast(*chunkSizeOpt) : 1; + auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); + for (auto [val, offs, mask] : llvm::zip( + adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { + rewriter.create( + loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(), + op.getL2HintAttr(), op.getL3HintAttr()); + // Update the layout_result_0 attribute to drop sg_layout and sg_data. + if (auto layoutAttr = + op->getAttrOfType("layout_result_0")) { + if (auto newLayout = layoutAttr.dropSgLayoutAndData()) + op->setAttr("layout_result_0", newLayout); + } + } + rewriter.eraseOp(op); + return success(); + } +}; + +struct WgToSgVectorStepOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto layoutName = xegpu::getLayoutName(op->getResult(0)); + auto attr = op->getAttr(layoutName); + + xegpu::DistributeLayoutAttr layoutAttr = nullptr; + // Try to get either SliceAttr or LayoutAttr, and keep as is + if (auto trySlice = dyn_cast_if_present(attr)) { + layoutAttr = trySlice; + } else if (auto tryLayout = dyn_cast_if_present(attr)) { + layoutAttr = tryLayout; + } + + if (!layoutAttr) + return failure(); + + Location loc = op.getLoc(); + VectorType type = op.getResult().getType(); + auto wgShape = type.getShape(); + std::optional> sgShape = + getSgShapeAndCount(wgShape, layoutAttr).first; + if (!sgShape) + return failure(); + + Value sgId = + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); + auto maybeOffsets = layoutAttr.getOffsets(rewriter, loc, sgId, wgShape); + if (failed(maybeOffsets)) + return failure(); + + VectorType newTy = type.cloneWith(*sgShape, type.getElementType()); + Value base = vector::StepOp::create(rewriter, loc, newTy); + SmallVector newOps; + for (auto offsets : *maybeOffsets) { + Value bcast = + vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]); + Value add = arith::AddIOp::create(rewriter, loc, base, bcast); + newOps.push_back(add); + } + + rewriter.replaceOpWithMultiple(op, {newOps}); + return success(); + } +}; + +>>>>>>> ddbdb0d7eb4f (Add vector.step distribution pattern) struct WgToSgLoadMatrixOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -824,8 +971,14 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern, WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp, +<<<<<<< HEAD WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>( patterns.getContext()); +======= + WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, + WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, + WgToSgStoreMatrixOp, WgToSgVectorStepOp>(patterns.getContext()); +>>>>>>> ddbdb0d7eb4f (Add vector.step distribution pattern) } } // namespace xegpu } // namespace mlir @@ -947,9 +1100,50 @@ void XeGPUWgToSgDistributePass::runOnOperation() { auto vecType = dyn_cast(op.getType()); if (!vecType) return true; - return isLegal(xegpu::getLayoutAttr(op.getResult())); + + auto layoutName = xegpu::getLayoutName(op->getResult(0)); + auto sliceAttr = op->getAttrOfType(layoutName); + if (sliceAttr) + return isLegal(sliceAttr); + + auto layoutAttr = op->getAttrOfType(layoutName); + if (layoutAttr) + return isLegal(layoutAttr); + + // If neither attribute is present, consider the op legal. + return true; + }); + + target.addDynamicallyLegalOp( + [=](xegpu::LoadGatherOp op) -> bool { + auto layout = xegpu::getLayoutAttr(op.getResult()); + return isLegal(layout); + }); + + target.addDynamicallyLegalOp( + [=](xegpu::StoreScatterOp op) -> bool { + // Check if the layout attribute is present on the result. + auto layout = op->getAttrOfType("layout_result_0"); + if (!layout) + return true; + return isLegal(layout); }); + target.addDynamicallyLegalOp([&](vector::StepOp op) -> bool { + // Check for either a SliceAttr or LayoutAttr on the result. + auto layoutName = xegpu::getLayoutName(op->getResult(0)); + auto sliceAttr = op->getAttrOfType(layoutName); + if (sliceAttr) + return isLegal(sliceAttr); + + auto layoutAttr = op->getAttrOfType(layoutName); + if (layoutAttr) + return isLegal(layoutAttr); + + // If neither attribute is present, consider the op legal. + return true; + }); + target.addDynamicallyLegalOp( [=](vector::BroadcastOp op) -> bool { return isLegal(xegpu::getLayoutAttr(op.getResult())); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 32157a7911f62..0869d0346fed7 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -2,6 +2,7 @@ //CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)> //CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)> +//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)> gpu.module @test_distribution { // CHECK-LABEL: create_nd_tdesc_no_offset // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> @@ -321,4 +322,43 @@ gpu.module @test_distribution { xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32> gpu.return } + + // CHECK-LABEL: vector_step_op + gpu.func @vector_step_op_slice_attr() { + //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index + //CHECK: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]] + //CHECK: [[c32:%.+]] = arith.constant 32 : index + //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]] + //CHECK: [[c0:%.+]] = arith.constant 0 : index + //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index + //CHECK: [[c128:%.+]] = arith.constant 128 : index + //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]] + //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex> + //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex> + //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex> + %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>}: vector<128xindex> + gpu.return + } + + gpu.func @vector_step_op_layout_attr() { + //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index + //CHECK: [[c16:%.+]] = arith.constant 16 : index + //CHECK: [[c8:%.+]] = arith.constant 8 : index + //CHECK: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]] + //CHECK: [[c0:%.+]] = arith.constant 0 : index + //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index + //CHECK: [[c128:%.+]] = arith.constant 128 : index + //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]] + //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex> + //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex> + //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex> + %step = vector.step {layout_result_0 = #xegpu.layout}: vector<128xindex> + gpu.return + } + + gpu.func @constant_with_slice_attr() { + //CHECK: [[cst:%.+]] = arith.constant dense<10> : vector<1xindex> + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1, 2, 3]>} dense<10> : vector<4xindex> + gpu.return + } } From ab57d1b226086686aaa3cf294bfb3a2147e677dd Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 26 Aug 2025 15:01:43 +0000 Subject: [PATCH 02/13] Add vector.shape_cast pattern --- .../Transforms/XeGPUWgToSgDistribute.cpp | 180 +++++++----------- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 13 ++ 2 files changed, 78 insertions(+), 115 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 39d0e9ea2e91d..286a74bfebb34 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -535,8 +535,17 @@ struct WgToSgElementwiseOp : public ConversionPattern { ArrayRef wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0)); - if (!layout || !layout.getSgLayout()) + auto layoutName = xegpu::getLayoutName(op->getResult(0)); + auto attr = op->getAttr(layoutName); + + xegpu::DistributeLayoutAttr layout = nullptr; + if (auto trySlice = dyn_cast_if_present(attr)) { + layout = trySlice; + } else if (auto tryLayout = dyn_cast_if_present(attr)) { + layout = tryLayout; + } + + if (!layout) return failure(); SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; @@ -741,7 +750,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern { auto attr = op->getAttr(layoutName); xegpu::DistributeLayoutAttr layout = nullptr; - // Try to get either SliceAttr or LayoutAttr, and keep as is if (auto trySlice = dyn_cast_if_present(attr)) { layout = trySlice; } else if (auto tryLayout = dyn_cast_if_present(attr)) { @@ -764,7 +772,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern { auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); - // Do nothing if layout is a SliceAttr if (auto layoutAttr = dyn_cast(layout)) { if (auto newLayout = layoutAttr.dropSgLayoutAndData()) { xegpu::setLayoutAttr(cstOp->getResult(0), newLayout); @@ -777,90 +784,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern { } }; -<<<<<<< HEAD -======= -// This pattern transforms the LoadGatherOp with explicit offsets to load -// subgroup data, similar to WgToSgLoadNdOpWithOffset. -struct WgToSgLoadGatherOpWithOffset - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - if (!op.getOffsets()) - return failure(); - - Location loc = op.getLoc(); - VectorType resultType = op.getResult().getType(); - ArrayRef wgShape = resultType.getShape(); - - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); - if (!layout || !layout.getSgLayout()) - return failure(); - - SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; - - SmallVector newLoadOps; - auto chunkSizeAttr = - rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1)); - VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); - for (auto [offsets, mask] : - llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { - auto newLoadOp = rewriter.create( - loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, - op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); - xegpu::setLayoutAttr(newLoadOp->getResult(0), - layout.dropSgLayoutAndData()); - newLoadOps.push_back(newLoadOp); - } - rewriter.replaceOpWithMultiple(op, {newLoadOps}); - return success(); - } -}; - -// This pattern transforms the StoreScatterOp with explicit offsets to store -// subgroup data, similar to WgToSgStoreNdOpWithOffset. -struct WgToSgStoreScatterOpWithOffset - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - if (!op.getOffsets()) - return failure(); - - Location loc = op.getLoc(); - VectorType valueType = dyn_cast(op.getValue().getType()); - if (!valueType) - return failure(); - - ArrayRef wgShape = valueType.getShape(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue()); - if (!layout || !layout.getSgLayout()) - return failure(); - - auto chunkSizeOpt = op.getChunkSize(); - int64_t chunkSize = chunkSizeOpt ? static_cast(*chunkSizeOpt) : 1; - auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); - for (auto [val, offs, mask] : llvm::zip( - adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { - rewriter.create( - loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); - // Update the layout_result_0 attribute to drop sg_layout and sg_data. - if (auto layoutAttr = - op->getAttrOfType("layout_result_0")) { - if (auto newLayout = layoutAttr.dropSgLayoutAndData()) - op->setAttr("layout_result_0", newLayout); - } - } - rewriter.eraseOp(op); - return success(); - } -}; - struct WgToSgVectorStepOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -870,7 +793,6 @@ struct WgToSgVectorStepOp : public OpConversionPattern { auto attr = op->getAttr(layoutName); xegpu::DistributeLayoutAttr layoutAttr = nullptr; - // Try to get either SliceAttr or LayoutAttr, and keep as is if (auto trySlice = dyn_cast_if_present(attr)) { layoutAttr = trySlice; } else if (auto tryLayout = dyn_cast_if_present(attr)) { @@ -909,7 +831,42 @@ struct WgToSgVectorStepOp : public OpConversionPattern { } }; ->>>>>>> ddbdb0d7eb4f (Add vector.step distribution pattern) +// This pattern transforms vector.shape_cast ops to work at subgroup level. +struct WgToSgVectorShapeCastOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + VectorType resultType = dyn_cast(op.getResult().getType()); + if (!resultType) + return failure(); + + ArrayRef wgShape = resultType.getShape(); + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); + if (!layout || !layout.getSgLayout()) + return failure(); + + SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType newResultType = + VectorType::get(sgShape, resultType.getElementType()); + + SmallVector newShapeCastOps; + for (auto src : adaptor.getSource()) { + auto newShapeCast = + rewriter.create(op.getLoc(), newResultType, src); + xegpu::setLayoutAttr(newShapeCast->getResult(0), + layout.dropSgLayoutAndData()); + newShapeCastOps.push_back(newShapeCast.getResult()); + } + + rewriter.replaceOpWithMultiple(op, {newShapeCastOps}); + return success(); + } +}; + struct WgToSgLoadMatrixOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -971,14 +928,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern, WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp, -<<<<<<< HEAD - WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>( - patterns.getContext()); -======= - WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, - WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, - WgToSgStoreMatrixOp, WgToSgVectorStepOp>(patterns.getContext()); ->>>>>>> ddbdb0d7eb4f (Add vector.step distribution pattern) + WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, + WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -1114,21 +1065,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return true; }); - target.addDynamicallyLegalOp( - [=](xegpu::LoadGatherOp op) -> bool { - auto layout = xegpu::getLayoutAttr(op.getResult()); - return isLegal(layout); - }); - - target.addDynamicallyLegalOp( - [=](xegpu::StoreScatterOp op) -> bool { - // Check if the layout attribute is present on the result. - auto layout = op->getAttrOfType("layout_result_0"); - if (!layout) - return true; - return isLegal(layout); - }); - target.addDynamicallyLegalOp([&](vector::StepOp op) -> bool { // Check for either a SliceAttr or LayoutAttr on the result. auto layoutName = xegpu::getLayoutName(op->getResult(0)); @@ -1149,6 +1085,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(xegpu::getLayoutAttr(op.getResult())); }); + target.addDynamicallyLegalOp( + [=](vector::ShapeCastOp op) -> bool { + return isLegal(xegpu::getLayoutAttr(op.getResult())); + }); + target.addDynamicallyLegalOp( [=](xegpu::ConvertLayoutOp op) -> bool { return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout()); @@ -1174,8 +1115,17 @@ void XeGPUWgToSgDistributePass::runOnOperation() { } } - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0)); - return isLegal(layout); + auto layoutName = xegpu::getLayoutName(op->getResult(0)); + auto sliceAttr = op->getAttrOfType(layoutName); + if (sliceAttr) + return isLegal(sliceAttr); + + auto layoutAttr = op->getAttrOfType(layoutName); + if (layoutAttr) + return isLegal(layoutAttr); + + // If neither attribute is present, consider the op legal. + return true; }); target.addDynamicallyLegalOp( diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 0869d0346fed7..7601274ba4969 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -356,9 +356,22 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: constant_with_slice_attr gpu.func @constant_with_slice_attr() { //CHECK: [[cst:%.+]] = arith.constant dense<10> : vector<1xindex> %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1, 2, 3]>} dense<10> : vector<4xindex> gpu.return } + + // CHECK-LABEL: vector_shape_cast + gpu.func @vector_shape_cast(%src: memref<256x128xf32>) { + %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> + -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc[0, 0] + : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> + -> vector<256x128xf32> + //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<8x4x8x4xf32> + %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout} : vector<256x128xf32> to vector<16x16x16x8xf32> + gpu.return + } } From 2c96a5c1509fb8d4e4a91a0924373220e2f04af6 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 26 Aug 2025 16:36:20 +0000 Subject: [PATCH 03/13] Support slice for shapecast --- .../Transforms/XeGPUWgToSgDistribute.cpp | 59 +++++++++++-------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 286a74bfebb34..1c13f59151a34 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -784,6 +784,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern { } }; +// This pattern distributes the vector.step ops to work at subgroup level struct WgToSgVectorStepOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -845,8 +846,17 @@ struct WgToSgVectorShapeCastOp return failure(); ArrayRef wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); - if (!layout || !layout.getSgLayout()) + auto layoutName = xegpu::getLayoutName(op->getResult(0)); + auto attr = op->getAttr(layoutName); + + xegpu::DistributeLayoutAttr layout = nullptr; + if (auto trySlice = dyn_cast_if_present(attr)) { + layout = trySlice; + } else if (auto tryLayout = dyn_cast_if_present(attr)) { + layout = tryLayout; + } + + if (!layout) return failure(); SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; @@ -857,13 +867,16 @@ struct WgToSgVectorShapeCastOp for (auto src : adaptor.getSource()) { auto newShapeCast = rewriter.create(op.getLoc(), newResultType, src); - xegpu::setLayoutAttr(newShapeCast->getResult(0), - layout.dropSgLayoutAndData()); - newShapeCastOps.push_back(newShapeCast.getResult()); - } + if (auto layoutAttr = dyn_cast(layout)) { + if (auto newLayout = layoutAttr.dropSgLayoutAndData()) { + xegpu::setLayoutAttr(newShapeCast->getResult(0), newLayout); + } + newShapeCastOps.push_back(newShapeCast.getResult()); + } - rewriter.replaceOpWithMultiple(op, {newShapeCastOps}); - return success(); + rewriter.replaceOpWithMultiple(op, {newShapeCastOps}); + return success(); + } } }; @@ -1065,31 +1078,27 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return true; }); - target.addDynamicallyLegalOp([&](vector::StepOp op) -> bool { - // Check for either a SliceAttr or LayoutAttr on the result. - auto layoutName = xegpu::getLayoutName(op->getResult(0)); - auto sliceAttr = op->getAttrOfType(layoutName); - if (sliceAttr) - return isLegal(sliceAttr); + target.addDynamicallyLegalOp( + [=](Operation *op) -> bool { + // Check for either a SliceAttr or LayoutAttr on the result. + auto layoutName = xegpu::getLayoutName(op->getResult(0)); + auto sliceAttr = op->getAttrOfType(layoutName); + if (sliceAttr) + return isLegal(sliceAttr); - auto layoutAttr = op->getAttrOfType(layoutName); - if (layoutAttr) - return isLegal(layoutAttr); + auto layoutAttr = op->getAttrOfType(layoutName); + if (layoutAttr) + return isLegal(layoutAttr); - // If neither attribute is present, consider the op legal. - return true; - }); + // If neither attribute is present, consider the op legal. + return true; + }); target.addDynamicallyLegalOp( [=](vector::BroadcastOp op) -> bool { return isLegal(xegpu::getLayoutAttr(op.getResult())); }); - target.addDynamicallyLegalOp( - [=](vector::ShapeCastOp op) -> bool { - return isLegal(xegpu::getLayoutAttr(op.getResult())); - }); - target.addDynamicallyLegalOp( [=](xegpu::ConvertLayoutOp op) -> bool { return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout()); From 6a19aa52c45966a4b291cb5b306e29b533b64f99 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 26 Aug 2025 16:39:38 +0000 Subject: [PATCH 04/13] Clean up --- .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 1c13f59151a34..a551e530c0e08 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -793,27 +793,27 @@ struct WgToSgVectorStepOp : public OpConversionPattern { auto layoutName = xegpu::getLayoutName(op->getResult(0)); auto attr = op->getAttr(layoutName); - xegpu::DistributeLayoutAttr layoutAttr = nullptr; + xegpu::DistributeLayoutAttr layout = nullptr; if (auto trySlice = dyn_cast_if_present(attr)) { - layoutAttr = trySlice; + layout = trySlice; } else if (auto tryLayout = dyn_cast_if_present(attr)) { - layoutAttr = tryLayout; + layout = tryLayout; } - if (!layoutAttr) + if (!layout) return failure(); Location loc = op.getLoc(); VectorType type = op.getResult().getType(); auto wgShape = type.getShape(); std::optional> sgShape = - getSgShapeAndCount(wgShape, layoutAttr).first; + getSgShapeAndCount(wgShape, layout).first; if (!sgShape) return failure(); Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - auto maybeOffsets = layoutAttr.getOffsets(rewriter, loc, sgId, wgShape); + auto maybeOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); if (failed(maybeOffsets)) return failure(); From f4d7108dbd22c3266913040ea5d1bd264174c745 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 29 Aug 2025 18:50:06 +0000 Subject: [PATCH 05/13] Temp --- .../Transforms/XeGPUWgToSgDistribute.cpp | 22 +------------------ 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 059641af2219a..f749d55e501e9 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -467,6 +467,7 @@ struct WgToSgVectorBroadcastOp LogicalResult matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + VectorType resultType = op.getResult().getType(); ArrayRef wgShape = resultType.getShape(); @@ -475,34 +476,14 @@ struct WgToSgVectorBroadcastOp if (!layout || !layout.isForWorkgroup()) return failure(); - // TODO: Currently only supports cases where the source and result ranks - // are the same. - auto srcType = - dyn_cast(adaptor.getOperands().front()[0].getType()); - if (!srcType || srcType.getRank() != resultType.getRank()) - return failure(); - SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; VectorType newResultType = VectorType::get(sgShape, resultType.getElementType()); - // Check if the output layout is distributable - SmallVector sgLayout = layout.getSgLayoutAsInt(); - if (sgLayout.empty()) - return failure(); if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout)) return failure(); - // Check if the srcShape has unit dim in dimensions being broadcasted, - // and the other dimensions are the same as the destination type - // TODO: Generalize it - auto srcShape = srcType.getShape(); - for (size_t i = 0; i < srcShape.size(); ++i) { - if (srcShape[i] != 1 && srcShape[i] != sgShape[i]) - return failure(); - } - SmallVector newBroadcastOps; for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), @@ -518,7 +499,6 @@ struct WgToSgVectorBroadcastOp } newBroadcastOps.push_back(newBroadcast.getResult()); } - rewriter.replaceOpWithMultiple(op, {newBroadcastOps}); return success(); } From 8cb5ebef69d5e4b5a564e666adb17b14c0d85397 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 4 Sep 2025 17:10:10 +0000 Subject: [PATCH 06/13] Clean up check --- .../Transforms/XeGPUWgToSgDistribute.cpp | 48 ++++++------------- 1 file changed, 14 insertions(+), 34 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index f421f49f96494..af514ca047db8 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -487,15 +487,10 @@ struct WgToSgVectorBroadcastOp for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), newResultType, operand); - if (auto sliceAttr = dyn_cast_if_present(layout)) { - if (sliceAttr.isForSubgroup()) - xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), - sliceAttr.dropSgLayoutAndData()); - } else if (auto layoutAttr = - dyn_cast_if_present(layout)) { - if (auto newLayout = layoutAttr.dropSgLayoutAndData()) - xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), newLayout); - } + if (!layout.getLaneLayoutAsInt().empty()) + xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), + layout.dropSgLayoutAndData()); + newBroadcastOps.push_back(newBroadcast.getResult()); } rewriter.replaceOpWithMultiple(op, {newBroadcastOps}); @@ -549,13 +544,10 @@ struct WgToSgElementwiseOp : public ConversionPattern { // Copy all attributes, but update "layout_result_0" to drop // sgLayout/sgData for (auto attr : op->getAttrs()) { - if (auto layout = dyn_cast(attr.getValue())) { - if (auto newLayout = layout.dropSgLayoutAndData()) - state.addAttribute(attr.getName(), newLayout); - } else if (auto sliceAttr = - dyn_cast(attr.getValue())) { - if (sliceAttr.isForSubgroup()) - state.addAttribute(attr.getName(), sliceAttr.dropSgLayoutAndData()); + if (auto layout = + dyn_cast(attr.getValue())) { + if (!layout.getLaneLayoutAsInt().empty()) + state.addAttribute(attr.getName(), layout.dropSgLayoutAndData()); } else { state.addAttribute(attr.getName(), attr.getValue()); } @@ -746,15 +738,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern { auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); - if (auto sliceAttr = dyn_cast_if_present(layout)) { - if (sliceAttr.isForSubgroup()) - xegpu::setDistributeLayoutAttr(cstOp->getResult(0), - sliceAttr.dropSgLayoutAndData()); - } else if (auto layoutAttr = - dyn_cast_if_present(layout)) { - if (auto newLayout = layoutAttr.dropSgLayoutAndData()) - xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout); - } + if (!layout.getLaneLayoutAsInt().empty()) + xegpu::setDistributeLayoutAttr(cstOp->getResult(0), + layout.dropSgLayoutAndData()); SmallVector newConsts(count, cstOp); rewriter.replaceOpWithMultiple(op, {newConsts}); @@ -983,15 +969,9 @@ struct WgToSgVectorShapeCastOp for (auto src : adaptor.getSource()) { auto newShapeCast = rewriter.create(op.getLoc(), newResultType, src); - if (auto sliceAttr = dyn_cast_if_present(layout)) { - if (sliceAttr.isForSubgroup()) - xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), - sliceAttr.dropSgLayoutAndData()); - } else if (auto layoutAttr = - dyn_cast_if_present(layout)) { - if (auto newLayout = layoutAttr.dropSgLayoutAndData()) - xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), newLayout); - } + if (!layout.getLaneLayoutAsInt().empty()) + xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), + layout.dropSgLayoutAndData()); newShapeCastOps.push_back(newShapeCast.getResult()); } From 1161e28e3e2424b03a974343254306f59c71b44a Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 4 Sep 2025 23:45:33 +0000 Subject: [PATCH 07/13] Feedback --- .../Transforms/XeGPUWgToSgDistribute.cpp | 28 ++++++++----- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 40 +++++++++---------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index af514ca047db8..3b9bd98742080 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -487,7 +487,8 @@ struct WgToSgVectorBroadcastOp for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), newResultType, operand); - if (!layout.getLaneLayoutAsInt().empty()) + if (!layout.getLaneLayoutAsInt().empty() || + !layout.getLaneDataAsInt().empty()) xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData()); @@ -546,7 +547,8 @@ struct WgToSgElementwiseOp : public ConversionPattern { for (auto attr : op->getAttrs()) { if (auto layout = dyn_cast(attr.getValue())) { - if (!layout.getLaneLayoutAsInt().empty()) + if (!layout.getLaneLayoutAsInt().empty() || + !layout.getLaneDataAsInt().empty()) state.addAttribute(attr.getName(), layout.dropSgLayoutAndData()); } else { state.addAttribute(attr.getName(), attr.getValue()); @@ -738,7 +740,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern { auto sgAttr = DenseElementsAttr::get(newType, singleVal); auto cstOp = arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); - if (!layout.getLaneLayoutAsInt().empty()) + if (!layout.getLaneLayoutAsInt().empty() || + !layout.getLaneDataAsInt().empty()) xegpu::setDistributeLayoutAttr(cstOp->getResult(0), layout.dropSgLayoutAndData()); SmallVector newConsts(count, cstOp); @@ -923,18 +926,20 @@ struct WgToSgVectorStepOp : public OpConversionPattern { Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - auto maybeOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); - if (failed(maybeOffsets)) + auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + if (failed(sgOffsets)) return failure(); VectorType newTy = type.cloneWith(*sgShape, type.getElementType()); - Value base = vector::StepOp::create(rewriter, loc, newTy); + Value steps = vector::StepOp::create(rewriter, loc, newTy); SmallVector newOps; - for (auto offsets : *maybeOffsets) { - Value bcast = + for (auto offsets : *sgOffsets) { + // Broadcast the offset scalar to a vector & add to the base steps + Value bcastOffset = vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]); - Value add = arith::AddIOp::create(rewriter, loc, base, bcast); - newOps.push_back(add); + Value finalSteps = + arith::AddIOp::create(rewriter, loc, steps, bcastOffset); + newOps.push_back(finalSteps); } rewriter.replaceOpWithMultiple(op, {newOps}); @@ -969,7 +974,8 @@ struct WgToSgVectorShapeCastOp for (auto src : adaptor.getSource()) { auto newShapeCast = rewriter.create(op.getLoc(), newResultType, src); - if (!layout.getLaneLayoutAsInt().empty()) + if (!layout.getLaneLayoutAsInt().empty() || + !layout.getInstDataAsInt().empty()) xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), layout.dropSgLayoutAndData()); newShapeCastOps.push_back(newShapeCast.getResult()); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index b7e313512e9b9..27d9fa6b06a7b 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -370,15 +370,15 @@ gpu.module @test_distribution { // CHECK-LABEL: vector_step_op gpu.func @vector_step_op_slice_attr() { //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index - //CHECK: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]] - //CHECK: [[c32:%.+]] = arith.constant 32 : index - //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]] - //CHECK: [[c0:%.+]] = arith.constant 0 : index - //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index - //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]] - //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex> - //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex> + //CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]] + //CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index + //CHECK-DAG: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]] + //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index + //CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index + //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index + //CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]] + //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex> + //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex> //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex> %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>}: vector<128xindex> gpu.return @@ -386,15 +386,15 @@ gpu.module @test_distribution { gpu.func @vector_step_op_layout_attr() { //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index - //CHECK: [[c16:%.+]] = arith.constant 16 : index - //CHECK: [[c8:%.+]] = arith.constant 8 : index - //CHECK: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]] - //CHECK: [[c0:%.+]] = arith.constant 0 : index - //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index - //CHECK: [[c128:%.+]] = arith.constant 128 : index - //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]] - //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex> - //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex> + //CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index + //CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index + //CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]] + //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index + //CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index + //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index + //CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]] + //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex> + //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex> //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex> %step = vector.step {layout_result_0 = #xegpu.layout}: vector<128xindex> gpu.return @@ -414,8 +414,8 @@ gpu.module @test_distribution { %load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> -> vector<256x128xf32> - //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<8x4x8x4xf32> - %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout} : vector<256x128xf32> to vector<16x16x16x8xf32> + //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<2x16x4x8xf32> + %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout} : vector<256x128xf32> to vector<16x16x16x8xf32> gpu.return } } From 9457b54dc1b601df78aae0689d8a179e0dc18129 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 5 Sep 2025 20:43:05 +0000 Subject: [PATCH 08/13] Add check --- .../Transforms/XeGPUWgToSgDistribute.cpp | 24 +++++++++++++++++++ .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 4 ++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 3b9bd98742080..5c15d0749e894 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -970,6 +970,30 @@ struct WgToSgVectorShapeCastOp VectorType newResultType = VectorType::get(sgShape, resultType.getElementType()); + // TODO: Add check for compatible layouts in layout attr. + // Only support ShapeCast which expands or reduces unit dims only. + // That is, only allow shape casts where the non-unit dimensions are + // preserved, and any added or removed dimensions must be of size 1. + auto srcType = dyn_cast(adaptor.getSource()[0].getType()); + if (!srcType) + return failure(); + + auto isUnitOrPreserved = [](ArrayRef src, ArrayRef dst) { + // Remove all 1s from both shapes and compare the rest. + SmallVector srcNonUnit, dstNonUnit; + for (int64_t d : src) + if (d != 1) + srcNonUnit.push_back(d); + for (int64_t d : dst) + if (d != 1) + dstNonUnit.push_back(d); + return srcNonUnit == dstNonUnit; + }; + + if (!isUnitOrPreserved(srcType.getShape(), sgShape) || + !isUnitOrPreserved(sgShape, srcType.getShape())) + return failure(); + SmallVector newShapeCastOps; for (auto src : adaptor.getSource()) { auto newShapeCast = diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 27d9fa6b06a7b..da015c6c0e4a7 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -414,8 +414,8 @@ gpu.module @test_distribution { %load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> -> vector<256x128xf32> - //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<2x16x4x8xf32> - %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout} : vector<256x128xf32> to vector<16x16x16x8xf32> + //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32> + %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout} : vector<256x128xf32> to vector<256x1x128x1xf32> gpu.return } } From b8021edc0eeb20b27998691719aeec176930e6b8 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 8 Sep 2025 15:57:57 +0000 Subject: [PATCH 09/13] Feedback --- .../Transforms/XeGPUWgToSgDistribute.cpp | 25 ++++++++++++------- .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 9 +++++++ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 5c15d0749e894..0d9ac35f07e02 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -931,14 +931,23 @@ struct WgToSgVectorStepOp : public OpConversionPattern { return failure(); VectorType newTy = type.cloneWith(*sgShape, type.getElementType()); - Value steps = vector::StepOp::create(rewriter, loc, newTy); + auto steps = vector::StepOp::create(rewriter, loc, newTy); SmallVector newOps; for (auto offsets : *sgOffsets) { // Broadcast the offset scalar to a vector & add to the base steps - Value bcastOffset = + auto bcastOffset = vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]); - Value finalSteps = + auto finalSteps = arith::AddIOp::create(rewriter, loc, steps, bcastOffset); + if (!layout.getLaneLayoutAsInt().empty() || + !layout.getLaneDataAsInt().empty()) { + xegpu::setDistributeLayoutAttr(steps->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0), + layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(finalSteps->getResult(0), + layout.dropSgLayoutAndData()); + } newOps.push_back(finalSteps); } @@ -971,14 +980,12 @@ struct WgToSgVectorShapeCastOp VectorType::get(sgShape, resultType.getElementType()); // TODO: Add check for compatible layouts in layout attr. - // Only support ShapeCast which expands or reduces unit dims only. - // That is, only allow shape casts where the non-unit dimensions are - // preserved, and any added or removed dimensions must be of size 1. auto srcType = dyn_cast(adaptor.getSource()[0].getType()); if (!srcType) return failure(); - auto isUnitOrPreserved = [](ArrayRef src, ArrayRef dst) { + // Check that shape_cast only adds/removes unit dimensions, + auto onlyUnitDims = [](ArrayRef src, ArrayRef dst) { // Remove all 1s from both shapes and compare the rest. SmallVector srcNonUnit, dstNonUnit; for (int64_t d : src) @@ -990,8 +997,8 @@ struct WgToSgVectorShapeCastOp return srcNonUnit == dstNonUnit; }; - if (!isUnitOrPreserved(srcType.getShape(), sgShape) || - !isUnitOrPreserved(sgShape, srcType.getShape())) + if (!onlyUnitDims(srcType.getShape(), sgShape) || + !onlyUnitDims(sgShape, srcType.getShape())) return failure(); SmallVector newShapeCastOps; diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index da015c6c0e4a7..7614b8a290ea1 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -418,4 +418,13 @@ gpu.module @test_distribution { %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout} : vector<256x128xf32> to vector<256x1x128x1xf32> gpu.return } + + // CHECK-LABEL: broadcast + // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index + gpu.func @broadcast(%arg0: index, %arg1: index) { + %muli = arith.muli %arg0, %arg1 : index + // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex> + %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout} : index to vector<4x2x6x32xindex> + gpu.return + } } From 2739fa83ace612f04afb5c6ec0a4ac50d227f1ba Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 8 Sep 2025 18:16:58 +0000 Subject: [PATCH 10/13] Fix --- mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 0d9ac35f07e02..a05dcc9c474b8 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -997,8 +997,7 @@ struct WgToSgVectorShapeCastOp return srcNonUnit == dstNonUnit; }; - if (!onlyUnitDims(srcType.getShape(), sgShape) || - !onlyUnitDims(sgShape, srcType.getShape())) + if (!onlyUnitDims(srcType.getShape(), sgShape)) return failure(); SmallVector newShapeCastOps; From 77f32611477d962d5b46ac034dfc67dbf3e481a1 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 9 Sep 2025 04:02:00 +0000 Subject: [PATCH 11/13] Add check --- .../Transforms/XeGPUWgToSgDistribute.cpp | 17 +++++++++++ .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 29 +++++++++---------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index a05dcc9c474b8..82ed77ae7130a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1000,6 +1000,23 @@ struct WgToSgVectorShapeCastOp if (!onlyUnitDims(srcType.getShape(), sgShape)) return failure(); + // Check to verify that if expanding dims, the input operand's layout + // is sliceAttr and if reducing dims, result's layout is + // sliceAttr. + int srcRank = srcType.getRank(); + int dstRank = sgShape.size(); + if (dstRank > srcRank) { + // Expanding dims: input operand's layout must be a SliceAttr + auto srcLayout = xegpu::getDistributeLayoutAttr(op.getSource()); + if (!srcLayout || !isa(srcLayout)) + return failure(); + } else if (dstRank < srcRank) { + // Reducing dims: result's layout must be a SliceAttr + auto resLayout = xegpu::getDistributeLayoutAttr(op.getResult()); + if (!resLayout || !isa(resLayout)) + return failure(); + } + SmallVector newShapeCastOps; for (auto src : adaptor.getSource()) { auto newShapeCast = diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 7614b8a290ea1..3478a9b91da5f 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -408,23 +408,20 @@ gpu.module @test_distribution { } // CHECK-LABEL: vector_shape_cast - gpu.func @vector_shape_cast(%src: memref<256x128xf32>) { - %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> - -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout> - %load = xegpu.load_nd %tdesc[0, 0] - : !xegpu.tensor_desc<256x128xf32, #xegpu.layout> - -> vector<256x128xf32> - //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32> - %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout} : vector<256x128xf32> to vector<256x1x128x1xf32> + gpu.func @vector_shape_cast() { + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1, 2]>} dense<10> : vector<128xindex> + %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1, 2]>} : vector<128xindex> + %muli = arith.muli %cst, %step {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1, 2]>} : vector<128xindex> + //CHECK: vector.shape_cast {{.*}} : vector<32xindex> to vector<1x1x1x32xindex> + %shape_cast = vector.shape_cast %muli {layout_result_0 = #xegpu.layout} : vector<128xindex> to vector<1x1x1x128xindex> gpu.return } - // CHECK-LABEL: broadcast - // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index - gpu.func @broadcast(%arg0: index, %arg1: index) { - %muli = arith.muli %arg0, %arg1 : index - // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex> - %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout} : index to vector<4x2x6x32xindex> - gpu.return - } + // CHECK-LABEL: vector_broadcast + gpu.func @vector_broadcast(%arg0: index, %arg1: index) { + %muli = arith.muli %arg0, %arg1 : index + // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex> + %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout} : index to vector<4x2x6x32xindex> + gpu.return + } } From d0546b214e244fc36411a93ade6381151b6a282f Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 9 Sep 2025 20:29:07 +0000 Subject: [PATCH 12/13] Add check --- .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 82ed77ae7130a..c62c90ab9693c 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1008,12 +1008,22 @@ struct WgToSgVectorShapeCastOp if (dstRank > srcRank) { // Expanding dims: input operand's layout must be a SliceAttr auto srcLayout = xegpu::getDistributeLayoutAttr(op.getSource()); - if (!srcLayout || !isa(srcLayout)) + auto srcSliceAttr = cast(srcLayout); + if (!srcLayout || !srcSliceAttr) + return failure(); + auto resLayout = xegpu::getDistributeLayoutAttr(op.getResult()); + // Check srcLayout is a slice attr on top of resLayout + if (srcSliceAttr.getParent() != resLayout) return failure(); } else if (dstRank < srcRank) { // Reducing dims: result's layout must be a SliceAttr auto resLayout = xegpu::getDistributeLayoutAttr(op.getResult()); - if (!resLayout || !isa(resLayout)) + auto resSliceAttr = cast(resLayout); + auto srcLayout = xegpu::getDistributeLayoutAttr(op.getSource()); + if (!resSliceAttr || !srcLayout) + return failure(); + // Check resLayout is a sliced attr from srcLayout + if (resSliceAttr.getParent() != srcLayout) return failure(); } From 9f3446e9d6a0bd81a58b79b3a18f472707e66202 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 12 Sep 2025 17:32:35 +0000 Subject: [PATCH 13/13] Clang-format --- .../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 30f40ad4969a9..d7592fed6d186 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1000,14 +1000,12 @@ struct WgToSgVectorShapeCastOp if (!onlyUnitDims(srcType.getShape(), sgShape)) return failure(); - // Check to verify that if expanding dims, the input operand's layout - // is sliceAttr and if reducing dims, result's layout is - // sliceAttr. // For rank reducing or increasing shape_cast ops, the lower rank layout // must be a slice of higher rank layout. - int64_t sourceRank = srcType.getRank();; + int64_t sourceRank = srcType.getRank(); int64_t resultRank = sgShape.size(); - xegpu::DistributeLayoutAttr sourceLayout = xegpu::getDistributeLayoutAttr(op.getSource()); + xegpu::DistributeLayoutAttr sourceLayout = + xegpu::getDistributeLayoutAttr(op.getSource()); if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout)) return failure(); if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))