Skip to content
4 changes: 2 additions & 2 deletions tests/benchdnn/bnorm/bnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,8 @@ void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
// Since bwd testing is done using results from forward which are random
// fp32 values, diff_scale starts fluctuating, so we check norm for both
// data, SC, and SH.
const bool compare_with_norm = (prb->dir & FLAG_BWD);
cmp.set_norm_validation_mode(compare_with_norm);
const bool allow_norm_check = (prb->dir & FLAG_BWD);
cmp.set_allow_norm_check(allow_norm_check);

// Digits must be non-negative for safe left-shifting when `digits_dt`
// exceeds `digits_f32`.
Expand Down
1 change: 0 additions & 1 deletion tests/benchdnn/brgemm/brgemm_aux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ dnnl_data_type_t prb_t::get_dt(data_kind_t data_kind) const {
case WEI: return wei_dt();
case BIA: return bia_dt;
case DST: return dst_dt();
case ACC: return acc_dt();
default: assert(!"unexpected"); return dnnl_data_type_undef;
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/benchdnn/conv/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ void skip_invalid_prb(const prb_t *prb, res_t *res) {}

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
const args_t &ref_args) {
const bool compare_with_norm = (prb->alg & WINO);
cmp.set_norm_validation_mode(compare_with_norm);
const bool allow_norm_check = (prb->alg & WINO);
cmp.set_allow_norm_check(allow_norm_check);

float trh = 0.f;
if (prb->alg & WINO) {
Expand Down
4 changes: 2 additions & 2 deletions tests/benchdnn/deconv/deconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,8 @@ void skip_invalid_prb(const prb_t *prb, res_t *res) {}

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
const args_t &ref_args) {
const bool compare_with_norm = (prb->alg & WINO);
cmp.set_norm_validation_mode(compare_with_norm);
const bool allow_norm_check = (prb->alg & WINO);
cmp.set_allow_norm_check(allow_norm_check);

float trh = 0.f;
if (prb->alg & WINO) {
Expand Down
4 changes: 4 additions & 0 deletions tests/benchdnn/dnnl_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,10 @@ void finalize() {

inline int measure_perf_individual(timer::timer_t &t, dnnl_stream_t stream,
perf_function_t &perf_func, std::vector<dnnl_exec_arg_t> &dnnl_args) {
// Warm-up run.
DNN_SAFE(perf_func(stream, dnnl_args), WARN);
DNN_SAFE(dnnl_stream_wait(stream), CRIT);

cold_cache_t cold_cache(dnnl_args, stream);

t.reset();
Expand Down
4 changes: 2 additions & 2 deletions tests/benchdnn/gnorm/gnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,8 @@ void skip_invalid_prb(const prb_t *prb, res_t *res) {

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
const args_t &ref_args) {
const bool compare_with_norm = (prb->dir & FLAG_BWD);
cmp.set_norm_validation_mode(compare_with_norm);
const bool allow_norm_check = (prb->dir & FLAG_BWD);
cmp.set_allow_norm_check(allow_norm_check);

const auto dt = prb->dir & FLAG_FWD ? prb->dt[1] : prb->dt[0];
// Digits must be non-negative for safe left-shifting when `digits_dt`
Expand Down
45 changes: 12 additions & 33 deletions tests/benchdnn/graph/ref_primitive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ void ref_primitive_t::check_correctness(
{DNNL_ARG_BIAS, BIA},
{DNNL_ARG_DIFF_BIAS, BIA},
{DNNL_ARG_DST, DST},
{DNNL_ARG_DST_1, DST_1},
{DNNL_ARG_DST_1, SDPA_STATS},
{DNNL_ARG_DIFF_SRC_0, DST},
{DNNL_ARG_SRC_1, SRC_1},
{DNNL_ARG_MEAN, MEAN},
Expand Down Expand Up @@ -355,8 +355,8 @@ void ref_primitive_t::check_correctness(
const auto &mem_dt = args.find(arg);
const auto &mem_fp = args_.find(arg);

if (dnnl_arg_2_data_kind_map.find(arg)
== dnnl_arg_2_data_kind_map.end()) {
auto it = dnnl_arg_2_data_kind_map.find(arg);
if (it == dnnl_arg_2_data_kind_map.end()) {
BENCHDNN_PRINT(1, "Output arg %d is unsupported!\n", arg);
res->state = UNIMPLEMENTED;
return;
Expand All @@ -365,36 +365,12 @@ void ref_primitive_t::check_correctness(
attr_t attr;
SWITCH_DRIVER(CASE_CHECK_CORRECTNESS, CASE_CUSTOM_CHECK_CORRECTNESS);

cmp.set_data_kind(it->second);
cmp.set_has_eltwise_post_op(has_eltwise);
cmp.set_op_output_has_nans(has_nans);
dnn_mem_t mem_fp_abx(mem_fp, dnnl_f32, tag::abx, ::get_cpu_engine());
// Reset `res` counters when more than a single arg is checked.
res->errors = 0;
res->total = 0;
auto st = cmp.compare(mem_fp_abx, mem_dt, attr, res);
if (st == OK) continue;

// If comparison failed, try a norm comparison. However, at this point,
// to limit the risk of hiding issues, the norm comparison is enabled
// if number of affected points is really small compared to the total
// number of points - 1 point per every 1024.
// This can be revisited later.
const size_t allowed_error_points = res->total / 1024;
const bool norm_check_allowed = allowed_error_points >= res->errors;

BENCHDNN_PRINT(0,
"[COMPARE_STATS] Norm check is %s; error_to_total_ratio: "
"%zu/%zu; allowed_ratio: %zu/%zu;\n",
norm_check_allowed ? "allowed" : "prohibited", res->errors,
res->total, allowed_error_points, res->total);

if (!norm_check_allowed) continue;

// Reset the `res` statistics state.
res->state = EXECUTED;
res->errors = 0;
res->total = 0;

// `cmp` object has internal knowledge on when this check must be
// enabled.
cmp.set_allow_norm_check(true);
// TODO: there's an open question with how to determine the threshold
// and what the criteria to use. Unless a partition says it is some
// complex fusion (such as SDP) with a specific data type, setting such
Expand All @@ -412,8 +388,11 @@ void ref_primitive_t::check_correctness(
//
// Note: the following threshold is obtained from actual runs on
// different hardware.
cmp.set_threshold(2.5e-3f);
cmp.set_norm_validation_mode(true);
cmp.set_threshold_norm(2.5e-3f);
dnn_mem_t mem_fp_abx(mem_fp, dnnl_f32, tag::abx, ::get_cpu_engine());
// Clear previous output stats.
auto cur_res_state = res->state;
res->reset_stats(cur_res_state);
cmp.compare(mem_fp_abx, mem_dt, attr, res);
}
}
Expand Down
4 changes: 3 additions & 1 deletion tests/benchdnn/inputs/rnn/harness_gru_regression
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
# int8 SIC != SLC
--reset --trivial-strides=true --prop=FWD_I --alg=VANILLA_GRU --activation=UNDEF --direction=left2right --cfg=u8u8u8f32 l1t32mb100sic128slc256dhc128dic128
--reset
--trivial-strides=true --prop=FWD_I --alg=VANILLA_GRU --activation=UNDEF
--direction=left2right --cfg=u8u8u8f32 l1t47mb100sic128slc256dhc128dic128
2 changes: 1 addition & 1 deletion tests/benchdnn/inputs/rnn/shapes_small
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# small shapes

l8t3mb12_sic16_n"uniform"
l4t2mb20_sic36_n"uniform:unroll_tail"
l4t3mb20_sic36_n"uniform:unroll_tail"
l1t2mb6_sic16_slc32_n"non-uniform:slc_neq_sic"
l1t1mb7_sic17_dhc34_n"non-uniform:slc_neq_dhc_tail"
l1t1mb3_sic16_slc32_dhc64_n"non-uniform:slc_neq_sic_neq_dhc"
Expand Down
4 changes: 2 additions & 2 deletions tests/benchdnn/lnorm/lnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,8 @@ void skip_invalid_prb(const prb_t *prb, res_t *res) {

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
const args_t &ref_args) {
const bool compare_with_norm = (prb->dir & FLAG_BWD);
cmp.set_norm_validation_mode(compare_with_norm);
const bool allow_norm_check = (prb->dir & FLAG_BWD);
cmp.set_allow_norm_check(allow_norm_check);

const auto dt = prb->dir & FLAG_FWD ? prb->dt[1] : prb->dt[0];
// Digits must be non-negative for safe left-shifting when `digits_dt`
Expand Down
33 changes: 19 additions & 14 deletions tests/benchdnn/rnn/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -993,16 +993,6 @@ void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
if (prb->prop == dnnl_backward) acc_dim *= MAX2(bwdd_acc_dim, bwdw_acc_dim);
// Here the factor 4 just gives some wiggle room for fp32 testing

// Note: the following process of picking a `trh` is likely fine for
// floating-point problems but doesn't suit well for int8. It may happen
// that underlying target implementation will compute DST[i] and DST_ITER[i]
// with small difference around X.5f point ending up rounded differently
// leading to a difference in the output. Turned out, one incorrect point
// leads to norm comparison failure which doesn't make norm validation
// meaningful.
// TODO: consider moving int8 config (DST_ITER only?) on per point check
// with additional verification that underlying sources can have diff_1
// (though slightly changing shapes can work around failures).
float trh = 4
* (1 + (prb->prop == dnnl_backward)) // double wiggle room for bwd
* ((prb->direction == dnnl_bidirectional_sum)
Expand Down Expand Up @@ -1039,10 +1029,10 @@ void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
// as long as we get precise u8 intermediate results (and so far we do),
// the f32 result should be pretty accurate -- the dequantization is just
// two simple ops: f32 = scale * u8 + shift.
bool check_p2p = (prb->skip_nonlinear
|| ((prb->n_layer == 1) && (prb->n_iter == 1)));
if (prb->is_int8() && rnn_kind == DST_ITER_C) check_p2p = false;
cmp.set_norm_validation_mode(!check_p2p);
const bool disallow_norm_check = prb->skip_nonlinear
|| (prb->n_layer == 1 && prb->n_iter == 1)
|| (prb->is_int8() && rnn_kind == DST_ITER_C);
cmp.set_allow_norm_check(!disallow_norm_check);

const auto rnn_add_check =
[&, prb](const compare::compare_t::driver_check_func_args_t &args) {
Expand All @@ -1054,6 +1044,21 @@ void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
&& prb->prop == dnnl_backward) {
return args.diff < args.trh;
}

// When a problem uses int computations, DST_ITER(_C) is computed using
// DST_LAYER. However, the library part can compute LAYER and ITER in
// parallel, which can lead to off-by-1 issue for ITER part.
// Reconstruct original DST_LAYER values on got and exp sides and if
// they are off-by-1, let them through.
if (prb->cfg.is_int8()
&& (args.dk == rnn_data_kind2data_kind(DST_ITER)
|| args.dk == rnn_data_kind2data_kind(DST_ITER_C))) {
const int exp_q = static_cast<int>(
args.exp * prb->data_scale + prb->data_shift);
const int got_q = static_cast<int>(
args.got * prb->data_scale + prb->data_shift);
return abs(got_q - exp_q) <= 1;
}
return false;
};
cmp.set_driver_check_function(rnn_add_check);
Expand Down
22 changes: 22 additions & 0 deletions tests/benchdnn/rnn/rnn_aux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,28 @@ rnn_data_kind_t data_kind2rnn_data_kind(data_kind_t data_kind) {
return KIND_TOTAL;
}

// Translates an RNN-specific data kind into the generic benchdnn data kind
// used by the common comparison machinery; the inverse direction of
// `data_kind2rnn_data_kind`. Unrecognized values trip an assert in debug
// builds and fall back to DAT_TOTAL in release builds.
data_kind_t rnn_data_kind2data_kind(rnn_data_kind_t rnn_data_kind) {
    data_kind_t kind = DAT_TOTAL; // release-build fallback for unknown kinds
    switch (rnn_data_kind) {
        case rnn_data_kind_t::DST_LAYER: kind = data_kind_t::DST; break;
        case rnn_data_kind_t::DST_ITER: kind = data_kind_t::DST_ITER; break;
        case rnn_data_kind_t::DST_ITER_C:
            kind = data_kind_t::DST_ITER_C;
            break;
        case rnn_data_kind_t::DIFF_SRC_LAYER: kind = data_kind_t::SRC; break;
        case rnn_data_kind_t::DIFF_AUGRU_ATTENTION:
            kind = data_kind_t::AUGRU_ATTENTION;
            break;
        case rnn_data_kind_t::DIFF_SRC_ITER:
            kind = data_kind_t::SRC_ITER;
            break;
        case rnn_data_kind_t::DIFF_SRC_ITER_C:
            kind = data_kind_t::SRC_ITER_C;
            break;
        case rnn_data_kind_t::DIFF_WEIGHTS_LAYER:
            kind = data_kind_t::WEI;
            break;
        case rnn_data_kind_t::DIFF_WEIGHTS_ITER:
            kind = data_kind_t::WEI_ITER;
            break;
        case rnn_data_kind_t::DIFF_WEIGHTS_PEEPHOLE:
            kind = data_kind_t::WEI_PEEPHOLE;
            break;
        case rnn_data_kind_t::DIFF_WEIGHTS_PROJECTION:
            kind = data_kind_t::WEI_PROJECTION;
            break;
        case rnn_data_kind_t::DIFF_BIAS: kind = data_kind_t::BIA; break;
        default: assert(!"unknown data kind"); break;
    }
    return kind;
}

void prb_t::set_qparams(float fp_min, float fp_max) {
if (!cfg.is_int8()) {
data_shift = 0.;
Expand Down
1 change: 1 addition & 0 deletions tests/benchdnn/rnn/rnn_aux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ void gates_reduction(
const prb_t &prb, const float *b_gates_, float *diff_bias_);

rnn_data_kind_t data_kind2rnn_data_kind(data_kind_t data_kind);
data_kind_t rnn_data_kind2data_kind(rnn_data_kind_t rnn_data_kind);

}; // namespace rnn

Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/self/norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ static int check_compare_norm() {
res_t res_bad {};
res_bad.state = EXECUTED;
compare::compare_t cmp;
cmp.set_norm_validation_mode(true);
cmp.set_allow_norm_check(true);
cmp.set_threshold(
sqrt(N) / sqrt(exp_sq_sum0) - 10.f * epsilon_dt(dnnl_f32));
cmp.compare(m0, m1, attr_t(), &res_bad);
Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/softmax/softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
const float trh = is_flt_or_dbl || is_relaxed_xf16 ? trh_f32 : 0.f;
#endif
cmp.set_threshold(trh);
if (driver_name == "graph" && kind == DST_1) {
if (driver_name == "graph" && kind == SDPA_STATS) {
// softmax stats is computed with eltwise-log, which has a different
// and larger threshold than softmax. So we need to adjust the threshold
// for this case.
Expand Down
Loading
Loading