Skip to content

Commit 9177ae4

Browse files
ishwar-raut1fs-eire
authored andcommitted
[NV RTX EP] Upstream changes from the win-ort (microsoft#25370)
### Description <!-- Describe your changes. --> Changes from win-onnxruntime to onnxruntime Fix the build break. --------- Co-authored-by: Yulong Wang <[email protected]>
1 parent 9a3e54f commit 9177ae4

File tree

4 files changed

+177
-1
lines changed

4 files changed

+177
-1
lines changed

onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,7 @@ std::vector<AllocatorPtr> NvExecutionProvider::CreatePreferredAllocators() {
13041304

13051305
AllocatorCreationInfo pinned_allocator_info(
13061306
[](OrtDevice::DeviceId device_id) {
1307-
return std::make_unique<CUDAPinnedAllocator>(CUDA_PINNED, device_id);
1307+
return std::make_unique<CUDAPinnedAllocator>(device_id, CUDA_PINNED);
13081308
},
13091309
narrow<OrtDevice::DeviceId>(device_id_));
13101310

onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,23 @@ struct Nv_Provider : Provider {
115115
return std::make_shared<NvProviderFactory>(info);
116116
}
117117

118+
Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/,
119+
const OrtKeyValuePairs* const* /*ep_metadata*/,
120+
size_t /*num_devices*/,
121+
ProviderOptions& provider_options,
122+
const OrtSessionOptions& session_options,
123+
const OrtLogger& logger,
124+
std::unique_ptr<IExecutionProvider>& ep) override {
125+
const ConfigOptions* config_options = &session_options.GetConfigOptions();
126+
127+
std::array<const void*, 2> configs_array = {&provider_options, config_options};
128+
const void* arg = reinterpret_cast<const void*>(&configs_array);
129+
auto ep_factory = CreateExecutionProviderFactory(arg);
130+
ep = ep_factory->CreateProvider(session_options, logger);
131+
132+
return Status::OK();
133+
}
134+
118135
void Initialize() override {
119136
InitializeRegistry();
120137
}
@@ -133,3 +150,118 @@ ORT_API(onnxruntime::Provider*, GetProvider) {
133150
return &onnxruntime::g_provider;
134151
}
135152
}
153+
154+
#include "core/framework/error_code_helper.h"
155+
156+
// OrtEpApi infrastructure to be able to use the NvTensorRTRTX EP as an OrtEpFactory for auto EP selection.
157+
struct NvTensorRtRtxEpFactory : OrtEpFactory {
158+
NvTensorRtRtxEpFactory(const OrtApi& ort_api_in,
159+
const char* ep_name,
160+
OrtHardwareDeviceType hw_type)
161+
: ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} {
162+
GetName = GetNameImpl;
163+
GetVendor = GetVendorImpl;
164+
GetVersion = GetVersionImpl;
165+
GetSupportedDevices = GetSupportedDevicesImpl;
166+
CreateEp = CreateEpImpl;
167+
ReleaseEp = ReleaseEpImpl;
168+
}
169+
170+
// Returns the name for the EP. Each unique factory configuration must have a unique name.
171+
// Ex: a factory that supports NPU should have a different than a factory that supports GPU.
172+
static const char* GetNameImpl(const OrtEpFactory* this_ptr) {
173+
const auto* factory = static_cast<const NvTensorRtRtxEpFactory*>(this_ptr);
174+
return factory->ep_name.c_str();
175+
}
176+
177+
static const char* GetVendorImpl(const OrtEpFactory* this_ptr) {
178+
const auto* factory = static_cast<const NvTensorRtRtxEpFactory*>(this_ptr);
179+
return factory->vendor.c_str();
180+
}
181+
182+
static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
183+
return ORT_VERSION;
184+
}
185+
186+
// Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports.
187+
// An EP created with this factory is expected to be able to execute a model with *all* supported
188+
// hardware devices at once. A single instance of NvTensorRtRtx EP is not currently setup to partition a model among
189+
// multiple different NvTensorRtRtx backends at once (e.g, npu, cpu, gpu), so this factory instance is set to only
190+
// support one backend: gpu. To support a different backend, like npu, create a different factory instance
191+
// that only supports NPU.
192+
static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr,
193+
const OrtHardwareDevice* const* devices,
194+
size_t num_devices,
195+
OrtEpDevice** ep_devices,
196+
size_t max_ep_devices,
197+
size_t* p_num_ep_devices) {
198+
size_t& num_ep_devices = *p_num_ep_devices;
199+
auto* factory = static_cast<NvTensorRtRtxEpFactory*>(this_ptr);
200+
201+
for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
202+
const OrtHardwareDevice& device = *devices[i];
203+
if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type &&
204+
factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) {
205+
OrtKeyValuePairs* ep_options = nullptr;
206+
factory->ort_api.CreateKeyValuePairs(&ep_options);
207+
ORT_API_RETURN_IF_ERROR(
208+
factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options,
209+
&ep_devices[num_ep_devices++]));
210+
}
211+
}
212+
213+
return nullptr;
214+
}
215+
216+
static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/,
217+
_In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/,
218+
_In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/,
219+
_In_ size_t /*num_devices*/,
220+
_In_ const OrtSessionOptions* /*session_options*/,
221+
_In_ const OrtLogger* /*logger*/,
222+
_Out_ OrtEp** /*ep*/) {
223+
return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[NvTensorRTRTX EP] EP factory does not support this method.");
224+
}
225+
226+
static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) {
227+
// no-op as we never create an EP here.
228+
}
229+
230+
const OrtApi& ort_api;
231+
const std::string ep_name;
232+
const std::string vendor{"NVIDIA"};
233+
234+
// NVIDIA vendor ID. Refer to the ACPI ID registry (search NVIDIA): https://uefi.org/ACPI_ID_List
235+
const uint32_t vendor_id{0x10de};
236+
const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice
237+
};
238+
239+
extern "C" {
240+
//
241+
// Public symbols
242+
//
243+
OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base,
244+
OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
245+
const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);
246+
247+
// Factory could use registration_name or define its own EP name.
248+
auto factory_gpu = std::make_unique<NvTensorRtRtxEpFactory>(*ort_api,
249+
onnxruntime::kNvTensorRTRTXExecutionProvider,
250+
OrtHardwareDeviceType_GPU);
251+
252+
if (max_factories < 1) {
253+
return ort_api->CreateStatus(ORT_INVALID_ARGUMENT,
254+
"Not enough space to return EP factory. Need at least one.");
255+
}
256+
257+
factories[0] = factory_gpu.release();
258+
*num_factories = 1;
259+
260+
return nullptr;
261+
}
262+
263+
OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) {
264+
delete static_cast<NvTensorRtRtxEpFactory*>(factory);
265+
return nullptr;
266+
}
267+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
EXPORTS
22
GetProvider
3+
CreateEpFactories
4+
ReleaseEpFactory

onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,5 +389,47 @@ TYPED_TEST(NvExecutionProviderTest, IOTypeTests) {
389389
}
390390
}
391391

392+
static bool SessionHasEp(Ort::Session& session, const char* ep_name) {
393+
// Access the underlying InferenceSession.
394+
const OrtSession* ort_session = session;
395+
const InferenceSession* s = reinterpret_cast<const InferenceSession*>(ort_session);
396+
bool has_ep = false;
397+
398+
for (const auto& provider : s->GetRegisteredProviderTypes()) {
399+
if (provider == ep_name) {
400+
has_ep = true;
401+
break;
402+
}
403+
}
404+
return has_ep;
405+
}
406+
407+
#if defined(WIN32)
408+
// Tests autoEP feature to automatically select an EP that supports the GPU.
409+
// Currently only works on Windows.
410+
TEST(NvExecutionProviderTest, AutoEp_PreferGpu) {
411+
PathString model_name = ORT_TSTR("nv_execution_provider_data_dyn_test.onnx");
412+
std::string graph_name = "test";
413+
std::vector<int> dims = {1, -1, -1};
414+
415+
CreateBaseModel(model_name, graph_name, dims, true);
416+
417+
auto env = Ort::Env();
418+
auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING;
419+
env.UpdateEnvWithCustomLogLevel(logging_level);
420+
421+
{
422+
env.RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("onnxruntime_providers_nv_tensorrt_rtx.dll"));
423+
424+
Ort::SessionOptions so;
425+
so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU);
426+
Ort::Session session_object(env, model_name.c_str(), so);
427+
EXPECT_TRUE(SessionHasEp(session_object, kNvTensorRTRTXExecutionProvider));
428+
}
429+
430+
env.UnregisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider);
431+
}
432+
#endif // defined(WIN32)
433+
392434
} // namespace test
393435
} // namespace onnxruntime

0 commit comments

Comments
 (0)