diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 18178614e39cd..3545acb20212d 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -385,6 +385,14 @@ LogicalResult GPUModuleConversion::matchAndRewrite( if (auto attr = moduleOp->getAttrOfType( spirv::getTargetEnvAttrName())) spvModule->setAttr(spirv::getTargetEnvAttrName(), attr); + if (ArrayAttr targets = moduleOp.getTargetsAttr()) { + for (Attribute targetAttr : targets) + if (auto spirvTargetEnvAttr = + dyn_cast(targetAttr)) { + spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr); + break; + } + } rewriter.eraseOp(moduleOp); return success(); diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp index a344f88326089..5eab05742d401 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -48,9 +48,36 @@ struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase { void runOnOperation() override; private: + /// Queries the target environment from 'targets' attribute of the given + /// `moduleOp`. + spirv::TargetEnvAttr lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp); + + /// Queries the target environment from 'targets' attribute of the given + /// `moduleOp` or returns target environment as returned by + /// `spirv::lookupTargetEnvOrDefault` if not provided by 'targets'. + spirv::TargetEnvAttr lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp); bool mapMemorySpace; }; +spirv::TargetEnvAttr +GPUToSPIRVPass::lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp) { + if (ArrayAttr targets = moduleOp.getTargetsAttr()) { + for (Attribute targetAttr : targets) + if (auto spirvTargetEnvAttr = dyn_cast(targetAttr)) + return spirvTargetEnvAttr; + } + + return {}; +} + +spirv::TargetEnvAttr +GPUToSPIRVPass::lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp) { + if (spirv::TargetEnvAttr targetEnvAttr = lookupTargetEnvInTargets(moduleOp)) + return targetEnvAttr; + + return spirv::lookupTargetEnvOrDefault(moduleOp); +} + void GPUToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -58,9 +85,8 @@ void GPUToSPIRVPass::runOnOperation() { SmallVector gpuModules; OpBuilder builder(context); - auto targetEnvSupportsKernelCapability = [](gpu::GPUModuleOp moduleOp) { - Operation *gpuModule = moduleOp.getOperation(); - auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule); + auto targetEnvSupportsKernelCapability = [this](gpu::GPUModuleOp moduleOp) { + auto targetAttr = lookupTargetEnvOrDefault(moduleOp); spirv::TargetEnv targetEnv(targetAttr); return targetEnv.allows(spirv::Capability::Kernel); }; @@ -86,7 +112,7 @@ void GPUToSPIRVPass::runOnOperation() { // TargetEnv attributes. for (Operation *gpuModule : gpuModules) { spirv::TargetEnvAttr targetAttr = - spirv::lookupTargetEnvOrDefault(gpuModule); + lookupTargetEnvOrDefault(cast(gpuModule)); // Map MemRef memory space to SPIR-V storage class first if requested. if (mapMemorySpace) { diff --git a/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir b/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir new file mode 100644 index 0000000000000..983747be57995 --- /dev/null +++ b/mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt --split-input-file --convert-gpu-to-spirv %s | FileCheck %s + +module attributes {gpu.container_module} { + // CHECK-LABEL: spirv.module @{{.*}} GLSL450 + gpu.module @kernels [#spirv.target_env<#spirv.vce, #spirv.resource_limits<>>] { + // CHECK: spirv.func @load_kernel + // CHECK-SAME: %[[ARG:.*]]: !spirv.ptr [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) + gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %c0 = arith.constant 0 : index + // CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}} + // CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32 + %0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32> + // CHECK: spirv.Return + gpu.return + } + } +} + +// ----- +// Checks that the `-convert-gpu-to-spirv` pass selects the first +// `spirv.target_env` from the `targets` array attribute attached to `gpu.module`. +module attributes {gpu.container_module} { + // CHECK-LABEL: spirv.module @{{.*}} GLSL450 + // CHECK-SAME: #spirv.target_env<#spirv.vce + gpu.module @kernels [ + #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>, + #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>, + #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>] { + // CHECK: spirv.func @load_kernel + // CHECK-SAME: %[[ARG:.*]]: !spirv.ptr [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) + gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + %c0 = arith.constant 0 : index + // CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}} + // CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32 + %0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32> + // CHECK: spirv.Return + gpu.return + } + } +}