Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/common/gemm_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ struct gemm_desc_t : public op_desc_t {
// if ndims < 3, it should return 1
int64_t batch = 1;
for (int i = 0; i < c_desc.ndims - 2; ++i) {
if (c_desc.dims[i] == DNNL_RUNTIME_DIM_VAL)
return DNNL_RUNTIME_DIM_VAL;
if (is_runtime_value(c_desc.dims[i]))
return runtime_value_for<dnnl_dim_t>();
batch *= c_desc.dims[i];
}
return batch;
Expand Down
6 changes: 3 additions & 3 deletions src/common/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,11 +495,11 @@ status_t matmul_desc_init(matmul_desc_t *matmul_desc,
const dim_t b_dim = with_bias ? op_d.bias_desc.dims[d] : 0;
const dim_t r_dim = with_reduce ? op_d.reduce_desc.dims[d] : 0;

if (one_of(DNNL_RUNTIME_DIM_VAL, s_dim, w_dim, d_dim, b_dim)) {
if (any_runtime_value(s_dim, w_dim, d_dim, b_dim)) {

VCHECK_MATMUL(everyone_is(DNNL_RUNTIME_DIM_VAL, s_dim, w_dim, d_dim)
VCHECK_MATMUL(all_runtime_values(s_dim, w_dim, d_dim)
&& IMPLICATION((bia_mask & (1 << d)) && with_bias,
b_dim == DNNL_RUNTIME_DIM_VAL),
is_runtime_value(b_dim)),
VERBOSE_RUNTIMEDIM_INCONSISTENT, d);
} else {
// This follows numpy semantics of broadcasting when 0 is involved.
Expand Down
6 changes: 3 additions & 3 deletions src/common/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespace {
// Returns the size required for memory descriptor mapping.
// Caveats:
// 1. If memory descriptor with run-time parameters, the mapping cannot be done;
// hence return DNNL_RUNTIME_SIZE_VAL
// hence return runtime_value_for<size_t>()
// 2. Otherwise, the size returned includes `offset0` and holes (for the case
// of non-trivial strides). Strictly speaking, the mapping should happen only
// for elements accessible with `md.off_l(0 .. md.nelems())`. However, for
Expand All @@ -59,7 +59,7 @@ namespace {
size_t memory_desc_map_size(const memory_desc_t *md, int index = 0) {
auto mdw = memory_desc_wrapper(md);

if (mdw.has_runtime_dims_or_strides()) return DNNL_RUNTIME_SIZE_VAL;
if (mdw.has_runtime_dims_or_strides()) return runtime_value_for<size_t>();

return mdw.size(index, true, true);
}
Expand Down Expand Up @@ -350,7 +350,7 @@ status_t dnnl_memory_map_data_v2(
if (map_size == 0) {
*mapped_ptr = nullptr;
return success;
} else if (map_size == DNNL_RUNTIME_SIZE_VAL) {
} else if (is_runtime_value(map_size)) {
return invalid_arguments;
}

Expand Down
20 changes: 9 additions & 11 deletions src/common/memory_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,9 @@ status_t memory_desc_init_by_strides(memory_desc_t &memory_desc, int ndims,
bool has_runtime_strides = false;
default_strides[md.ndims - 1] = 1;
for (int d = md.ndims - 2; d >= 0; --d) {
if (md.padded_dims[d] == DNNL_RUNTIME_DIM_VAL)
has_runtime_strides = true;
if (is_runtime_value(md.padded_dims[d])) has_runtime_strides = true;
default_strides[d] = has_runtime_strides
? DNNL_RUNTIME_DIM_VAL
? runtime_value_for(default_strides[d])
: default_strides[d + 1] * md.padded_dims[d + 1];
}
strides = default_strides;
Expand Down Expand Up @@ -220,9 +219,8 @@ status_t memory_desc_init_submemory(memory_desc_t &memory_desc,
VERBOSE_UNSUPPORTED_MEM_STRIDE);

for (int d = 0; d < src_d.ndims(); ++d) {
VCHECK_MEMORY(
!(utils::one_of(DNNL_RUNTIME_DIM_VAL, dims[d], offsets[d])),
unimplemented, VERBOSE_RUNTIMEDIM_UNSUPPORTED);
VCHECK_MEMORY(!any_runtime_value(dims[d], offsets[d]), unimplemented,
VERBOSE_RUNTIMEDIM_UNSUPPORTED);

const bool dim_offsets_oob = (dims[d] < 0 || offsets[d] < 0
|| (offsets[d] + dims[d] > src_d.dims()[d]));
Expand Down Expand Up @@ -269,7 +267,7 @@ status_t memory_desc_reshape(memory_desc_t &out_memory_desc,
auto volume = [](const dim_t *dims, int ndims) -> dim_t {
dim_t prod = 1;
for (int i = 0; i < ndims; ++i) {
if (dims[i] == DNNL_RUNTIME_DIM_VAL) return DNNL_RUNTIME_DIM_VAL;
if (is_runtime_value(dims[i])) return runtime_value_for(prod);
prod *= dims[i] > 0 ? dims[i] : 1;
}
return prod;
Expand Down Expand Up @@ -576,12 +574,12 @@ status_t memory_desc_init_by_string_tag(memory_desc_t &md, int ndims,
blk.strides[dim_idx] = stride;

dim_t fib = full_inner_blks[dim_idx];
dim_t padded_dim = md.dims[dim_idx] == DNNL_RUNTIME_DIM_VAL
? DNNL_RUNTIME_DIM_VAL
const auto padded_dim = is_runtime_value(md.dims[dim_idx])
? runtime_value_for(md.padded_dims[dim_idx])
: (md.dims[dim_idx] + fib - 1) / fib * fib;
md.padded_dims[dim_idx] = padded_dim;
if (one_of(DNNL_RUNTIME_DIM_VAL, padded_dim, stride))
stride = DNNL_RUNTIME_DIM_VAL;
if (any_runtime_value(padded_dim, stride))
stride = runtime_value_for(stride);
else
stride *= (padded_dim / fib);
} else {
Expand Down
29 changes: 25 additions & 4 deletions src/common/memory_desc_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,14 @@ status_t fill_blocked(memory_desc_t &md, std::initializer_list<int> perm,

utils::array_set(md.padded_offsets, 0, md.ndims);
for (int d = 0; d < md.ndims; ++d)
md.padded_dims[d] = md.dims[d] == DNNL_RUNTIME_DIM_VAL
? DNNL_RUNTIME_DIM_VAL
md.padded_dims[d] = is_runtime_value(md.dims[d])
? runtime_value_for(md.padded_dims[d])
: utils::rnd_up(md.dims[d], blocks[d]);

// tracks max stride for integral overflow checks
dim_t max_stride = 1;
int max_stride_d = 0;

// setting the strides
{
dim_t stride = block_size;
Expand All @@ -72,14 +76,31 @@ status_t fill_blocked(memory_desc_t &md, std::initializer_list<int> perm,
blk.strides[d] = stride;

const dim_t pdim = md.padded_dims[d];
if (utils::one_of(DNNL_RUNTIME_DIM_VAL, stride, pdim))
stride = DNNL_RUNTIME_DIM_VAL;
if (any_runtime_value(stride, pdim))
stride = runtime_value_for(stride);
else if (pdim != 0)
stride *= pdim / blocks[d];

if (max_stride <= stride) {
max_stride = stride;
max_stride_d = d;
}

} while (iter_d != perm.begin());
}

const size_t dt_size = types::data_type_size(md.data_type);

// guard against integral overflow due to strides exceeding numeric limits
if (!is_runtime_value(md.padded_dims[max_stride_d])) {
size_t dim_val = static_cast<size_t>(
md.padded_dims[max_stride_d] / blocks[max_stride_d]);
dim_val = dim_val == (size_t)max_stride ? 1 : dim_val;
if (dim_val > SIZE_MAX / max_stride) return status::invalid_arguments;
if (dt_size && ((dim_val * max_stride) > SIZE_MAX / dt_size))
return status::invalid_arguments;
}

return status::success;
}

Expand Down
8 changes: 4 additions & 4 deletions src/common/memory_desc_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ struct memory_desc_wrapper : public c_compatible {
* is true, and the number of data elements otherwise */
dim_t nelems(bool with_padding = false) const {
if (is_zero()) return 0;
if (has_runtime_dims()) return DNNL_RUNTIME_DIM_VAL;
if (has_runtime_dims()) return runtime_value_for<dim_t>();
return utils::array_product(
with_padding ? padded_dims() : dims(), ndims());
}
Expand Down Expand Up @@ -301,7 +301,7 @@ struct memory_desc_wrapper : public c_compatible {
return 0;
}

if (has_runtime_dims_or_strides()) return DNNL_RUNTIME_SIZE_VAL;
if (has_runtime_dims_or_strides()) return runtime_value_for<size_t>();

if (is_wino_desc()) {
return wino_desc().size;
Expand Down Expand Up @@ -433,15 +433,15 @@ struct memory_desc_wrapper : public c_compatible {
/** returns true if at least one dim is not known */
bool has_runtime_dims() const {
for (int d = 0; d < ndims(); ++d)
if (dims()[d] == DNNL_RUNTIME_DIM_VAL) return true;
if (is_runtime_value(dims()[d])) return true;
return false;
}

/** returns true if at least one dim is not known */
bool has_runtime_strides() const {
if (!is_blocking_desc()) return false;
for (int d = 0; d < ndims(); ++d)
if (blocking_desc().strides[d] == DNNL_RUNTIME_DIM_VAL) return true;
if (is_runtime_value(blocking_desc().strides[d])) return true;
return false;
}

Expand Down
2 changes: 1 addition & 1 deletion src/common/memory_zero_pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ status_t typed_zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
if (mdw.nelems(false) == mdw.nelems(true)) return success;

const size_t map_size = mdw.size();
assert(map_size != DNNL_RUNTIME_SIZE_VAL);
assert(!is_runtime_value(map_size));

void *mapped_ptr
= ctx.map_memory_storage(memory_storage, ctx.stream(), map_size);
Expand Down
4 changes: 2 additions & 2 deletions src/common/primitive_attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ status_t post_ops_t::validate_binary(alg_kind_t alg,

// Additional check to restrict run-time dimension usage until supported.
for (int d = 0; d < user_src1_desc->ndims; ++d) {
VCHECK_ATTR(user_src1_desc->dims[d] != DNNL_RUNTIME_DIM_VAL,
VCHECK_ATTR(!is_runtime_value(user_src1_desc->dims[d]),
VERBOSE_RUNTIMEDIM_UNSUPPORTED);
}

Expand All @@ -251,7 +251,7 @@ status_t post_ops_t::validate_binary(alg_kind_t alg,
VCHECK_ATTR(memory_desc_sanity_check(*user_src2_desc),
VERBOSE_MEM_DESC_CHECK_FAIL);
for (int d = 0; d < user_src2_desc->ndims; ++d) {
VCHECK_ATTR(user_src2_desc->dims[d] != DNNL_RUNTIME_DIM_VAL,
VCHECK_ATTR(!is_runtime_value(user_src2_desc->dims[d]),
VERBOSE_RUNTIMEDIM_UNSUPPORTED);
}
}
Expand Down
62 changes: 41 additions & 21 deletions src/common/type_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1089,8 +1089,8 @@ inline bool memory_desc_strides_check(
if (md.padded_dims[d] == 0) return true;

// no strides verification for runtime dims
const bool has_runtime_dim = utils::one_of(
DNNL_RUNTIME_DIM_VAL, strides[d], md.padded_dims[d]);
const bool has_runtime_dim
= any_runtime_value(strides[d], md.padded_dims[d]);
if (has_runtime_dim) return true;

perm[d] = d;
Expand All @@ -1115,6 +1115,10 @@ inline bool memory_desc_strides_check(
};
std::sort(perm, perm + md.ndims, idx_sorter);

// tracks max stride for integral overflow checks
dim_t max_stride = 1;
int max_stride_d = 0;

dim_t min_stride = block_size;
for (int idx = 0; idx < md.ndims; ++idx) {
const int d = perm[idx];
Expand All @@ -1134,6 +1138,22 @@ inline bool memory_desc_strides_check(
// update min_stride for next iteration
const auto padded_dim = md.padded_dims[d];
min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
if (max_stride <= strides[d]) {
max_stride = strides[d];
max_stride_d = d;
}
}

const size_t dt_size = types::data_type_size(md.data_type);

// guard against integral overflow due to strides exceeding numeric limits
if (!is_runtime_value(md.padded_dims[max_stride_d])) {
size_t dim_val = static_cast<size_t>(
md.padded_dims[max_stride_d] / blocks[max_stride_d]);
dim_val = dim_val == (size_t)max_stride ? 1 : dim_val;
if (dim_val > SIZE_MAX / max_stride) return false;
if (dt_size && ((dim_val * max_stride) > SIZE_MAX / dt_size))
return false;
}
return true;
}
Expand Down Expand Up @@ -1218,8 +1238,10 @@ inline status_t memory_desc_init_by_blocking_desc(

utils::simultaneous_sort(
mblk.strides, ou_blocks, perm, ndims, [](stride_t a, stride_t b) {
if (utils::one_of(DNNL_RUNTIME_DIM_VAL, a, b))
return DNNL_RUNTIME_DIM_VAL;
static_assert(runtime_value_for<stride_t>() < 0,
"negative value is expected");
if (any_runtime_value(a, b))
return runtime_value_for<stride_t>(); // negative: preserves order
return b - a;
});

Expand Down Expand Up @@ -1298,21 +1320,6 @@ format_tag_t memory_desc_matches_one_of_tag(
return format_tag::undef;
}

/** returns true if fp32 value denotes DNNL_RUNTIME_F32_VAL */
inline bool is_runtime_value(float val) {
return utils::bit_cast<unsigned>(val) == DNNL_RUNTIME_F32_VAL_REP.u;
}

/** returns true if s32 value denotes DNNL_RUNTIME_S32_VAL */
inline bool is_runtime_value(int val) {
return val == DNNL_RUNTIME_S32_VAL;
}

/** returns true if dim_t value denotes DNNL_RUNTIME_DIM_VAL */
inline bool is_runtime_value(dim_t val) {
return val == DNNL_RUNTIME_DIM_VAL;
}

inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
data_type_t data_type, format_kind_t format_kind) {
using namespace data_type;
Expand All @@ -1324,10 +1331,23 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
f8_e4m3, f16, bf16, f32, f64, s64, s32, s8, u8, s4, u4);
if (!ok) return false;

// A bounds check on the dimensions ensures that the tensor size
// computation does not trigger a overflow during memory creation.
dim_t prod = 1;
for (int d = 0; d < ndims; ++d) {
if (dims[d] != DNNL_RUNTIME_DIM_VAL) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (dims[d] != DNNL_RUNTIME_DIM_VAL) {
if (!is_runtime_value(dims[d])) {

if (dims[d] < 0) return false;
if (dims[d] > 0) {
if (prod > std::numeric_limits<dim_t>::max() / dims[d])
return false;
prod *= dims[d];
}
}
}

bool has_runtime_dims = false;
for (int d = 0; d < ndims; ++d) {
if (dims[d] != DNNL_RUNTIME_DIM_VAL && dims[d] < 0) return false;
if (dims[d] == DNNL_RUNTIME_DIM_VAL) has_runtime_dims = true;
if (is_runtime_value(dims[d])) has_runtime_dims = true;
}

if (has_runtime_dims) {
Expand Down
Loading
Loading