Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,11 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
spirv::getTargetEnvAttrName()))
spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
if (const ArrayAttr &targets = moduleOp.getTargetsAttr()) {
for (const Attribute &targetAttr : targets)
if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr))
spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr);
}

rewriter.eraseOp(moduleOp);
return success();
Expand Down
34 changes: 30 additions & 4 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,45 @@ struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
void runOnOperation() override;

private:
/// Queries the target environment from 'targets' attribute of the given
/// `moduleOp`.
spirv::TargetEnvAttr lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp);

/// Queries the target environment from 'targets' attribute of the given
/// `moduleOp` or returns target environment as returned by
/// `spirv::lookupTargetEnvOrDefault` if not provided by 'targets'.
spirv::TargetEnvAttr lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp);
bool mapMemorySpace;
};

spirv::TargetEnvAttr
GPUToSPIRVPass::lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp) {
if (const ArrayAttr &targets = moduleOp.getTargetsAttr()) {
for (const Attribute &targetAttr : targets)
if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr))
return spirvTargetEnvAttr;
}

return {};
}

spirv::TargetEnvAttr
GPUToSPIRVPass::lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp) {
if (spirv::TargetEnvAttr targetEnvAttr = lookupTargetEnvInTargets(moduleOp))
return targetEnvAttr;

return spirv::lookupTargetEnvOrDefault(moduleOp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the tests, it looks like they both test the if statement above. Is there a test checking this part of the function? If not, do we need a test to check that the function behaves as intended when no attributes are attached?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understood correctly, we should test the behavior when the gpu.module does not have a targets attribute with spirv target env so that the line 78 code is executed.

I believe this case might already be covered by existing tests. For instance, gpu-to-spirv.mlir tests the behavior when no target env is present at all, and load-store.mlir covers the case where the target env is not attached to thegpu.module's targets attribute but to the module.

However, if a more specific test is needed, such as one where the gpu.module lacks the targets attribute but has a target env in its attr-dict, or if other additional tests are necessary, I would be happy to add them based on your feedback.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understood correctly, we should test the behavior when the gpu.module does not have a targets attribute with spirv target env so that the line 78 code is executed.

Correct.

I believe this case might already be covered by existing tests. For instance, gpu-to-spirv.mlir tests the behavior when no target env is present at all, and load-store.mlir covers the case where the target env is not attached to thegpu.module's targets attribute but to the module.

Sounds good, I think we are good.

}

void GPUToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();

SmallVector<Operation *, 1> gpuModules;
OpBuilder builder(context);

auto targetEnvSupportsKernelCapability = [](gpu::GPUModuleOp moduleOp) {
Operation *gpuModule = moduleOp.getOperation();
auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule);
auto targetEnvSupportsKernelCapability = [this](gpu::GPUModuleOp moduleOp) {
auto targetAttr = lookupTargetEnvOrDefault(moduleOp);
spirv::TargetEnv targetEnv(targetAttr);
return targetEnv.allows(spirv::Capability::Kernel);
};
Expand All @@ -86,7 +112,7 @@ void GPUToSPIRVPass::runOnOperation() {
// TargetEnv attributes.
for (Operation *gpuModule : gpuModules) {
spirv::TargetEnvAttr targetAttr =
spirv::lookupTargetEnvOrDefault(gpuModule);
lookupTargetEnvOrDefault(cast<gpu::GPUModuleOp>(gpuModule));

// Map MemRef memory space to SPIR-V storage class first if requested.
if (mapMemorySpace) {
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Conversion/GPUToSPIRV/lookup-target-env.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: mlir-opt --convert-gpu-to-spirv %s | FileCheck %s

module attributes {gpu.container_module} {
// CHECK-LABEL: spirv.module @{{.*}} GLSL450
gpu.module @kernels [#spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>] {
// CHECK: spirv.func @load_kernel
// CHECK-SAME: %[[ARG:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
gpu.func @load_kernel(%arg0: memref<12x4xf32>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%c0 = arith.constant 0 : index
// CHECK: %[[PTR:.*]] = spirv.AccessChain %[[ARG]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
// CHECK-NEXT: {{%.*}} = spirv.Load "StorageBuffer" %[[PTR]] : f32
%0 = memref.load %arg0[%c0, %c0] : memref<12x4xf32>
// CHECK: spirv.Return
gpu.return
}
}
}