Conversation

nbpatel
Contributor

@nbpatel nbpatel commented Aug 26, 2025

This PR adds patterns to distribute the vector.step and vector.shape_cast ops from the workgroup (wg) level to the subgroup (sg) level. It also updates the constant, broadcast, and elementwise patterns to handle the slice attribute.

@nbpatel nbpatel requested a review from chencha3 August 26, 2025 16:08
auto layoutName = xegpu::getLayoutName(op->getResult(0));
auto attr = op->getAttr(layoutName);

xegpu::DistributeLayoutAttr layout = nullptr;


Make this a static function maybe? It seems like it is used in many places.
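
For reference, a minimal sketch of such a shared helper, factoring out the repeated layout lookup (the function name and exact signature are hypothetical, not part of this PR):

// Hypothetical helper: fetch the DistributeLayoutAttr attached to an op's
// first result, or return null when no layout attribute is present.
static xegpu::DistributeLayoutAttr getResultDistributeLayout(Operation *op) {
  auto layoutName = xegpu::getLayoutName(op->getResult(0));
  auto attr = op->getAttr(layoutName);
  return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(attr);
}

The broadcast, constant, and shape_cast patterns could then call this instead of repeating the lookup in each matchAndRewrite.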

@nbpatel nbpatel marked this pull request as ready for review August 28, 2025 15:14
@nbpatel nbpatel requested a review from adam-smnk August 28, 2025 15:14
@llvmbot
Member

llvmbot commented Aug 28, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/155443.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+118-7)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+53)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0b7fe81facfce..059641af2219a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -507,8 +507,15 @@ struct WgToSgVectorBroadcastOp
     for (auto operand : adaptor.getOperands().front()) {
       auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                       newResultType, operand);
-      xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
-                                     layout.dropSgLayoutAndData());
+      if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+        if (sliceAttr.isForSubgroup())
+          xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+                                         sliceAttr.dropSgLayoutAndData());
+      } else if (auto layoutAttr =
+                     dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+        if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+          xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), newLayout);
+      }
       newBroadcastOps.push_back(newBroadcast.getResult());
     }
 
@@ -566,6 +573,10 @@ struct WgToSgElementwiseOp : public ConversionPattern {
         if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
           if (auto newLayout = layout.dropSgLayoutAndData())
             state.addAttribute(attr.getName(), newLayout);
+        } else if (auto sliceAttr =
+                       dyn_cast<xegpu::SliceAttr>(attr.getValue())) {
+          if (sliceAttr.isForSubgroup())
+            state.addAttribute(attr.getName(), sliceAttr.dropSgLayoutAndData());
         } else {
           state.addAttribute(attr.getName(), attr.getValue());
         }
@@ -756,8 +767,15 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (auto newLayout = layout.dropSgLayoutAndData())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+      if (sliceAttr.isForSubgroup())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+                                       sliceAttr.dropSgLayoutAndData());
+    } else if (auto layoutAttr =
+                   dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+      if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    }
     SmallVector<Value> newConsts(count, cstOp);
 
     rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -765,6 +783,90 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
+// This pattern distributes the vector.step ops to work at subgroup level
+struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType type = op.getResult().getType();
+    auto wgShape = type.getShape();
+    std::optional<SmallVector<int64_t>> sgShape =
+        getSgShapeAndCount(wgShape, layout).first;
+    if (!sgShape)
+      return failure();
+
+    Value sgId =
+        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+    auto maybeOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+    if (failed(maybeOffsets))
+      return failure();
+
+    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
+    Value base = vector::StepOp::create(rewriter, loc, newTy);
+    SmallVector<Value> newOps;
+    for (auto offsets : *maybeOffsets) {
+      Value bcast =
+          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
+      Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
+      newOps.push_back(add);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newOps});
+    return success();
+  }
+};
+
+// This pattern transforms vector.shape_cast ops to work at subgroup level.
+struct WgToSgVectorShapeCastOp
+    : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    SmallVector<Value> newShapeCastOps;
+    for (auto src : adaptor.getSource()) {
+      auto newShapeCast =
+          rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
+      if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+        if (sliceAttr.isForSubgroup())
+          xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
+                                         sliceAttr.dropSgLayoutAndData());
+      } else if (auto layoutAttr =
+                     dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+        if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+          xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), newLayout);
+      }
+      newShapeCastOps.push_back(newShapeCast.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
+    return success();
+  }
+};
+
 struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
   using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
   LogicalResult
@@ -826,8 +928,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
            WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-           WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
-          patterns.getContext());
+           WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
+           WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -949,7 +1051,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         auto vecType = dyn_cast<VectorType>(op.getType());
         if (!vecType)
           return true;
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+
+        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
+        return isLegal(layout);
+      });
+
+  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+      [=](Operation *op) -> bool {
+        // Check for either a SliceAttr or LayoutAttr on the result.
+        auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+        return isLegal(layout);
       });
 
   target.addDynamicallyLegalOp<vector::BroadcastOp>(
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 32157a7911f62..7601274ba4969 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -2,6 +2,7 @@
 
 //CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
 //CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
+//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)>
 gpu.module @test_distribution {
   // CHECK-LABEL: create_nd_tdesc_no_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -321,4 +322,56 @@ gpu.module @test_distribution {
     xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_step_op
+  gpu.func @vector_step_op_slice_attr() {
+    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+    %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
+    gpu.return
+  }
+
+  gpu.func @vector_step_op_layout_attr() {
+    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[c16:%.+]] = arith.constant 16 : index
+    //CHECK: [[c8:%.+]] = arith.constant 8 : index
+    //CHECK: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+    //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex>
+    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
+    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
+    %step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
+    gpu.return
+  }
+
+  // CHECK-LABEL: constant_with_slice_attr
+  gpu.func @constant_with_slice_attr() {
+    //CHECK: [[cst:%.+]] = arith.constant dense<10> : vector<1xindex>
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [1, 2, 3]>} dense<10> : vector<4xindex>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_shape_cast
+  gpu.func @vector_shape_cast(%src: memref<256x128xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<8x4x8x4xf32>
+    %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [2, 4, 2, 2], sg_data = [8, 4, 8, 4]>} : vector<256x128xf32> to vector<16x16x16x8xf32>
+    gpu.return
+  }
 }

@llvmbot
Member

llvmbot commented Aug 28, 2025

@llvm/pr-subscribers-mlir

@@ -507,8 +507,15 @@ struct WgToSgVectorBroadcastOp
     for (auto operand : adaptor.getOperands().front()) {
       auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                       newResultType, operand);
-      xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
-                                     layout.dropSgLayoutAndData());
+      if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {

@Garra1980 Garra1980 Aug 28, 2025


Please mention this and the elementwise op (and arith) changes in the description.

@nbpatel
Copy link
Contributor Author

nbpatel commented Sep 2, 2025

Ping for reviews
