Skip to content

Commit 9ebe7f5

Browse files
committed
Apply suggestions from code review
1 parent 101e82e commit 9ebe7f5

File tree

2 files changed

+56
-53
lines changed

2 files changed

+56
-53
lines changed

src/plugins/intel_cpu/src/cpu_memory.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "memory_desc/cpu_memory_desc_utils.h"
1010
#include "nodes/common/cpu_memcpy.h"
1111
#include "nodes/reorder.h"
12+
#include "utils/bfloat16.hpp"
1213
#include "utils/debug_capabilities.h"
1314
#if defined(__linux__)
1415
# include <sys/syscall.h> /* Definition of SYS_* constants */
@@ -38,9 +39,11 @@ inline void setSubnormalsToZeroAndbf16Saturation(float* data, size_t size, bool
3839
if (ftz && ((u32data[i] & (0xFF << 23)) == 0)) {
3940
u32data[i] = 0;
4041
} else if (bf16saturation && !std::isnan(floatdata[i]) && !std::isinf(floatdata[i])) {
41-
floatdata[i] = (floatdata[i] < -3.3895313899137927e38f) ? -3.3895313899137927e38f
42-
: (floatdata[i] > 3.3895313899137927e38f) ? 3.3895313899137927e38f
43-
: floatdata[i];
42+
floatdata[i] = (floatdata[i] < static_cast<float>(std::numeric_limits<ov::bfloat16>::lowest()))
43+
? static_cast<float>(std::numeric_limits<ov::bfloat16>::lowest())
44+
: (floatdata[i] > static_cast<float>(std::numeric_limits<ov::bfloat16>::max()))
45+
? static_cast<float>(std::numeric_limits<ov::bfloat16>::max())
46+
: floatdata[i];
4447
}
4548
}
4649
}

