Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions onnxruntime/core/providers/cann/cann_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1385,8 +1385,7 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
HashValue hash;
cann::GenerateHashValue(input_shape, hash);
std::string filename = cann_state->node_name + "_" + std::to_string(hash);
std::string filename_with_suffix = filename + ".om";

bool dynamic_shape = false;
// TODO(FFFrog): Resource Management
// It is very necessary to provide a new mechanism for memory reclamation to avoid inference failure caused by
// device memory exhaustion
Expand All @@ -1395,8 +1394,8 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
modelID = modelIDs_[filename];
} else {
std::lock_guard<std::mutex> lock(g_mutex);

if (cann::FileExist(filename_with_suffix)) {
auto filename_with_suffix = cann::RegexMatchFile(filename);
if (!filename_with_suffix.empty()) {
CANN_RETURN_IF_ERROR(aclmdlLoadFromFile(filename_with_suffix.c_str(), &modelID));
} else {
ge::Graph graph{cann_state->node_name.c_str()};
Expand Down Expand Up @@ -1424,20 +1423,46 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
for (size_t i = 0; i < aclmdlGetNumOutputs(prepare.modelDesc_); i++) {
aclmdlIODims dims;
CANN_CALL_THROW(aclmdlGetOutputDims(prepare.modelDesc_, i, &dims));
std::vector<int64_t> vec{dims.dims, dims.dims + dims.dimCount};
auto output = ctx.GetOutput(i, vec);
CANN_MODEL_PREPARE_OUTPUTBUFFER(prepare,
const_cast<void*>(output.GetTensorRawData()),
aclmdlGetOutputSizeByIndex(prepare.modelDesc_, i));

if (cann::is_dynamic_shape(dims)) {
CANN_MODEL_PREPARE_OUTPUTBUFFER(prepare, nullptr, 0);
dynamic_shape = true;
} else {
std::vector<int64_t> vec{dims.dims, dims.dims + dims.dimCount};
auto output = ctx.GetOutput(i, vec);
CANN_MODEL_PREPARE_OUTPUTBUFFER(prepare,
const_cast<void*>(output.GetTensorRawData()),
aclmdlGetOutputSizeByIndex(prepare.modelDesc_, i));
}
}
}
ORT_CATCH(const std::exception& e) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what());
}

aclrtStream stream = static_cast<aclrtStream>(ctx.GetGPUComputeStream());
CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream));

if (dynamic_shape) {
aclrtSynchronizeStream(stream);
CANN_RETURN_IF_ERROR(aclmdlExecute(modelID, prepare.inputSet_, prepare.outputSet_));
for (size_t i = 0; i < aclmdlGetNumOutputs(prepare.modelDesc_); i++) {
std::vector<int64_t> shape;
aclTensorDesc* desc = aclmdlGetDatasetTensorDesc(prepare.outputSet_, i);
size_t num_dims = aclGetTensorDescNumDims(desc);
shape.reserve(num_dims);
for (size_t j = 0; j < num_dims; j++) {
int64_t dim;
CANN_RETURN_IF_ERROR(aclGetTensorDescDimV2(desc, j, &dim));
shape.push_back(dim);
}
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(prepare.outputSet_, i);
void* src_data = aclGetDataBufferAddr(dataBuffer);
void* dst_data = const_cast<void*>(ctx.GetOutput(i, shape).GetTensorRawData());
size_t count = aclGetTensorDescSize(desc);
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, count, src_data, count, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));
}
} else {
CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream));
}
return Status::OK();
};

Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/core/providers/cann/cann_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,5 +224,23 @@
hash_value = hash[0] | (uint64_t(hash[1]) << 32);
}

bool is_dynamic_shape(const aclmdlIODims& dims) {
return std::find(dims.dims, dims.dims + dims.dimCount, -1) != dims.dims + dims.dimCount;
}

namespace fs = std::filesystem;
std::string RegexMatchFile(const std::string& file_name) {
fs::path current_dir = fs::current_path();
std::regex pattern(file_name);
for (const auto& entry : fs::directory_iterator(current_dir)) {
if (entry.is_regular_file()) {
std::string name = entry.path().filename().string();

Check warning on line 237 in onnxruntime/core/providers/cann/cann_utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cann/cann_utils.cc:237: Add #include <string> for string [build/include_what_you_use] [4]
if (std::regex_search(name, pattern)) {
return name;
}
}
}
return "";
}
} // namespace cann
} // namespace onnxruntime
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/cann/cann_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <iomanip>
#include <string>
#include <memory>
#include <filesystem>

Check warning on line 12 in onnxruntime/core/providers/cann/cann_utils.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 <filesystem> is an unapproved C++17 header. [build/c++17] [5] Raw Output: onnxruntime/core/providers/cann/cann_utils.h:12: <filesystem> is an unapproved C++17 header. [build/c++17] [5]
#include <regex>

#include "core/framework/murmurhash3.h"
#include "core/providers/cann/cann_common.h"
Expand Down Expand Up @@ -124,7 +126,8 @@

bool FileExist(const std::string& file_name);
void GenerateHashValue(const std::string string, HashValue& hash_value);

bool is_dynamic_shape(const aclmdlIODims& dims);
std::string RegexMatchFile(const std::string& file_name);
std::unique_ptr<Model> CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger);

} // namespace cann
Expand Down
Loading