Skip to content

Commit 101e674

Browse files
Resolve code review comments
Signed-off-by: Davide Grohmann <[email protected]> Change-Id: Ie69a1696a7b31869c1ba94bdf7aa214d52175565
1 parent f3eb19f commit 101e674

File tree

7 files changed

+104
-61
lines changed

7 files changed

+104
-61
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,7 +1343,7 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"Stora
13431343
];
13441344
}
13451345
def SPIRV_C_GraphARM : I32EnumAttrCase<"GraphARM", 4191> {
1346-
list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel];
1346+
list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM];
13471347
list<Availability> availability = [
13481348
Extension<[SPV_ARM_graph]>
13491349
];
@@ -4883,7 +4883,6 @@ class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
48834883
SPIRV_VendorOp<mnemonic, "ARM", traits> {
48844884
}
48854885

4886-
48874886
def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">;
48884887
def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>;
48894888
def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>;

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -29,39 +29,11 @@ class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
2929
let availability = [
3030
MinVersion<SPIRV_V_1_0>,
3131
MaxVersion<SPIRV_V_1_6>,
32-
Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>,
32+
Extension<[SPV_ARM_graph, SPV_ARM_tensors]>,
3333
Capability<[SPIRV_C_GraphARM]>
3434
];
3535
}
3636

37-
def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> {
38-
let summary = "Declare a graph constant.";
39-
40-
let description = [{
41-
Declare a graph constant.
42-
Result Type must be an OpTypeTensorARM.
43-
GraphConstantID must be a 32-bit integer literal.
44-
}];
45-
46-
let arguments = (ins
47-
I32Attr: $graph_constant_id
48-
);
49-
50-
let results = (outs
51-
SPIRV_AnyTensorArm:$output
52-
);
53-
54-
let hasVerifier = 0;
55-
56-
let autogenSerialization = 0;
57-
58-
let assemblyFormat = [{
59-
attr-dict `:` type($output)
60-
}];
61-
}
62-
63-
// -----
64-
6537
def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
6638
AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
6739
FunctionOpInterface, InModuleScope, IsolatedFromAbove
@@ -122,13 +94,47 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
12294
}];
12395
}
12496

97+
// -----
98+
12599
// Check that an op can only be used within the scope of a spirv.ARM.Graph op.
126100
def InGraphScope : PredOpTrait<
127101
"op must appear in a spirv.ARM.Graph op's block",
128102
CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>;
129103

130104
// -----
131105

