Skip to content

Commit 25d1908

Browse files
bachelor-dou authored and
Sanket Kale committed
[CANN]Fix issue with negative dynamic tensor shape (microsoft#25431)
### Description
<!-- Describe your changes. -->
The error is:
```
..2025-07-17 11:21:36.861835596 [E:onnxruntime:, sequential_executor.cc:572 ExecuteKernel] Non-zero status code returned while running main_graph_11957213504832792607_0 node. Name:'CANNExecutionProvider_main_graph_11957213504832792607_0_0' Status Message: ~/code/onnxruntime/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. tensor.cc:57 CalculateTensorStorageSize Tensor shape.Size() must be >= 0
[ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running main_graph_11957213504832792607_0 node. Name:'CANNExecutionProvider_main_graph_11957213504832792607_0_0' Status Message: ~/code/onnxruntime/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. tensor.cc:57 CalculateTensorStorageSize Tensor shape.Size() must be >= 0
```
1 parent e575ebe commit 25d1908

File tree

3 files changed

+58
-12
lines changed

3 files changed

+58
-12
lines changed

onnxruntime/core/providers/cann/cann_execution_provider.cc

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,8 +1385,7 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
13851385
HashValue hash;
13861386
cann::GenerateHashValue(input_shape, hash);
13871387
std::string filename = cann_state->node_name + "_" + std::to_string(hash);
1388-
std::string filename_with_suffix = filename + ".om";
1389-
1388+
bool dynamic_shape = false;
13901389
// TODO(FFFrog): Resource Management
13911390
// It is very necessary to provide a new mechanism for memory reclamation to avoid inference failure caused by
13921391
// device memory exhaustion
@@ -1395,8 +1394,8 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
13951394
modelID = modelIDs_[filename];
13961395
} else {
13971396
std::lock_guard<std::mutex> lock(g_mutex);
1398-
1399-
if (cann::FileExist(filename_with_suffix)) {
1397+
auto filename_with_suffix = cann::RegexMatchFile(filename);
1398+
if (!filename_with_suffix.empty()) {
14001399
CANN_RETURN_IF_ERROR(aclmdlLoadFromFile(filename_with_suffix.c_str(), &modelID));
14011400
} else {
14021401
ge::Graph graph{cann_state->node_name.c_str()};
@@ -1424,20 +1423,46 @@ Status CANNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fuse
14241423
for (size_t i = 0; i < aclmdlGetNumOutputs(prepare.modelDesc_); i++) {
14251424
aclmdlIODims dims;
14261425
CANN_CALL_THROW(aclmdlGetOutputDims(prepare.modelDesc_, i, &dims));
1427-
std::vector<int64_t> vec{dims.dims, dims.dims + dims.dimCount};
1428-
auto output = ctx.GetOutput(i, vec);
1429-
CANN_MODEL_PREPARE_OUTPUTBUFFER(prepare,
1430-
const_cast<void*>(output.GetTensorRawData()),
1431-
aclmdlGetOutputSizeByIndex(prepare.modelDesc_, i));
1426+
1427+
if (cann::is_dynamic_shape(dims)) {
1428+
CANN_MODEL_PREPARE_OUTPUTBUFFER(prepare, nullptr, 0);
1429+
dynamic_shape = true;
1430+
} else {
1431+
std::vector<int64_t> vec{dims.dims, dims.dims + dims.dimCount};
1432+
auto output = ctx.GetOutput(i, vec);
1433+
CANN_MODEL_PREPARE_OUTPUTBUFFER(prepare,
1434+
const_cast<void*>(output.GetTensorRawData()),
1435+
aclmdlGetOutputSizeByIndex(prepare.modelDesc_, i));
1436+
}
14321437
}
14331438
}
14341439
ORT_CATCH(const std::exception& e) {
14351440
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what());
14361441
}
14371442

14381443
aclrtStream stream = static_cast<aclrtStream>(ctx.GetGPUComputeStream());
1439-
CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream));
1440-
1444+
if (dynamic_shape) {
1445+
aclrtSynchronizeStream(stream);
1446+
CANN_RETURN_IF_ERROR(aclmdlExecute(modelID, prepare.inputSet_, prepare.outputSet_));
1447+
for (size_t i = 0; i < aclmdlGetNumOutputs(prepare.modelDesc_); i++) {
1448+
std::vector<int64_t> shape;
1449+
aclTensorDesc* desc = aclmdlGetDatasetTensorDesc(prepare.outputSet_, i);
1450+
size_t num_dims = aclGetTensorDescNumDims(desc);
1451+
shape.reserve(num_dims);
1452+
for (size_t j = 0; j < num_dims; j++) {
1453+
int64_t dim;
1454+
CANN_RETURN_IF_ERROR(aclGetTensorDescDimV2(desc, j, &dim));
1455+
shape.push_back(dim);
1456+
}
1457+
aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(prepare.outputSet_, i);
1458+
void* src_data = aclGetDataBufferAddr(dataBuffer);
1459+
void* dst_data = const_cast<void*>(ctx.GetOutput(i, shape).GetTensorRawData());
1460+
size_t count = aclGetTensorDescSize(desc);
1461+
CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, count, src_data, count, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));
1462+
}
1463+
} else {
1464+
CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream));
1465+
}
14411466
return Status::OK();
14421467
};
14431468

