Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2606,12 +2606,34 @@ common::Status InferenceSession::CheckShapes(const std::string& input_output_nam
" Please fix either the inputs/outputs or the model.");
}

// Helper that decides whether a mismatch on the first dimension may be tolerated because the
// QNN EP supports a batch multiplier on that dimension.
//
// Parameters:
//   input_dim    - size of the first dimension of the tensor supplied at run time.
//   expected_dim - size of the first dimension declared by the model.
//   graph        - graph whose node-to-EP assignments are inspected.
//
// Returns true only when ALL of the following hold:
//   - both dimensions are positive and input_dim is an exact multiple of expected_dim,
//   - the QNN EP is registered in execution_providers_,
//   - every node returned by graph.Nodes() is assigned to the QNN EP.
//
// We will check whether only the Htp backend is used inside QnnModel::ExecuteGraph.
auto is_qnn_batch_multiplier_valid = [this](int64_t input_dim, int64_t expected_dim, const Graph& graph) -> bool {
  // Do the O(1) arithmetic checks first so the common "not a multiple" case never pays for the
  // provider lookup or the walk over the graph nodes below.
  // A non-positive input_dim is rejected explicitly: 0 % expected_dim == 0 would otherwise let an
  // empty batch (or a negative size) pass as a valid "multiple".
  if (expected_dim <= 0 || input_dim <= 0) {
    return false;
  }
  if (input_dim % expected_dim != 0) {
    return false;
  }

  // check if QNN EP is used
  if (execution_providers_.Get(kQnnExecutionProvider) == nullptr) {
    return false;
  }

  // check if all nodes are assigned to QNN EP
  for (const auto& node : graph.Nodes()) {
    const auto& node_provider = node.GetExecutionProviderType();
    if (node_provider.empty() || node_provider != kQnnExecutionProvider) {
      return false;
    }
  }

  return true;
};

InlinedVector<size_t> invalid_dim_indices;
for (size_t i = 0; i < shape_size; ++i) {
if (expected_shape[i] < 0) {
continue; // this represents a symbolic shape dimension
}
if (input_output_shape[i] != expected_shape[i]) {
} else if (i == 0 && is_qnn_batch_multiplier_valid(input_output_shape[i], expected_shape[i], model_->MainGraph())) {
continue; // Qnn API supports batch multiplier, but the running batch size must be divisible by the original batch size.
Comment on lines +2634 to +2635
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want to add a QNN EP-specific relaxation of shape validation here. If the graph were not running on the QNN EP, would the same input be considered invalid? Can you elaborate on what you are trying to do?

} else if (input_output_shape[i] != expected_shape[i]) {
invalid_dim_indices.push_back(i);
}
}
Expand Down