src/plugins/intel_cpu/src/nodes/input.cpp

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ namespace node {
2323

2424
#if defined(OPENVINO_ARCH_X86_64)
2525
namespace {
26-
struct jit_subnormals_bf16saturation_check_base : public jit_generator {
27-
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_subnormals_bf16saturation_check_base)
26+
struct jit_has_special_value_base : public jit_generator {
27+
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_has_special_value_base)
2828

2929
typedef struct {
3030
const float* src;
@@ -34,7 +34,7 @@ struct jit_subnormals_bf16saturation_check_base : public jit_generator {
3434

3535
typedef void (*fn_t)(const args_t*);
3636

37-
jit_subnormals_bf16saturation_check_base() : jit_generator(jit_name()) {
37+
jit_has_special_value_base() : jit_generator(jit_name()) {
3838
jit_ker_ = nullptr;
3939
}
4040

@@ -152,31 +152,31 @@ struct jit_subnormals_bf16saturation_check_base : public jit_generator {
152152
static const float bf16_min_mask_data[8];
153153
};
154154

155-
const uint32_t jit_subnormals_bf16saturation_check_base::exponent_mask_data[8] =
155+
const uint32_t jit_has_special_value_base::exponent_mask_data[8] =
156156
{0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000};
157157

158-
const uint32_t jit_subnormals_bf16saturation_check_base::mantissa_mask_data[8] =
158+
const uint32_t jit_has_special_value_base::mantissa_mask_data[8] =
159159
{0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff};
160160

161-
const float jit_subnormals_bf16saturation_check_base::bf16_max_mask_data[8] = {3.38953139e+38f,
162-
3.38953139e+38f,
163-
3.38953139e+38f,
164-
3.38953139e+38f,
165-
3.38953139e+38f,
166-
3.38953139e+38f,
167-
3.38953139e+38f,
168-
3.38953139e+38f};
169-
170-
const float jit_subnormals_bf16saturation_check_base::bf16_min_mask_data[8] = {-3.38953139e+38f,
171-
-3.38953139e+38f,
172-
-3.38953139e+38f,
173-
-3.38953139e+38f,
174-
-3.38953139e+38f,
175-
-3.38953139e+38f,
176-
-3.38953139e+38f,
177-
-3.38953139e+38f};
161+
const float jit_has_special_value_base::bf16_max_mask_data[8] = {std::numeric_limits<ov::bfloat16>::max(),
162+
std::numeric_limits<ov::bfloat16>::max(),
163+
std::numeric_limits<ov::bfloat16>::max(),
164+
std::numeric_limits<ov::bfloat16>::max(),
165+
std::numeric_limits<ov::bfloat16>::max(),
166+
std::numeric_limits<ov::bfloat16>::max(),
167+
std::numeric_limits<ov::bfloat16>::max(),
168+
std::numeric_limits<ov::bfloat16>::max()};
169+
170+
const float jit_has_special_value_base::bf16_min_mask_data[8] = {std::numeric_limits<ov::bfloat16>::lowest(),
171+
std::numeric_limits<ov::bfloat16>::lowest(),
172+
std::numeric_limits<ov::bfloat16>::lowest(),
173+
std::numeric_limits<ov::bfloat16>::lowest(),
174+
std::numeric_limits<ov::bfloat16>::lowest(),
175+
std::numeric_limits<ov::bfloat16>::lowest(),
176+
std::numeric_limits<ov::bfloat16>::lowest(),
177+
std::numeric_limits<ov::bfloat16>::lowest()};
178178
template <cpu_isa_t isa>
179-
struct jit_has_subnormals : public jit_subnormals_bf16saturation_check_base {
179+
struct jit_has_subnormals : public jit_has_special_value_base {
180180
using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type;
181181

182182
const Vmm rmm4 = Vmm(4);
@@ -250,7 +250,7 @@ struct jit_has_subnormals : public jit_subnormals_bf16saturation_check_base {
250250
}
251251
};
252252
template <cpu_isa_t isa>
253-
struct jit_has_bf16_overflows : public jit_subnormals_bf16saturation_check_base {
253+
struct jit_has_bf16_overflows : public jit_has_special_value_base {
254254
using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type;
255255

256256
const Vmm rmm4 = Vmm(4);
@@ -323,7 +323,7 @@ struct jit_has_bf16_overflows : public jit_subnormals_bf16saturation_check_base
323323
postamble();
324324
}
325325
};
326-
jit_subnormals_bf16saturation_check_base::fn_t jit_has_subnormals_function() {
326+
jit_has_special_value_base::fn_t jit_has_subnormals_function() {
327327
if (mayiuse(cpu_isa_t::avx2)) {
328328
static jit_has_subnormals<cpu_isa_t::avx2> generator;
329329
static auto fn = generator.get();
@@ -335,7 +335,7 @@ jit_subnormals_bf16saturation_check_base::fn_t jit_has_subnormals_function() {
335335
}
336336
return nullptr;
337337
}
338-
jit_subnormals_bf16saturation_check_base::fn_t jit_has_bf16_overflows_function() {
338+
jit_has_special_value_base::fn_t jit_has_bf16_overflows_function() {
339339
if (mayiuse(cpu_isa_t::avx2)) {
340340
static jit_has_bf16_overflows<cpu_isa_t::avx2> generator;
341341
static auto fn = generator.get();
@@ -414,24 +414,25 @@ void Input::cloneBlobIfRequired() {
414414

415415
volatile bool has_subnormals_local = false;
416416
volatile bool has_bf16_overflows_local = false;
417+
if (needFlushDenormalsToZero) {
418+
parallel_for(iterations_num, [&](int n) {
419+
auto ptr = u32data + n * batch_size;
420+
const jit_has_special_value_base::args_t args1 = {
421+
reinterpret_cast<float const*>(ptr),
422+
std::min(batch_size, (size_t)(u32data + size - ptr)),
423+
false};
417424

418-
parallel_for(iterations_num, [&](int n) {
419-
auto ptr = u32data + n * batch_size;
420-
const jit_subnormals_bf16saturation_check_base::args_t args1 = {
421-
reinterpret_cast<float const*>(ptr),
422-
std::min(batch_size, (size_t)(u32data + size - ptr)),
423-
false};
424-
425-
fn(&args1);
425+
fn(&args1);
426426

427-
if (args1.hasTargetValues)
428-
has_subnormals_local = true;
429-
});
427+
if (args1.hasTargetValues)
428+
has_subnormals_local = true;
429+
});
430+
}
430431

431432
if (do_bf16_saturation_check) {
432433
parallel_for(iterations_num, [&](int n) {
433434
auto ptr2 = f32data + n * batch_size;
434-
const jit_subnormals_bf16saturation_check_base::args_t args2 = {
435+
const jit_has_special_value_base::args_t args2 = {
435436
reinterpret_cast<float const*>(ptr2),
436437
std::min(batch_size, (size_t)(f32data + size - ptr2)),
437438
false};
@@ -452,19 +453,18 @@ void Input::cloneBlobIfRequired() {
452453

453454
uint32_t mantissaMask = 0x007fffff;
454455
uint32_t exponentMask = 0x7f800000;
455-
const float bf16_max = 3.3895313899137927e38f;
456+
const float bf16_max = std::numeric_limits<ov::bfloat16>::max();
456457
for (size_t i = 0; i < size; ++i) {
457-
if ((u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0) {
458+
if (needFlushDenormalsToZero && (u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0) {
458459
has_subnormals = true;
459460
}
460-
if (do_bf16_saturation_check) {
461-
if (f32data[i] < -bf16_max || f32data[i] > bf16_max) {
462-
has_bf16_overflows = true;
463-
}
464-
if (has_subnormals && has_bf16_overflows) {
465-
return;
466-
}
467-
} else if (has_subnormals) {
461+
462+
if (do_bf16_saturation_check && (f32data[i] < -bf16_max || f32data[i] > bf16_max)) {
463+
has_bf16_overflows = true;
464+
}
465+
466+
if ((!needFlushDenormalsToZero || has_subnormals) &&
467+
(!do_bf16_saturation_check || has_bf16_overflows)) {
468468
return;
469469
}
470470
}
@@ -508,7 +508,7 @@ void Input::cloneBlobIfRequired() {
508508
} else {
509509
ptr = std::make_shared<StaticMemory>(getEngine(), memDesc);
510510
}
511-
ptr->load(*memory.get(), needFlushDenormalsToZero, has_bf16_overflows);
511+
ptr->load(*memory.get(), has_subnormals, has_bf16_overflows);
512512

513513
return ptr;
514514
};
@@ -536,7 +536,7 @@ void Input::cloneBlobIfRequired() {
536536
prec != element::string &&
537537
// IRs already have all subnormals flushed to zero, but in
538538
// read_model scenario with directly loaded original model still can have subnormals
539-
isBlobAligned(m_constOp) && (!needFlushDenormalsToZero || !has_subnormals) && !has_bf16_overflows &&
539+
isBlobAligned(m_constOp) && !has_subnormals && !has_bf16_overflows &&
540540
// Blob should be cloned in cache only if original weights are stored on other numa node.
541541
// This is possible only in multistream case on multisocket machine.
542542
// TODO: don't clone blob for multisocket + multistream case if current stream is run on the numa node where

0 commit comments

Comments
 (0)