Skip to content

Commit 44dc2c8

Browse files
[GPU] Gracefully handle zero batch size (#34515)
### Details: Avoid a division-by-zero error when creating a memory descriptor in the GPU plugin (the batch or feature dimension may be 0). ### Tickets: closes #24243. ### AI Assistance: AI assistance used: yes — to generate tests.
1 parent 93835b8 commit 44dc2c8

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

src/plugins/intel_gpu/src/graph/impls/onednn/utils.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,12 +351,22 @@ class MemoryDescriptorBuilder {
351351
dnnl::memory::dims dims;
352352
auto fmt_tag = _target_fmt;
353353

354+
auto count_except = [this](size_t skip_idx) -> int64_t {
355+
const auto& raw = _layout.get_tensor().raw;
356+
int64_t result = 1;
357+
for (size_t i = 0; i < raw.size(); ++i) {
358+
if (i != skip_idx)
359+
result *= static_cast<int64_t>(raw[i]);
360+
}
361+
return result;
362+
};
363+
354364
if (fmt_tag == dnnl::memory::format_tag::ab && _flatten) {
355365
dims = flatten_tensor(_layout.get_tensor());
356366
dims.insert(dims.begin(), 1);
357367
} else if (fmt_tag == dnnl::memory::format_tag::ab) {
358368
dims.push_back(_layout.batch());
359-
dims.push_back(_layout.get_tensor().count() / _layout.batch());
369+
dims.push_back(count_except(0));
360370
} else if (fmt_tag == dnnl::memory::format_tag::abc) {
361371
dims.push_back(_layout.batch());
362372
dims.push_back(_layout.feature());
@@ -385,7 +395,7 @@ class MemoryDescriptorBuilder {
385395
dims.push_back(_layout.spatial(1));
386396
} else if (fmt_tag == dnnl::memory::format_tag::ba) {
387397
dims.push_back(_layout.feature());
388-
dims.push_back(_layout.get_tensor().count() / _layout.feature());
398+
dims.push_back(count_except(1));
389399
} else if (_flatten) {
390400
dims = flatten_tensor(_layout.get_tensor());
391401
} else {

src/plugins/intel_gpu/tests/unit/onednn/utils_test.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,3 +533,29 @@ TEST_F(test_layout_to_memory_desc, regression_3d_shape_format_selection) {
533533
EXPECT_EQ(get_format_tag_from_desc(desc_byxf), dnnl::memory::format_tag::acb);
534534
// Format tags should be different (abc vs acb)
535535
}
536+
537+
TEST_F(test_layout_to_memory_desc, zero_batch_ab_format) {
538+
// batch=0 should not cause division-by-zero; inner dims should be preserved
539+
layout l = layout{ov::PartialShape{0, 256}, data_types::f32, format::bfyx};
540+
auto desc = layout_to_memory_desc(l, dnnl::memory::format_tag::ab);
541+
EXPECT_EQ(desc.get_ndims(), 2);
542+
EXPECT_EQ(desc.get_dims()[0], 0);
543+
EXPECT_EQ(desc.get_dims()[1], 256);
544+
}
545+
546+
TEST_F(test_layout_to_memory_desc, zero_feature_ba_format) {
547+
// feature=0 should not cause division-by-zero; inner dims should be preserved
548+
layout l = layout{ov::PartialShape{256, 0}, data_types::f32, format::bfyx};
549+
auto desc = layout_to_memory_desc(l, dnnl::memory::format_tag::ba);
550+
EXPECT_EQ(desc.get_ndims(), 2);
551+
EXPECT_EQ(desc.get_dims()[0], 0);
552+
EXPECT_EQ(desc.get_dims()[1], 256);
553+
}
554+
555+
TEST_F(test_layout_to_memory_desc, zero_batch_default_format) {
556+
// All-zero batch with default format should not crash
557+
layout l = layout{ov::PartialShape{0, 64, 32, 32}, data_types::f16, format::bfyx};
558+
EXPECT_NO_THROW({
559+
auto desc = layout_to_memory_desc(l, dnnl::memory::format_tag::undef);
560+
});
561+
}

0 commit comments

Comments (0)