onnxruntime/core/providers/cann/cann_utils.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,5 +224,23 @@ void GenerateHashValue(const std::string string, HashValue& hash_value) {
224224
hash_value = hash[0] | (uint64_t(hash[1]) << 32);
225225
}
226226

227+
bool is_dynamic_shape(const aclmdlIODims& dims) {
228+
return std::find(dims.dims, dims.dims + dims.dimCount, -1) != dims.dims + dims.dimCount;
229+
}
230+
231+
namespace fs = std::filesystem;

// Looks in the current working directory for a cached offline model that
// belongs to `file_name` and returns its file name (no directory component),
// or an empty string when there is none or the directory cannot be read.
//
// A directory entry is accepted when all of the following hold:
//   * it is a regular file;
//   * its name starts with `file_name`, compared as literal text;
//   * its name ends with ".om";
//   * the character right after the `file_name` prefix is not a decimal digit.
//
// So both "<file_name>.om" and "<file_name>_<anything>.om" are found.
//
// Why not a regular expression (the name is kept only for API compatibility):
//   * `file_name` is assembled by the caller from a graph node name, so it may
//     contain regex metacharacters, which would either change what is matched
//     or make the std::regex constructor throw.
//   * The caller ends `file_name` with a decimal hash, so an unanchored search
//     for "x_123" would also accept "x_1234.om" and load the wrong model. The
//     prefix anchor plus the digit check rules that out.
//
// Only the non-throwing std::filesystem overloads are used, because callers
// report failures through Status values rather than exceptions.
std::string RegexMatchFile(const std::string& file_name) {
  const std::string model_suffix = ".om";
  if (file_name.empty()) {
    return "";
  }

  std::error_code ec;
  const fs::path current_dir = fs::current_path(ec);
  if (ec) {
    return "";
  }

  const fs::directory_iterator end;
  // If construction fails, `it` equals `end` and `ec` is set, so the loop is skipped.
  for (fs::directory_iterator it(current_dir, ec); !ec && it != end; it.increment(ec)) {
    std::error_code entry_ec;
    if (!it->is_regular_file(entry_ec) || entry_ec) {
      continue;
    }

    const std::string name = it->path().filename().string();
    if (name.size() < file_name.size() + model_suffix.size()) {
      continue;
    }
    if (name.compare(0, file_name.size(), file_name) != 0) {
      continue;
    }
    if (name.compare(name.size() - model_suffix.size(), model_suffix.size(), model_suffix) != 0) {
      continue;
    }

    // A digit here means the entry belongs to a different, longer hash value.
    const char next = name[file_name.size()];
    if (next >= '0' && next <= '9') {
      continue;
    }
    return name;
  }
  return "";
}
227245
} // namespace cann
228246
} // namespace onnxruntime

onnxruntime/core/providers/cann/cann_utils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include <iomanip>
1010
#include <string>
1111
#include <memory>
12+
#include <filesystem>
13+
#include <regex>
1214

1315
#include "core/framework/murmurhash3.h"
1416
#include "core/providers/cann/cann_common.h"
@@ -124,7 +126,8 @@ Status aclrtblasGemmEx(aclTransType transA,
124126

125127
bool FileExist(const std::string& file_name);
126128
void GenerateHashValue(const std::string string, HashValue& hash_value);
127-
129+
bool is_dynamic_shape(const aclmdlIODims& dims);
130+
std::string RegexMatchFile(const std::string& file_name);
128131
std::unique_ptr<Model> CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger);
129132

130133
} // namespace cann

0 commit comments

Comments
 (0)