diff --git a/offload/plugins-nextgen/common/include/JIT.h b/offload/plugins-nextgen/common/include/JIT.h index 8c530436a754b..d62516d20764a 100644 --- a/offload/plugins-nextgen/common/include/JIT.h +++ b/offload/plugins-nextgen/common/include/JIT.h @@ -55,6 +55,10 @@ struct JITEngine { process(const __tgt_device_image &Image, target::plugin::GenericDeviceTy &Device); + /// Remove \p Image from the jit engine's cache + void erase(const __tgt_device_image &Image, + target::plugin::GenericDeviceTy &Device); + private: /// Compile the bitcode image \p Image and generate the binary image that can /// be loaded to the target device of the triple \p Triple architecture \p @@ -89,11 +93,13 @@ struct JITEngine { /// LLVM Context in which the modules will be constructed. LLVMContext Context; - /// Output images generated from LLVM backend. - SmallVector, 4> JITImages; + /// A map of embedded IR images to the buffer used to store JITed code + DenseMap> + JITImages; /// A map of embedded IR images to JITed images. - DenseMap TgtImageMap; + DenseMap> + TgtImageMap; }; /// Map from (march) "CPUs" (e.g., sm_80, or gfx90a), which we call compute diff --git a/offload/plugins-nextgen/common/src/JIT.cpp b/offload/plugins-nextgen/common/src/JIT.cpp index c82a06e36d8f9..00720fa2d8103 100644 --- a/offload/plugins-nextgen/common/src/JIT.cpp +++ b/offload/plugins-nextgen/common/src/JIT.cpp @@ -285,8 +285,8 @@ JITEngine::compile(const __tgt_device_image &Image, // Check if we JITed this image for the given compute unit kind before. ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind]; - if (__tgt_device_image *JITedImage = CUI.TgtImageMap.lookup(&Image)) - return JITedImage; + if (CUI.TgtImageMap.contains(&Image)) + return CUI.TgtImageMap[&Image].get(); auto ObjMBOrErr = getOrCreateObjFile(Image, CUI.Context, ComputeUnitKind); if (!ObjMBOrErr) @@ -296,17 +296,15 @@ JITEngine::compile(const __tgt_device_image &Image, if (!ImageMBOrErr) return ImageMBOrErr.takeError(); - CUI.JITImages.push_back(std::move(*ImageMBOrErr)); - __tgt_device_image *&JITedImage = CUI.TgtImageMap[&Image]; - JITedImage = new __tgt_device_image(); + CUI.JITImages.insert({&Image, std::move(*ImageMBOrErr)}); + auto &ImageMB = CUI.JITImages[&Image]; + CUI.TgtImageMap.insert({&Image, std::make_unique<__tgt_device_image>()}); + auto &JITedImage = CUI.TgtImageMap[&Image]; *JITedImage = Image; - - auto &ImageMB = CUI.JITImages.back(); - JITedImage->ImageStart = const_cast(ImageMB->getBufferStart()); JITedImage->ImageEnd = const_cast(ImageMB->getBufferEnd()); - return JITedImage; + return JITedImage.get(); } Expected @@ -324,3 +322,13 @@ JITEngine::process(const __tgt_device_image &Image, return &Image; } + +void JITEngine::erase(const __tgt_device_image &Image, + target::plugin::GenericDeviceTy &Device) { + std::lock_guard Lock(ComputeUnitMapMutex); + const std::string &ComputeUnitKind = Device.getComputeUnitKind(); + ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind]; + + CUI.TgtImageMap.erase(&Image); + CUI.JITImages.erase(&Image); +} diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index 81b9d423e13d8..94a050b559efe 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -854,6 +854,9 @@ Error GenericDeviceTy::unloadBinary(DeviceImageTy *Image) { return Err; } + if (Image->getTgtImageBitcode()) + Plugin.getJIT().erase(*Image->getTgtImageBitcode(), Image->getDevice()); + return unloadBinaryImpl(Image); }