diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
index c5b6507ac847b..711d81186bad1 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -1304,7 +1304,7 @@ std::vector<AllocatorPtr> NvExecutionProvider::CreatePreferredAllocators() {
 
   AllocatorCreationInfo pinned_allocator_info(
       [](OrtDevice::DeviceId device_id) {
-        return std::make_unique<CUDAPinnedAllocator>(CUDA_PINNED, device_id);
+        return std::make_unique<CUDAPinnedAllocator>(device_id, CUDA_PINNED);
       },
       narrow<OrtDevice::DeviceId>(device_id_));
 
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
index 0fc3e5443bc28..83edc6ccdd313 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
@@ -115,6 +115,23 @@ struct Nv_Provider : Provider {
     return std::make_shared<NvProviderFactory>(info);
   }
 
+  Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/,
+                                  const OrtKeyValuePairs* const* /*ep_metadata*/,
+                                  size_t /*num_devices*/,
+                                  ProviderOptions& provider_options,
+                                  const OrtSessionOptions& session_options,
+                                  const OrtLogger& logger,
+                                  std::unique_ptr<IExecutionProvider>& ep) override {
+    const ConfigOptions* config_options = &session_options.GetConfigOptions();
+
+    std::array<const void*, 2> configs_array = {&provider_options, config_options};
+    const void* arg = reinterpret_cast<const void*>(&configs_array);
+    auto ep_factory = CreateExecutionProviderFactory(arg);
+    ep = ep_factory->CreateProvider(session_options, logger);
+
+    return Status::OK();
+  }
+
   void Initialize() override {
     InitializeRegistry();
   }
@@ -133,3 +150,118 @@ ORT_API(onnxruntime::Provider*, GetProvider) {
   return &onnxruntime::g_provider;
 }
 }
+
+#include "core/framework/error_code_helper.h"
+
+// OrtEpApi infrastructure to be able to use the NvTensorRTRTX EP as an OrtEpFactory for auto EP selection.
+struct NvTensorRtRtxEpFactory : OrtEpFactory {
+  NvTensorRtRtxEpFactory(const OrtApi& ort_api_in,
+                         const char* ep_name,
+                         OrtHardwareDeviceType hw_type)
+      : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} {
+    GetName = GetNameImpl;
+    GetVendor = GetVendorImpl;
+    GetVersion = GetVersionImpl;
+    GetSupportedDevices = GetSupportedDevicesImpl;
+    CreateEp = CreateEpImpl;
+    ReleaseEp = ReleaseEpImpl;
+  }
+
+  // Returns the name for the EP. Each unique factory configuration must have a unique name.
+  // Ex: a factory that supports NPU should have a different name than a factory that supports GPU.
+  static const char* GetNameImpl(const OrtEpFactory* this_ptr) {
+    const auto* factory = static_cast<const NvTensorRtRtxEpFactory*>(this_ptr);
+    return factory->ep_name.c_str();
+  }
+
+  static const char* GetVendorImpl(const OrtEpFactory* this_ptr) {
+    const auto* factory = static_cast<const NvTensorRtRtxEpFactory*>(this_ptr);
+    return factory->vendor.c_str();
+  }
+
+  static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
+    return ORT_VERSION;
+  }
+
+  // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports.
+  // An EP created with this factory is expected to be able to execute a model with *all* supported
+  // hardware devices at once. A single instance of NvTensorRtRtx EP is not currently setup to partition a model among
+  // multiple different NvTensorRtRtx backends at once (e.g, npu, cpu, gpu), so this factory instance is set to only
+  // support one backend: gpu. To support a different backend, like npu, create a different factory instance
+  // that only supports NPU.
+  static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr,
+                                            const OrtHardwareDevice* const* devices,
+                                            size_t num_devices,
+                                            OrtEpDevice** ep_devices,
+                                            size_t max_ep_devices,
+                                            size_t* p_num_ep_devices) {
+    size_t& num_ep_devices = *p_num_ep_devices;
+    auto* factory = static_cast<NvTensorRtRtxEpFactory*>(this_ptr);
+
+    for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
+      const OrtHardwareDevice& device = *devices[i];
+      if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type &&
+          factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) {
+        OrtKeyValuePairs* ep_options = nullptr;
+        factory->ort_api.CreateKeyValuePairs(&ep_options);
+        ORT_API_RETURN_IF_ERROR(
+            factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options,
+                                                        &ep_devices[num_ep_devices++]));
+      }
+    }
+
+    return nullptr;
+  }
+
+  static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/,
+                                 _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/,
+                                 _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/,
+                                 _In_ size_t /*num_devices*/,
+                                 _In_ const OrtSessionOptions* /*session_options*/,
+                                 _In_ const OrtLogger* /*logger*/,
+                                 _Out_ OrtEp** /*ep*/) {
+    return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[NvTensorRTRTX EP] EP factory does not support this method.");
+  }
+
+  static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) {
+    // no-op as we never create an EP here.
+  }
+
+  const OrtApi& ort_api;
+  const std::string ep_name;
+  const std::string vendor{"NVIDIA"};
+
+  // NVIDIA vendor ID. Refer to the ACPI ID registry (search NVIDIA): https://uefi.org/ACPI_ID_List
+  const uint32_t vendor_id{0x10de};
+  const OrtHardwareDeviceType ort_hw_device_type;  // Supported OrtHardwareDevice
+};
+
+extern "C" {
+//
+// Public symbols
+//
+OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base,
+                             OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
+  const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);
+
+  // Factory could use registration_name or define its own EP name.
+  auto factory_gpu = std::make_unique<NvTensorRtRtxEpFactory>(*ort_api,
+                                                              onnxruntime::kNvTensorRTRTXExecutionProvider,
+                                                              OrtHardwareDeviceType_GPU);
+
+  if (max_factories < 1) {
+    return ort_api->CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "Not enough space to return EP factory. Need at least one.");
+  }
+
+  factories[0] = factory_gpu.release();
+  *num_factories = 1;
+
+  return nullptr;
+}
+
+OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) {
+  delete static_cast<NvTensorRtRtxEpFactory*>(factory);
+  return nullptr;
+}
+}
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/symbols.def b/onnxruntime/core/providers/nv_tensorrt_rtx/symbols.def
index 4ec2f7914c208..3afed01da1966 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/symbols.def
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/symbols.def
@@ -1,2 +1,4 @@
 EXPORTS
   GetProvider
