@@ -115,6 +115,23 @@ struct Nv_Provider : Provider {
115
115
return std::make_shared<NvProviderFactory>(info);
116
116
}
117
117
118
+ Status CreateIExecutionProvider (const OrtHardwareDevice* const * /* devices*/ ,
119
+ const OrtKeyValuePairs* const * /* ep_metadata*/ ,
120
+ size_t /* num_devices*/ ,
121
+ ProviderOptions& provider_options,
122
+ const OrtSessionOptions& session_options,
123
+ const OrtLogger& logger,
124
+ std::unique_ptr<IExecutionProvider>& ep) override {
125
+ const ConfigOptions* config_options = &session_options.GetConfigOptions ();
126
+
127
+ std::array<const void *, 2 > configs_array = {&provider_options, config_options};
128
+ const void * arg = reinterpret_cast <const void *>(&configs_array);
129
+ auto ep_factory = CreateExecutionProviderFactory (arg);
130
+ ep = ep_factory->CreateProvider (session_options, logger);
131
+
132
+ return Status::OK ();
133
+ }
134
+
118
135
void Initialize () override {
119
136
InitializeRegistry ();
120
137
}
@@ -133,3 +150,118 @@ ORT_API(onnxruntime::Provider*, GetProvider) {
133
150
return &onnxruntime::g_provider;
134
151
}
135
152
}
153
+
154
+ #include " core/framework/error_code_helper.h"
155
+
156
+ // OrtEpApi infrastructure to be able to use the NvTensorRTRTX EP as an OrtEpFactory for auto EP selection.
157
+ struct NvTensorRtRtxEpFactory : OrtEpFactory {
158
+ NvTensorRtRtxEpFactory (const OrtApi& ort_api_in,
159
+ const char * ep_name,
160
+ OrtHardwareDeviceType hw_type)
161
+ : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} {
162
+ GetName = GetNameImpl;
163
+ GetVendor = GetVendorImpl;
164
+ GetVersion = GetVersionImpl;
165
+ GetSupportedDevices = GetSupportedDevicesImpl;
166
+ CreateEp = CreateEpImpl;
167
+ ReleaseEp = ReleaseEpImpl;
168
+ }
169
+
170
+ // Returns the name for the EP. Each unique factory configuration must have a unique name.
171
+ // Ex: a factory that supports NPU should have a different than a factory that supports GPU.
172
+ static const char * GetNameImpl (const OrtEpFactory* this_ptr) {
173
+ const auto * factory = static_cast <const NvTensorRtRtxEpFactory*>(this_ptr);
174
+ return factory->ep_name .c_str ();
175
+ }
176
+
177
+ static const char * GetVendorImpl (const OrtEpFactory* this_ptr) {
178
+ const auto * factory = static_cast <const NvTensorRtRtxEpFactory*>(this_ptr);
179
+ return factory->vendor .c_str ();
180
+ }
181
+
182
+ static const char * ORT_API_CALL GetVersionImpl (const OrtEpFactory* /* this_ptr*/ ) noexcept {
183
+ return ORT_VERSION;
184
+ }
185
+
186
+ // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports.
187
+ // An EP created with this factory is expected to be able to execute a model with *all* supported
188
+ // hardware devices at once. A single instance of NvTensorRtRtx EP is not currently setup to partition a model among
189
+ // multiple different NvTensorRtRtx backends at once (e.g, npu, cpu, gpu), so this factory instance is set to only
190
+ // support one backend: gpu. To support a different backend, like npu, create a different factory instance
191
+ // that only supports NPU.
192
+ static OrtStatus* GetSupportedDevicesImpl (OrtEpFactory* this_ptr,
193
+ const OrtHardwareDevice* const * devices,
194
+ size_t num_devices,
195
+ OrtEpDevice** ep_devices,
196
+ size_t max_ep_devices,
197
+ size_t * p_num_ep_devices) {
198
+ size_t & num_ep_devices = *p_num_ep_devices;
199
+ auto * factory = static_cast <NvTensorRtRtxEpFactory*>(this_ptr);
200
+
201
+ for (size_t i = 0 ; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
202
+ const OrtHardwareDevice& device = *devices[i];
203
+ if (factory->ort_api .HardwareDevice_Type (&device) == factory->ort_hw_device_type &&
204
+ factory->ort_api .HardwareDevice_VendorId (&device) == factory->vendor_id ) {
205
+ OrtKeyValuePairs* ep_options = nullptr ;
206
+ factory->ort_api .CreateKeyValuePairs (&ep_options);
207
+ ORT_API_RETURN_IF_ERROR (
208
+ factory->ort_api .GetEpApi ()->CreateEpDevice (factory, &device, nullptr , ep_options,
209
+ &ep_devices[num_ep_devices++]));
210
+ }
211
+ }
212
+
213
+ return nullptr ;
214
+ }
215
+
216
+ static OrtStatus* CreateEpImpl (OrtEpFactory* /* this_ptr*/ ,
217
+ _In_reads_ (num_devices) const OrtHardwareDevice* const * /* devices*/ ,
218
+ _In_reads_(num_devices) const OrtKeyValuePairs* const * /* ep_metadata*/ ,
219
+ _In_ size_t /* num_devices*/ ,
220
+ _In_ const OrtSessionOptions* /* session_options*/ ,
221
+ _In_ const OrtLogger* /* logger*/ ,
222
+ _Out_ OrtEp** /* ep*/ ) {
223
+ return onnxruntime::CreateStatus (ORT_INVALID_ARGUMENT, " [NvTensorRTRTX EP] EP factory does not support this method." );
224
+ }
225
+
226
+ static void ReleaseEpImpl (OrtEpFactory* /* this_ptr*/ , OrtEp* /* ep*/ ) {
227
+ // no-op as we never create an EP here.
228
+ }
229
+
230
+ const OrtApi& ort_api;
231
+ const std::string ep_name;
232
+ const std::string vendor{" NVIDIA" };
233
+
234
+ // NVIDIA vendor ID. Refer to the ACPI ID registry (search NVIDIA): https://uefi.org/ACPI_ID_List
235
+ const uint32_t vendor_id{0x10de };
236
+ const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice
237
+ };
238
+
239
+ extern " C" {
240
+ //
241
+ // Public symbols
242
+ //
243
+ OrtStatus* CreateEpFactories (const char * /* registration_name*/ , const OrtApiBase* ort_api_base,
244
+ OrtEpFactory** factories, size_t max_factories, size_t * num_factories) {
245
+ const OrtApi* ort_api = ort_api_base->GetApi (ORT_API_VERSION);
246
+
247
+ // Factory could use registration_name or define its own EP name.
248
+ auto factory_gpu = std::make_unique<NvTensorRtRtxEpFactory>(*ort_api,
249
+ onnxruntime::kNvTensorRTRTXExecutionProvider ,
250
+ OrtHardwareDeviceType_GPU);
251
+
252
+ if (max_factories < 1 ) {
253
+ return ort_api->CreateStatus (ORT_INVALID_ARGUMENT,
254
+ " Not enough space to return EP factory. Need at least one." );
255
+ }
256
+
257
+ factories[0 ] = factory_gpu.release ();
258
+ *num_factories = 1 ;
259
+
260
+ return nullptr ;
261
+ }
262
+
263
+ OrtStatus* ReleaseEpFactory (OrtEpFactory* factory) {
264
+ delete static_cast <NvTensorRtRtxEpFactory*>(factory);
265
+ return nullptr ;
266
+ }
267
+ }
0 commit comments