Skip to content

benchdnn: mem check: touch up #3153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tests/benchdnn/binary/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ int checkit(std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,

int doit(const std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
const prb_t *prb, res_t *res) {
set_zmalloc_max_expected_size(res->mem_size_args.zmalloc_expected_size);

const auto &prim = v_prim[0];

dnn_mem_map_t mem_map, ref_mem_map;
Expand All @@ -306,7 +308,7 @@ int doit(const std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,

SAFE(execute_and_wait(prim, args, res), WARN);

check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
check_correctness(prb, {DST}, args, ref_args, setup_cmp, res, prb->dir);
SAFE(check_bitwise(prim, {DST}, args, prb->attr, prb->inplace, res), WARN);

return measure_perf(prb->ctx_exe, res, prim, args);
Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/binary/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,

void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
void compute_ref(const prb_t *prb, const args_t &args,
void compute_ref(const prb_t *prb, dir_t dir, const args_t &args,
dnnl_primitive_t prim_ref = nullptr);

int createit(std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
Expand Down
4 changes: 2 additions & 2 deletions tests/benchdnn/binary/ref_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

namespace binary {

void compute_ref(
const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
void compute_ref(const prb_t *prb, dir_t dir, const args_t &args,
dnnl_primitive_t prim_ref) {

const dnn_mem_t &src0 = args.find(DNNL_ARG_SRC_0);
const dnn_mem_t &src1 = args.find(DNNL_ARG_SRC_1);
Expand Down
25 changes: 8 additions & 17 deletions tests/benchdnn/bnorm/bnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,8 @@ int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,
dnnl_primitive_t prim_ref) {
if (has_bench_mode_modifier(mode_modifier_t::no_ref_memory)) return OK;

if (!ref_mem_map.empty()) { erase_unused_args(ref_mem_map, mem_map); }

// TODO: this function still allocates the full memory footprint needed to
// fill the data, and each argument can't be destroyed right away since
// filling requires all of them at once.
Expand All @@ -638,16 +640,7 @@ int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,

switch (exec_arg) {
case DNNL_ARG_DST:
if (prb->dir & FLAG_BWD) {
// Stash for backward which is used in reference code:
// src_hat[i] = (src[i] - mean) / sqrt(var + prb->eps)
ref_mem_map.emplace(DNNL_ARG_DST_1,
dnn_mem_t(mem.md_, dnnl_f32, tag::abx, ref_engine,
/* prefill = */ false));
}
break;
case DNNL_ARG_DIFF_SRC: break; // Skip on backward.
case DNNL_ARG_DST_1: break; // Skip on backward.
case DNNL_ARG_MEAN:
case DNNL_ARG_VARIANCE:
if (prb->dir & FLAG_INF) {
Expand All @@ -656,11 +649,6 @@ int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,
tag::abx, ref_engine, /* prefill = */ false);
}
break;
case DNNL_ARG_WORKSPACE: {
ref_mem_map[exec_arg] = dnn_mem_t(mem.md_, dnnl_u8, tag::abx,
ref_engine, /* prefill = */ false);
break;
}
default: break;
}
}
Expand All @@ -678,7 +666,8 @@ int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,
}

// Reference code uses different kind of workspace. Adjust to ref needs.
if (ref_mem_map.count(DNNL_ARG_WORKSPACE)) {
// Insert a reference workspace if it comes non-empty from the library side.
if (query_md_ndims(mem_map.at(DNNL_ARG_WORKSPACE).md_) > 0 && is_fwd_prim) {
const auto &src_md = ref_mem_map[DNNL_ARG_SRC].md_;
ref_mem_map[DNNL_ARG_WORKSPACE] = dnn_mem_t(
src_md, dnnl_u8, tag::abx, ref_engine, /* prefill = */ false);
Expand Down Expand Up @@ -735,6 +724,8 @@ int checkit(std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,

int doit(const std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
const prb_t *prb, res_t *res) {
set_zmalloc_max_expected_size(res->mem_size_args.zmalloc_expected_size);

const auto &prim = prb->dir & FLAG_FWD ? v_prim[0] : v_prim[1];

dnn_mem_map_t mem_map, ref_mem_map;
Expand All @@ -749,7 +740,7 @@ int doit(const std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
SAFE(execute_and_wait(v_prim[0], args, res), WARN);

check_correctness(prb, get_kinds_to_check(prb, FLAG_FWD), args, ref_args,
setup_cmp, res);
setup_cmp, res, FLAG_FWD);
SAFE(check_bitwise(prim, get_kinds_to_check(prb, FLAG_FWD), args, prb->attr,
prb->inplace, res),
WARN);
Expand All @@ -770,7 +761,7 @@ int doit(const std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
SAFE(execute_and_wait(v_prim[1], args, res), WARN);

check_correctness(prb, get_kinds_to_check(prb, FLAG_BWD), args,
ref_args, setup_cmp, res);
ref_args, setup_cmp, res, FLAG_BWD);
SAFE(check_bitwise(prim, get_kinds_to_check(prb, FLAG_BWD), args,
prb->attr, prb->inplace, res),
WARN);
Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/bnorm/bnorm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,

void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
void compute_ref(const prb_t *prb, const args_t &args,
void compute_ref(const prb_t *prb, dir_t dir, const args_t &args,
dnnl_primitive_t prim_ref = nullptr);

int createit(std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
Expand Down
26 changes: 14 additions & 12 deletions tests/benchdnn/bnorm/ref_bnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) {
const dnn_mem_t &sh = args.find(DNNL_ARG_SHIFT);
const dnn_mem_t &ws = args.find(DNNL_ARG_WORKSPACE);
const dnn_mem_t &dst = args.find(DNNL_ARG_DST);
const dnn_mem_t &src_hat = args.find(DNNL_ARG_DST_1);

uint8_t *ws_ptr = (uint8_t *)ws;
float *dst_ptr = (float *)dst;

const int64_t MB = prb->mb;
const int64_t C = prb->ic;
Expand Down Expand Up @@ -63,16 +59,20 @@ void compute_ref_fwd(const prb_t *prb, const args_t &args) {
float res = gamma * x_hat + beta;
if (fuse_add_relu) res += src_add.get_f32_elem(off);
if (fuse_relu && res < 0) res = 0;
if (need_ws) ws_ptr[off] = !!res;
if (need_ws) { ws.set_elem(off, !!res); }
maybe_post_ops(attr, res);
dst_ptr[off] = res;
if (prb->dir & FLAG_BWD) src_hat.set_f32_elem(off, x_hat);
// Write to dst only for forward, backward will stash necessary
// values in src memory.
if (prb->dir & FLAG_FWD) dst.set_f32_elem(off, res);
// Write the updated value back to `SRC` to save on computations on
// backward: `src_hat[i] = (src[i] - mean) / sqrt(var + prb->eps)`.
if (prb->dir & FLAG_BWD) src.set_f32_elem(off, x_hat);
}
});
}

void compute_ref_bwd(const prb_t *prb, const args_t &args) {
const dnn_mem_t &src_hat = args.find(DNNL_ARG_DST_1);
const dnn_mem_t &src_hat = args.find(DNNL_ARG_SRC);
const dnn_mem_t &var = args.find(DNNL_ARG_VARIANCE);
const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST);
const dnn_mem_t &sc = args.find(DNNL_ARG_SCALE);
Expand Down Expand Up @@ -134,12 +134,14 @@ void compute_ref_bwd(const prb_t *prb, const args_t &args) {
});
}

void compute_ref(
const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
void compute_ref(const prb_t *prb, dir_t dir, const args_t &args,
dnnl_primitive_t prim_ref) {
// Running fwd ref on bwd to collect src_hat (used instead of src + mean)
// and ws, if fuse_relu flag is requested.
compute_ref_fwd(prb, args);
if (prb->dir & FLAG_BWD) compute_ref_bwd(prb, args);
if (dir & FLAG_FWD)
compute_ref_fwd(prb, args);
else
compute_ref_bwd(prb, args);
}

} // namespace bnorm
2 changes: 1 addition & 1 deletion tests/benchdnn/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ int doit(const prb_t *prb, res_t *res) {
res->state = EXECUTED;

if (has_bench_mode_bit(mode_bit_t::corr)) {
check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
check_correctness(prb, {DST}, args, ref_args, setup_cmp, res, prb->dir);
}

// Create a bind to match internals to run performance measurements.
Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/brgemm/brgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ int brgemm_finalize();

void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
void compute_ref(const prb_t *prb, const args_t &args,
void compute_ref(const prb_t *prb, dir_t dir, const args_t &args,
dnnl_primitive_t prim_ref = nullptr);

int doit(const prb_t *prb, res_t *res);
Expand Down
4 changes: 2 additions & 2 deletions tests/benchdnn/brgemm/ref_brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ void compute_ref_brgemm(const prb_t *prb, const args_t &args) {
});
}

void compute_ref(
const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
void compute_ref(const prb_t *prb, dir_t dir, const args_t &args,
dnnl_primitive_t prim_ref) {
if (prim_ref) {
SAFE_V(execute_and_wait(prim_ref, args));
return;
Expand Down
20 changes: 14 additions & 6 deletions tests/benchdnn/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,20 +270,28 @@ struct memory_registry_t {
}

void set_expected_max(size_t size) {
expected_max_ = size;
constexpr float expected_trh = 1.1f; // Smooth out small allocations.
expected_max_ = static_cast<size_t>(expected_trh * size);
has_warned_ = false;
warn_size_check();
}

private:
size_t size() const { return total_size_; }
void warn_size_check() {
if (expected_max_ != unset_ && !has_warned_
&& total_size_ > expected_max_) {
// Switch to WARNING once existing failures are resolved
const bool is_max_set = expected_max_ != unset_;
// Verify the total amount of allocated memory only once it reaches the
// 1 GB threshold. A small amount of memory is highly unlikely to cause OOM.
// There's an idea to take a portion of RAM into account as well; keep
// only the 1 GB threshold for now to check whether it works well.
const bool is_total_size_big = total_size_ >= 1024 * 1024 * 1024;
const bool is_total_size_unexpected = total_size_ > expected_max_;
if (!has_warned_ && is_max_set && is_total_size_big
&& is_total_size_unexpected) {
BENCHDNN_PRINT(0,
"[CHECK_MEM][INFO]: memory use underestimated, "
"zmalloc allocations exceed %s\n",
"[CHECK_MEM][ERROR]: Memory use is underestimated. Current "
"allocation size: %s; expected size: %s.\n",
smart_bytes(total_size_).c_str(),
smart_bytes(expected_max_).c_str());
// Prevent spamming logs with subsequent overflowing allocations.
has_warned_ = true;
Expand Down
4 changes: 3 additions & 1 deletion tests/benchdnn/concat/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ int checkit(std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,

int doit(const std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
const prb_t *prb, res_t *res) {
set_zmalloc_max_expected_size(res->mem_size_args.zmalloc_expected_size);

const auto &prim = v_prim[0];

dnn_mem_map_t mem_map, ref_mem_map;
Expand All @@ -236,7 +238,7 @@ int doit(const std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,

SAFE(execute_and_wait(prim, args, res), WARN);

check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
check_correctness(prb, {DST}, args, ref_args, setup_cmp, res, prb->dir);
SAFE(check_bitwise(prim, {DST}, args, prb->attr, prb->inplace, res), WARN);

return measure_perf(prb->ctx_exe, res, prim, args);
Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/concat/concat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,

void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
void compute_ref(const prb_t *prb, const args_t &args,
void compute_ref(const prb_t *prb, dir_t dir, const args_t &args,
dnnl_primitive_t prim_ref = nullptr);

int createit(std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
Expand Down
4 changes: 2 additions & 2 deletions tests/benchdnn/concat/ref_concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ void get_sizes(const prb_t *prb, int64_t &outer_size, int64_t &inner_size,
axis_size = prb->axis_size();
}

void compute_ref(
const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
void compute_ref(const prb_t *prb, dir_t dir, const args_t &args,
dnnl_primitive_t prim_ref) {
const dnn_mem_t &dst = args.find(DNNL_ARG_DST);

float *dst_ptr = (float *)dst;
Expand Down
31 changes: 30 additions & 1 deletion tests/benchdnn/conv/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ int check_reorder_presence(
= prb->get_dt(WEI) == dnnl_s8 && prb->get_dt(SRC) == dt_check;
const bool is_def_zp = prb->attr.zero_points.is_def(DNNL_ARG_SRC);
if (wei_x8x8 || !is_def_zp) {
// A workaround for the zmalloc registry checker: temporarily increase the
// capacity just for this check since there's no simple way to account
// for the memory allocated here to verify an extra reorder.
size_t extra_reorder_mem_size
= (dnnl_memory_desc_get_size(mem_fp.md_) / 4)
+ dnnl_memory_desc_get_size(mem_dt.md_);
res->mem_size_args.zmalloc_expected_size += extra_reorder_mem_size;
set_zmalloc_max_expected_size(res->mem_size_args.zmalloc_expected_size);

// Check that s8 -> s8_comp exists in the library since users may have
// already quantized data.
dnn_mem_t mem_fp_s8(mem_fp.md_, dnnl_s8, tag::abx, get_cpu_engine(),
Expand All @@ -120,6 +129,14 @@ int check_reorder_presence(
SAFE(mem_dt.size() == mem_dt_s8.size() ? OK : FAIL, WARN);
int rc = std::memcmp((void *)mem_dt, (void *)mem_dt_s8, mem_dt.size());
SAFE(rc == 0 ? OK : FAIL, WARN);

// Subtract to restore the original size.
res->mem_size_args.zmalloc_expected_size -= extra_reorder_mem_size;
}
// Must be done in a separate scope to have the extra memory objects
// destroyed before restoring the limit to the original value.
if (wei_x8x8 || !is_def_zp) {
set_zmalloc_max_expected_size(res->mem_size_args.zmalloc_expected_size);
}

return OK;
Expand Down Expand Up @@ -574,6 +591,9 @@ int checkit(std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
if (res_copy.state == SKIPPED) {
v_prim[1].reset(nullptr);
SAFE(check_total_size(res), WARN);
} else {
// Copy estimations back to original `res`.
*res = res_copy;
}
} else {
SAFE(check_total_size(res), WARN);
Expand All @@ -588,6 +608,15 @@ int checkit(std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,

int doit(const std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
const prb_t *prb, res_t *res) {
set_zmalloc_max_expected_size(res->mem_size_args.zmalloc_expected_size);
// TODO: move Winograd's reference implementation scratchpad to a dedicated
// class to be able to query its sizes.
// For now, just double the expected size and let it roll.
if (prb->alg == WINO) {
set_zmalloc_max_expected_size(
2 * res->mem_size_args.zmalloc_expected_size);
}

const auto &prim = v_prim[0];
const auto &prim_ref = v_prim[1];

Expand All @@ -602,7 +631,7 @@ int doit(const std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
SAFE(execute_and_wait(prim, args, res), WARN);

check_correctness(prb, get_kinds_to_check(prb), args, ref_args, setup_cmp,
res, prim_ref);
res, prb->dir, prim_ref);
SAFE(check_bitwise(prim, get_kinds_to_check(prb), args, prb->attr,
prb->inplace, res),
WARN);
Expand Down
2 changes: 1 addition & 1 deletion tests/benchdnn/conv/conv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,

void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
void compute_ref(const prb_t *prb, const args_t &args,
void compute_ref(const prb_t *prb, dir_t dir, const args_t &args,
dnnl_primitive_t prim_ref = nullptr);

int createit(std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
Expand Down
10 changes: 5 additions & 5 deletions tests/benchdnn/conv/ref_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,13 +475,13 @@ void compute_ref_bwd_w(
}
}

void compute_ref(
const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
if (prb->dir & FLAG_FWD)
void compute_ref(const prb_t *prb, dir_t dir, const args_t &args,
dnnl_primitive_t prim_ref) {
if (dir & FLAG_FWD)
compute_ref_fwd(prb, args, prim_ref);
else if (prb->dir == BWD_D)
else if (dir == BWD_D)
compute_ref_bwd_d(prb, args, prim_ref);
else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI)
else if ((dir & FLAG_BWD) && (dir & FLAG_WEI))
compute_ref_bwd_w(prb, args, prim_ref);
}

Expand Down
Loading
Loading