From f3eb19fdd5acf04635e842ce0b002c607047d688 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Mon, 4 Aug 2025 10:42:01 +0200 Subject: [PATCH 1/5] [mlir][spirv] Add support for SPV_ARM_graph extension - part 1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is the first patch to add support for the SPV_ARM_graph SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new Graph abstraction for expressing dataflow computations over full resources. The part 1 implementation includes: A new GraphType, modeled similarly to FunctionType, for typed graph signatures. New operations in the spirv.arm namespace: spirv.arm.Graph spirv.arm.GraphEntryPoint spirv.arm.GraphConstant spirv.arm.GraphOutput Verifier and VCE updates to properly gate usage under SPV_ARM_graph. Tests covering parsing, verification. Graphs currently support only SPV_ARM_tensors, but are designed to generalize to other resource types, such as images. Spec: KhronosGroup/SPIRV-Registry#346 RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947 Signed-off-by: Davide Grohmann Change-Id: Ia74b7ab0161b03d3d4702e93c34d7f55cd295a5f --- .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 26 +- .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td | 201 +++++++++++++++ .../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td | 1 + mlir/include/mlir/IR/Builders.h | 2 + mlir/include/mlir/IR/BuiltinTypes.td | 18 +- mlir/include/mlir/IR/CommonTypeConstraints.td | 7 + mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 7 +- .../Dialect/SPIRV/IR/SPIRVOpDefinition.cpp | 12 + mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 230 ++++++++++++++++++ .../SPIRV/Transforms/UpdateVCEPass.cpp | 8 + mlir/lib/IR/AsmPrinter.cpp | 17 +- mlir/lib/IR/Builders.cpp | 4 + mlir/lib/IR/BuiltinTypes.cpp | 39 +++ mlir/test/Dialect/SPIRV/IR/availability.mlir | 17 ++ mlir/test/Dialect/SPIRV/IR/graph-ops.mlir | 30 +++ .../SPIRV/Transforms/vce-deduction.mlir | 11 + .../lib/Dialect/SPIRV/TestAvailability.cpp | 18 +- 17 files changed, 630 insertions(+), 18 deletions(-) create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td create mode 100644 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index bdfd728d1d0b3..a27554a3c6f64 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -425,6 +425,7 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>; def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>; +def SPV_ARM_graph : I32EnumAttrCase<"SPV_ARM_graph", 6001>; def SPIRV_ExtensionAttr : SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [ @@ -449,7 +450,7 @@ def SPIRV_ExtensionAttr : SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max, SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add, SPV_EXT_mesh_shader, SPV_EXT_replicated_composites, - SPV_ARM_tensors, + SPV_ARM_tensors, SPV_ARM_graph, SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot, SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask, SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod, @@ -1341,6 +1342,12 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"Stora Extension<[SPV_ARM_tensors]> ]; } +def SPIRV_C_GraphARM : I32EnumAttrCase<"GraphARM", 4191> { + list implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel]; + list availability = [ + Extension<[SPV_ARM_graph]> + ]; +} def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> { list implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR]; list availability = [ @@ -1560,7 +1567,7 @@ def SPIRV_CapabilityAttr : SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect, SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport, SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT, - SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, + SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_GraphARM, SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers, SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV, SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV, @@ -4569,6 +4576,13 @@ def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUnifo def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>; def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>; def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>; +def SPIRV_OC_OpGraphConstantARM : I32EnumAttrCase<"OpGraphConstantARM", 4181>; +def SPIRV_OC_OpGraphEntryPointARM : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>; +def SPIRV_OC_OpGraphARM : I32EnumAttrCase<"OpGraphARM", 4183>; +def SPIRV_OC_OpGraphInputARM : I32EnumAttrCase<"OpGraphInputARM", 4184>; +def SPIRV_OC_OpGraphSetOutputARM : I32EnumAttrCase<"OpGraphSetOutputARM", 4185>; +def SPIRV_OC_OpGraphEndARM : I32EnumAttrCase<"OpGraphEndARM", 4186>; +def SPIRV_OC_OpTypeGraphARM : I32EnumAttrCase<"OpTypeGraphARM", 4190>; def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>; def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>; @@ -4689,6 +4703,9 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr, SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpTypeTensorARM, + SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM, + SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM, + SPIRV_OC_OpTypeGraphARM, SPIRV_OC_OpGraphConstantARM, SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat, @@ -4862,6 +4879,11 @@ class SPIRV_NvVendorOp traits = []> : SPIRV_VendorOp { } +class SPIRV_ArmVendorOp traits = []> : + SPIRV_VendorOp { +} + + def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">; def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>; def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td new file mode 100644 index 0000000000000..38fb4b2eff414 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td @@ -0,0 +1,201 @@ +//===- SPIRVGraphOps.td - Graph extended insts spec file -----*- tablegen -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the op definition spec of Graph extension ops. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_IR_GRAPH_OPS +#define MLIR_DIALECT_SPIRV_IR_GRAPH_OPS + +include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/FunctionInterfaces.td" + +//===----------------------------------------------------------------------===// +// SPIR-V Graph opcode specification. +//===----------------------------------------------------------------------===// + +// Base class for all Graph ops. +class SPIRV_GraphARMOp traits = []> : + SPIRV_ArmVendorOp { + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>, + Capability<[SPIRV_C_GraphARM]> + ]; +} + +def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> { + let summary = "Declare a graph constant."; + + let description = [{ + Declare a graph constant. + Result Type must be an OpTypeTensorARM. + GraphConstantID must be a 32-bit integer literal. + }]; + + let arguments = (ins + I32Attr: $graph_constant_id + ); + + let results = (outs + SPIRV_AnyTensorArm:$output + ); + + let hasVerifier = 0; + + let autogenSerialization = 0; + + let assemblyFormat = [{ + attr-dict `:` type($output) + }]; +} + +// ----- + +def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [ + AutomaticAllocationScope, DeclareOpInterfaceMethods, + FunctionOpInterface, InModuleScope, IsolatedFromAbove + ]> { + + let summary = "Declare or define a SPIR-V graph"; + + let description = [{ + This op declares or defines a SPIR-V graph using one region, which + contains one or more blocks. + + Different from the SPIR-V binary format, this op is not allowed to + implicitly capture global values, and all external references must use + function arguments or symbol references. This op itself defines a symbol + that is unique in the enclosing module op. + + This op itself takes no operands and generates no results. Its region + can take zero or more arguments and return zero or more values. + + ``` + spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature + region + ``` + }]; + + let arguments = (ins + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, + OptionalAttr:$entry_point, + StrAttr:$sym_name + ); + + let results = (outs); + + let regions = (region AnyRegion:$body); + + let hasVerifier = 0; + + let builders = [ + OpBuilder<(ins "StringRef":$name, "GraphType":$type, + CArg<"ArrayRef", "{}">:$attrs, CArg<"bool", "false">:$entry_point)>]; + + let hasOpcode = 0; + + let autogenSerialization = 0; + + let extraClassDeclaration = [{ + /// Hook for FunctionOpInterface, called after verifying that the 'type' + /// attribute is present and checks if it holds a function type. Ensures + /// getType, getNumArguments, and getNumResults can be called safely + LogicalResult verifyType(); + + /// Hook for FunctionOpInterface, called after verifying the function + /// type and the presence of the (potentially empty) function body. + /// Ensures SPIR-V specific semantics. + LogicalResult verifyBody(); + }]; +} + +// Check that an op can only be used within the scope of a spirv.ARM.Graph op. +def InGraphScope : PredOpTrait< + "op must appear in a spirv.ARM.Graph op's block", + CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>; + +// ----- + +def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> { + let summary = [{ + Declare a graph entry point and its interface. + }]; + + let description = [{ + Graph Entry Point must be the Result of an OpGraphARM instruction. + + Name is a name string for the graphentry point. A module cannot have two + OpGraphEntryPointARM instructions with the same Name string. + + Interface is a list of symbol references to `spirv.GlobalVariable` + operations. These declare the set of global variables from a + module that form the interface of this entry point. The set of + Interface symbols must be equal to or a superset of the + `spirv.GlobalVariable`s referenced by the entry point’s static call + tree, within the interface’s storage classes. + + ``` + entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint` + symbol-reference (`, ` symbol-reference)* + ``` + }]; + + let arguments = (ins + FlatSymbolRefAttr:$fn, + SymbolRefArrayAttr:$interface + ); + + let results = (outs); + + let autogenSerialization = 0; + + let builders = [ + OpBuilder<(ins "spirv::GraphARMOp":$graph, "ArrayRef":$interfaceVars)>]; +} + +// ----- + +def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pure, + Terminator]> { + + let summary = "Define graph outputs."; + + let description = [{ + Values are the graph outputs values and must match the GraphOutputs Type + operand of the OpTypeGraphARM type of the OpGraphARM body this + instruction is in. + + This instruction must be the last instruction in a block. + + ``` + graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens + ``` + }]; + + let arguments = (ins + Variadic:$value + ); + + let results = (outs); + + let autogenSerialization = 0; + + let hasOpcode = 0; + + let assemblyFormat = "$value attr-dict `:` type($value)"; +} + +#endif // MLIR_DIALECT_SPIRV_IR_GRAPH_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td index 0fa1bb9d5bd01..96ef035eda37a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td @@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td" +include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td" diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 2e356dec1981f..9d8d81a839fcb 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -24,6 +24,7 @@ class Type; class IntegerType; class FloatType; class FunctionType; +class GraphType; class IndexType; class MemRefType; class VectorType; @@ -81,6 +82,7 @@ class Builder { IntegerType getIntegerType(unsigned width); IntegerType getIntegerType(unsigned width, bool isSigned); FunctionType getFunctionType(TypeRange inputs, TypeRange results); + GraphType getGraphType(TypeRange inputs, TypeRange results); TupleType getTupleType(TypeRange elementTypes); NoneType getNoneType(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index a0c8acea91dc5..08847dd11c685 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -403,7 +403,7 @@ def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> { // FunctionType //===----------------------------------------------------------------------===// -def Builtin_Function : Builtin_Type<"Function", "function"> { +class Builtin_FunctionLike : Builtin_Type { let summary = "Map from a list of inputs to a list of results"; let description = [{ Syntax: @@ -434,6 +434,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> { }]> ]; let skipDefaultBuilders = 1; + let storageClass = "FunctionTypeStorage"; let genStorageClass = 0; let extraClassDeclaration = [{ /// Input types. @@ -444,23 +445,26 @@ def Builtin_Function : Builtin_Type<"Function", "function"> { unsigned getNumResults() const; Type getResult(unsigned i) const { return getResults()[i]; } - /// Returns a clone of this function type with the given argument + /// Returns a clone of this function-like type with the given argument /// and result types. - FunctionType clone(TypeRange inputs, TypeRange results) const; + }] # Name # "Type" # [{ clone(TypeRange inputs, TypeRange results) const; - /// Returns a new function type with the specified arguments and results + /// Returns a new function-like type with the specified arguments and results /// inserted. - FunctionType getWithArgsAndResults(ArrayRef argIndices, + }] # Name # "Type" # [{ getWithArgsAndResults(ArrayRef argIndices, TypeRange argTypes, ArrayRef resultIndices, TypeRange resultTypes); - /// Returns a new function type without the specified arguments and results. - FunctionType getWithoutArgsAndResults(const BitVector &argIndices, + /// Returns a new function-like type without the specified arguments and results. + }] # Name # "Type" # [{ getWithoutArgsAndResults(const BitVector &argIndices, const BitVector &resultIndices); }]; } +def Builtin_Function : Builtin_FunctionLike<"Function", "function">; +def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">; + //===----------------------------------------------------------------------===// // IndexType //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 45ec1846580f2..aab1b01c5cff9 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -387,6 +387,13 @@ class OpaqueType def FunctionType : Type($_self)">, "function type", "::mlir::FunctionType">; +// Graph Type + +// Any graph type. +def GraphType : Type($_self)">, + "graph type", "::mlir::GraphType">; + + // A container type is a type that has another type embedded within it. class ContainerType : diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index fcf1526491971..6f18dcefea14d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -1065,8 +1065,9 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op, return verifyRegionAttribute(op->getLoc(), argType, attribute); } -LogicalResult SPIRVDialect::verifyRegionResultAttribute( - Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/, - NamedAttribute attribute) { +LogicalResult +SPIRVDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, + unsigned resultIndex, + NamedAttribute attribute) { return op->emitError("cannot attach SPIR-V attributes to region result"); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp index d8dfe164458e2..2f3a28ff16173 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp @@ -31,6 +31,18 @@ static bool isNestedInFunctionOpInterface(Operation *op) { return isNestedInFunctionOpInterface(op->getParentOp()); } +/// Returns true if the given op is a GraphARM op or nested in a +/// GraphARM op without a module-like op in the middle. +static bool isNestedInGraphARMOpInterface(Operation *op) { + if (!op) + return false; + if (op->hasTrait()) + return false; + if (isa(op)) + return true; + return isNestedInGraphARMOpInterface(op->getParentOp()); +} + /// Returns true if the given op is an module-like op that maintains a symbol /// table. static bool isDirectInModuleLikeOp(Operation *op) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index f99339852824c..8dfdfea8a5c54 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1126,6 +1126,236 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state, state.addRegion(); } +//===----------------------------------------------------------------------===// +// spirv.GraphEntryPointARM +//===----------------------------------------------------------------------===// + +void spirv::GraphEntryPointARMOp::build(OpBuilder &builder, + OperationState &state, + spirv::GraphARMOp graph, + ArrayRef interfaceVars) { + build(builder, state, SymbolRefAttr::get(graph), + builder.getArrayAttr(interfaceVars)); +} + +ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector idTypes; + SmallVector interfaceVars; + + FlatSymbolRefAttr fn; + if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) { + return failure(); + } + + if (!parser.parseOptionalComma()) { + // Parse the interface variables + if (parser.parseCommaSeparatedList([&]() -> ParseResult { + // The name of the interface variable attribute isnt important + FlatSymbolRefAttr var; + NamedAttrList attrs; + if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) + return failure(); + interfaceVars.push_back(var); + return success(); + })) + return failure(); + } + result.addAttribute("interface", + parser.getBuilder().getArrayAttr(interfaceVars)); + return success(); +} + +void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) { + printer << " "; + printer.printSymbolName(getFn()); + auto interfaceVars = getInterface().getValue(); + if (!interfaceVars.empty()) { + printer << ", "; + llvm::interleaveComma(interfaceVars, printer); + } +} + +LogicalResult spirv::GraphEntryPointARMOp::verify() { + // Checks for fn and interface symbol reference are done in spirv::ModuleOp + // verification. + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GraphARM +//===----------------------------------------------------------------------===// + +ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector entryArgs; + SmallVector resultAttrs; + SmallVector resultTypes; + auto &builder = parser.getBuilder(); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // Parse the function signature. + bool isVariadic = false; + if (function_interface_impl::parseFunctionSignatureWithArguments( + parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, + resultAttrs)) + return failure(); + + SmallVector argTypes; + for (auto &arg : entryArgs) + argTypes.push_back(arg.type); + auto grType = builder.getGraphType(argTypes, resultTypes); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(grType)); + + // If additional attributes are present, parse them. + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + // Add the attributes to the function arguments. + assert(resultAttrs.size() == resultTypes.size()); + call_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); + + // Parse the optional function body. + auto *body = result.addRegion(); + OptionalParseResult parseResult = + parser.parseOptionalRegion(*body, entryArgs); + return failure(parseResult.has_value() && failed(*parseResult)); +} + +void spirv::GraphARMOp::print(OpAsmPrinter &printer) { + // Print graph name, signature, and control. + printer << " "; + printer.printSymbolName(getSymName()); + auto grType = getFunctionType(); + function_interface_impl::printFunctionSignature( + printer, *this, grType.getInputs(), + /*isVariadic=*/false, grType.getResults()); + function_interface_impl::printFunctionAttributes(printer, *this, + {getFunctionTypeAttrName(), + getArgAttrsAttrName(), + getResAttrsAttrName()}); + + // Print the body. + Region &body = this->getBody(); + if (!body.empty()) { + printer << ' '; + printer.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } +} + +LogicalResult spirv::GraphARMOp::verifyType() { + if (getFunctionType().getNumResults() < 1) + return emitOpError("there should be at least one result"); + return success(); +} + +LogicalResult spirv::GraphARMOp::verifyBody() { + GraphType grType = getFunctionType(); + if (!isExternal()) { + Block &entryBlock = front(); + + unsigned numArguments = this->getNumArguments(); + if (entryBlock.getNumArguments() != numArguments) + return emitOpError("entry block must have ") + << numArguments << " arguments to match graph signature"; + + for (auto [index, grArgType, blockArgType] : + llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) { + if (blockArgType != grArgType) { + return emitOpError("type of entry block argument #") + << index << '(' << blockArgType + << ") must match the type of the corresponding argument in " + << "graph signature(" << grArgType << ')'; + } + } + } + + auto walkResult = walk([grType](Operation *op) -> WalkResult { + if (auto graphOutputsARMOp = dyn_cast(op)) { + if (grType.getNumResults() != graphOutputsARMOp.getNumOperands()) + return graphOutputsARMOp.emitOpError("has GraphOutputsARM returning ") + << graphOutputsARMOp.getNumOperands() + << "value(s) but enclosing graph requires " + << grType.getNumResults() << " results"; + + auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType(); + for (unsigned i = 0; i < graphOutputOperandTypes.size(); ++i) { + auto graphOutputOperandType = graphOutputOperandTypes[i]; + auto grResultType = grType.getResult(i); + if (graphOutputOperandType != grResultType) + return graphOutputsARMOp.emitError("type of return operand ") + << i << " (" << graphOutputOperandType + << ") doesn't match graph result type (" << grResultType + << ")"; + } + } + return WalkResult::advance(); + }); + + return failure(walkResult.wasInterrupted()); +} + +void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state, + StringRef name, GraphType type, + ArrayRef attrs, bool entryPoint) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addAttribute(getEntryPointAttrName(state.name), + builder.getBoolAttr(entryPoint)); + state.addRegion(); +} + +// Returns the argument types of this function. +ArrayRef spirv::GraphARMOp::getArgumentTypes() { + return getFunctionType().getInputs(); +} + +// Returns the result types of this function. +ArrayRef spirv::GraphARMOp::getResultTypes() { + return getFunctionType().getResults(); +} + +// CallableOpInterface +Region *spirv::GraphARMOp::getCallableRegion() { + return isExternal() ? nullptr : &getBody(); +} + +//===----------------------------------------------------------------------===// +// spirv.GraphOutputsARM +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GraphOutputsARMOp::verify() { + auto graph = cast((*this)->getParentOp()); + + // The operand number and types must match the graph signature. + const auto &results = graph.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing graph (@" + << graph.getName() << ") returns " << results.size(); + + for (unsigned i = 0; i < results.size(); i++) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match graph result type (" << results[i] + << ")" + << " in graph @" << graph.getName(); + + return success(); +} + //===----------------------------------------------------------------------===// // spirv.GLFClampOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index 6d3bda421f309..fd97b09d802f1 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -158,6 +158,14 @@ void UpdateVCEPass::runOnOperation() { if (auto globalVar = dyn_cast(op)) valueTypes.push_back(globalVar.getType()); + // If the op is FunctionLike make sure to process input and result types + if (auto funcOpInterface = dyn_cast(op)) { + auto inputTypes = funcOpInterface.getArgumentTypes(); + auto resultTypes = funcOpInterface.getResultTypes(); + valueTypes.append(inputTypes.begin(), inputTypes.end()); + valueTypes.append(resultTypes.begin(), resultTypes.end()); + } + // Requirements from values' types SmallVector, 4> typeExtensions; SmallVector, 8> typeCapabilities; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index de52fbd3f215c..9a5dbcf6f598e 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -104,7 +104,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) { // it is a function (avoiding a grammar ambiguity). bool wrapped = op->getNumResults() != 1; if (!wrapped && op->getResult(0).getType() && - llvm::isa(op->getResult(0).getType())) + (llvm::isa(op->getResult(0).getType()) || + llvm::isa(op->getResult(0).getType()))) wrapped = true; if (wrapped) @@ -2836,6 +2837,20 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { os << '>'; }) .Case([&](Type) { os << "none"; }) + .Case([&](GraphType graphTy) { + os << '('; + interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); }); + os << ") -> "; + ArrayRef results = graphTy.getResults(); + if (results.size() == 1 && !(llvm::isa(results[0]) || + llvm::isa(results[0]))) { + printType(results[0]); + } else { + os << '('; + interleaveComma(results, [&](Type ty) { printType(ty); }); + os << ')'; + } + }) .Default([&](Type type) { return printDialectType(type); }); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index f657db142eeb9..3d366276b4375 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -76,6 +76,10 @@ FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) { return FunctionType::get(context, inputs, results); } +GraphType Builder::getGraphType(TypeRange inputs, TypeRange results) { + return GraphType::get(context, inputs, results); +} + TupleType Builder::getTupleType(TypeRange elementTypes) { return TupleType::get(context, elementTypes); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 1604ebba190a1..ce47c60c9b932 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -179,6 +179,45 @@ FunctionType::getWithoutArgsAndResults(const BitVector &argIndices, return clone(newArgTypes, newResultTypes); } +//===----------------------------------------------------------------------===// +// GraphType +//===----------------------------------------------------------------------===// + +unsigned GraphType::getNumInputs() const { return getImpl()->numInputs; } + +ArrayRef GraphType::getInputs() const { return getImpl()->getInputs(); } + +unsigned GraphType::getNumResults() const { return getImpl()->numResults; } + +ArrayRef GraphType::getResults() const { return getImpl()->getResults(); } + +GraphType GraphType::clone(TypeRange inputs, TypeRange results) const { + return get(getContext(), inputs, results); +} + +/// Returns a new function type with the specified arguments and results +/// inserted. +GraphType GraphType::getWithArgsAndResults(ArrayRef argIndices, + TypeRange argTypes, + ArrayRef resultIndices, + TypeRange resultTypes) { + SmallVector argStorage, resultStorage; + TypeRange newArgTypes = + insertTypesInto(getInputs(), argIndices, argTypes, argStorage); + TypeRange newResultTypes = + insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage); + return clone(newArgTypes, newResultTypes); +} + +/// Returns a new function type without the specified arguments and results. +GraphType GraphType::getWithoutArgsAndResults(const BitVector &argIndices, + const BitVector &resultIndices) { + SmallVector argStorage, resultStorage; + TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage); + TypeRange newResultTypes = + filterTypesOut(getResults(), resultIndices, resultStorage); + return clone(newArgTypes, newResultTypes); +} //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir index f56bc3967b4b7..bc1505d32d4d5 100644 --- a/mlir/test/Dialect/SPIRV/IR/availability.mlir +++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir @@ -306,3 +306,20 @@ func.func @constant_composite_replicate() -> () { %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32> spirv.Return } + +//===----------------------------------------------------------------------===// +// GraphARM ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: graph_arm +spirv.ARM.Graph @graph_arm(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { + // CHECK: spirv.ARM.GraphOutputs min version: v1.0 + // CHECK: spirv.ARM.GraphOutputs max version: v1.6 + // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ] + // CHECK: spirv.ARM.GraphOutputs capabilities: [ [GraphARM] ] + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8> +// CHECK: spirv.ARM.Graph min version: v1.0 +// CHECK: spirv.ARM.Graph max version: v1.6 +// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ] +// CHECK: spirv.ARM.Graph capabilities: [ [GraphARM] ] +} diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir new file mode 100644 index 0000000000000..90c31e19db382 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// spirv.ARM.GraphConstant +//===----------------------------------------------------------------------===// + +spirv.module Logical Vulkan requires #spirv.vce { + // CHECK: spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<14xi32> + %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<14xi32> + + // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr, UniformConstant> + // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]] + spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0 + // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} { + spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} { + // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16> + %1 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16> + // CHECK: spirv.ARM.GraphOutputs [[OUT:%.*]] : !spirv.arm.tensor<2x3xi16> + spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x3xi16> + } + + // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { + spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8> + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8> + } +} diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index 4e534a30ad516..cf9d86576b1f6 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -231,3 +231,14 @@ spirv.module Logical GLSL450 attributes { spirv.ReturnValue %val : bf16 } } + +// CHECK: spirv.module Logical Vulkan requires #spirv.vce +spirv.module Logical Vulkan attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, + #spirv.resource_limits<>> +} { + spirv.ARM.Graph @argmax(%arg0 : !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> { + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi8> + } +} diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp index 2e5e591fe5f91..9efca825a663d 100644 --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -21,7 +21,7 @@ using namespace mlir; namespace { /// A pass for testing SPIR-V op availability. struct PrintOpAvailability - : public PassWrapper> { + : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability) void runOnOperation() override; @@ -33,12 +33,10 @@ struct PrintOpAvailability } // namespace void PrintOpAvailability::runOnOperation() { - auto f = getOperation(); - llvm::outs() << f.getName() << "\n"; - + auto moduleOp = getOperation(); Dialect *spirvDialect = getContext().getLoadedDialect("spirv"); - f->walk([&](Operation *op) { + auto opCallback = [&](Operation *op) { if (op->getDialect() != spirvDialect) return WalkResult::advance(); @@ -89,6 +87,16 @@ void PrintOpAvailability::runOnOperation() { os.flush(); return WalkResult::advance(); + }; + + moduleOp.walk([&](func::FuncOp f) { + llvm::outs() << f.getName() << "\n"; + f->walk(opCallback); + }); + + moduleOp.walk([&](spirv::GraphARMOp g) { + llvm::outs() << g.getName() << "\n"; + g->walk(opCallback); }); } From 101e67424627d0762c745a3ec9b7e506657b82b2 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Wed, 6 Aug 2025 10:11:35 +0200 Subject: [PATCH 2/5] Resolve code review comments Signed-off-by: Davide Grohmann Change-Id: Ie69a1696a7b31869c1ba94bdf7aa214d52175565 --- .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 3 +- .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td | 64 ++++++++++-------- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 19 +++--- .../SPIRV/Transforms/UpdateVCEPass.cpp | 4 +- mlir/test/Dialect/SPIRV/IR/availability.mlir | 4 +- mlir/test/Dialect/SPIRV/IR/graph-ops.mlir | 67 +++++++++++++++---- .../SPIRV/Transforms/vce-deduction.mlir | 4 +- 7 files changed, 104 insertions(+), 61 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index a27554a3c6f64..0e42d08cdb1fc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -1343,7 +1343,7 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"Stora ]; } def SPIRV_C_GraphARM : I32EnumAttrCase<"GraphARM", 4191> { - list implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel]; + list implies = [SPIRV_C_TensorsARM]; list availability = [ Extension<[SPV_ARM_graph]> ]; @@ -4883,7 +4883,6 @@ class SPIRV_ArmVendorOp traits = []> : SPIRV_VendorOp { } - def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">; def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>; def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td index 38fb4b2eff414..f2913239cc4e8 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td @@ -29,39 +29,11 @@ class SPIRV_GraphARMOp traits = []> : let availability = [ MinVersion, MaxVersion, - Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>, + Extension<[SPV_ARM_graph, SPV_ARM_tensors]>, Capability<[SPIRV_C_GraphARM]> ]; } -def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> { - let summary = "Declare a graph constant."; - - let description = [{ - Declare a graph constant. - Result Type must be an OpTypeTensorARM. - GraphConstantID must be a 32-bit integer literal. - }]; - - let arguments = (ins - I32Attr: $graph_constant_id - ); - - let results = (outs - SPIRV_AnyTensorArm:$output - ); - - let hasVerifier = 0; - - let autogenSerialization = 0; - - let assemblyFormat = [{ - attr-dict `:` type($output) - }]; -} - -// ----- - def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [ AutomaticAllocationScope, DeclareOpInterfaceMethods, FunctionOpInterface, InModuleScope, IsolatedFromAbove @@ -122,6 +94,8 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [ }]; } +// ----- + // Check that an op can only be used within the scope of a spirv.ARM.Graph op. def InGraphScope : PredOpTrait< "op must appear in a spirv.ARM.Graph op's block", @@ -129,6 +103,38 @@ def InGraphScope : PredOpTrait< // ----- +def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope, Pure, ConstantLike]> { + let summary = "Declare a graph constant."; + + let description = [{ + Declare a graph constant. + Result Type must be an OpTypeTensorARM. + GraphConstantID must be a 32-bit integer literal. + + ``` + spv-graph-constant-arm-op ::= `spirv.ARM.GraphConstant` { graph_constant_id = 42 : i32 } + ``` + }]; + + let arguments = (ins + I32Attr: $graph_constant_id + ); + + let results = (outs + SPIRV_AnyTensorArm:$output + ); + + let hasVerifier = 0; + + let autogenSerialization = 0; + + let assemblyFormat = [{ + attr-dict `:` type($output) + }]; +} + +// ----- + def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> { let summary = [{ Declare a graph entry point and its interface. diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 8dfdfea8a5c54..953406da60a57 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1140,13 +1140,11 @@ void spirv::GraphEntryPointARMOp::build(OpBuilder &builder, ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector idTypes; SmallVector interfaceVars; FlatSymbolRefAttr fn; - if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) { + if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) return failure(); - } if (!parser.parseOptionalComma()) { // Parse the interface variables @@ -1224,7 +1222,7 @@ ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser, getResAttrsAttrName(result.name)); // Parse the optional function body. - auto *body = result.addRegion(); + Region *body = result.addRegion(); OptionalParseResult parseResult = parser.parseOptionalRegion(*body, entryArgs); return failure(parseResult.has_value() && failed(*parseResult)); @@ -1234,7 +1232,7 @@ void spirv::GraphARMOp::print(OpAsmPrinter &printer) { // Print graph name, signature, and control. printer << " "; printer.printSymbolName(getSymName()); - auto grType = getFunctionType(); + GraphType grType = getFunctionType(); function_interface_impl::printFunctionSignature( printer, *this, grType.getInputs(), /*isVariadic=*/false, grType.getResults()); @@ -1288,9 +1286,10 @@ LogicalResult spirv::GraphARMOp::verifyBody() { << grType.getNumResults() << " results"; auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType(); - for (unsigned i = 0; i < graphOutputOperandTypes.size(); ++i) { - auto graphOutputOperandType = graphOutputOperandTypes[i]; - auto grResultType = grType.getResult(i); + for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size; + ++i) { + Type graphOutputOperandType = graphOutputOperandTypes[i]; + Type grResultType = grType.getResult(i); if (graphOutputOperandType != grResultType) return graphOutputsARMOp.emitError("type of return operand ") << i << " (" << graphOutputOperandType @@ -1339,13 +1338,13 @@ LogicalResult spirv::GraphOutputsARMOp::verify() { auto graph = cast((*this)->getParentOp()); // The operand number and types must match the graph signature. - const auto &results = graph.getFunctionType().getResults(); + const ArrayRef &results = graph.getFunctionType().getResults(); if (getNumOperands() != results.size()) return emitOpError("has ") << getNumOperands() << " operands, but enclosing graph (@" << graph.getName() << ") returns " << results.size(); - for (unsigned i = 0; i < results.size(); i++) + for (unsigned i = 0, size = results.size(); i < size; ++i) if (getOperand(i).getType() != results[i]) return emitError() << "type of return operand " << i << " (" << getOperand(i).getType() diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index fd97b09d802f1..a2d221252fb69 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -160,8 +160,8 @@ void UpdateVCEPass::runOnOperation() { // If the op is FunctionLike make sure to process input and result types if (auto funcOpInterface = dyn_cast(op)) { - auto inputTypes = funcOpInterface.getArgumentTypes(); - auto resultTypes = funcOpInterface.getResultTypes(); + ArrayRef inputTypes = funcOpInterface.getArgumentTypes(); + ArrayRef resultTypes = funcOpInterface.getResultTypes(); valueTypes.append(inputTypes.begin(), inputTypes.end()); valueTypes.append(resultTypes.begin(), resultTypes.end()); } diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir index bc1505d32d4d5..4ef242bdc5b16 100644 --- a/mlir/test/Dialect/SPIRV/IR/availability.mlir +++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir @@ -315,11 +315,11 @@ func.func @constant_composite_replicate() -> () { spirv.ARM.Graph @graph_arm(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { // CHECK: spirv.ARM.GraphOutputs min version: v1.0 // CHECK: spirv.ARM.GraphOutputs max version: v1.6 - // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ] + // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors] ] // CHECK: spirv.ARM.GraphOutputs capabilities: [ [GraphARM] ] spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8> // CHECK: spirv.ARM.Graph min version: v1.0 // CHECK: spirv.ARM.Graph max version: v1.6 -// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ] +// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors] ] // CHECK: spirv.ARM.Graph capabilities: [ [GraphARM] ] } diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir index 90c31e19db382..6919c7eecc632 100644 --- a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir @@ -1,29 +1,68 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s +//===----------------------------------------------------------------------===// +// spirv.ARM.Graph and spirv.ARM.GraphOutputs +//===----------------------------------------------------------------------===// + +spirv.module Logical Vulkan requires #spirv.vce { + // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { + spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16> + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16> + } +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.ARM.GraphConstant //===----------------------------------------------------------------------===// -spirv.module Logical Vulkan requires #spirv.vce { - // CHECK: spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<14xi32> - %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<14xi32> +spirv.module Logical Vulkan requires #spirv.vce { + // CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> { + spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> { + // CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16> + %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16> + // CHECK: spirv.ARM.GraphOutputs [[CONST:%.*]] : !spirv.arm.tensor<2x3xi16> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16> + } +} +// ----- + +//===----------------------------------------------------------------------===// +// spirv.ARM.GraphEntryPoint +//===----------------------------------------------------------------------===// + +spirv.module Logical Vulkan requires #spirv.vce { // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr, UniformConstant> - spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> - // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr, UniformConstant> - spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr, UniformConstant> // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]] - spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0 - // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} { - spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} { - // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16> - %1 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16> - // CHECK: spirv.ARM.GraphOutputs [[OUT:%.*]] : !spirv.arm.tensor<2x3xi16> - spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x3xi16> + spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0 + // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { + spirv.ARM.Graph @entrypoint(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16> + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16> + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// Multiple spirv.ARM.Graphs +//===----------------------------------------------------------------------===// + +spirv.module Logical Vulkan requires #spirv.vce { + // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { + spirv.ARM.Graph @graph1(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16> + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16> } // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { - spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { + spirv.ARM.Graph @graph2(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8> spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8> } diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index cf9d86576b1f6..18958cef7b00a 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -232,10 +232,10 @@ spirv.module Logical GLSL450 attributes { } } -// CHECK: spirv.module Logical Vulkan requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical Vulkan attributes { spirv.target_env = #spirv.target_env< - #spirv.vce, + #spirv.vce, #spirv.resource_limits<>> } { spirv.ARM.Graph @argmax(%arg0 : !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> { From 3cf2ee9ce028265868e05ac84d561b09db4e847e Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Wed, 6 Aug 2025 12:14:31 +0200 Subject: [PATCH 3/5] Fix one more comment Signed-off-by: Davide Grohmann Change-Id: Ia24695d965919ffebdff9945cd7d72233faa922a --- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 953406da60a57..fdefd5e3966ca 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1167,7 +1167,7 @@ ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser, void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) { printer << " "; printer.printSymbolName(getFn()); - auto interfaceVars = getInterface().getValue(); + ArrayRef interfaceVars = getInterface().getValue(); if (!interfaceVars.empty()) { printer << ", "; llvm::interleaveComma(interfaceVars, printer); From 5ae57fe2d6fdfd8ccf95bd6a03c96bcc9db93ba1 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Fri, 8 Aug 2025 11:04:19 +0200 Subject: [PATCH 4/5] Resolve more review comments and expand testing In particular add negative testing. Signed-off-by: Davide Grohmann Change-Id: Iee4ba17c74b451eda7f76c6f905ca12c734d39d6 --- .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td | 36 +++++ mlir/include/mlir/IR/CommonTypeConstraints.td | 1 - mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 36 +++-- mlir/test/Dialect/SPIRV/IR/graph-ops.mlir | 131 +++++++++++++----- 4 files changed, 151 insertions(+), 53 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td index f2913239cc4e8..51df4dc79ae68 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td @@ -57,6 +57,14 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [ spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature region ``` + + #### Example: + + ```mlir + spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16> + } + ``` }]; let arguments = (ins @@ -114,6 +122,12 @@ def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope, ``` spv-graph-constant-arm-op ::= `spirv.ARM.GraphConstant` { graph_constant_id = 42 : i32 } ``` + + #### Example: + + ```mlir + %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16> + ``` }]; let arguments = (ins @@ -157,6 +171,17 @@ def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleSc entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint` symbol-reference (`, ` symbol-reference)* ``` + + #### Example: + + ```mlir + spirv.GlobalVariable @arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> + spirv.GlobalVariable @res_0 bind(0, 1) : !spirv.ptr, UniformConstant> + spirv.ARM.GraphEntryPoint @graph, @arg_0, @res_0 + spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { + ... + } + ``` }]; let arguments = (ins @@ -166,6 +191,9 @@ def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleSc let results = (outs); + // Checks for graph and interface symbol reference are done in spirv::ModuleOp verification. + let hasVerifier = 0; + let autogenSerialization = 0; let builders = [ @@ -189,6 +217,14 @@ def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pu ``` graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens ``` + + #### Example: + + ```mlir + spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16> + } + ``` }]; let arguments = (ins diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index aab1b01c5cff9..8ba2daefd97aa 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -393,7 +393,6 @@ def FunctionType : Type($_self)">, def GraphType : Type($_self)">, "graph type", "::mlir::GraphType">; - // A container type is a type that has another type embedded within it. class ContainerType : diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index fdefd5e3966ca..398dc046b3912 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1174,12 +1174,6 @@ void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) { } } -LogicalResult spirv::GraphEntryPointARMOp::verify() { - // Checks for fn and interface symbol reference are done in spirv::ModuleOp - // verification. - return success(); -} - //===----------------------------------------------------------------------===// // spirv.GraphARM //===----------------------------------------------------------------------===// @@ -1257,7 +1251,19 @@ LogicalResult spirv::GraphARMOp::verifyType() { } LogicalResult spirv::GraphARMOp::verifyBody() { - GraphType grType = getFunctionType(); + for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) { + if (!isa(graphArgType)) { + return emitOpError("type of argument #") + << index << " must be a TensorArmType, but got " << graphArgType; + } + } + for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) { + if (!isa(graphResType)) { + return emitOpError("type of result #") + << index << " must be a TensorArmType, but got " << graphResType; + } + } + if (!isExternal()) { Block &entryBlock = front(); @@ -1277,15 +1283,17 @@ LogicalResult spirv::GraphARMOp::verifyBody() { } } + GraphType grType = getFunctionType(); auto walkResult = walk([grType](Operation *op) -> WalkResult { if (auto graphOutputsARMOp = dyn_cast(op)) { if (grType.getNumResults() != graphOutputsARMOp.getNumOperands()) - return graphOutputsARMOp.emitOpError("has GraphOutputsARM returning ") + return graphOutputsARMOp.emitOpError("is returning ") << graphOutputsARMOp.getNumOperands() - << "value(s) but enclosing graph requires " - << grType.getNumResults() << " results"; + << " value(s) but enclosing spirv.ARM.Graph requires " + << grType.getNumResults() << " result(s)"; - auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType(); + ValueTypeRange graphOutputOperandTypes = + graphOutputsARMOp.getValue().getType(); for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size; ++i) { Type graphOutputOperandType = graphOutputOperandTypes[i]; @@ -1341,15 +1349,15 @@ LogicalResult spirv::GraphOutputsARMOp::verify() { const ArrayRef &results = graph.getFunctionType().getResults(); if (getNumOperands() != results.size()) return emitOpError("has ") - << getNumOperands() << " operands, but enclosing graph (@" + << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@" << graph.getName() << ") returns " << results.size(); for (unsigned i = 0, size = results.size(); i < size; ++i) if (getOperand(i).getType() != results[i]) return emitError() << "type of return operand " << i << " (" << getOperand(i).getType() - << ") doesn't match graph result type (" << results[i] - << ")" + << ") doesn't match spirv.ARM.Graph result type (" + << results[i] << ")" << " in graph @" << graph.getName(); return success(); diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir index 6919c7eecc632..591eaaea4c802 100644 --- a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir @@ -4,12 +4,10 @@ // spirv.ARM.Graph and spirv.ARM.GraphOutputs //===----------------------------------------------------------------------===// -spirv.module Logical Vulkan requires #spirv.vce { - // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { - spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { - // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16> - spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16> - } +// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { +spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { + // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16> + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16> } // ----- @@ -18,14 +16,12 @@ spirv.module Logical Vulkan requires #spirv.vce { - // CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> { - spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> { - // CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16> - %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16> - // CHECK: spirv.ARM.GraphOutputs [[CONST:%.*]] : !spirv.arm.tensor<2x3xi16> - spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16> - } +// CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> { +spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> { + // CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16> + %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16> + // CHECK: spirv.ARM.GraphOutputs [[CONST:%.*]] : !spirv.arm.tensor<2x3xi16> + spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16> } // ----- @@ -33,37 +29,96 @@ spirv.module Logical Vulkan requires #spirv.vce, UniformConstant> +spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> +// CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr, UniformConstant> +spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr, UniformConstant> +// CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]] +spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0 -spirv.module Logical Vulkan requires #spirv.vce { - // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr, UniformConstant> - spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr, UniformConstant> - // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr, UniformConstant> - spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr, UniformConstant> - // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]] - spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0 - // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { - spirv.ARM.Graph @entrypoint(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { - // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16> - spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16> - } +// ----- + +//===----------------------------------------------------------------------===// +// spirv.ARM.Graph with no terminator +//===----------------------------------------------------------------------===// + +// expected-error @+1 {{empty block: expect at least a terminator}} +spirv.ARM.Graph @graphNoterminator(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.ARM.Graph with no result types +//===----------------------------------------------------------------------===// + +// expected-error @+1 {{'spirv.ARM.Graph' op there should be at least one result}} +spirv.ARM.Graph @graphNoOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> () { } // ----- //===----------------------------------------------------------------------===// -// Multiple spirv.ARM.Graphs +// spirv.ARM.GraphConstant outside graph scope //===----------------------------------------------------------------------===// -spirv.module Logical Vulkan requires #spirv.vce { - // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { - spirv.ARM.Graph @graph1(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { - // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16> - spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16> - } +// expected-error @+1 {{'spirv.ARM.GraphConstant' op failed to verify that op must appear in a spirv.ARM.Graph op's block}} +%0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16> +// ----- + +//===----------------------------------------------------------------------===// +// spirv.ARM.GraphOutputs outside graph scope +//===----------------------------------------------------------------------===// + +%0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16> +// expected-error @+1 {{'spirv.ARM.GraphOutputs' op failed to verify that op must appear in a spirv.ARM.Graph op's block}} +spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1xi16> + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.ARM.Graph return type does not match spirv.ARM.GraphOutputs +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<5x3xi16> { + // expected-error @+1 {{type of return operand 0 ('!spirv.arm.tensor<14x19xi16>') doesn't match graph result type ('!spirv.arm.tensor<5x3xi16>')}} + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16> +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.ARM.Graph return type does not match number of results in spirv.ARM.GraphOutputs +//===----------------------------------------------------------------------===// + +spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> (!spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>) { + // expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 1 value(s) but enclosing spirv.ARM.Graph requires 2 result(s)}} + spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16> +} + +// ----- + +spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> { + // expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 2 value(s) but enclosing spirv.ARM.Graph requires 1 result(s)}} + spirv.ARM.GraphOutputs %arg0, %arg0 : !spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16> +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.ARM.Graph using a non TensorArmType argument +//===----------------------------------------------------------------------===// + +// expected-error @+1 {{'spirv.ARM.Graph' op type of argument #0 must be a TensorArmType, but got 'i8'}} +spirv.ARM.Graph @graphAndOutputs(%arg0 : i8) -> !spirv.arm.tensor<14x19xi16> { +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.ARM.Graph using a non TensorArmType result +//===----------------------------------------------------------------------===// - // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { - spirv.ARM.Graph @graph2(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> { - // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8> - spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8> - } +// expected-error @+1 {{'spirv.ARM.Graph' op type of result #0 must be a TensorArmType, but got 'i8'}} +spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> i8 { } From 1a4f0aa4486f25230821c3c08830729b0efd900e Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Fri, 8 Aug 2025 13:20:35 +0200 Subject: [PATCH 5/5] Extract SPV_ARM_graph operations in its own file Signed-off-by: Davide Grohmann Change-Id: Ia74db44157fb724c9f787387386338814413db30 --- mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp | 264 ++++++++++++++++++++++ mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt | 1 + mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 237 ------------------- 3 files changed, 265 insertions(+), 237 deletions(-) create mode 100644 mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp diff --git a/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp new file mode 100644 index 0000000000000..e300596fd3733 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp @@ -0,0 +1,264 @@ +//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations +//------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the SPV_ARM_graph operations in the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + +#include "SPIRVParsingUtils.h" + +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/FunctionImplementation.h" + +using namespace mlir; +using namespace mlir::spirv::AttrNames; + +//===----------------------------------------------------------------------===// +// spirv.GraphARM +//===----------------------------------------------------------------------===// + +ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector entryArgs; + SmallVector resultAttrs; + SmallVector resultTypes; + auto &builder = parser.getBuilder(); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // Parse the function signature. + bool isVariadic = false; + if (function_interface_impl::parseFunctionSignatureWithArguments( + parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, + resultAttrs)) + return failure(); + + SmallVector argTypes; + for (auto &arg : entryArgs) + argTypes.push_back(arg.type); + auto grType = builder.getGraphType(argTypes, resultTypes); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(grType)); + + // If additional attributes are present, parse them. + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + // Add the attributes to the function arguments. + assert(resultAttrs.size() == resultTypes.size()); + call_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); + + // Parse the optional function body. + Region *body = result.addRegion(); + OptionalParseResult parseResult = + parser.parseOptionalRegion(*body, entryArgs); + return failure(parseResult.has_value() && failed(*parseResult)); +} + +void spirv::GraphARMOp::print(OpAsmPrinter &printer) { + // Print graph name, signature, and control. + printer << " "; + printer.printSymbolName(getSymName()); + GraphType grType = getFunctionType(); + function_interface_impl::printFunctionSignature( + printer, *this, grType.getInputs(), + /*isVariadic=*/false, grType.getResults()); + function_interface_impl::printFunctionAttributes(printer, *this, + {getFunctionTypeAttrName(), + getArgAttrsAttrName(), + getResAttrsAttrName()}); + + // Print the body. + Region &body = this->getBody(); + if (!body.empty()) { + printer << ' '; + printer.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } +} + +LogicalResult spirv::GraphARMOp::verifyType() { + if (getFunctionType().getNumResults() < 1) + return emitOpError("there should be at least one result"); + return success(); +} + +LogicalResult spirv::GraphARMOp::verifyBody() { + for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) { + if (!isa(graphArgType)) { + return emitOpError("type of argument #") + << index << " must be a TensorArmType, but got " << graphArgType; + } + } + for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) { + if (!isa(graphResType)) { + return emitOpError("type of result #") + << index << " must be a TensorArmType, but got " << graphResType; + } + } + + if (!isExternal()) { + Block &entryBlock = front(); + + unsigned numArguments = this->getNumArguments(); + if (entryBlock.getNumArguments() != numArguments) + return emitOpError("entry block must have ") + << numArguments << " arguments to match graph signature"; + + for (auto [index, grArgType, blockArgType] : + llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) { + if (blockArgType != grArgType) { + return emitOpError("type of entry block argument #") + << index << '(' << blockArgType + << ") must match the type of the corresponding argument in " + << "graph signature(" << grArgType << ')'; + } + } + } + + GraphType grType = getFunctionType(); + auto walkResult = walk([grType](Operation *op) -> WalkResult { + if (auto graphOutputsARMOp = dyn_cast(op)) { + if (grType.getNumResults() != graphOutputsARMOp.getNumOperands()) + return graphOutputsARMOp.emitOpError("is returning ") + << graphOutputsARMOp.getNumOperands() + << " value(s) but enclosing spirv.ARM.Graph requires " + << grType.getNumResults() << " result(s)"; + + ValueTypeRange graphOutputOperandTypes = + graphOutputsARMOp.getValue().getType(); + for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size; + ++i) { + Type graphOutputOperandType = graphOutputOperandTypes[i]; + Type grResultType = grType.getResult(i); + if (graphOutputOperandType != grResultType) + return graphOutputsARMOp.emitError("type of return operand ") + << i << " (" << graphOutputOperandType + << ") doesn't match graph result type (" << grResultType + << ")"; + } + } + return WalkResult::advance(); + }); + + return failure(walkResult.wasInterrupted()); +} + +void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state, + StringRef name, GraphType type, + ArrayRef attrs, bool entryPoint) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + state.addAttribute(getEntryPointAttrName(state.name), + builder.getBoolAttr(entryPoint)); + state.addRegion(); +} + +// Returns the argument types of this function. +ArrayRef spirv::GraphARMOp::getArgumentTypes() { + return getFunctionType().getInputs(); +} + +// Returns the result types of this function. +ArrayRef spirv::GraphARMOp::getResultTypes() { + return getFunctionType().getResults(); +} + +// CallableOpInterface +Region *spirv::GraphARMOp::getCallableRegion() { + return isExternal() ? nullptr : &getBody(); +} + +//===----------------------------------------------------------------------===// +// spirv.GraphOutputsARM +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GraphOutputsARMOp::verify() { + auto graph = cast((*this)->getParentOp()); + + // The operand number and types must match the graph signature. + const ArrayRef &results = graph.getFunctionType().getResults(); + if (getNumOperands() != results.size()) + return emitOpError("has ") + << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@" + << graph.getName() << ") returns " << results.size(); + + for (unsigned i = 0, size = results.size(); i < size; ++i) + if (getOperand(i).getType() != results[i]) + return emitError() << "type of return operand " << i << " (" + << getOperand(i).getType() + << ") doesn't match spirv.ARM.Graph result type (" + << results[i] << ")" + << " in graph @" << graph.getName(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// spirv.GraphEntryPointARM +//===----------------------------------------------------------------------===// + +void spirv::GraphEntryPointARMOp::build(OpBuilder &builder, + OperationState &state, + spirv::GraphARMOp graph, + ArrayRef interfaceVars) { + build(builder, state, SymbolRefAttr::get(graph), + builder.getArrayAttr(interfaceVars)); +} + +ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector interfaceVars; + + FlatSymbolRefAttr fn; + if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) + return failure(); + + if (!parser.parseOptionalComma()) { + // Parse the interface variables + if (parser.parseCommaSeparatedList([&]() -> ParseResult { + // The name of the interface variable attribute isnt important + FlatSymbolRefAttr var; + NamedAttrList attrs; + if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) + return failure(); + interfaceVars.push_back(var); + return success(); + })) + return failure(); + } + result.addAttribute("interface", + parser.getBuilder().getArrayAttr(interfaceVars)); + return success(); +} + +void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) { + printer << " "; + printer.printSymbolName(getFn()); + ArrayRef interfaceVars = getInterface().getValue(); + if (!interfaceVars.empty()) { + printer << ", "; + llvm::interleaveComma(interfaceVars, printer); + } +} diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt index b9aa7b7491abf..60d705d940cfc 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -3,6 +3,7 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters) add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen) add_mlir_dialect_library(MLIRSPIRVDialect + ArmGraphOps.cpp AtomicOps.cpp CastOps.cpp ControlFlowOps.cpp diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 398dc046b3912..f99339852824c 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1126,243 +1126,6 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state, state.addRegion(); } -//===----------------------------------------------------------------------===// -// spirv.GraphEntryPointARM -//===----------------------------------------------------------------------===// - -void spirv::GraphEntryPointARMOp::build(OpBuilder &builder, - OperationState &state, - spirv::GraphARMOp graph, - ArrayRef interfaceVars) { - build(builder, state, SymbolRefAttr::get(graph), - builder.getArrayAttr(interfaceVars)); -} - -ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser, - OperationState &result) { - SmallVector interfaceVars; - - FlatSymbolRefAttr fn; - if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) - return failure(); - - if (!parser.parseOptionalComma()) { - // Parse the interface variables - if (parser.parseCommaSeparatedList([&]() -> ParseResult { - // The name of the interface variable attribute isnt important - FlatSymbolRefAttr var; - NamedAttrList attrs; - if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) - return failure(); - interfaceVars.push_back(var); - return success(); - })) - return failure(); - } - result.addAttribute("interface", - parser.getBuilder().getArrayAttr(interfaceVars)); - return success(); -} - -void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) { - printer << " "; - printer.printSymbolName(getFn()); - ArrayRef interfaceVars = getInterface().getValue(); - if (!interfaceVars.empty()) { - printer << ", "; - llvm::interleaveComma(interfaceVars, printer); - } -} - -//===----------------------------------------------------------------------===// -// spirv.GraphARM -//===----------------------------------------------------------------------===// - -ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser, - OperationState &result) { - SmallVector entryArgs; - SmallVector resultAttrs; - SmallVector resultTypes; - auto &builder = parser.getBuilder(); - - // Parse the name as a symbol. - StringAttr nameAttr; - if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), - result.attributes)) - return failure(); - - // Parse the function signature. - bool isVariadic = false; - if (function_interface_impl::parseFunctionSignatureWithArguments( - parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, - resultAttrs)) - return failure(); - - SmallVector argTypes; - for (auto &arg : entryArgs) - argTypes.push_back(arg.type); - auto grType = builder.getGraphType(argTypes, resultTypes); - result.addAttribute(getFunctionTypeAttrName(result.name), - TypeAttr::get(grType)); - - // If additional attributes are present, parse them. - if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) - return failure(); - - // Add the attributes to the function arguments. - assert(resultAttrs.size() == resultTypes.size()); - call_interface_impl::addArgAndResultAttrs( - builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), - getResAttrsAttrName(result.name)); - - // Parse the optional function body. - Region *body = result.addRegion(); - OptionalParseResult parseResult = - parser.parseOptionalRegion(*body, entryArgs); - return failure(parseResult.has_value() && failed(*parseResult)); -} - -void spirv::GraphARMOp::print(OpAsmPrinter &printer) { - // Print graph name, signature, and control. - printer << " "; - printer.printSymbolName(getSymName()); - GraphType grType = getFunctionType(); - function_interface_impl::printFunctionSignature( - printer, *this, grType.getInputs(), - /*isVariadic=*/false, grType.getResults()); - function_interface_impl::printFunctionAttributes(printer, *this, - {getFunctionTypeAttrName(), - getArgAttrsAttrName(), - getResAttrsAttrName()}); - - // Print the body. - Region &body = this->getBody(); - if (!body.empty()) { - printer << ' '; - printer.printRegion(body, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/true); - } -} - -LogicalResult spirv::GraphARMOp::verifyType() { - if (getFunctionType().getNumResults() < 1) - return emitOpError("there should be at least one result"); - return success(); -} - -LogicalResult spirv::GraphARMOp::verifyBody() { - for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) { - if (!isa(graphArgType)) { - return emitOpError("type of argument #") - << index << " must be a TensorArmType, but got " << graphArgType; - } - } - for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) { - if (!isa(graphResType)) { - return emitOpError("type of result #") - << index << " must be a TensorArmType, but got " << graphResType; - } - } - - if (!isExternal()) { - Block &entryBlock = front(); - - unsigned numArguments = this->getNumArguments(); - if (entryBlock.getNumArguments() != numArguments) - return emitOpError("entry block must have ") - << numArguments << " arguments to match graph signature"; - - for (auto [index, grArgType, blockArgType] : - llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) { - if (blockArgType != grArgType) { - return emitOpError("type of entry block argument #") - << index << '(' << blockArgType - << ") must match the type of the corresponding argument in " - << "graph signature(" << grArgType << ')'; - } - } - } - - GraphType grType = getFunctionType(); - auto walkResult = walk([grType](Operation *op) -> WalkResult { - if (auto graphOutputsARMOp = dyn_cast(op)) { - if (grType.getNumResults() != graphOutputsARMOp.getNumOperands()) - return graphOutputsARMOp.emitOpError("is returning ") - << graphOutputsARMOp.getNumOperands() - << " value(s) but enclosing spirv.ARM.Graph requires " - << grType.getNumResults() << " result(s)"; - - ValueTypeRange graphOutputOperandTypes = - graphOutputsARMOp.getValue().getType(); - for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size; - ++i) { - Type graphOutputOperandType = graphOutputOperandTypes[i]; - Type grResultType = grType.getResult(i); - if (graphOutputOperandType != grResultType) - return graphOutputsARMOp.emitError("type of return operand ") - << i << " (" << graphOutputOperandType - << ") doesn't match graph result type (" << grResultType - << ")"; - } - } - return WalkResult::advance(); - }); - - return failure(walkResult.wasInterrupted()); -} - -void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state, - StringRef name, GraphType type, - ArrayRef attrs, bool entryPoint) { - state.addAttribute(SymbolTable::getSymbolAttrName(), - builder.getStringAttr(name)); - state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); - state.attributes.append(attrs.begin(), attrs.end()); - state.addAttribute(getEntryPointAttrName(state.name), - builder.getBoolAttr(entryPoint)); - state.addRegion(); -} - -// Returns the argument types of this function. -ArrayRef spirv::GraphARMOp::getArgumentTypes() { - return getFunctionType().getInputs(); -} - -// Returns the result types of this function. -ArrayRef spirv::GraphARMOp::getResultTypes() { - return getFunctionType().getResults(); -} - -// CallableOpInterface -Region *spirv::GraphARMOp::getCallableRegion() { - return isExternal() ? nullptr : &getBody(); -} - -//===----------------------------------------------------------------------===// -// spirv.GraphOutputsARM -//===----------------------------------------------------------------------===// - -LogicalResult spirv::GraphOutputsARMOp::verify() { - auto graph = cast((*this)->getParentOp()); - - // The operand number and types must match the graph signature. - const ArrayRef &results = graph.getFunctionType().getResults(); - if (getNumOperands() != results.size()) - return emitOpError("has ") - << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@" - << graph.getName() << ") returns " << results.size(); - - for (unsigned i = 0, size = results.size(); i < size; ++i) - if (getOperand(i).getType() != results[i]) - return emitError() << "type of return operand " << i << " (" - << getOperand(i).getType() - << ") doesn't match spirv.ARM.Graph result type (" - << results[i] << ")" - << " in graph @" << graph.getName(); - - return success(); -} - //===----------------------------------------------------------------------===// // spirv.GLFClampOp //===----------------------------------------------------------------------===//