diff --git a/tools/render-test/options.cpp b/tools/render-test/options.cpp index 120e03456f..0ecda69443 100644 --- a/tools/render-test/options.cpp +++ b/tools/render-test/options.cpp @@ -278,6 +278,10 @@ static rhi::DeviceType _toRenderType(Slang::RenderApiType apiType) { outOptions.showAdapterInfo = true; } + else if (argValue == "-cache-rhi-device") + { + outOptions.cacheRhiDevice = true; + } else { // Lookup diff --git a/tools/render-test/options.h b/tools/render-test/options.h index 49e05c440f..bbc2364cd4 100644 --- a/tools/render-test/options.h +++ b/tools/render-test/options.h @@ -96,6 +96,9 @@ struct Options bool skipSPIRVValidation = false; + // Whether to enable RHI device caching (default: false in render-test) + bool cacheRhiDevice = false; + Slang::List capabilities; Options() { downstreamArgs.addName("slang"); } diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index 991d606839..5b3974fc94 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -12,6 +12,7 @@ #include "shader-input-layout.h" #include "shader-renderer-util.h" #include "slang-support.h" +#include "slang-test-device-cache.h" #include "window.h" #if defined(_WIN32) @@ -1440,7 +1441,7 @@ static SlangResult _innerMain( } } - renderer_test::CoreToRHIDebugBridge debugCallback; + static renderer_test::CoreToRHIDebugBridge debugCallback; debugCallback.setCoreCallback(stdWriters->getDebugCallback()); // Use the profile name set on options if set @@ -1495,7 +1496,7 @@ static SlangResult _innerMain( return SLANG_E_NOT_AVAILABLE; } - Slang::ComPtr device; + CachedDeviceWrapper deviceWrapper; { DeviceDesc desc = {}; desc.deviceType = options.deviceType; @@ -1558,8 +1559,27 @@ static SlangResult _innerMain( { getRHI()->enableDebugLayers(); } - SlangResult res = getRHI()->createDevice(desc, device.writeRef()); - if (SLANG_FAILED(res)) + Slang::ComPtr rhiDevice; + SlangResult res; + if (options.cacheRhiDevice) + { + 
res = DeviceCache::acquireDevice(desc, rhiDevice.writeRef()); + if (SLANG_FAILED(res)) + { + rhiDevice = nullptr; + } + } + else + { + res = rhi::getRHI()->createDevice(desc, rhiDevice.writeRef()); + if (SLANG_FAILED(res)) + { + rhiDevice = nullptr; + } + } + + // Check result for both cached and non-cached paths + if (SLANG_FAILED(res) || !rhiDevice) { // We need to be careful here about SLANG_E_NOT_AVAILABLE. This return value means // that the renderer couldn't be created because it required *features* that were @@ -1575,21 +1595,20 @@ static SlangResult _innerMain( { return res; } - if (!options.onlyStartup) { fprintf(stderr, "Unable to create renderer %s\n", rendererName.getBuffer()); } - return res; } - SLANG_ASSERT(device); + SLANG_ASSERT(rhiDevice); + deviceWrapper = CachedDeviceWrapper(rhiDevice); } for (const auto& feature : requiredFeatureList) { // If doesn't have required feature... we have to give up - if (!device->hasFeature(feature)) + if (!deviceWrapper->hasFeature(feature)) { return SLANG_E_NOT_AVAILABLE; } @@ -1599,7 +1618,7 @@ static SlangResult _innerMain( // Print adapter info after device creation but before any other operations if (options.showAdapterInfo) { - auto info = device->getInfo(); + auto info = deviceWrapper->getInfo(); auto out = stdWriters->getOut(); out.print("Using graphics adapter: %s\n", info.adapterName); } @@ -1613,14 +1632,20 @@ static SlangResult _innerMain( { RenderTestApp app; renderDocBeginFrame(); - SLANG_RETURN_ON_FAIL(app.initialize(session, device, options, input)); + SLANG_RETURN_ON_FAIL(app.initialize(session, deviceWrapper.get(), options, input)); app.update(); renderDocEndFrame(); app.finalize(); } + return SLANG_OK; } +SLANG_TEST_TOOL_API void cleanDeviceCache() +{ + DeviceCache::cleanCache(); +} + SLANG_TEST_TOOL_API SlangResult innerMain( Slang::StdWriters* stdWriters, SlangSession* sharedSession, diff --git a/tools/render-test/slang-test-device-cache.cpp b/tools/render-test/slang-test-device-cache.cpp new 
file mode 100644
index 0000000000..a486ee3f3c
--- /dev/null
+++ b/tools/render-test/slang-test-device-cache.cpp
@@ -0,0 +1,182 @@
+#include "slang-test-device-cache.h"
+
+#include <algorithm>
+
+// The cache state lives in function-local statics (Meyers singletons) instead of
+// namespace-scope statics: each object is constructed on first use and destroyed in
+// reverse order of construction, which avoids the static destruction order fiasco.
+
+/// Returns the mutex that guards the device cache and the creation counter.
+std::mutex& DeviceCache::getMutex()
+{
+    static std::mutex instance;
+    return instance;
+}
+
+/// Returns the map from device configuration to cached device.
+std::unordered_map<
+    DeviceCache::DeviceCacheKey,
+    DeviceCache::CachedDevice,
+    DeviceCache::DeviceCacheKeyHash>&
+DeviceCache::getDeviceCache()
+{
+    static std::unordered_map<DeviceCacheKey, CachedDevice, DeviceCacheKeyHash> instance;
+    return instance;
+}
+
+/// Returns the counter used to stamp each new cache entry with its creation order.
+uint64_t& DeviceCache::getNextCreationOrder()
+{
+    static uint64_t instance = 0;
+    return instance;
+}
+
+/// Two keys are equal when all five fields match.
+bool DeviceCache::DeviceCacheKey::operator==(const DeviceCacheKey& other) const
+{
+    return deviceType == other.deviceType && enableValidation == other.enableValidation &&
+           enableRayTracingValidation == other.enableRayTracingValidation &&
+           profileName == other.profileName && requiredFeatures == other.requiredFeatures;
+}
+
+/// Combines the hashes of the same fields that operator== compares.
+std::size_t DeviceCache::DeviceCacheKeyHash::operator()(const DeviceCacheKey& key) const
+{
+    std::size_t h1 = std::hash<int>{}(static_cast<int>(key.deviceType));
+    std::size_t h2 = std::hash<bool>{}(key.enableValidation);
+    std::size_t h3 = std::hash<bool>{}(key.enableRayTracingValidation);
+    std::size_t h4 = std::hash<std::string>{}(key.profileName);
+
+    // The combine is order-dependent; acquireDevice() sorts the list before using the key
+    std::size_t h5 = 0;
+    for (const auto& feature : key.requiredFeatures)
+    {
+        h5 ^= std::hash<std::string>{}(feature) + 0x9e3779b9 + (h5 << 6) + (h5 >> 2);
+    }
+
+    return h1 ^ (h2 << 1) ^ (h3 << 2) ^ (h4 << 3) ^ (h5 << 4);
+}
+
+DeviceCache::CachedDevice::CachedDevice()
+    : creationOrder(0)
+{
+}
+
+/// Evicts the entry with the smallest creationOrder once the cache holds
+/// MAX_CACHED_DEVICES entries. Takes no lock itself; acquireDevice() calls it with the
+/// cache mutex held.
+void DeviceCache::evictOldestDeviceIfNeeded()
+{
+    auto& deviceCache = getDeviceCache();
+    if (deviceCache.size() < MAX_CACHED_DEVICES)
+        return;
+
+    // Find the oldest device to evict
+    auto oldestIt = deviceCache.end();
+    uint64_t oldestCreationOrder = UINT64_MAX;
+
+    for (auto it = deviceCache.begin(); it != deviceCache.end(); ++it)
+    {
+        if (it->second.creationOrder < oldestCreationOrder)
+        {
+            oldestCreationOrder = it->second.creationOrder;
+            oldestIt = it;
+        }
+    }
+
+    // Remove the oldest device - ComPtr will handle the actual device release
+    if (oldestIt != deviceCache.end())
+    {
+        deviceCache.erase(oldestIt);
+    }
+}
+
+/// Returns a device matching `desc`, reusing a cached device when one exists.
+///
+/// CUDA devices bypass the cache and are always created fresh.
+///
+/// @param desc      Device description; its type, validation flags, target profile and
+///                  required features form the cache key.
+/// @param outDevice Receives the device with a reference added for the caller.
+/// @return SLANG_E_INVALID_ARG when outDevice is null, the createDevice() result when a
+///         device had to be created and that failed (or for CUDA), SLANG_OK otherwise.
+SlangResult DeviceCache::acquireDevice(const rhi::DeviceDesc& desc, rhi::IDevice** outDevice)
+{
+    if (!outDevice)
+        return SLANG_E_INVALID_ARG;
+
+    *outDevice = nullptr;
+
+    // Skip caching for CUDA devices due to crashes
+    if (desc.deviceType == rhi::DeviceType::CUDA)
+    {
+        return rhi::getRHI()->createDevice(desc, outDevice);
+    }
+
+    std::lock_guard<std::mutex> lock(getMutex());
+    auto& deviceCache = getDeviceCache();
+    auto& nextCreationOrder = getNextCreationOrder();
+
+    // Create cache key
+    DeviceCacheKey key;
+    key.deviceType = desc.deviceType;
+    key.enableValidation = desc.enableValidation;
+    key.enableRayTracingValidation = desc.enableRayTracingValidation;
+    key.profileName = desc.slang.targetProfile ? desc.slang.targetProfile : "Unknown";
+
+    // Add required features to key, sorted so their order in desc does not matter
+    for (int i = 0; i < desc.requiredFeatureCount; ++i)
+    {
+        key.requiredFeatures.push_back(desc.requiredFeatures[i]);
+    }
+    std::sort(key.requiredFeatures.begin(), key.requiredFeatures.end());
+
+    // Look the key up before evicting anything: a hit adds no entry, so evicting first
+    // would drop a cached device (possibly the one being requested) for no reason.
+    auto it = deviceCache.find(key);
+    if (it != deviceCache.end())
+    {
+        // Return the cached device - COM reference counting handles the references
+        *outDevice = it->second.device.get();
+        if (*outDevice)
+        {
+            (*outDevice)->addRef();
+            return SLANG_OK;
+        }
+    }
+    else
+    {
+        // A new entry is about to be inserted; make room if the cache is full
+        evictOldestDeviceIfNeeded();
+    }
+
+    // Create new device
+    Slang::ComPtr<rhi::IDevice> device;
+    auto result = rhi::getRHI()->createDevice(desc, device.writeRef());
+    if (SLANG_FAILED(result))
+    {
+        return result;
+    }
+
+    // Cache the device
+    CachedDevice& cached = deviceCache[key];
+    cached.device = device;
+    cached.creationOrder = nextCreationOrder++;
+
+    // Return the device with proper reference counting
+    *outDevice = device.get();
+    if (*outDevice)
+    {
+        (*outDevice)->addRef();
+    }
+
+    return SLANG_OK;
+}
+
+/// Empties the cache; each erased entry drops the cache's ComPtr to its device.
+void DeviceCache::cleanCache()
+{
+    std::lock_guard<std::mutex> lock(getMutex());
+    auto& deviceCache = getDeviceCache();
+    deviceCache.clear();
+}
diff --git a/tools/render-test/slang-test-device-cache.h b/tools/render-test/slang-test-device-cache.h
new file mode 100644
index 0000000000..752d03ff16
--- /dev/null
+++ b/tools/render-test/slang-test-device-cache.h
@@ -0,0 +1,107 @@
+#pragma once
+
+#include <mutex>
+#include <slang-rhi.h>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+// Device Cache for preventing NVIDIA Tegra driver state corruption
+// This cache reuses Vulkan instances and devices to avoid the VK_ERROR_INCOMPATIBLE_DRIVER
+// issue that occurs after ~19 device creation/destruction cycles on Tegra platforms.
+// Uses ComPtr for automatic device lifecycle management - devices are released when removed from
+// cache.
+class DeviceCache
+{
+public:
+    /// The fields of a device description that identify a cache entry.
+    struct DeviceCacheKey
+    {
+        rhi::DeviceType deviceType;
+        bool enableValidation;
+        bool enableRayTracingValidation;
+        std::string profileName;
+        std::vector<std::string> requiredFeatures;
+
+        bool operator==(const DeviceCacheKey& other) const;
+    };
+
+    /// Hash functor that lets DeviceCacheKey be used as an unordered_map key.
+    struct DeviceCacheKeyHash
+    {
+        std::size_t operator()(const DeviceCacheKey& key) const;
+    };
+
+    /// A cached device plus a stamp recording when it entered the cache.
+    struct CachedDevice
+    {
+        Slang::ComPtr<rhi::IDevice> device;
+        uint64_t creationOrder; // lower value = created earlier = evicted first
+
+        CachedDevice();
+    };
+
+private:
+    // Number of entries at which the oldest entry is evicted to make room
+    static constexpr int MAX_CACHED_DEVICES = 10;
+
+    // Use function-local statics to control destruction order (Meyers singleton pattern)
+    static std::mutex& getMutex();
+    static std::unordered_map<DeviceCacheKey, CachedDevice, DeviceCacheKeyHash>& getDeviceCache();
+    static uint64_t& getNextCreationOrder();
+
+    static void evictOldestDeviceIfNeeded();
+
+public:
+    /// Returns a device matching desc, creating and caching one when none is cached.
+    /// outDevice receives a reference that the caller owns.
+    static SlangResult acquireDevice(const rhi::DeviceDesc& desc, rhi::IDevice** outDevice);
+
+    /// Empties the cache.
+    static void cleanCache();
+};
+
+// RAII wrapper for cached devices to ensure proper cleanup.
+// Move-only: it holds the device in a ComPtr and cannot be copied.
+class CachedDeviceWrapper
+{
+private:
+    Slang::ComPtr<rhi::IDevice> m_device;
+
+public:
+    CachedDeviceWrapper() = default;
+
+    CachedDeviceWrapper(Slang::ComPtr<rhi::IDevice> device)
+        : m_device(device)
+    {
+    }
+
+    ~CachedDeviceWrapper() {}
+
+    // Move constructor
+    CachedDeviceWrapper(CachedDeviceWrapper&& other) noexcept
+        : m_device(std::move(other.m_device))
+    {
+    }
+
+    // Move assignment
+    CachedDeviceWrapper& operator=(CachedDeviceWrapper&& other) noexcept
+    {
+        if (this != &other)
+        {
+            m_device = std::move(other.m_device);
+        }
+        return *this;
+    }
+
+    // Delete copy constructor and assignment
+    CachedDeviceWrapper(const CachedDeviceWrapper&) = delete;
+    CachedDeviceWrapper& operator=(const CachedDeviceWrapper&) = delete;
+
+    // Non-owning access to the wrapped device (null when empty)
+    rhi::IDevice* get() const { return m_device.get(); }
+    rhi::IDevice* operator->() const { return m_device.get(); }
+    operator bool() const { return m_device != nullptr; }
+
+    Slang::ComPtr<rhi::IDevice>& getComPtr() { return m_device; }
+};
diff --git a/tools/slang-test/options.cpp
b/tools/slang-test/options.cpp index ec30262fce..7a98caec67 100644 --- a/tools/slang-test/options.cpp +++ b/tools/slang-test/options.cpp @@ -76,6 +76,7 @@ static bool _isSubCommand(const char* arg) " -verbose-paths Use verbose paths in output\n" " -category Only run tests in specified category\n" " -exclude Exclude tests in specified category\n" + " -exclude-prefix Exclude tests with specified path prefix\n" " -api Enable specific APIs (e.g., 'vk+dx12' or '+dx11')\n" " -synthesizedTestApi Set APIs for synthesized tests\n" " -skip-api-detection Skip API availability detection\n" @@ -91,6 +92,7 @@ static bool _isSubCommand(const char* arg) " -capability Compile with the given capability\n" " -enable-debug-layers [true|false] Enable or disable Validation Layer for Vulkan\n" " and Debug Device for DX\n" + " -cache-rhi-device [true|false] Enable or disable RHI device caching (default: true)\n" #if _DEBUG " -disable-debug-layers Disable the debug layers (default enabled in debug " "build)\n" @@ -357,6 +359,18 @@ static bool _isSubCommand(const char* arg) optionsOut->excludeCategories.add(category, category); } } + else if (strcmp(arg, "-exclude-prefix") == 0) + { + if (argCursor == argEnd) + { + stdError.print("error: expected operand for '%s'\n", arg); + showHelp(stdError); + return SLANG_FAIL; + } + Slang::StringBuilder sb; + Slang::Path::simplify(*argCursor++, Slang::Path::SimplifyStyle::NoRoot, sb); + optionsOut->excludePrefixes.add(sb); + } else if (strcmp(arg, "-api") == 0) { if (argCursor == argEnd) @@ -488,6 +502,26 @@ static bool _isSubCommand(const char* arg) optionsOut->enableDebugLayers = false; } } + else if (strcmp(arg, "-cache-rhi-device") == 0) + { + optionsOut->cacheRhiDevice = true; + + if (argCursor == argEnd) + { + stdError.print("error: expected operand for '%s'\n", arg); + showHelp(stdError); + return SLANG_FAIL; + } + + // Check for false variants + const char* value = *argCursor++; + if (value[0] == 'f' || value[0] == 'F' || value[0] == 'n' || 
value[0] == 'N' || + value[0] == '0' || + ((value[0] == 'o' || value[0] == 'O') && (value[1] == 'f' || value[1] == 'F'))) + { + optionsOut->cacheRhiDevice = false; + } + } #if _DEBUG else if (strcmp(arg, "-disable-debug-layers") == 0) { diff --git a/tools/slang-test/options.h b/tools/slang-test/options.h index df2de310a1..9223b2eafb 100644 --- a/tools/slang-test/options.h +++ b/tools/slang-test/options.h @@ -64,6 +64,9 @@ struct Options // only run test cases with names have one of these prefixes. Slang::List testPrefixes; + // skip test cases with names that have one of these prefixes. + Slang::List excludePrefixes; + // verbosity level for output VerbosityLevel verbosity = VerbosityLevel::Info; @@ -133,6 +136,9 @@ struct Options bool emitSPIRVDirectly = true; + // Whether to enable RHI device caching in render-test (default: true in slang-test) + bool cacheRhiDevice = true; + Slang::HashSet capabilities; Slang::HashSet expectedFailureList; diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index 0297227d63..08019f9951 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -3643,6 +3643,11 @@ static void _addRenderTestOptions(const Options& options, CommandLine& ioCmdLine { ioCmdLine.addArg("-enable-debug-layers"); } + + if (options.cacheRhiDevice) + { + ioCmdLine.addArg("-cache-rhi-device"); + } } static SlangResult _extractProfileTime(const UnownedStringSlice& text, double& timeOut) @@ -4766,6 +4771,23 @@ static bool shouldRunTest(TestContext* context, String filePath) if (!endsWithAllowedExtension(context, filePath)) return false; + // Check exclude prefixes first - if any match, skip the test + for (auto& excludePrefix : context->options.excludePrefixes) + { + if (filePath.startsWith(excludePrefix)) + { + if (context->options.verbosity == VerbosityLevel::Verbose) + { + context->getTestReporter()->messageFormat( + TestMessageType::Info, + "%s file is excluded from the test because it 
is found from the exclusion " + "list\n", + filePath.getBuffer()); + } + return false; + } + } + if (!context->options.testPrefixes.getCount()) { return true; @@ -5129,6 +5151,15 @@ static SlangResult runUnitTestModule( return SLANG_OK; } +static void cleanupRenderTestDeviceCache(TestContext& context) +{ + auto cleanFunc = context.getCleanDeviceCacheFunc("render-test"); + if (cleanFunc) + { + cleanFunc(); + } +} + SlangResult innerMain(int argc, char** argv) { auto stdWriters = StdWriters::initDefaultSingleton(); @@ -5420,12 +5451,15 @@ SlangResult innerMain(int argc, char** argv) } reporter.outputSummary(); + + cleanupRenderTestDeviceCache(context); return reporter.didAllSucceed() ? SLANG_OK : SLANG_FAIL; } } int main(int argc, char** argv) { + // Fallback: run without cleanup if context initialization fails SlangResult res = innerMain(argc, argv); slang::shutdown(); Slang::RttiInfo::deallocateAll(); diff --git a/tools/slang-test/test-context.cpp b/tools/slang-test/test-context.cpp index ede12d0d7f..e3655e61d9 100644 --- a/tools/slang-test/test-context.cpp +++ b/tools/slang-test/test-context.cpp @@ -130,6 +130,8 @@ TestContext::InnerMainFunc TestContext::getInnerMainFunc(const String& dirPath, loader->loadPlatformSharedLibrary(path.begin(), tool.m_sharedLibrary.writeRef()))) { tool.m_func = (InnerMainFunc)tool.m_sharedLibrary->findFuncByName("innerMain"); + tool.m_cleanDeviceCacheFunc = + (CleanDeviceCacheFunc)tool.m_sharedLibrary->findFuncByName("cleanDeviceCache"); } m_sharedLibTools.add(name, tool); @@ -152,6 +154,17 @@ void TestContext::setInnerMainFunc(const String& name, InnerMainFunc func) } } +TestContext::CleanDeviceCacheFunc TestContext::getCleanDeviceCacheFunc(const String& name) +{ + SharedLibraryTool* tool = m_sharedLibTools.tryGetValue(name); + if (tool) + { + return tool->m_cleanDeviceCacheFunc; + } + + return nullptr; +} + DownstreamCompilerSet* TestContext::getCompilerSet() { std::lock_guard lock(mutex); diff --git 
a/tools/slang-test/test-context.h b/tools/slang-test/test-context.h index e760e8dda5..0161637dd3 100644 --- a/tools/slang-test/test-context.h +++ b/tools/slang-test/test-context.h @@ -90,6 +90,7 @@ class TestContext { public: typedef Slang::TestToolUtil::InnerMainFunc InnerMainFunc; + typedef void (*CleanDeviceCacheFunc)(); /// Get the slang session SlangSession* getSession() const { return m_session; } @@ -101,6 +102,9 @@ class TestContext /// Set the function for the shared library void setInnerMainFunc(const Slang::String& name, InnerMainFunc func); + /// Get the device cache cleanup function (from shared library) + CleanDeviceCacheFunc getCleanDeviceCacheFunc(const Slang::String& name); + void setTestRequirements(TestRequirements* req); TestRequirements* getTestRequirements() const; @@ -196,6 +200,7 @@ class TestContext { Slang::ComPtr m_sharedLibrary; InnerMainFunc m_func; + CleanDeviceCacheFunc m_cleanDeviceCacheFunc; }; Slang::List> m_jsonRpcConnections;