[MLIR][XeGPU] Distribute vector.step & vector.shape_cast op from wg to sg #155443
base: main
Conversation
auto layoutName = xegpu::getLayoutName(op->getResult(0));
auto attr = op->getAttr(layoutName);
xegpu::DistributeLayoutAttr layout = nullptr;
Make it a static function maybe? It seems like it is used in many places.
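One possible shape for such a helper, as a hedged sketch (the name setSgLayoutOnResult is illustrative and not part of this PR): it would live in XeGPUWgToSgDistribute.cpp and factor out the SliceAttr/LayoutAttr handling repeated in the broadcast, constant, and shape_cast patterns in this change.

// Hedged sketch of the suggested static helper: drop the sg_layout/sg_data
// fields from either a SliceAttr or a LayoutAttr and, if a layout remains,
// attach it to the given result.
static void setSgLayoutOnResult(OpResult result,
                                xegpu::DistributeLayoutAttr layout) {
  if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
    if (sliceAttr.isForSubgroup())
      xegpu::setDistributeLayoutAttr(result, sliceAttr.dropSgLayoutAndData());
  } else if (auto layoutAttr = dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
    if (auto newLayout = layoutAttr.dropSgLayoutAndData())
      xegpu::setDistributeLayoutAttr(result, newLayout);
  }
}

With such a helper, each pattern would reduce to a single call, e.g. setSgLayoutOnResult(newBroadcast->getResult(0), layout).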
@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/155443.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0b7fe81facfce..059641af2219a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -507,8 +507,15 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
- xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
- layout.dropSgLayoutAndData());
+ if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+ if (sliceAttr.isForSubgroup())
+ xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+ sliceAttr.dropSgLayoutAndData());
+ } else if (auto layoutAttr =
+ dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+ if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+ xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), newLayout);
+ }
newBroadcastOps.push_back(newBroadcast.getResult());
}
@@ -566,6 +573,10 @@ struct WgToSgElementwiseOp : public ConversionPattern {
if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
if (auto newLayout = layout.dropSgLayoutAndData())
state.addAttribute(attr.getName(), newLayout);
+ } else if (auto sliceAttr =
+ dyn_cast<xegpu::SliceAttr>(attr.getValue())) {
+ if (sliceAttr.isForSubgroup())
+ state.addAttribute(attr.getName(), sliceAttr.dropSgLayoutAndData());
} else {
state.addAttribute(attr.getName(), attr.getValue());
}
@@ -756,8 +767,15 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
- if (auto newLayout = layout.dropSgLayoutAndData())
- xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+ if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+ if (sliceAttr.isForSubgroup())
+ xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+ sliceAttr.dropSgLayoutAndData());
+ } else if (auto layoutAttr =
+ dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+ if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+ xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+ }
SmallVector<Value> newConsts(count, cstOp);
rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -765,6 +783,90 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
}
};
+// This pattern distributes the vector.step ops to work at subgroup level
+struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
+ using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType type = op.getResult().getType();
+ auto wgShape = type.getShape();
+ std::optional<SmallVector<int64_t>> sgShape =
+ getSgShapeAndCount(wgShape, layout).first;
+ if (!sgShape)
+ return failure();
+
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+ auto maybeOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ if (failed(maybeOffsets))
+ return failure();
+
+ VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
+ Value base = vector::StepOp::create(rewriter, loc, newTy);
+ SmallVector<Value> newOps;
+ for (auto offsets : *maybeOffsets) {
+ Value bcast =
+ vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
+ Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
+ newOps.push_back(add);
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newOps});
+ return success();
+ }
+};
+
+// This pattern transforms vector.shape_cast ops to work at subgroup level.
+struct WgToSgVectorShapeCastOp
+ : public OpConversionPattern<vector::ShapeCastOp> {
+ using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return failure();
+
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+
+ SmallVector<Value> newShapeCastOps;
+ for (auto src : adaptor.getSource()) {
+ auto newShapeCast =
+ rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
+ if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+ if (sliceAttr.isForSubgroup())
+ xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
+ sliceAttr.dropSgLayoutAndData());
+ } else if (auto layoutAttr =
+ dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+ if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+ xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), newLayout);
+ }
+ newShapeCastOps.push_back(newShapeCast.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
+ return success();
+ }
+};
+
struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
LogicalResult
@@ -826,8 +928,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
- WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
- patterns.getContext());
+ WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
+ WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -949,7 +1051,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
auto vecType = dyn_cast<VectorType>(op.getType());
if (!vecType)
return true;
- return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+
+ auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
+ return isLegal(layout);
+ });
+
+ target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+ [=](Operation *op) -> bool {
+ // Check for either a SliceAttr or LayoutAttr on the result.
+ auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+ return isLegal(layout);
});
target.addDynamicallyLegalOp<vector::BroadcastOp>(
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 32157a7911f62..7601274ba4969 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -2,6 +2,7 @@
//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
+//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)>
gpu.module @test_distribution {
// CHECK-LABEL: create_nd_tdesc_no_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -321,4 +322,56 @@ gpu.module @test_distribution {
xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
gpu.return
}
+
+ // CHECK-LABEL: vector_step_op
+ gpu.func @vector_step_op_slice_attr() {
+ //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
+ //CHECK: [[c32:%.+]] = arith.constant 32 : index
+ //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+ //CHECK: [[c0:%.+]] = arith.constant 0 : index
+ //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+ //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+ //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+ %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
+ gpu.return
+ }
+
+ gpu.func @vector_step_op_layout_attr() {
+ //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[c16:%.+]] = arith.constant 16 : index
+ //CHECK: [[c8:%.+]] = arith.constant 8 : index
+ //CHECK: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
+ //CHECK: [[c0:%.+]] = arith.constant 0 : index
+ //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+ //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
+ //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
+ %step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
+ gpu.return
+ }
+
+ // CHECK-LABEL: constant_with_slice_attr
+ gpu.func @constant_with_slice_attr() {
+ //CHECK: [[cst:%.+]] = arith.constant dense<10> : vector<1xindex>
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [1, 2, 3]>} dense<10> : vector<4xindex>
+ gpu.return
+ }
+
+ // CHECK-LABEL: vector_shape_cast
+ gpu.func @vector_shape_cast(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<8x4x8x4xf32>
+ %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [2, 4, 2, 2], sg_data = [8, 4, 8, 4]>} : vector<256x128xf32> to vector<16x16x16x8xf32>
+ gpu.return
+ }
}
@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Full diff: https://github.com/llvm/llvm-project/pull/155443.diff
@@ -507,8 +507,15 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
- xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
- layout.dropSgLayoutAndData());
+ if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
Please mention this and the elementwise op (and arith) changes in the description.
Ping for reviews.
This PR adds patterns to distribute vector.step and vector.shape_cast ops from the workgroup (wg) level to the subgroup (sg) level, and it also enables the constant, broadcast, and elementwise patterns to handle the slice attribute.
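For context, below is a minimal sketch of how the patterns added in this PR could be driven through the dialect-conversion framework. The driver function name and the simplified legality check are assumptions for illustration only; the existing XeGPUWgToSgDistributePass performs the equivalent (and more complete) setup itself.

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hedged sketch: distribute wg-level ops in `root` to sg level.
// vector.step / vector.shape_cast count as legal once their result no longer
// carries a workgroup-level distribute layout, mirroring the
// addDynamicallyLegalOp hook added in this PR.
static LogicalResult distributeWgToSg(Operation *root) {
  MLIRContext *ctx = root->getContext();

  ConversionTarget target(*ctx);
  target.addDynamicallyLegalOp<vector::StepOp, vector::ShapeCastOp>(
      [](Operation *op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
        return !layout || !layout.isForWorkgroup();
      });
  // Every other op is left alone in this sketch; the real pass constrains
  // XeGPU, arith, and other vector ops as well.
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  RewritePatternSet patterns(ctx);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  return applyPartialConversion(root, target, std::move(patterns));
}

In practice one would simply run the existing pass; the sketch only shows where the new WgToSgVectorStepOp and WgToSgVectorShapeCastOp patterns plug in.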