Skip to content

Commit 5ea1a54

Browse files
[mlir][spirv] Add support for SPV_ARM_graph extension - part 1
This is the first patch to add support for the SPV_ARM_graph SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new Graph abstraction for expressing dataflow computations over full resources. The part 1 implementation includes: A new GraphType, modeled similarly to FunctionType, for typed graph signatures. New operations in the spirv.arm namespace: spirv.arm.Graph spirv.arm.GraphEntryPoint spirv.arm.GraphConstant spirv.arm.GraphOutput Verifier and VCE updates to properly gate usage under SPV_ARM_graph. Tests covering parsing, verification. Graphs currently support only SPV_ARM_tensors, but are designed to generalize to other resource types, such as images. Spec: KhronosGroup/SPIRV-Registry#346 RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947 Signed-off-by: Davide Grohmann <[email protected]> Change-Id: Ia74b7ab0161b03d3d4702e93c34d7f55cd295a5f
1 parent df71243 commit 5ea1a54

File tree

17 files changed

+630
-18
lines changed

17 files changed

+630
-18
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m
425425
def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
426426

427427
def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
428+
def SPV_ARM_graph : I32EnumAttrCase<"SPV_ARM_graph", 6001>;
428429

429430
def SPIRV_ExtensionAttr :
430431
SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
@@ -449,7 +450,7 @@ def SPIRV_ExtensionAttr :
449450
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
450451
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
451452
SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
452-
SPV_ARM_tensors,
453+
SPV_ARM_tensors, SPV_ARM_graph,
453454
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
454455
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
455456
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1341,6 +1342,12 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"Stora
13411342
Extension<[SPV_ARM_tensors]>
13421343
];
13431344
}
1345+
def SPIRV_C_GraphARM : I32EnumAttrCase<"GraphARM", 4191> {
1346+
list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel];
1347+
list<Availability> availability = [
1348+
Extension<[SPV_ARM_graph]>
1349+
];
1350+
}
13441351
def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
13451352
list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
13461353
list<Availability> availability = [
@@ -1560,7 +1567,7 @@ def SPIRV_CapabilityAttr :
15601567
SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
15611568
SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
15621569
SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
1563-
SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
1570+
SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_GraphARM,
15641571
SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
15651572
SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
15661573
SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4569,6 +4576,13 @@ def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUnifo
45694576
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
45704577
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
45714578
def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
4579+
def SPIRV_OC_OpGraphConstantARM : I32EnumAttrCase<"OpGraphConstantARM", 4181>;
4580+
def SPIRV_OC_OpGraphEntryPointARM : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>;
4581+
def SPIRV_OC_OpGraphARM : I32EnumAttrCase<"OpGraphARM", 4183>;
4582+
def SPIRV_OC_OpGraphInputARM : I32EnumAttrCase<"OpGraphInputARM", 4184>;
4583+
def SPIRV_OC_OpGraphSetOutputARM : I32EnumAttrCase<"OpGraphSetOutputARM", 4185>;
4584+
def SPIRV_OC_OpGraphEndARM : I32EnumAttrCase<"OpGraphEndARM", 4186>;
4585+
def SPIRV_OC_OpTypeGraphARM : I32EnumAttrCase<"OpTypeGraphARM", 4190>;
45724586
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
45734587
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
45744588
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4689,6 +4703,9 @@ def SPIRV_OpcodeAttr :
46894703
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
46904704
SPIRV_OC_OpGroupNonUniformLogicalXor,
46914705
SPIRV_OC_OpTypeTensorARM,
4706+
SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM,
4707+
SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM,
4708+
SPIRV_OC_OpTypeGraphARM, SPIRV_OC_OpGraphConstantARM,
46924709
SPIRV_OC_OpSubgroupBallotKHR,
46934710
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
46944711
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
@@ -4862,6 +4879,11 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
48624879
SPIRV_VendorOp<mnemonic, "NV", traits> {
48634880
}
48644881

4882+
class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
4883+
SPIRV_VendorOp<mnemonic, "ARM", traits> {
4884+
}
4885+
4886+
48654887
def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">;
48664888
def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>;
48674889
def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>;
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
//===- SPIRVGraphOps.td - Graph extended insts spec file -----*- tablegen -*-=//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This is the op definition spec of Graph extension ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
14+
#define MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
15+
16+
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
17+
include "mlir/Interfaces/CallInterfaces.td"
18+
include "mlir/Interfaces/SideEffectInterfaces.td"
19+
include "mlir/Interfaces/FunctionInterfaces.td"
20+
21+
//===----------------------------------------------------------------------===//
22+
// SPIR-V Graph opcode specification.
23+
//===----------------------------------------------------------------------===//
24+
25+
// Base class for all Graph ops.
26+
class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
27+
SPIRV_ArmVendorOp<mnemonic, traits> {
28+
29+
let availability = [
30+
MinVersion<SPIRV_V_1_0>,
31+
MaxVersion<SPIRV_V_1_6>,
32+
Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>,
33+
Capability<[SPIRV_C_GraphARM]>
34+
];
35+
}
36+
37+
def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> {
38+
let summary = "Declare a graph constant.";
39+
40+
let description = [{
41+
Declare a graph constant.
42+
Result Type must be an OpTypeTensorARM.
43+
GraphConstantID must be a 32-bit integer literal.
44+
}];
45+
46+
let arguments = (ins
47+
I32Attr: $graph_constant_id
48+
);
49+
50+
let results = (outs
51+
SPIRV_AnyTensorArm:$output
52+
);
53+
54+
let hasVerifier = 0;
55+
56+
let autogenSerialization = 0;
57+
58+
let assemblyFormat = [{
59+
attr-dict `:` type($output)
60+
}];
61+
}
62+
63+
// -----
64+
65+
def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
66+
AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
67+
FunctionOpInterface, InModuleScope, IsolatedFromAbove
68+
]> {
69+
70+
let summary = "Declare or define a SPIR-V graph";
71+
72+
let description = [{
73+
This op declares or defines a SPIR-V graph using one region, which
74+
contains one or more blocks.
75+
76+
Different from the SPIR-V binary format, this op is not allowed to
77+
implicitly capture global values, and all external references must use
78+
function arguments or symbol references. This op itself defines a symbol
79+
that is unique in the enclosing module op.
80+
81+
This op itself takes no operands and generates no results. Its region
82+
can take zero or more arguments and return zero or more values.
83+
84+
```
85+
spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
86+
region
87+
```
88+
}];
89+
90+
let arguments = (ins
91+
TypeAttrOf<GraphType>:$function_type,
92+
OptionalAttr<DictArrayAttr>:$arg_attrs,
93+
OptionalAttr<DictArrayAttr>:$res_attrs,
94+
OptionalAttr<BoolAttr>:$entry_point,
95+
StrAttr:$sym_name
96+
);
97+
98+
let results = (outs);
99+
100+
let regions = (region AnyRegion:$body);
101+
102+
let hasVerifier = 0;
103+
104+
let builders = [
105+
OpBuilder<(ins "StringRef":$name, "GraphType":$type,
106+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs, CArg<"bool", "false">:$entry_point)>];
107+
108+
let hasOpcode = 0;
109+
110+
let autogenSerialization = 0;
111+
112+
let extraClassDeclaration = [{
113+
/// Hook for FunctionOpInterface, called after verifying that the 'type'
114+
/// attribute is present and checks if it holds a function type. Ensures
115+
/// getType, getNumArguments, and getNumResults can be called safely
116+
LogicalResult verifyType();
117+
118+
/// Hook for FunctionOpInterface, called after verifying the function
119+
/// type and the presence of the (potentially empty) function body.
120+
/// Ensures SPIR-V specific semantics.
121+
LogicalResult verifyBody();
122+
}];
123+
}
124+
125+
// Check that an op can only be used within the scope of a spirv.ARM.Graph op.
126+
def InGraphScope : PredOpTrait<
127+
"op must appear in a spirv.ARM.Graph op's block",
128+
CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>;
129+
130+
// -----
131+
132+
def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
133+
let summary = [{
134+
Declare a graph entry point and its interface.
135+
}];
136+
137+
let description = [{
138+
Graph Entry Point must be the Result <id> of an OpGraphARM instruction.
139+
140+
Name is a name string for the graphentry point. A module cannot have two
141+
OpGraphEntryPointARM instructions with the same Name string.
142+
143+
Interface is a list of symbol references to `spirv.GlobalVariable`
144+
operations. These declare the set of global variables from a
145+
module that form the interface of this entry point. The set of
146+
Interface symbols must be equal to or a superset of the
147+
`spirv.GlobalVariable`s referenced by the entry point’s static call
148+
tree, within the interface’s storage classes.
149+
150+
```
151+
entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint`
152+
symbol-reference (`, ` symbol-reference)*
153+
```
154+
}];
155+
156+
let arguments = (ins
157+
FlatSymbolRefAttr:$fn,
158+
SymbolRefArrayAttr:$interface
159+
);
160+
161+
let results = (outs);
162+
163+
let autogenSerialization = 0;
164+
165+
let builders = [
166+
OpBuilder<(ins "spirv::GraphARMOp":$graph, "ArrayRef<Attribute>":$interfaceVars)>];
167+
}
168+
169+
// -----
170+
171+
def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pure,
172+
Terminator]> {
173+
174+
let summary = "Define graph outputs.";
175+
176+
let description = [{
177+
Values are the graph outputs values and must match the GraphOutputs Type
178+
operand of the OpTypeGraphARM type of the OpGraphARM body this
179+
instruction is in.
180+
181+
This instruction must be the last instruction in a block.
182+
183+
```
184+
graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens
185+
```
186+
}];
187+
188+
let arguments = (ins
189+
Variadic<SPIRV_AnyTensorArm>:$value
190+
);
191+
192+
let results = (outs);
193+
194+
let autogenSerialization = 0;
195+
196+
let hasOpcode = 0;
197+
198+
let assemblyFormat = "$value attr-dict `:` type($value)";
199+
}
200+
201+
#endif // MLIR_DIALECT_SPIRV_IR_GRAPH_OPS

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
3232
include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
3333
include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
3434
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
35+
include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
3536
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
3637
include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
3738
include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"

mlir/include/mlir/IR/Builders.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Type;
2424
class IntegerType;
2525
class FloatType;
2626
class FunctionType;
27+
class GraphType;
2728
class IndexType;
2829
class MemRefType;
2930
class VectorType;
@@ -81,6 +82,7 @@ class Builder {
8182
IntegerType getIntegerType(unsigned width);
8283
IntegerType getIntegerType(unsigned width, bool isSigned);
8384
FunctionType getFunctionType(TypeRange inputs, TypeRange results);
85+
GraphType getGraphType(TypeRange inputs, TypeRange results);
8486
TupleType getTupleType(TypeRange elementTypes);
8587
NoneType getNoneType();
8688

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
403403
// FunctionType
404404
//===----------------------------------------------------------------------===//
405405

406-
def Builtin_Function : Builtin_Type<"Function", "function"> {
406+
class Builtin_FunctionLike<string Name, string typeMnemonic> : Builtin_Type<Name, typeMnemonic> {
407407
let summary = "Map from a list of inputs to a list of results";
408408
let description = [{
409409
Syntax:
@@ -434,6 +434,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
434434
}]>
435435
];
436436
let skipDefaultBuilders = 1;
437+
let storageClass = "FunctionTypeStorage";
437438
let genStorageClass = 0;
438439
let extraClassDeclaration = [{
439440
/// Input types.
@@ -444,23 +445,26 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
444445
unsigned getNumResults() const;
445446
Type getResult(unsigned i) const { return getResults()[i]; }
446447

447-
/// Returns a clone of this function type with the given argument
448+
/// Returns a clone of this function-like type with the given argument
448449
/// and result types.
449-
FunctionType clone(TypeRange inputs, TypeRange results) const;
450+
}] # Name # "Type" # [{ clone(TypeRange inputs, TypeRange results) const;
450451

451-
/// Returns a new function type with the specified arguments and results
452+
/// Returns a new function-like type with the specified arguments and results
452453
/// inserted.
453-
FunctionType getWithArgsAndResults(ArrayRef<unsigned> argIndices,
454+
}] # Name # "Type" # [{ getWithArgsAndResults(ArrayRef<unsigned> argIndices,
454455
TypeRange argTypes,
455456
ArrayRef<unsigned> resultIndices,
456457
TypeRange resultTypes);
457458

458-
/// Returns a new function type without the specified arguments and results.
459-
FunctionType getWithoutArgsAndResults(const BitVector &argIndices,
459+
/// Returns a new function-like type without the specified arguments and results.
460+
}] # Name # "Type" # [{ getWithoutArgsAndResults(const BitVector &argIndices,
460461
const BitVector &resultIndices);
461462
}];
462463
}
463464

465+
def Builtin_Function : Builtin_FunctionLike<"Function", "function">;
466+
def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;
467+
464468
//===----------------------------------------------------------------------===//
465469
// IndexType
466470
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,13 @@ class OpaqueType<string dialect, string name, string summary>
387387
def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
388388
"function type", "::mlir::FunctionType">;
389389

390+
// Graph Type
391+
392+
// Any graph type.
393+
def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,
394+
"graph type", "::mlir::GraphType">;
395+
396+
390397
// A container type is a type that has another type embedded within it.
391398
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
392399
string descr, string cppType = "::mlir::Type"> :

mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,8 +1065,9 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
10651065
return verifyRegionAttribute(op->getLoc(), argType, attribute);
10661066
}
10671067

1068-
LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1069-
Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
1070-
NamedAttribute attribute) {
1068+
LogicalResult
1069+
SPIRVDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
1070+
unsigned resultIndex,
1071+
NamedAttribute attribute) {
10711072
return op->emitError("cannot attach SPIR-V attributes to region result");
10721073
}

mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ static bool isNestedInFunctionOpInterface(Operation *op) {
3131
return isNestedInFunctionOpInterface(op->getParentOp());
3232
}
3333

34+
/// Returns true if the given op is a GraphARM op or nested in a
35+
/// GraphARM op without a module-like op in the middle.
36+
static bool isNestedInGraphARMOpInterface(Operation *op) {
37+
if (!op)
38+
return false;
39+
if (op->hasTrait<OpTrait::SymbolTable>())
40+
return false;
41+
if (isa<spirv::GraphARMOp>(op))
42+
return true;
43+
return isNestedInGraphARMOpInterface(op->getParentOp());
44+
}
45+
3446
/// Returns true if the given op is an module-like op that maintains a symbol
3547
/// table.
3648
static bool isDirectInModuleLikeOp(Operation *op) {

0 commit comments

Comments
 (0)