diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 81e25f7537cb0..e9f8437d7c102 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,9 +29,22 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
     void printProperties(::mlir::MLIRContext *ctx,
                          ::mlir::OpAsmPrinter &p, const Properties &prop,
                          ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-      Attribute propAttr = getPropertiesAsAttr(ctx, prop);
-      if (propAttr)
-        p << "<" << propAttr << ">";
+
+      DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
+
+      // filter out the elidedProps from propAttr, and get the resultAttr
+      mlir::SmallVector<mlir::NamedAttribute> filteredAttrs;
+      if (propAttr) {
+        for (auto namedAttr : propAttr.getValue()) {
+          if (llvm::is_contained(elidedProps, namedAttr.getName().strref()))
+            continue;
+          filteredAttrs.push_back(namedAttr);
+        }
+      }
+
+      if (!filteredAttrs.empty()) {
+        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
+      }
     }
 
     static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
@@ -288,6 +301,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -298,7 +313,18 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
     }
   }];
 
-  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
+  let assemblyFormat = [{
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    prop-dict attr-dict `:` qualified(type($TensorDesc))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
 
   let hasVerifier = 1;
 }
@@ -343,6 +369,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<UnitAttr>: $packed,
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -361,7 +389,20 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     }
   }];
 
-  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)";
+  let assemblyFormat = [{
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+                   "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
+
   let hasVerifier = 1;
 }
@@ -400,6 +441,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
 
   let arguments = (ins XeGPU_ValueType: $value,
                        XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -414,8 +457,21 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     }
   }];
 
-  let assemblyFormat = [{$value `,` $TensorDesc prop-dict attr-dict
-        `:` type($value) `,` qualified(type($TensorDesc))}];
+  let assemblyFormat = [{
+    $value `,`
+    $TensorDesc ``
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)> + ]; + + let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index c8da5558438ea..e0046d2c9a37a 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -329,18 +329,30 @@ ParseResult parseOptionalDynamicIndexList( return success(); } -void printOptionalDynamicIndexList( - OpAsmPrinter &printer, Operation *op, OperandRange values, - ArrayRef integers, TypeRange valueTypes = TypeRange(), - AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { +void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, + OperandRange values, + DenseI64ArrayAttr integers) { + + if (!integers) + return; return printDynamicIndexList(printer, op, values, integers, - /*scalableFlags=*/{}, valueTypes, delimiter); + /*scalableFlags=*/{}, {}, + AsmParser::Delimiter::Square); } - //===----------------------------------------------------------------------===// // XeGPU_PrefetchNdOp //===----------------------------------------------------------------------===// + +void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, + Value tensorDesc, xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + + return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), + l1_hint, l2_hint, l3_hint); +} + LogicalResult PrefetchNdOp::verify() { auto tdescTy = getTensorDescType(); if (tdescTy.isScattered()) @@ -355,12 +367,34 @@ LogicalResult PrefetchNdOp::verify() { if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); + int64_t tDescRank = tdescTy.getRank(); + int64_t offsetSize = static_cast(getOffsets().size()); + int64_t constOffsetSize = + getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0; + if (((offsetSize != 0) && (offsetSize != tDescRank)) || + ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + return emitOpError( + "Mismatched ranks between offsets and tensor descriptor"); + return success(); } //===----------------------------------------------------------------------===// // XeGPU_LoadNdOp //===----------------------------------------------------------------------===// + +void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, + Value tensorDesc, UnitAttr packed, + DenseI64ArrayAttr transpose, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + + return build(builder, state, retType, tensorDesc, ValueRange(), + DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, + l3_hint); +} + LogicalResult LoadNdOp::verify() { auto tdescTy = getTensorDescType(); auto valueTy = getType(); @@ -442,12 +476,31 @@ LogicalResult LoadNdOp::verify() { << " is not consistent with tensor descriptor " << tdescTy; + int64_t tDescRank = tdescTy.getRank(); + int64_t offsetSize = static_cast(getOffsets().size()); + int64_t constOffsetSize = + getConstOffsetsAttr() ? 
getConstOffsetsAttr().size() : 0; + if (((offsetSize != 0) && (offsetSize != tDescRank)) || + ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + return emitOpError( + "Mismatched ranks between offsets and tensor descriptor"); + return success(); } //===----------------------------------------------------------------------===// // XeGPU_StoreNdOp //===----------------------------------------------------------------------===// + +void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, + Value tensorDesc, xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + + return build(builder, state, value, tensorDesc, ValueRange(), + DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint); +} + LogicalResult StoreNdOp::verify() { auto dstTy = getTensorDescType(); // Tile auto valTy = getValueType(); // Vector @@ -502,6 +555,15 @@ LogicalResult StoreNdOp::verify() { << " is not consistent with tensor descriptor " << dstTy; + int64_t tDescRank = dstTy.getRank(); + int64_t offsetSize = static_cast(getOffsets().size()); + int64_t constOffsetSize = + getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0; + if (((offsetSize != 0) && (offsetSize != tDescRank)) || + ((constOffsetSize != 0) && (constOffsetSize != tDescRank))) + return emitOpError( + "Mismatched ranks between offsets and tensor descriptor"); + return success(); } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index e95d2f75d8b5a..8957ea5399ea2 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -352,6 +352,10 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern { if (!storeOp) return failure(); + int64_t offsetSize = static_cast(storeOp.getOffsets().size()); + if ((offsetSize != 0) || storeOp.getConstOffsetsAttr()) + return failure(); + xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType(); xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr(); if (!layout) @@ -464,6 +468,11 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern { warpOp, "warp result is not a xegpu::LoadNd op"); auto loadOp = operand->get().getDefiningOp(); + + int64_t offsetSize = static_cast(loadOp.getOffsets().size()); + if ((offsetSize != 0) || loadOp.getConstOffsetsAttr()) + return failure(); + xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType(); xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr(); if (!layout) @@ -767,6 +776,11 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern { auto prefetchOp = dyn_cast_or_null(lastNode); if (!prefetchOp) return failure(); + + int64_t offsetSize = static_cast(prefetchOp.getOffsets().size()); + if ((offsetSize != 0) || prefetchOp.getConstOffsetsAttr()) + return failure(); + xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr(); if (!layout) return rewriter.notifyMatchFailure( diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index 0d44415595cb8..a6208b455aa35 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -218,6 +218,10 @@ struct UnrollPrefetchNdOp : public UnrollPattern { if (!targetShape) return failure(); + int64_t offsetSize = static_cast(op.getOffsets().size()); + if ((offsetSize != 0) || op.getConstOffsetsAttr()) + return 
failure(); + SmallVector convertedTdescTypes = getUnrolledTypes(tdescTy, *targetShape); SmallVector convertedTdesc = pack( @@ -245,6 +249,10 @@ struct UnrollLoadNdOp : public UnrollPattern { if (!targetShape) return failure(); + int64_t offsetSize = static_cast(op.getOffsets().size()); + if ((offsetSize != 0) || op.getConstOffsetsAttr()) + return failure(); + Type elemTy = tdescTy.getElementType(); VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy); @@ -279,6 +287,10 @@ struct UnrollStoreNdOp : public UnrollPattern { if (!targetShape) return failure(); + int64_t offsetSize = static_cast(op.getOffsets().size()); + if ((offsetSize != 0) || op.getConstOffsetsAttr()) + return failure(); + SmallVector convertedValTypes = getUnrolledTypes(valueTy, *targetShape); SmallVector convertedTdescTypes = diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 80bb5e888bdc7..e73f780daaa0d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -219,6 +219,11 @@ struct WgToSgLoadNdOp : public OpConversionPattern { matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector newLoadOps; + + int64_t offsetSize = static_cast(op.getOffsets().size()); + if ((offsetSize != 0) || op.getConstOffsetsAttr()) + return failure(); + for (auto src : adaptor.getTensorDesc()) { xegpu::TensorDescType tdescTy = dyn_cast(src.getType()); @@ -241,6 +246,11 @@ struct WgToSgStoreNdOp : public OpConversionPattern { LogicalResult matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + + int64_t offsetSize = static_cast(op.getOffsets().size()); + if ((offsetSize != 0) || op.getConstOffsetsAttr()) + return failure(); + for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc())) xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); @@ -323,6 +333,11 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern { LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + + int64_t offsetSize = static_cast(op.getOffsets().size()); + if ((offsetSize != 0) || op.getConstOffsetsAttr()) + return failure(); + for (auto src : adaptor.getTensorDesc()) xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src, op->getAttrs()); diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index eb564d55bfd51..516c2158cb0f8 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -132,6 +132,31 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) { return } +// ----- +func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) { + %1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16> +// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}} + %2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16> + return +} + +// ----- +func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) { + %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + // expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}} + xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16> + return +} 
+
+// -----
+func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
+  %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  // expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
+  xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+  return
+}
+
 // -----
 func.func @load_nd_layout(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 695437354cd7c..3ebb1b969ac74 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -112,12 +112,21 @@ gpu.func @prefetch_nd(%src: memref<24x32xf16>) {
   gpu.return
 }
 
-// CHECK: gpu.func @prefetch_nd_2(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
-gpu.func @prefetch_nd_2(%src: memref<8x24x32x48x64xf16>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
-  %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
-  // CHECK: xegpu.prefetch_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
-  xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
+// CHECK: gpu.func @prefetch_nd_2(%[[arg0:.*]]: memref<48x64xf16>) {
+gpu.func @prefetch_nd_2(%src: memref<48x64xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
+  // CHECK: xegpu.prefetch_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf16>
+  xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<8x16xf16>
+  gpu.return
+}
+
+// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<48x64xf16>, %arg1: index, %arg2: index) {
+gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
+  // CHECK: xegpu.prefetch_nd %[[R0]][%arg1, %arg2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf16>
+  xegpu.prefetch_nd %1[%x, %y] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<8x16xf16>
   gpu.return
 }
 
@@ -260,6 +269,15 @@ gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: func @subgroup_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>, %arg1: index, %arg2: index) {
+gpu.func @subgroup_load_nd_offset_1(%src: memref<24x32xf32>, %x : index, %y : index) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  %1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][%arg1, %arg2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  %2 = xegpu.load_nd %1[%x, %y] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  gpu.return
+}
+
 // CHECK: func @simt_load_nd_8(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
@@ -269,6 +287,16 @@ gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+
+// CHECK: func @simt_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @simt_load_nd_offset_1(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  %1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+  %2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+  gpu.return
+}
+
 // CHECK: func @subgroup_store_nd(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -291,8 +319,19 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
   gpu.return
 }
 
-// CHECK: func @subgroup_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
+// CHECK: func @subgroup_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>, %arg1: index) {
+gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
+  %1 = arith.constant dense<1.0>: vector<32xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %dst : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]][%arg1] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2[%x] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  gpu.return
+}
+
+// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
   %1 = arith.constant dense<1.0>: vector<32xf16>
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
@@ -313,6 +352,17 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: func @simt_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @simt_store_nd_offset_1(%src: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+  %1 = arith.constant dense<1.0>: vector<2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %src : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16>
+  gpu.return
+}
+
 // CHECK: gpu.func @update_nd_tdesc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
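
For quick reference, here is a minimal usage sketch of the new optional per-op offsets (not part of the patch; the function and value names are made up, and it only combines forms already exercised in ops.mlir and invalid.mlir above). With this change, offsets may be supplied on prefetch_nd/load_nd/store_nd themselves, either as SSA values or constants; when present, the new verifiers require their rank to match the tensor descriptor rank.

```mlir
// Hypothetical example assuming the syntax added by this patch.
gpu.module @offset_example {
  gpu.func @nd_ops_with_offsets(%src: memref<48x64xf16>, %x: index, %y: index) {
    %cst = arith.constant dense<1.0> : vector<8x16xf16>
    // Descriptor created without offsets; each access supplies its own.
    %td = xegpu.create_nd_tdesc %src : memref<48x64xf16> -> !xegpu.tensor_desc<8x16xf16>
    // Dynamic offsets on prefetch_nd.
    xegpu.prefetch_nd %td[%x, %y] <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x16xf16>
    // Constant offsets on load_nd.
    %v = xegpu.load_nd %td[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
    // Dynamic offsets on store_nd.
    xegpu.store_nd %cst, %td[%x, %y] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
    gpu.return
  }
}
```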