+  CreateEpFactories
+  ReleaseEpFactory
diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
index 4718a38ce4e1c..144901d8dac33 100644
--- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
+++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
@@ -389,5 +389,47 @@ TYPED_TEST(NvExecutionProviderTest, IOTypeTests) {
   }
 }
 
+static bool SessionHasEp(Ort::Session& session, const char* ep_name) {
+  // Access the underlying InferenceSession.
+  const OrtSession* ort_session = session;
+  const InferenceSession* s = reinterpret_cast<const InferenceSession*>(ort_session);
+  bool has_ep = false;
+
+  for (const auto& provider : s->GetRegisteredProviderTypes()) {
+    if (provider == ep_name) {
+      has_ep = true;
+      break;
+    }
+  }
+  return has_ep;
+}
+
+#if defined(WIN32)
+// Tests autoEP feature to automatically select an EP that supports the GPU.
+// Currently only works on Windows.
+TEST(NvExecutionProviderTest, AutoEp_PreferGpu) {
+  PathString model_name = ORT_TSTR("nv_execution_provider_data_dyn_test.onnx");
+  std::string graph_name = "test";
+  std::vector<int> dims = {1, -1, -1};
+
+  CreateBaseModel(model_name, graph_name, dims, true);
+
+  auto env = Ort::Env();
+  auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING;
+  env.UpdateEnvWithCustomLogLevel(logging_level);
+
+  {
+    env.RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("onnxruntime_providers_nv_tensorrt_rtx.dll"));
+
+    Ort::SessionOptions so;
+    so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU);
+    Ort::Session session_object(env, model_name.c_str(), so);
+    EXPECT_TRUE(SessionHasEp(session_object, kNvTensorRTRTXExecutionProvider));
+  }
+
+  env.UnregisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider);
+}
+#endif  // defined(WIN32)
+
 }  // namespace test
 }  // namespace onnxruntime