Skip to content

Commit f04c942

Browse files
nieubankjeffkilpatrick
authored andcommitted
[VitisAI] Upstream changes from win-ort (microsoft#25448)
1 parent a02bb9c commit f04c942

File tree

3 files changed

+154
-4
lines changed

3 files changed

+154
-4
lines changed

onnxruntime/core/providers/vitisai/imp/global_api.cc

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,35 @@ using namespace onnxruntime;
4242
#define LIBRARY_EXTENSION ".so"
4343
#endif
4444

45+
/// @brief Gets the path of directory containing the dynamic library that contains the address.
46+
/// @param address An address of a function or variable in the dynamic library.
47+
/// @return The path of the directory containing the dynamic library, or an empty string if the path cannot be determined.
48+
static onnxruntime::PathString GetDynamicLibraryLocationByAddress(const void* address) {
49+
#ifdef _WIN32
50+
HMODULE moduleHandle;
51+
if (!::GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
52+
reinterpret_cast<LPCWSTR>(address), &moduleHandle)) {
53+
return {};
54+
}
55+
std::wstring buffer;
56+
for (std::uint32_t size{70}; size < 4096; size *= 2) {
57+
buffer.resize(size, L'\0');
58+
const std::uint32_t requiredSize = ::GetModuleFileNameW(moduleHandle, buffer.data(), size);
59+
if (requiredSize == 0) {
60+
break;
61+
}
62+
if (requiredSize == size) {
63+
continue;
64+
}
65+
buffer.resize(requiredSize);
66+
return {std::move(buffer)};
67+
}
68+
#else
69+
std::ignore = address;
70+
#endif
71+
return {};
72+
}
73+
4574
vaip_core::OrtApiForVaip* create_org_api_hook();
4675
struct OrtVitisAIEpAPI {
4776
void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector<OrtCustomOpDomain*>& ret_domain);
@@ -74,8 +103,20 @@ struct OrtVitisAIEpAPI {
74103
// this dll is already linked to the executable, normally a test program
75104
handle_ = reinterpret_cast<void*>(GetModuleHandle(TEXT("onnxruntime_vitisai_ep.dll")));
76105
if (!handle_) {
106+
// First try loading with full path
107+
auto library_filename = PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION);
77108
auto full_path = env.GetRuntimePath() + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION);
78-
ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_));
109+
if (std::filesystem::exists(full_path)) {
110+
ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_));
111+
} else {
112+
// Identify the path of the current dynamic library, and expect that onnxruntime_vitisai_ep is in the same directory.
113+
PathString current_path = GetDynamicLibraryLocationByAddress(reinterpret_cast<const void*>(create_org_api_hook));
114+
if (!current_path.empty()) {
115+
const std::filesystem::path parent_path = std::filesystem::path{std::move(current_path)}.parent_path();
116+
PathString module_relative_full_path = PathString(parent_path / library_filename);
117+
ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(module_relative_full_path, true, &handle_));
118+
}
119+
}
79120
}
80121
#else
81122
auto full_path = env.GetRuntimePath() + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION);
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
EXPORTS
22
GetProvider
3+
CreateEpFactories
4+
ReleaseEpFactory

onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@ std::unique_ptr<IExecutionProvider> VitisAIProviderFactory::CreateProvider(const
5757
}
5858
}
5959

60-
// Store pointer to session options as done in SessionOptionsAppendExecutionProvider_VitisAI
61-
provider_options["session_options"] = std::to_string((uintptr_t)(void*)&session_options);
62-
6360
auto ep_instance = std::make_unique<VitisAIExecutionProvider>(provider_options);
6461
ep_instance->SetLogger(reinterpret_cast<const logging::Logger*>(&session_logger));
6562
return ep_instance;
@@ -89,13 +86,123 @@ struct VitisAI_Provider : Provider {
8986
void Initialize() override { initialize_vitisai_ep(); }
9087
// Called right before unloading the shared library
9188
void Shutdown() override { deinitialize_vitisai_ep(); }
89+
90+
Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/,
91+
const OrtKeyValuePairs* const* /*ep_metadata*/,
92+
size_t /*num_devices*/,
93+
ProviderOptions& provider_options,
94+
const OrtSessionOptions& session_options,
95+
const OrtLogger& logger,
96+
std::unique_ptr<IExecutionProvider>& ep) override {
97+
auto ep_factory = CreateExecutionProviderFactory(&provider_options);
98+
ep = ep_factory->CreateProvider(session_options, logger);
99+
return Status::OK();
100+
}
92101
} g_provider;
93102