106+
def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope, Pure, ConstantLike]> {
107+
let summary = "Declare a graph constant.";
108+
109+
let description = [{
110+
Declare a graph constant.
111+
Result Type must be an OpTypeTensorARM.
112+
GraphConstantID must be a 32-bit integer literal.
113+
114+
```
115+
spv-graph-constant-arm-op ::= `spirv.ARM.GraphConstant` { graph_constant_id = 42 : i32 }
116+
```
117+
}];
118+
119+
let arguments = (ins
120+
I32Attr: $graph_constant_id
121+
);
122+
123+
let results = (outs
124+
SPIRV_AnyTensorArm:$output
125+
);
126+
127+
let hasVerifier = 0;
128+
129+
let autogenSerialization = 0;
130+
131+
let assemblyFormat = [{
132+
attr-dict `:` type($output)
133+
}];
134+
}
135+
136+
// -----
137+
132138
def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
133139
let summary = [{
134140
Declare a graph entry point and its interface.

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,13 +1140,11 @@ void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
11401140

11411141
ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
11421142
OperationState &result) {
1143-
SmallVector<Type, 0> idTypes;
11441143
SmallVector<Attribute, 4> interfaceVars;
11451144

11461145
FlatSymbolRefAttr fn;
1147-
if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
1146+
if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
11481147
return failure();
1149-
}
11501148

11511149
if (!parser.parseOptionalComma()) {
11521150
// Parse the interface variables
@@ -1224,7 +1222,7 @@ ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
12241222
getResAttrsAttrName(result.name));
12251223

12261224
// Parse the optional function body.
1227-
auto *body = result.addRegion();
1225+
Region *body = result.addRegion();
12281226
OptionalParseResult parseResult =
12291227
parser.parseOptionalRegion(*body, entryArgs);
12301228
return failure(parseResult.has_value() && failed(*parseResult));
@@ -1234,7 +1232,7 @@ void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
12341232
// Print graph name, signature, and control.
12351233
printer << " ";
12361234
printer.printSymbolName(getSymName());
1237-
auto grType = getFunctionType();
1235+
GraphType grType = getFunctionType();
12381236
function_interface_impl::printFunctionSignature(
12391237
printer, *this, grType.getInputs(),
12401238
/*isVariadic=*/false, grType.getResults());
@@ -1288,9 +1286,10 @@ LogicalResult spirv::GraphARMOp::verifyBody() {
12881286
<< grType.getNumResults() << " results";
12891287

12901288
auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType();
1291-
for (unsigned i = 0; i < graphOutputOperandTypes.size(); ++i) {
1292-
auto graphOutputOperandType = graphOutputOperandTypes[i];
1293-
auto grResultType = grType.getResult(i);
1289+
for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
1290+
++i) {
1291+
Type graphOutputOperandType = graphOutputOperandTypes[i];
1292+
Type grResultType = grType.getResult(i);
12941293
if (graphOutputOperandType != grResultType)
12951294
return graphOutputsARMOp.emitError("type of return operand ")
12961295
<< i << " (" << graphOutputOperandType
@@ -1339,13 +1338,13 @@ LogicalResult spirv::GraphOutputsARMOp::verify() {
13391338
auto graph = cast<GraphARMOp>((*this)->getParentOp());
13401339

13411340
// The operand number and types must match the graph signature.
1342-
const auto &results = graph.getFunctionType().getResults();
1341+
const ArrayRef<Type> &results = graph.getFunctionType().getResults();
13431342
if (getNumOperands() != results.size())
13441343
return emitOpError("has ")
13451344
<< getNumOperands() << " operands, but enclosing graph (@"
13461345
<< graph.getName() << ") returns " << results.size();
13471346

1348-
for (unsigned i = 0; i < results.size(); i++)
1347+
for (unsigned i = 0, size = results.size(); i < size; ++i)
13491348
if (getOperand(i).getType() != results[i])
13501349
return emitError() << "type of return operand " << i << " ("
13511350
<< getOperand(i).getType()

mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ void UpdateVCEPass::runOnOperation() {
160160

161161
// If the op is FunctionLike make sure to process input and result types
162162
if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
163-
auto inputTypes = funcOpInterface.getArgumentTypes();
164-
auto resultTypes = funcOpInterface.getResultTypes();
163+
ArrayRef<Type> inputTypes = funcOpInterface.getArgumentTypes();
164+
ArrayRef<Type> resultTypes = funcOpInterface.getResultTypes();
165165
valueTypes.append(inputTypes.begin(), inputTypes.end());
166166
valueTypes.append(resultTypes.begin(), resultTypes.end());
167167
}

mlir/test/Dialect/SPIRV/IR/availability.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,11 @@ func.func @constant_composite_replicate() -> () {
315315
spirv.ARM.Graph @graph_arm(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
316316
// CHECK: spirv.ARM.GraphOutputs min version: v1.0
317317
// CHECK: spirv.ARM.GraphOutputs max version: v1.6
318-
// CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
318+
// CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors] ]
319319
// CHECK: spirv.ARM.GraphOutputs capabilities: [ [GraphARM] ]
320320
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
321321
// CHECK: spirv.ARM.Graph min version: v1.0
322322
// CHECK: spirv.ARM.Graph max version: v1.6
323-
// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
323+
// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors] ]
324324
// CHECK: spirv.ARM.Graph capabilities: [ [GraphARM] ]
325325
}

mlir/test/Dialect/SPIRV/IR/graph-ops.mlir

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,68 @@
11
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
22

3+
//===----------------------------------------------------------------------===//
4+
// spirv.ARM.Graph and spirv.ARM.GraphOutputs
5+
//===----------------------------------------------------------------------===//
6+
7+
spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
8+
// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
9+
spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
10+
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
11+
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
12+
}
13+
}
14+
15+
// -----
16+
317
//===----------------------------------------------------------------------===//
418
// spirv.ARM.GraphConstant
519
//===----------------------------------------------------------------------===//
620

7-
spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
8-
// CHECK: spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<14xi32>
9-
%0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<14xi32>
21+
spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
22+
// CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> {
23+
spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> {
24+
// CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
25+
%0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
26+
// CHECK: spirv.ARM.GraphOutputs [[CONST:%.*]] : !spirv.arm.tensor<2x3xi16>
27+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
28+
}
29+
}
30+
// -----
31+
32+
//===----------------------------------------------------------------------===//
33+
// spirv.ARM.GraphEntryPoint
34+
//===----------------------------------------------------------------------===//
35+
1036

37+
spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
1138
// CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
12-
spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
13-
// CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
14-
spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
39+
spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
40+
// CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
41+
spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
1542
// CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
16-
spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0
17-
// CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
18-
spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
19-
// CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
20-
%1 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
21-
// CHECK: spirv.ARM.GraphOutputs [[OUT:%.*]] : !spirv.arm.tensor<2x3xi16>
22-
spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x3xi16>
43+
spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0
44+
// CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
45+
spirv.ARM.Graph @entrypoint(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
46+
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
47+
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
48+
}
49+
}
50+
51+
// -----
52+
53+
//===----------------------------------------------------------------------===//
54+
// Multiple spirv.ARM.Graphs
55+
//===----------------------------------------------------------------------===//
56+
57+
spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
58+
// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
59+
spirv.ARM.Graph @graph1(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
60+
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
61+
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
2362
}
2463

2564
// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
26-
spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
65+
spirv.ARM.Graph @graph2(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
2766
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8>
2867
spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
2968
}

mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,10 +232,10 @@ spirv.module Logical GLSL450 attributes {
232232
}
233233
}
234234

235-
// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.5, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
235+
// CHECK: requires #spirv.vce<v1.5, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
236236
spirv.module Logical Vulkan attributes {
237237
spirv.target_env = #spirv.target_env<
238-
#spirv.vce<v1.5, [GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>,
238+
#spirv.vce<v1.5, [VulkanMemoryModel, GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph]>,
239239
#spirv.resource_limits<>>
240240
} {
241241
spirv.ARM.Graph @argmax(%arg0 : !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> {

0 commit comments

Comments
 (0)