Skip to content

Commit fa3b44e

Browse files
committed
fixup! [mlir][SPIRV] Fix lookup logic spirv.target_env for gpu.module
Add lookup target env in "targets" attr logic to GPUToSPIRV pass
1 parent ae6ba9c commit fa3b44e

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,38 @@ struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
4848
void runOnOperation() override;
4949

5050
private:
51+
spirv::TargetEnvAttr lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp);
52+
spirv::TargetEnvAttr lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp);
5153
bool mapMemorySpace;
5254
};
5355

56+
spirv::TargetEnvAttr
57+
GPUToSPIRVPass::lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp) {
58+
for (auto &targetAttr : moduleOp.getTargetsAttr())
59+
if (auto spirvTargetEnvAttr =
60+
llvm::dyn_cast<spirv::TargetEnvAttr>(targetAttr))
61+
return spirvTargetEnvAttr;
62+
63+
return {};
64+
}
65+
66+
spirv::TargetEnvAttr
67+
GPUToSPIRVPass::lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp) {
68+
if (auto targetEnvAttr = lookupTargetEnvInTargets(moduleOp))
69+
return targetEnvAttr;
70+
71+
return spirv::lookupTargetEnvOrDefault(moduleOp);
72+
}
73+
5474
void GPUToSPIRVPass::runOnOperation() {
5575
MLIRContext *context = &getContext();
5676
ModuleOp module = getOperation();
5777

5878
SmallVector<Operation *, 1> gpuModules;
5979
OpBuilder builder(context);
6080

61-
auto targetEnvSupportsKernelCapability = [](gpu::GPUModuleOp moduleOp) {
62-
Operation *gpuModule = moduleOp.getOperation();
63-
auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule);
81+
auto targetEnvSupportsKernelCapability = [this](gpu::GPUModuleOp moduleOp) {
82+
auto targetAttr = lookupTargetEnvOrDefault(moduleOp);
6483
spirv::TargetEnv targetEnv(targetAttr);
6584
return targetEnv.allows(spirv::Capability::Kernel);
6685
};
@@ -86,7 +105,7 @@ void GPUToSPIRVPass::runOnOperation() {
86105
// TargetEnv attributes.
87106
for (Operation *gpuModule : gpuModules) {
88107
spirv::TargetEnvAttr targetAttr =
89-
spirv::lookupTargetEnvOrDefault(gpuModule);
108+
lookupTargetEnvOrDefault(llvm::cast<gpu::GPUModuleOp>(gpuModule));
90109

91110
// Map MemRef memory space to SPIR-V storage class first if requested.
92111
if (mapMemorySpace) {

mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,6 @@ spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
184184
if (!op)
185185
break;
186186

187-
if (auto arrAttr = op->getAttrOfType<ArrayAttr>("targets")) {
188-
for (auto attr : arrAttr)
189-
if (auto spirvTargetEnvAttr =
190-
llvm::dyn_cast<spirv::TargetEnvAttr>(attr))
191-
return spirvTargetEnvAttr;
192-
}
193-
194187
if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
195188
spirv::getTargetEnvAttrName()))
196189
return attr;

0 commit comments

Comments
 (0)