103+
struct VitisAIEpFactory : OrtEpFactory {
104+
VitisAIEpFactory(const OrtApi& ort_api_in)
105+
: ort_api{ort_api_in} {
106+
ort_version_supported = ORT_API_VERSION;
107+
GetName = GetNameImpl;
108+
GetVendor = GetVendorImpl;
109+
GetVendorId = GetVendorIdImpl;
110+
GetVersion = GetVersionImpl;
111+
GetSupportedDevices = GetSupportedDevicesImpl;
112+
CreateEp = CreateEpImpl;
113+
ReleaseEp = ReleaseEpImpl;
114+
}
115+
116+
static const char* GetNameImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
117+
return ep_name;
118+
}
119+
120+
static const char* GetVendorImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
121+
return vendor;
122+
}
123+
124+
static uint32_t GetVendorIdImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
125+
return hardware_vendor_id;
126+
}
127+
128+
static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
129+
return ORT_VERSION;
130+
}
131+
132+
static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* ep_factory,
133+
const OrtHardwareDevice* const* devices,
134+
size_t num_devices,
135+
OrtEpDevice** ep_devices,
136+
size_t max_ep_devices,
137+
size_t* p_num_ep_devices) noexcept {
138+
size_t& num_ep_devices = *p_num_ep_devices;
139+
VitisAIEpFactory* factory = static_cast<VitisAIEpFactory*>(ep_factory);
140+
141+
for (size_t i = 0; i < num_devices; ++i) {
142+
const OrtHardwareDevice* hardware_device = devices[i];
143+
const std::uint32_t vendor_id = factory->ort_api.HardwareDevice_VendorId(hardware_device);
144+
const OrtHardwareDeviceType device_type = factory->ort_api.HardwareDevice_Type(hardware_device);
145+
146+
if ((vendor_id != VitisAIEpFactory::hardware_vendor_id) ||
147+
(device_type != OrtHardwareDeviceType_NPU)) {
148+
continue;
149+
}
150+
151+
if (num_ep_devices == max_ep_devices) {
152+
return factory->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Not enough space to return EP devices.");
153+
}
154+
155+
auto status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, hardware_device, nullptr, nullptr,
156+
&ep_devices[num_ep_devices++]);
157+
if (status != nullptr) {
158+
return status;
159+
}
160+
}
161+
return nullptr;
162+
}
163+
164+
static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/,
165+
_In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/,
166+
_In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/,
167+
_In_ size_t /*num_devices*/,
168+
_In_ const OrtSessionOptions* /*session_options*/,
169+
_In_ const OrtLogger* /*logger*/,
170+
_Out_ OrtEp** /*ep*/) noexcept {
171+
return CreateStatus(ORT_INVALID_ARGUMENT, "VitisAI EP factory does not support this method.");
172+
}
173+
174+
static void ReleaseEpImpl(OrtEpFactory*, OrtEp*) noexcept {
175+
// no-op as we never create an EP here.
176+
}
177+
178+
const OrtApi& ort_api;
179+
static constexpr const char* const ep_name{kVitisAIExecutionProvider};
180+
static constexpr std::uint32_t hardware_vendor_id{0x1022};
181+
static constexpr const char* const vendor{"AMD"};
182+
};
183+
94184
} // namespace onnxruntime
95185

96186
extern "C" {
97187

98188
ORT_API(onnxruntime::Provider*, GetProvider) {
99189
return &onnxruntime::g_provider;
100190
}
191+
192+
OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base,
193+
OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
194+
const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);
195+
if (max_factories < 1) {
196+
return ort_api->CreateStatus(ORT_INVALID_ARGUMENT,
197+
"Not enough space to return EP factory. Need at least one.");
198+
}
199+
factories[0] = std::make_unique<onnxruntime::VitisAIEpFactory>(*ort_api).release();
200+
*num_factories = 1;
201+
return nullptr;
202+
}
203+
204+
OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) {
205+
delete static_cast<onnxruntime::VitisAIEpFactory*>(factory);
206+
return nullptr;
207+
}
101208
}

0 commit comments

Comments
 (0)