Skip to content

Commit df1b9a3

Browse files
committed
add xegpu transform ops
1 parent 9f733f4 commit df1b9a3

File tree

9 files changed

+1085
-0
lines changed

9 files changed

+1085
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
3+
add_subdirectory(TransformOps)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td)
2+
mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls)
3+
mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen)
5+
6+
add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
10+
#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
11+
12+
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/SCF/IR/SCF.h"
14+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
15+
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
16+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
17+
18+
#define GET_OP_CLASSES
19+
#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc>
20+
21+
namespace mlir {
22+
class DialectRegistry;
23+
24+
namespace xegpu {
25+
void registerTransformDialectExtension(DialectRegistry &registry);
26+
} // namespace xegpu
27+
} // namespace mlir
28+
29+
#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef XEGPU_EXTENSION
10+
#define XEGPU_EXTENSION
11+
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
14+
include "mlir/Dialect/Transform/IR/TransformTypes.td"
15+
include "mlir/IR/OpBase.td"
16+
include "mlir/Interfaces/SideEffectInterfaces.td"
17+
18+
def XeGPUHoistDescOp : Op<Transform_Dialect, "xegpu.hoist_desc_ops", [
19+
TransformOpInterface, TransformEachOpTrait,
20+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
21+
]> {
22+
23+
let summary = "Hoists xegpu tile descriptor ops outside the containing loop";
24+
let description = [{
25+
Hoists `xepu.create_nd_tdesc` out of the loop. If the
26+
descriptor's offset is loop dependent, a `xegpu.update_nd_offset` op is
27+
inserted in the loop to increment the offset.
28+
}];
29+
30+
let arguments = (ins TransformHandleTypeInterface : $loop);
31+
let results = (outs TransformHandleTypeInterface : $transformed);
32+
33+
let assemblyFormat = "$loop attr-dict `:` functional-type(operands, results)";
34+
35+
let extraClassDeclaration = [{
36+
::mlir::DiagnosedSilenceableFailure applyToOne(
37+
::mlir::transform::TransformRewriter & rewriter,
38+
::mlir::Operation * target,
39+
::mlir::transform::ApplyToEachResultList & results,
40+
::mlir::transform::TransformState & state);
41+
}];
42+
}
43+
44+
def XeGPUSetDPASLayoutOp : Op<Transform_Dialect, "xegpu.set_dpas_layout", [
45+
TransformOpInterface, TransformEachOpTrait,
46+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
47+
]> {
48+
49+
let summary = "Set xegpu.layout attribute to an DPAS op operand.";
50+
let description = [{
51+
Given a `xegpu.dpas` operation, this transform adds `xegpu.layout`
52+
attribute to it's operand's tensor descriptor. The target operand is
53+
defined by the `tileIndex` argument. The layout is defined by the
54+
`sg_layout`, `sg_data` and `inst_data` attributes. The `load_data`
55+
attribute defines the tile size used for loading the data. It must be a
56+
multiple of the `inst_data` size.
57+
}];
58+
59+
let arguments = (ins TransformHandleTypeInterface : $dpasOp,
60+
I64Attr : $tileIndex,
61+
DenseI32ArrayAttr : $sgLayout,
62+
DenseI32ArrayAttr : $sgData,
63+
DenseI32ArrayAttr : $loadData,
64+
DenseI32ArrayAttr : $instData);
65+
66+
let results = (outs);
67+
68+
let assemblyFormat =
69+
"$dpasOp `index` `=` $tileIndex `sg_layout` `=` $sgLayout `sg_data` `=` "
70+
"$sgData `load_data` `=` $loadData `inst_data` `=` $instData attr-dict `:` type($dpasOp)";
71+
72+
let extraClassDeclaration = [{
73+
::mlir::DiagnosedSilenceableFailure applyToOne(
74+
::mlir::transform::TransformRewriter & rewriter,
75+
::mlir::Operation * target,
76+
::mlir::transform::ApplyToEachResultList & results,
77+
::mlir::transform::TransformState & state);
78+
}];
79+
}
80+
81+
def XeGPUInsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch",
82+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
83+
DeclareOpInterfaceMethods<TransformOpInterface>]> {
84+
85+
let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
86+
let description = [{
87+
Given a `xegpu.dpas` operation residing in a `scf.for` loop, this transform inserts cooperative `xegpu.prefetch` operations for the A (index = 0) or B (index = 1) operand. The prefetch tile size is determined by the `sg_layout` and `sg_data` attributes.
88+
}];
89+
90+
let arguments = (ins TransformHandleTypeInterface : $dpasOp,
91+
TransformHandleTypeInterface : $loopOp,
92+
I64Attr : $tileIndex,
93+
DenseI32ArrayAttr : $sgLayout,
94+
DenseI32ArrayAttr : $sgData);
95+
96+
let results = (outs TransformHandleTypeInterface : $transformedDpasOp,
97+
TransformHandleTypeInterface : $transformedLoopOp);
98+
99+
let assemblyFormat =
100+
"$dpasOp $loopOp `index` `=` $tileIndex `sg_layout` `=` $sgLayout `sg_data` `=` "
101+
"$sgData attr-dict `:` functional-type(operands, results)";
102+
}
103+
104+
// TODO this should be handled with gpu transform ops.
105+
// Add gpu mapping to scf.forall op and use something like
106+
// transform.gpu.map_forall_to_blocks to convert to gpu.launch op.
107+
def XeGPUSetGPULaunchThreadsOp
108+
: Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [
109+
TransformOpInterface, TransformEachOpTrait,
110+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
111+
]> {
112+
113+
let summary = "Set number of threads for a given gpu.launch operation";
114+
let description = [{Set number of threads for a given gpu.launch operation}];
115+
116+
let arguments = (ins TransformHandleTypeInterface
117+
: $launchOp, DenseI32ArrayAttr
118+
: $threads);
119+
let results = (outs);
120+
let assemblyFormat =
121+
"$launchOp `threads` `=` $threads attr-dict `:` type($launchOp)";
122+
123+
let extraClassDeclaration = [{
124+
::mlir::DiagnosedSilenceableFailure applyToOne(
125+
::mlir::transform::TransformRewriter & rewriter,
126+
::mlir::Operation * target,
127+
::mlir::transform::ApplyToEachResultList & results,
128+
::mlir::transform::TransformState & state);
129+
}];
130+
}
131+
132+
#endif // XEGPU_EXTENSION

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
5656
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
5757
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
58+
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
5859
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
5960
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
6061
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
@@ -114,6 +115,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
114115
vector::registerTransformDialectExtension(registry);
115116
arm_neon::registerTransformDialectExtension(registry);
116117
arm_sve::registerTransformDialectExtension(registry);
118+
xegpu::registerTransformDialectExtension(registry);
117119

118120
// Translation extensions need to be registered by calling
119121
// `registerAllToLLVMIRTranslations` (see All.h).

mlir/lib/Dialect/XeGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
33
add_subdirectory(Utils)
4+
add_subdirectory(TransformOps)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
add_mlir_dialect_library(MLIRXeGPUTransformOps
2+
XeGPUTransformOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/
6+
7+
DEPENDS
8+
MLIRXeGPUTransformOpsIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRIR
12+
MLIRTransformDialect
13+
MLIRFuncDialect
14+
MLIRSCFDialect
15+
)

0 commit comments

Comments
 (0)