125 changes: 118 additions & 7 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -507,8 +507,15 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
- xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
-     layout.dropSgLayoutAndData());
+ if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+   if (sliceAttr.isForSubgroup())
+     xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+         sliceAttr.dropSgLayoutAndData());
+ } else if (auto layoutAttr =
+                dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+   if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+     xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), newLayout);
+ }
newBroadcastOps.push_back(newBroadcast.getResult());
}
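// Note: both branches above reduce a workgroup-level layout to its
// subgroup-level remainder; only the attribute kind differs. A minimal
// sketch of the two attribute forms (values illustrative, mirroring the
// tests in this PR):
//
//   #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32],
//                 lane_layout = [1, 16], lane_data = [1, 1]>
//     --dropSgLayoutAndData()--> #xegpu.layout<lane_layout = [1, 16],
//                                              lane_data = [1, 1]>
//   #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>,
//                dims = [1]>
//
// dropSgLayoutAndData() returns null when nothing besides sg_layout/sg_data
// remains, hence the guards around setDistributeLayoutAttr.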

@@ -566,6 +573,10 @@ struct WgToSgElementwiseOp : public ConversionPattern {
if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
if (auto newLayout = layout.dropSgLayoutAndData())
state.addAttribute(attr.getName(), newLayout);
} else if (auto sliceAttr =
dyn_cast<xegpu::SliceAttr>(attr.getValue())) {
if (sliceAttr.isForSubgroup())
state.addAttribute(attr.getName(), sliceAttr.dropSgLayoutAndData());
} else {
state.addAttribute(attr.getName(), attr.getValue());
}
@@ -756,15 +767,106 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
- if (auto newLayout = layout.dropSgLayoutAndData())
-   xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+ if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+   if (sliceAttr.isForSubgroup())
+     xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+         sliceAttr.dropSgLayoutAndData());
+ } else if (auto layoutAttr =
+                dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+   if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+     xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+ }
SmallVector<Value> newConsts(count, cstOp);

rewriter.replaceOpWithMultiple(op, {newConsts});
return success();
}
};

// This pattern distributes vector.step ops to work at the subgroup level.
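// For example (a minimal sketch using the shapes from the layout-attr test
// added in this PR; %off stands for the subgroup's start offset computed
// from gpu.subgroup_id):
//
//   %step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16],
//       sg_data = [8]>} : vector<128xindex>
//
// becomes, per subgroup, a local iota shifted by that subgroup's offset:
//
//   %base  = vector.step : vector<8xindex>
//   %bcast = vector.broadcast %off : index to vector<8xindex>
//   %new   = arith.addi %base, %bcast : vector<8xindex>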
struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
using OpConversionPattern<vector::StepOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getResult());
if (!layout || !layout.isForWorkgroup())
return failure();

Location loc = op.getLoc();
VectorType type = op.getResult().getType();
auto wgShape = type.getShape();
std::optional<SmallVector<int64_t>> sgShape =
getSgShapeAndCount(wgShape, layout).first;
if (!sgShape)
return failure();

Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
auto maybeOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
if (failed(maybeOffsets))
return failure();

VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
Value base = vector::StepOp::create(rewriter, loc, newTy);
SmallVector<Value> newOps;
for (auto offsets : *maybeOffsets) {
Value bcast =
vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
newOps.push_back(add);
}

rewriter.replaceOpWithMultiple(op, {newOps});
return success();
}
};

// This pattern transforms vector.shape_cast ops to work at subgroup level.
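// For example (a minimal sketch using the shapes from the test added in this
// PR): a workgroup-level
//
//   %cast = vector.shape_cast %load : vector<256x128xf32> to vector<16x16x16x8xf32>
//
// with result layout sg_layout = [2, 4, 2, 2], sg_data = [8, 4, 8, 4]
// becomes, per subgroup, the same cast on the fragment that subgroup owns:
//
//   %cast_sg = vector.shape_cast %load_sg : vector<32x32xf32> to vector<8x4x8x4xf32>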
struct WgToSgVectorShapeCastOp
: public OpConversionPattern<vector::ShapeCastOp> {
using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
if (!resultType)
return failure();

ArrayRef<int64_t> wgShape = resultType.getShape();
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getResult());
if (!layout || !layout.isForWorkgroup())
return failure();

SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());

SmallVector<Value> newShapeCastOps;
for (auto src : adaptor.getSource()) {
auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
    newResultType, src);
if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
if (sliceAttr.isForSubgroup())
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
sliceAttr.dropSgLayoutAndData());
} else if (auto layoutAttr =
dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
if (auto newLayout = layoutAttr.dropSgLayoutAndData())
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), newLayout);
}
newShapeCastOps.push_back(newShapeCast.getResult());
}

rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
return success();
}
};

struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
LogicalResult
@@ -826,8 +928,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
- WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
-     patterns.getContext());
+ WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
+ WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -949,7 +1051,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
auto vecType = dyn_cast<VectorType>(op.getType());
if (!vecType)
return true;
- return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+
+ auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
+ return isLegal(layout);
});

target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
[=](Operation *op) -> bool {
// Check for either a SliceAttr or LayoutAttr on the result.
auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
return isLegal(layout);
});

target.addDynamicallyLegalOp<vector::BroadcastOp>(
53 changes: 53 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -2,6 +2,7 @@

//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)>
gpu.module @test_distribution {
// CHECK-LABEL: create_nd_tdesc_no_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -321,4 +322,56 @@ gpu.module @test_distribution {
xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
gpu.return
}

// CHECK-LABEL: vector_step_op_slice_attr
gpu.func @vector_step_op_slice_attr() {
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
//CHECK: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
//CHECK: [[c32:%.+]] = arith.constant 32 : index
//CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
//CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
//CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
%step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
gpu.return
}
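// Note: with sg_layout = [4, 8] and dims = [1], #map2 extracts the row
// index, so subgroup id s starts at element ((s floordiv 8) * 32) mod 128 --
// exactly the index.mul / index.remu chain checked above.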

// CHECK-LABEL: vector_step_op_layout_attr
gpu.func @vector_step_op_layout_attr() {
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
//CHECK: [[c16:%.+]] = arith.constant 16 : index
//CHECK: [[c8:%.+]] = arith.constant 8 : index
//CHECK: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
//CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
//CHECK: [[BASE:%.+]] = vector.step : vector<8xindex>
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
%step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
gpu.return
}

// CHECK-LABEL: constant_with_slice_attr
gpu.func @constant_with_slice_attr() {
//CHECK: [[cst:%.+]] = arith.constant dense<10> : vector<1xindex>
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [1, 2, 3]>} dense<10> : vector<4xindex>
gpu.return
}
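// Note: slicing dims [1, 2, 3] out of sg_layout = [4, 2, 6, 1] leaves a 1-D
// layout of 4 subgroups with sg_data = [1], so each subgroup holds a single
// element of the vector<4xindex> splat -- hence the vector<1xindex> constant.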

// CHECK-LABEL: vector_shape_cast
gpu.func @vector_shape_cast(%src: memref<256x128xf32>) {
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-> vector<256x128xf32>
//CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<8x4x8x4xf32>
%cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [2, 4, 2, 2], sg_data = [8, 4, 8, 4]>} : vector<256x128xf32> to vector<16x16x16x8xf32>
gpu.return
}
}