From e1dceaea78f40d1e8300597bf39dc5ba986277f9 Mon Sep 17 00:00:00 2001 From: "Simonov, Alexander" Date: Sun, 8 Mar 2026 11:01:04 -0700 Subject: [PATCH 1/2] common: deduce runtime value marker from types --- src/common/gemm_types.hpp | 4 +- src/common/matmul.cpp | 6 +- src/common/memory.cpp | 6 +- src/common/memory_desc.cpp | 20 +++---- src/common/memory_desc_wrapper.cpp | 8 +-- src/common/memory_desc_wrapper.hpp | 8 +-- src/common/memory_zero_pad.cpp | 2 +- src/common/primitive_attr.cpp | 4 +- src/common/type_helpers.hpp | 29 +++------ src/common/utils.hpp | 70 ++++++++++++++++++++++ src/common/verbose.cpp | 2 +- src/cpu/aarch64/matmul/brgemm_matmul.cpp | 2 +- src/cpu/aarch64/matmul/jit_int8_matmul.cpp | 3 +- src/cpu/gemm_inner_product_utils.hpp | 4 +- src/cpu/matmul/gemm_based_common.hpp | 2 +- src/cpu/matmul/gemm_bf16_matmul.cpp | 4 +- src/cpu/matmul/gemm_bf16_matmul.hpp | 2 +- src/cpu/matmul/gemm_f32_matmul.cpp | 8 +-- src/cpu/matmul/gemm_f32_matmul.hpp | 2 +- src/cpu/matmul/gemm_x8s8s32x_matmul.cpp | 4 +- src/cpu/matmul/gemm_x8s8s32x_matmul.hpp | 2 +- src/cpu/matmul/matmul_utils.hpp | 3 +- src/cpu/x64/matmul/brgemm_matmul.cpp | 2 +- 23 files changed, 127 insertions(+), 70 deletions(-) diff --git a/src/common/gemm_types.hpp b/src/common/gemm_types.hpp index 23209d3eb58..f7be676f791 100644 --- a/src/common/gemm_types.hpp +++ b/src/common/gemm_types.hpp @@ -91,8 +91,8 @@ struct gemm_desc_t : public op_desc_t { // if ndims < 3, it should return 1 int64_t batch = 1; for (int i = 0; i < c_desc.ndims - 2; ++i) { - if (c_desc.dims[i] == DNNL_RUNTIME_DIM_VAL) - return DNNL_RUNTIME_DIM_VAL; + if (is_runtime_value(c_desc.dims[i])) + return runtime_value_for(); batch *= c_desc.dims[i]; } return batch; diff --git a/src/common/matmul.cpp b/src/common/matmul.cpp index 45c57693ce4..07eb8b89cc4 100644 --- a/src/common/matmul.cpp +++ b/src/common/matmul.cpp @@ -495,11 +495,11 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc, const dim_t b_dim = with_bias ? 
op_d.bias_desc.dims[d] : 0; const dim_t r_dim = with_reduce ? op_d.reduce_desc.dims[d] : 0; - if (one_of(DNNL_RUNTIME_DIM_VAL, s_dim, w_dim, d_dim, b_dim)) { + if (any_runtime_value(s_dim, w_dim, d_dim, b_dim)) { - VCHECK_MATMUL(everyone_is(DNNL_RUNTIME_DIM_VAL, s_dim, w_dim, d_dim) + VCHECK_MATMUL(all_runtime_values(s_dim, w_dim, d_dim) && IMPLICATION((bia_mask & (1 << d)) && with_bias, - b_dim == DNNL_RUNTIME_DIM_VAL), + is_runtime_value(b_dim)), VERBOSE_RUNTIMEDIM_INCONSISTENT, d); } else { // This follows numpy semantics of broadcasting when 0 is involved. diff --git a/src/common/memory.cpp b/src/common/memory.cpp index 1ea4661f660..a58bb32915d 100644 --- a/src/common/memory.cpp +++ b/src/common/memory.cpp @@ -48,7 +48,7 @@ namespace { // Returns the size required for memory descriptor mapping. // Caveats: // 1. If memory descriptor with run-time parameters, the mapping cannot be done; -// hence return DNNL_RUNTIME_SIZE_VAL +// hence return runtime_value_for() // 2. Otherwise, the size returned includes `offset0` and holes (for the case // of non-trivial strides). Strictly speaking, the mapping should happen only // for elements accessible with `md.off_l(0 .. md.nelems())`. 
However, for @@ -59,7 +59,7 @@ namespace { size_t memory_desc_map_size(const memory_desc_t *md, int index = 0) { auto mdw = memory_desc_wrapper(md); - if (mdw.has_runtime_dims_or_strides()) return DNNL_RUNTIME_SIZE_VAL; + if (mdw.has_runtime_dims_or_strides()) return runtime_value_for(); return mdw.size(index, true, true); } @@ -350,7 +350,7 @@ status_t dnnl_memory_map_data_v2( if (map_size == 0) { *mapped_ptr = nullptr; return success; - } else if (map_size == DNNL_RUNTIME_SIZE_VAL) { + } else if (is_runtime_value(map_size)) { return invalid_arguments; } diff --git a/src/common/memory_desc.cpp b/src/common/memory_desc.cpp index 5723c43736d..5f20684e7d5 100644 --- a/src/common/memory_desc.cpp +++ b/src/common/memory_desc.cpp @@ -105,10 +105,9 @@ status_t memory_desc_init_by_strides(memory_desc_t &memory_desc, int ndims, bool has_runtime_strides = false; default_strides[md.ndims - 1] = 1; for (int d = md.ndims - 2; d >= 0; --d) { - if (md.padded_dims[d] == DNNL_RUNTIME_DIM_VAL) - has_runtime_strides = true; + if (is_runtime_value(md.padded_dims[d])) has_runtime_strides = true; default_strides[d] = has_runtime_strides - ? DNNL_RUNTIME_DIM_VAL + ? 
runtime_value_for(default_strides[d]) : default_strides[d + 1] * md.padded_dims[d + 1]; } strides = default_strides; @@ -220,9 +219,8 @@ status_t memory_desc_init_submemory(memory_desc_t &memory_desc, VERBOSE_UNSUPPORTED_MEM_STRIDE); for (int d = 0; d < src_d.ndims(); ++d) { - VCHECK_MEMORY( - !(utils::one_of(DNNL_RUNTIME_DIM_VAL, dims[d], offsets[d])), - unimplemented, VERBOSE_RUNTIMEDIM_UNSUPPORTED); + VCHECK_MEMORY(!any_runtime_value(dims[d], offsets[d]), unimplemented, + VERBOSE_RUNTIMEDIM_UNSUPPORTED); const bool dim_offsets_oob = (dims[d] < 0 || offsets[d] < 0 || (offsets[d] + dims[d] > src_d.dims()[d])); @@ -269,7 +267,7 @@ status_t memory_desc_reshape(memory_desc_t &out_memory_desc, auto volume = [](const dim_t *dims, int ndims) -> dim_t { dim_t prod = 1; for (int i = 0; i < ndims; ++i) { - if (dims[i] == DNNL_RUNTIME_DIM_VAL) return DNNL_RUNTIME_DIM_VAL; + if (is_runtime_value(dims[i])) return runtime_value_for(prod); prod *= dims[i] > 0 ? dims[i] : 1; } return prod; @@ -576,12 +574,12 @@ status_t memory_desc_init_by_string_tag(memory_desc_t &md, int ndims, blk.strides[dim_idx] = stride; dim_t fib = full_inner_blks[dim_idx]; - dim_t padded_dim = md.dims[dim_idx] == DNNL_RUNTIME_DIM_VAL - ? DNNL_RUNTIME_DIM_VAL + const auto padded_dim = is_runtime_value(md.dims[dim_idx]) + ? 
runtime_value_for(md.padded_dims[dim_idx]) : (md.dims[dim_idx] + fib - 1) / fib * fib; md.padded_dims[dim_idx] = padded_dim; - if (one_of(DNNL_RUNTIME_DIM_VAL, padded_dim, stride)) - stride = DNNL_RUNTIME_DIM_VAL; + if (any_runtime_value(padded_dim, stride)) + stride = runtime_value_for(stride); else stride *= (padded_dim / fib); } else { diff --git a/src/common/memory_desc_wrapper.cpp b/src/common/memory_desc_wrapper.cpp index 3370b5d237e..8f75e7e9aa4 100644 --- a/src/common/memory_desc_wrapper.cpp +++ b/src/common/memory_desc_wrapper.cpp @@ -59,8 +59,8 @@ status_t fill_blocked(memory_desc_t &md, std::initializer_list perm, utils::array_set(md.padded_offsets, 0, md.ndims); for (int d = 0; d < md.ndims; ++d) - md.padded_dims[d] = md.dims[d] == DNNL_RUNTIME_DIM_VAL - ? DNNL_RUNTIME_DIM_VAL + md.padded_dims[d] = is_runtime_value(md.dims[d]) + ? runtime_value_for(md.padded_dims[d]) : utils::rnd_up(md.dims[d], blocks[d]); // setting the strides @@ -72,8 +72,8 @@ status_t fill_blocked(memory_desc_t &md, std::initializer_list perm, blk.strides[d] = stride; const dim_t pdim = md.padded_dims[d]; - if (utils::one_of(DNNL_RUNTIME_DIM_VAL, stride, pdim)) - stride = DNNL_RUNTIME_DIM_VAL; + if (any_runtime_value(stride, pdim)) + stride = runtime_value_for(stride); else if (pdim != 0) stride *= pdim / blocks[d]; diff --git a/src/common/memory_desc_wrapper.hpp b/src/common/memory_desc_wrapper.hpp index 2b3d3fe34a8..a9c63a94b78 100644 --- a/src/common/memory_desc_wrapper.hpp +++ b/src/common/memory_desc_wrapper.hpp @@ -183,7 +183,7 @@ struct memory_desc_wrapper : public c_compatible { * is true, and the number of data elements otherwise */ dim_t nelems(bool with_padding = false) const { if (is_zero()) return 0; - if (has_runtime_dims()) return DNNL_RUNTIME_DIM_VAL; + if (has_runtime_dims()) return runtime_value_for(); return utils::array_product( with_padding ? 
padded_dims() : dims(), ndims()); } @@ -301,7 +301,7 @@ struct memory_desc_wrapper : public c_compatible { return 0; } - if (has_runtime_dims_or_strides()) return DNNL_RUNTIME_SIZE_VAL; + if (has_runtime_dims_or_strides()) return runtime_value_for(); if (is_wino_desc()) { return wino_desc().size; @@ -433,7 +433,7 @@ struct memory_desc_wrapper : public c_compatible { /** returns true if at least one dim is not known */ bool has_runtime_dims() const { for (int d = 0; d < ndims(); ++d) - if (dims()[d] == DNNL_RUNTIME_DIM_VAL) return true; + if (is_runtime_value(dims()[d])) return true; return false; } @@ -441,7 +441,7 @@ struct memory_desc_wrapper : public c_compatible { bool has_runtime_strides() const { if (!is_blocking_desc()) return false; for (int d = 0; d < ndims(); ++d) - if (blocking_desc().strides[d] == DNNL_RUNTIME_DIM_VAL) return true; + if (is_runtime_value(blocking_desc().strides[d])) return true; return false; } diff --git a/src/common/memory_zero_pad.cpp b/src/common/memory_zero_pad.cpp index e48de56c416..c8b19805e33 100644 --- a/src/common/memory_zero_pad.cpp +++ b/src/common/memory_zero_pad.cpp @@ -199,7 +199,7 @@ status_t typed_zero_pad(const memory_t *memory, const exec_ctx_t &ctx) { if (mdw.nelems(false) == mdw.nelems(true)) return success; const size_t map_size = mdw.size(); - assert(map_size != DNNL_RUNTIME_SIZE_VAL); + assert(!is_runtime_value(map_size)); void *mapped_ptr = ctx.map_memory_storage(memory_storage, ctx.stream(), map_size); diff --git a/src/common/primitive_attr.cpp b/src/common/primitive_attr.cpp index 8c40e44a9e7..7d76bcb0143 100644 --- a/src/common/primitive_attr.cpp +++ b/src/common/primitive_attr.cpp @@ -242,7 +242,7 @@ status_t post_ops_t::validate_binary(alg_kind_t alg, // Additional check to restrict run-time dimension usage until supported. 
for (int d = 0; d < user_src1_desc->ndims; ++d) { - VCHECK_ATTR(user_src1_desc->dims[d] != DNNL_RUNTIME_DIM_VAL, + VCHECK_ATTR(!is_runtime_value(user_src1_desc->dims[d]), VERBOSE_RUNTIMEDIM_UNSUPPORTED); } @@ -251,7 +251,7 @@ status_t post_ops_t::validate_binary(alg_kind_t alg, VCHECK_ATTR(memory_desc_sanity_check(*user_src2_desc), VERBOSE_MEM_DESC_CHECK_FAIL); for (int d = 0; d < user_src2_desc->ndims; ++d) { - VCHECK_ATTR(user_src2_desc->dims[d] != DNNL_RUNTIME_DIM_VAL, + VCHECK_ATTR(!is_runtime_value(user_src2_desc->dims[d]), VERBOSE_RUNTIMEDIM_UNSUPPORTED); } } diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp index e78ced9dd33..f110b5b61b1 100644 --- a/src/common/type_helpers.hpp +++ b/src/common/type_helpers.hpp @@ -1089,8 +1089,8 @@ inline bool memory_desc_strides_check( if (md.padded_dims[d] == 0) return true; // no strides verification for runtime dims - const bool has_runtime_dim = utils::one_of( - DNNL_RUNTIME_DIM_VAL, strides[d], md.padded_dims[d]); + const bool has_runtime_dim + = any_runtime_value(strides[d], md.padded_dims[d]); if (has_runtime_dim) return true; perm[d] = d; @@ -1218,8 +1218,10 @@ inline status_t memory_desc_init_by_blocking_desc( utils::simultaneous_sort( mblk.strides, ou_blocks, perm, ndims, [](stride_t a, stride_t b) { - if (utils::one_of(DNNL_RUNTIME_DIM_VAL, a, b)) - return DNNL_RUNTIME_DIM_VAL; + static_assert(runtime_value_for() < 0, + "negative value is expected"); + if (any_runtime_value(a, b)) + return runtime_value_for(); // negative: preserves order return b - a; }); @@ -1298,21 +1300,6 @@ format_tag_t memory_desc_matches_one_of_tag( return format_tag::undef; } -/** returns true if fp32 value denotes DNNL_RUNTIME_F32_VAL */ -inline bool is_runtime_value(float val) { - return utils::bit_cast(val) == DNNL_RUNTIME_F32_VAL_REP.u; -} - -/** returns true if s32 value denotes DNNL_RUNTIME_S32_VAL */ -inline bool is_runtime_value(int val) { - return val == DNNL_RUNTIME_S32_VAL; -} - -/** returns true if dim_t 
value denotes DNNL_RUNTIME_DIM_VAL */ -inline bool is_runtime_value(dim_t val) { - return val == DNNL_RUNTIME_DIM_VAL; -} - inline bool memory_desc_sanity_check(int ndims, const dims_t dims, data_type_t data_type, format_kind_t format_kind) { using namespace data_type; @@ -1326,8 +1313,8 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims, bool has_runtime_dims = false; for (int d = 0; d < ndims; ++d) { - if (dims[d] != DNNL_RUNTIME_DIM_VAL && dims[d] < 0) return false; - if (dims[d] == DNNL_RUNTIME_DIM_VAL) has_runtime_dims = true; + if (!is_runtime_value(dims[d]) && dims[d] < 0) return false; + if (is_runtime_value(dims[d])) has_runtime_dims = true; } if (has_runtime_dims) { diff --git a/src/common/utils.hpp b/src/common/utils.hpp index e5015b4d779..29378b699f7 100644 --- a/src/common/utils.hpp +++ b/src/common/utils.hpp @@ -970,6 +970,76 @@ class mask_iterator { // NOLINT(readability-identifier-naming) } }; +/** returns true if fp32 value denotes DNNL_RUNTIME_F32_VAL */ +inline bool is_runtime_value(float val) { + return utils::bit_cast(val) == DNNL_RUNTIME_F32_VAL_REP.u; +} + +/** returns true if s32 value denotes DNNL_RUNTIME_S32_VAL */ +inline bool is_runtime_value(int val) { + return val == DNNL_RUNTIME_S32_VAL; +} + +/** returns true if dim_t value denotes DNNL_RUNTIME_DIM_VAL */ +inline bool is_runtime_value(dim_t val) { + return val == DNNL_RUNTIME_DIM_VAL; +} + +/** returns true if size_t value denotes DNNL_RUNTIME_SIZE_VAL */ +inline bool is_runtime_value(size_t val) { + return val == DNNL_RUNTIME_SIZE_VAL; +} + +template +constexpr bool any_runtime_value(T item) { + return is_runtime_value(item); +} +template +bool any_runtime_value(T item, Args... item_others) { + return is_runtime_value(item) || any_runtime_value(item_others...); +} + +template +constexpr bool all_runtime_values(T item) { + return is_runtime_value(item); +} +template +constexpr bool all_runtime_values(T item, Args... 
item_others) { + return is_runtime_value(item) && all_runtime_values(item_others...); +} + +template +constexpr T runtime_value_for() { + static_assert(sizeof(T) == 0, "no runtime value defined for this type"); + return T {}; +} + +template <> +inline float runtime_value_for() { + return DNNL_RUNTIME_F32_VAL; +} + +template <> +constexpr int runtime_value_for() { + return DNNL_RUNTIME_S32_VAL; +} + +template <> +constexpr dim_t runtime_value_for() { + return DNNL_RUNTIME_DIM_VAL; +} + +template <> +constexpr size_t runtime_value_for() { + return DNNL_RUNTIME_SIZE_VAL; +} + +/** returns the runtime placeholder constant for the argument type T */ +template +inline T runtime_value_for(T) { + return runtime_value_for::type>(); +} + } // namespace impl } // namespace dnnl diff --git a/src/common/verbose.cpp b/src/common/verbose.cpp index db93ef19741..9cebeda1ad0 100644 --- a/src/common/verbose.cpp +++ b/src/common/verbose.cpp @@ -629,7 +629,7 @@ namespace { int get_runtime_mask(const memory_desc_t *md) { int mask = 0; for (int d = md->ndims - 1; d >= 0; --d) { - mask += md->dims[d] == DNNL_RUNTIME_DIM_VAL ? 1 << d : 0; + mask += is_runtime_value(md->dims[d]) ? 
1 << d : 0; } return mask; } diff --git a/src/cpu/aarch64/matmul/brgemm_matmul.cpp b/src/cpu/aarch64/matmul/brgemm_matmul.cpp index 609990fa565..9b2c43408bd 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul.cpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul.cpp @@ -122,7 +122,7 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { && !attr()->scales_.has_default_values(DNNL_ARG_WEIGHTS) && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad - if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; + if (is_runtime_value(N())) ok = false; } if (!attr()->post_ops_.sum_with_default_dt()) return false; diff --git a/src/cpu/aarch64/matmul/jit_int8_matmul.cpp b/src/cpu/aarch64/matmul/jit_int8_matmul.cpp index c909077ac1d..08a687ac58b 100644 --- a/src/cpu/aarch64/matmul/jit_int8_matmul.cpp +++ b/src/cpu/aarch64/matmul/jit_int8_matmul.cpp @@ -1,4 +1,5 @@ /******************************************************************************* +* Copyright 2026 Intel Corporation * Copyright 2025 FUJITSU LIMITED * Copyright 2025 Arm Ltd. and affiliates * @@ -700,7 +701,7 @@ status_t jit_int8_matmul_t::pd_t::init(engine_t *engine) { if (is_src_scl && is_wei_scl && wei_scl_msk > 0) { // This case requires scratchpad. 
- if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; + if (is_runtime_value(N())) ok = false; } return ok; }; diff --git a/src/cpu/gemm_inner_product_utils.hpp b/src/cpu/gemm_inner_product_utils.hpp index 1eefbe9fc89..933bda4abc7 100644 --- a/src/cpu/gemm_inner_product_utils.hpp +++ b/src/cpu/gemm_inner_product_utils.hpp @@ -91,8 +91,8 @@ struct pp_kernel_t { return (!runtime_oc()) && (OC_ == (size_t)dst_mb_stride_); } bool do_bias() const { return bias_data_type_ != data_type::undef; } - bool runtime_oc() const { return OC_ == (size_t)DNNL_RUNTIME_DIM_VAL; } - bool runtime_mb() const { return MB_ == (size_t)DNNL_RUNTIME_DIM_VAL; } + bool runtime_oc() const { return is_runtime_value(OC_); } + bool runtime_mb() const { return is_runtime_value(MB_); } }; inline const bcast_set_t &gemm_default_strategies() { diff --git a/src/cpu/matmul/gemm_based_common.hpp b/src/cpu/matmul/gemm_based_common.hpp index b5b6c37ccae..acddda2a031 100644 --- a/src/cpu/matmul/gemm_based_common.hpp +++ b/src/cpu/matmul/gemm_based_common.hpp @@ -111,7 +111,7 @@ inline bool check_gemm_binary_per_oc_compatible_formats(const matmul_pd_t &pd) { const int ndims = dst_d.ndims(); for (auto d : dims) - if (d == DNNL_RUNTIME_DIM_VAL) return false; + if (is_runtime_value(d)) return false; // check d, h, w... (b2, m, n... 
for matmul) dimensions are continuous bool ok = true; diff --git a/src/cpu/matmul/gemm_bf16_matmul.cpp b/src/cpu/matmul/gemm_bf16_matmul.cpp index 7bb56648bc6..5ca0f9fa1a9 100644 --- a/src/cpu/matmul/gemm_bf16_matmul.cpp +++ b/src/cpu/matmul/gemm_bf16_matmul.cpp @@ -111,7 +111,7 @@ status_t gemm_bf16_matmul_t::pd_t::check_and_configure_attributes( && !attr()->scales_.has_default_values(DNNL_ARG_WEIGHTS) && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad with unknown size - if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; + if (is_runtime_value(N())) ok = false; } return ok; }; @@ -138,7 +138,7 @@ status_t gemm_bf16_matmul_t::pd_t::check_and_configure_attributes( && IMPLICATION(is_binary_po_per_oc, gemm_based::check_gemm_binary_per_oc_compatible_formats( *this)) - && IMPLICATION(N() == DNNL_RUNTIME_DIM_VAL, !has_prelu); + && IMPLICATION(is_runtime_value(N()), !has_prelu); }; // check basic attributes diff --git a/src/cpu/matmul/gemm_bf16_matmul.hpp b/src/cpu/matmul/gemm_bf16_matmul.hpp index 03f29a09f0f..98a38a128ed 100644 --- a/src/cpu/matmul/gemm_bf16_matmul.hpp +++ b/src/cpu/matmul/gemm_bf16_matmul.hpp @@ -63,7 +63,7 @@ struct gemm_bf16_matmul_t : public primitive_t { // mb value is calculated based on work-sharing using // balance211 in execute() - dim_t mb = DNNL_RUNTIME_DIM_VAL; + auto mb = runtime_value_for(); if (!has_runtime_dims && ((batch * M) % nthr == 0)) { const dim_t m_per_thr = nstl::max(1, (batch * M) / nthr); if (m_per_thr >= M && m_per_thr % M == 0) { diff --git a/src/cpu/matmul/gemm_f32_matmul.cpp b/src/cpu/matmul/gemm_f32_matmul.cpp index df4ad3a7989..021361dc3c6 100644 --- a/src/cpu/matmul/gemm_f32_matmul.cpp +++ b/src/cpu/matmul/gemm_f32_matmul.cpp @@ -54,7 +54,7 @@ status_t gemm_f32_matmul_t::pd_t::init(engine_t *engine) { && !attr()->scales_.has_default_values(DNNL_ARG_WEIGHTS) && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad with unknown size - if (N() == 
DNNL_RUNTIME_DIM_VAL) ok = false; + if (is_runtime_value(N())) ok = false; } return ok; }; @@ -81,7 +81,7 @@ status_t gemm_f32_matmul_t::pd_t::init(engine_t *engine) { && IMPLICATION(is_binary_po_per_oc, gemm_based::check_gemm_binary_per_oc_compatible_formats( *this)) - && IMPLICATION(N() == DNNL_RUNTIME_DIM_VAL, !has_prelu); + && IMPLICATION(is_runtime_value(N()), !has_prelu); }; const bool problem_dt_correct = src_md()->data_type == src_type @@ -160,8 +160,8 @@ status_t gemm_f32_matmul_t::pd_t::configure_attributes() { data_type::undef); // `C_is_abx` limitation comes from `extended_sgemm`. - const bool C_is_abx = helper.ldc() >= helper.N() - && helper.ldc() != DNNL_RUNTIME_DIM_VAL; + const bool C_is_abx + = !is_runtime_value(helper.ldc()) && helper.ldc() >= helper.N(); params_.dst_is_acc_ = C_is_abx && IMPLICATION(attr()->post_ops_.find(primitive_kind::sum) != -1, sum_po_via_gemm_beta); diff --git a/src/cpu/matmul/gemm_f32_matmul.hpp b/src/cpu/matmul/gemm_f32_matmul.hpp index 5f9197f2ed3..d65a6b7d9c9 100644 --- a/src/cpu/matmul/gemm_f32_matmul.hpp +++ b/src/cpu/matmul/gemm_f32_matmul.hpp @@ -61,7 +61,7 @@ struct gemm_f32_matmul_t : public primitive_t { // mb value is calculated based on work-sharing using // balance211 in execute() - dim_t mb = DNNL_RUNTIME_DIM_VAL; + auto mb = runtime_value_for(); if (!has_runtime_dims && ((batch * M) % nthr == 0)) { const dim_t m_per_thr = nstl::max(1, (batch * M) / nthr); if (m_per_thr >= M && m_per_thr % M == 0) { diff --git a/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp b/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp index d6bdd36abc4..e6434e753d2 100644 --- a/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp +++ b/src/cpu/matmul/gemm_x8s8s32x_matmul.cpp @@ -65,7 +65,7 @@ status_t gemm_x8s8s32x_matmul_t::pd_t::init(engine_t *engine) { && !attr()->scales_.has_default_values(DNNL_ARG_WEIGHTS) && attr()->scales_.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad with unknown size - if (N() == DNNL_RUNTIME_DIM_VAL) ok = false; + if 
(is_runtime_value(N())) ok = false; } return ok; }; @@ -105,7 +105,7 @@ status_t gemm_x8s8s32x_matmul_t::pd_t::init(engine_t *engine) { && IMPLICATION(is_binary_po_per_oc, gemm_based::check_gemm_binary_per_oc_compatible_formats( *this)) - && IMPLICATION(N() == DNNL_RUNTIME_DIM_VAL, !has_prelu); + && IMPLICATION(is_runtime_value(N()), !has_prelu); }; VDISPATCH_MATMUL(DNNL_CPU_THREADING_RUNTIME != DNNL_RUNTIME_THREADPOOL, diff --git a/src/cpu/matmul/gemm_x8s8s32x_matmul.hpp b/src/cpu/matmul/gemm_x8s8s32x_matmul.hpp index f30aad1c093..cc09ba69858 100644 --- a/src/cpu/matmul/gemm_x8s8s32x_matmul.hpp +++ b/src/cpu/matmul/gemm_x8s8s32x_matmul.hpp @@ -62,7 +62,7 @@ struct gemm_x8s8s32x_matmul_t : public primitive_t { // mb value is calculated based on work-sharing using // balance211 in execute() - dim_t mb = DNNL_RUNTIME_DIM_VAL; + auto mb = runtime_value_for(); if (!has_runtime_dims && ((batch * M) % nthr == 0)) { const dim_t m_per_thr = nstl::max(1, (batch * M) / nthr); if (m_per_thr >= M && m_per_thr % M == 0) { diff --git a/src/cpu/matmul/matmul_utils.hpp b/src/cpu/matmul/matmul_utils.hpp index 3764ad1a707..9d0b2683449 100644 --- a/src/cpu/matmul/matmul_utils.hpp +++ b/src/cpu/matmul/matmul_utils.hpp @@ -213,7 +213,8 @@ struct matmul_helper_t { dim_t batch_size = 1; for (int b_idx = 0; b_idx < batch_dims; b_idx++) { dim_t batch_dim = tensor_md.dims()[b_idx]; - if (DNNL_RUNTIME_DIM_VAL == batch_dim) return DNNL_RUNTIME_DIM_VAL; + if (is_runtime_value(batch_dim)) + return runtime_value_for(batch_size); batch_size *= batch_dim; } diff --git a/src/cpu/x64/matmul/brgemm_matmul.cpp b/src/cpu/x64/matmul/brgemm_matmul.cpp index 5684186956f..6a6dea2607d 100644 --- a/src/cpu/x64/matmul/brgemm_matmul.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul.cpp @@ -217,7 +217,7 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { && !asc.has_default_values(DNNL_ARG_WEIGHTS) && asc.get_mask(DNNL_ARG_WEIGHTS) > 0) { // This case requires scratchpad - if (N() == DNNL_RUNTIME_DIM_VAL) ok = 
false; + if (is_runtime_value(N())) ok = false; } // Impl suppports f32 scales only for non-weight decompression if (!(is_bf16_with_int_wei || is_f16_with_int_wei From 48a989683b59a092e11c1ce1ff85ada6e3968671 Mon Sep 17 00:00:00 2001 From: Ankit Manerikar Date: Tue, 10 Mar 2026 14:57:35 -0700 Subject: [PATCH 2/2] fix: common: memory_desc: add upper bounds check for md tensor dims --- src/common/memory_desc_wrapper.cpp | 21 ++++++++++++++++++ src/common/type_helpers.hpp | 35 +++++++++++++++++++++++++++++- src/common/utils.hpp | 6 ++++- 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/src/common/memory_desc_wrapper.cpp b/src/common/memory_desc_wrapper.cpp index 8f75e7e9aa4..fb154610f33 100644 --- a/src/common/memory_desc_wrapper.cpp +++ b/src/common/memory_desc_wrapper.cpp @@ -63,6 +63,10 @@ status_t fill_blocked(memory_desc_t &md, std::initializer_list perm, ? runtime_value_for(md.padded_dims[d]) : utils::rnd_up(md.dims[d], blocks[d]); + // tracks max stride for integral overflow checks + dim_t max_stride = 1; + int max_stride_d = 0; + // setting the strides { dim_t stride = block_size; @@ -77,9 +81,26 @@ status_t fill_blocked(memory_desc_t &md, std::initializer_list perm, else if (pdim != 0) stride *= pdim / blocks[d]; + if (max_stride <= stride) { + max_stride = stride; + max_stride_d = d; + } + } while (iter_d != perm.begin()); } + const size_t dt_size = types::data_type_size(md.data_type); + + // guard against integral overflow due to strides exceeding numeric limits + if (!is_runtime_value(md.padded_dims[max_stride_d])) { + size_t dim_val = static_cast( + md.padded_dims[max_stride_d] / blocks[max_stride_d]); + dim_val = dim_val == (size_t)max_stride ? 
1 : dim_val; + if (dim_val > SIZE_MAX / max_stride) return status::invalid_arguments; + if (dt_size && ((dim_val * max_stride) > SIZE_MAX / dt_size)) + return status::invalid_arguments; + } + return status::success; } diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp index f110b5b61b1..36c3b15c8be 100644 --- a/src/common/type_helpers.hpp +++ b/src/common/type_helpers.hpp @@ -1115,6 +1115,10 @@ inline bool memory_desc_strides_check( }; std::sort(perm, perm + md.ndims, idx_sorter); + // tracks max stride for integral overflow checks + dim_t max_stride = 1; + int max_stride_d = 0; + dim_t min_stride = block_size; for (int idx = 0; idx < md.ndims; ++idx) { const int d = perm[idx]; @@ -1134,6 +1138,22 @@ inline bool memory_desc_strides_check( // update min_stride for next iteration const auto padded_dim = md.padded_dims[d]; min_stride = block_size * strides[d] * (padded_dim / blocks[d]); + if (max_stride <= strides[d]) { + max_stride = strides[d]; + max_stride_d = d; + } + } + + const size_t dt_size = types::data_type_size(md.data_type); + + // guard against integral overflow due to strides exceeding numeric limits + if (!is_runtime_value(md.padded_dims[max_stride_d])) { + size_t dim_val = static_cast( + md.padded_dims[max_stride_d] / blocks[max_stride_d]); + dim_val = dim_val == (size_t)max_stride ? 1 : dim_val; + if (dim_val > SIZE_MAX / max_stride) return false; + if (dt_size && ((dim_val * max_stride) > SIZE_MAX / dt_size)) + return false; } return true; } @@ -1311,9 +1331,22 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims, f8_e4m3, f16, bf16, f32, f64, s64, s32, s8, u8, s4, u4); if (!ok) return false; + // A bounds check on the dimensions ensures that the tensor size + computation does not trigger an overflow during memory creation.
+ dim_t prod = 1; + for (int d = 0; d < ndims; ++d) { + if (dims[d] != DNNL_RUNTIME_DIM_VAL) { + if (dims[d] < 0) return false; + if (dims[d] > 0) { + if (prod > std::numeric_limits::max() / dims[d]) + return false; + prod *= dims[d]; + } + } + } + bool has_runtime_dims = false; for (int d = 0; d < ndims; ++d) { - if (!is_runtime_value(dims[d]) && dims[d] < 0) return false; if (is_runtime_value(dims[d])) has_runtime_dims = true; } diff --git a/src/common/utils.hpp b/src/common/utils.hpp index 29378b699f7..51b55b909d1 100644 --- a/src/common/utils.hpp +++ b/src/common/utils.hpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -296,8 +297,11 @@ constexpr T array_product(const T *arr) { template inline R array_product(const T *arr, size_t size) { R prod = 1; - for (size_t i = 0; i < size; ++i) + for (size_t i = 0; i < size; ++i) { + assert(IMPLICATION(arr[i] > 0 && prod > 0, + prod <= std::numeric_limits::max() / arr[i])); prod *= arr[i]; + } return prod; }