@@ -23,8 +23,8 @@ namespace node {
2323
2424#if defined(OPENVINO_ARCH_X86_64)
2525namespace {
26- struct jit_subnormals_bf16saturation_check_base : public jit_generator {
27- DECLARE_CPU_JIT_AUX_FUNCTIONS (jit_subnormals_bf16saturation_check_base )
26+ struct jit_has_special_value_base : public jit_generator {
27+ DECLARE_CPU_JIT_AUX_FUNCTIONS (jit_has_special_value_base )
2828
2929 typedef struct {
3030 const float * src;
@@ -34,7 +34,7 @@ struct jit_subnormals_bf16saturation_check_base : public jit_generator {
3434
3535 typedef void (*fn_t )(const args_t *);
3636
37- jit_subnormals_bf16saturation_check_base () : jit_generator(jit_name()) {
37+ jit_has_special_value_base () : jit_generator(jit_name()) {
3838 jit_ker_ = nullptr ;
3939 }
4040
@@ -152,31 +152,31 @@ struct jit_subnormals_bf16saturation_check_base : public jit_generator {
152152 static const float bf16_min_mask_data[8 ];
153153};
154154
155- const uint32_t jit_subnormals_bf16saturation_check_base ::exponent_mask_data[8 ] =
155+ const uint32_t jit_has_special_value_base ::exponent_mask_data[8 ] =
156156 {0x7f800000 , 0x7f800000 , 0x7f800000 , 0x7f800000 , 0x7f800000 , 0x7f800000 , 0x7f800000 , 0x7f800000 };
157157
158- const uint32_t jit_subnormals_bf16saturation_check_base ::mantissa_mask_data[8 ] =
158+ const uint32_t jit_has_special_value_base ::mantissa_mask_data[8 ] =
159159 {0x007fffff , 0x007fffff , 0x007fffff , 0x007fffff , 0x007fffff , 0x007fffff , 0x007fffff , 0x007fffff };
160160
161- const float jit_subnormals_bf16saturation_check_base ::bf16_max_mask_data[8 ] = {3 .38953139e+ 38f ,
162- 3 .38953139e+ 38f ,
163- 3 .38953139e+ 38f ,
164- 3 .38953139e+ 38f ,
165- 3 .38953139e+ 38f ,
166- 3 .38953139e+ 38f ,
167- 3 .38953139e+ 38f ,
168- 3 .38953139e+ 38f };
169-
170- const float jit_subnormals_bf16saturation_check_base ::bf16_min_mask_data[8 ] = {- 3 .38953139e+ 38f ,
171- - 3 .38953139e+ 38f ,
172- - 3 .38953139e+ 38f ,
173- - 3 .38953139e+ 38f ,
174- - 3 .38953139e+ 38f ,
175- - 3 .38953139e+ 38f ,
176- - 3 .38953139e+ 38f ,
177- - 3 .38953139e+ 38f };
161+ const float jit_has_special_value_base ::bf16_max_mask_data[8 ] = {std::numeric_limits<ov::bfloat16>:: max () ,
162+ std::numeric_limits<ov::bfloat16>:: max () ,
163+ std::numeric_limits<ov::bfloat16>:: max () ,
164+ std::numeric_limits<ov::bfloat16>:: max () ,
165+ std::numeric_limits<ov::bfloat16>:: max () ,
166+ std::numeric_limits<ov::bfloat16>:: max () ,
167+ std::numeric_limits<ov::bfloat16>:: max () ,
168+ std::numeric_limits<ov::bfloat16>:: max () };
169+
170+ const float jit_has_special_value_base ::bf16_min_mask_data[8 ] = {std::numeric_limits<ov::bfloat16>:: lowest () ,
171+ std::numeric_limits<ov::bfloat16>:: lowest () ,
172+ std::numeric_limits<ov::bfloat16>:: lowest () ,
173+ std::numeric_limits<ov::bfloat16>:: lowest () ,
174+ std::numeric_limits<ov::bfloat16>:: lowest () ,
175+ std::numeric_limits<ov::bfloat16>:: lowest () ,
176+ std::numeric_limits<ov::bfloat16>:: lowest () ,
177+ std::numeric_limits<ov::bfloat16>:: lowest () };
178178template <cpu_isa_t isa>
179- struct jit_has_subnormals : public jit_subnormals_bf16saturation_check_base {
179+ struct jit_has_subnormals : public jit_has_special_value_base {
180180 using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type;
181181
182182 const Vmm rmm4 = Vmm(4 );
@@ -250,7 +250,7 @@ struct jit_has_subnormals : public jit_subnormals_bf16saturation_check_base {
250250 }
251251};
252252template <cpu_isa_t isa>
253- struct jit_has_bf16_overflows : public jit_subnormals_bf16saturation_check_base {
253+ struct jit_has_bf16_overflows : public jit_has_special_value_base {
254254 using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type;
255255
256256 const Vmm rmm4 = Vmm(4 );
@@ -323,7 +323,7 @@ struct jit_has_bf16_overflows : public jit_subnormals_bf16saturation_check_base
323323 postamble ();
324324 }
325325};
326- jit_subnormals_bf16saturation_check_base ::fn_t jit_has_subnormals_function () {
326+ jit_has_special_value_base ::fn_t jit_has_subnormals_function () {
327327 if (mayiuse (cpu_isa_t ::avx2)) {
328328 static jit_has_subnormals<cpu_isa_t ::avx2> generator;
329329 static auto fn = generator.get ();
@@ -335,7 +335,7 @@ jit_subnormals_bf16saturation_check_base::fn_t jit_has_subnormals_function() {
335335 }
336336 return nullptr ;
337337}
338- jit_subnormals_bf16saturation_check_base ::fn_t jit_has_bf16_overflows_function () {
338+ jit_has_special_value_base ::fn_t jit_has_bf16_overflows_function () {
339339 if (mayiuse (cpu_isa_t ::avx2)) {
340340 static jit_has_bf16_overflows<cpu_isa_t ::avx2> generator;
341341 static auto fn = generator.get ();
@@ -414,24 +414,25 @@ void Input::cloneBlobIfRequired() {
414414
415415 volatile bool has_subnormals_local = false ;
416416 volatile bool has_bf16_overflows_local = false ;
417+ if (needFlushDenormalsToZero) {
418+ parallel_for (iterations_num, [&](int n) {
419+ auto ptr = u32data + n * batch_size;
420+ const jit_has_special_value_base::args_t args1 = {
421+ reinterpret_cast <float const *>(ptr),
422+ std::min (batch_size, (size_t )(u32data + size - ptr)),
423+ false };
417424
418- parallel_for (iterations_num, [&](int n) {
419- auto ptr = u32data + n * batch_size;
420- const jit_subnormals_bf16saturation_check_base::args_t args1 = {
421- reinterpret_cast <float const *>(ptr),
422- std::min (batch_size, (size_t )(u32data + size - ptr)),
423- false };
424-
425- fn (&args1);
425+ fn (&args1);
426426
427- if (args1.hasTargetValues )
428- has_subnormals_local = true ;
429- });
427+ if (args1.hasTargetValues )
428+ has_subnormals_local = true ;
429+ });
430+ }
430431
431432 if (do_bf16_saturation_check) {
432433 parallel_for (iterations_num, [&](int n) {
433434 auto ptr2 = f32data + n * batch_size;
434- const jit_subnormals_bf16saturation_check_base ::args_t args2 = {
435+ const jit_has_special_value_base ::args_t args2 = {
435436 reinterpret_cast <float const *>(ptr2),
436437 std::min (batch_size, (size_t )(f32data + size - ptr2)),
437438 false };
@@ -452,19 +453,18 @@ void Input::cloneBlobIfRequired() {
452453
453454 uint32_t mantissaMask = 0x007fffff ;
454455 uint32_t exponentMask = 0x7f800000 ;
455- const float bf16_max = 3 .3895313899137927e38f ;
456+ const float bf16_max = std::numeric_limits<ov::bfloat16>:: max () ;
456457 for (size_t i = 0 ; i < size; ++i) {
457- if ((u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0 ) {
458+ if (needFlushDenormalsToZero && (u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0 ) {
458459 has_subnormals = true ;
459460 }
460- if (do_bf16_saturation_check) {
461- if (f32data[i] < -bf16_max || f32data[i] > bf16_max) {
462- has_bf16_overflows = true ;
463- }
464- if (has_subnormals && has_bf16_overflows) {
465- return ;
466- }
467- } else if (has_subnormals) {
461+
462+ if (do_bf16_saturation_check && (f32data[i] < -bf16_max || f32data[i] > bf16_max)) {
463+ has_bf16_overflows = true ;
464+ }
465+
466+ if ((!needFlushDenormalsToZero || has_subnormals) &&
467+ (!do_bf16_saturation_check || has_bf16_overflows)) {
468468 return ;
469469 }
470470 }
@@ -508,7 +508,7 @@ void Input::cloneBlobIfRequired() {
508508 } else {
509509 ptr = std::make_shared<StaticMemory>(getEngine (), memDesc);
510510 }
511- ptr->load (*memory.get (), needFlushDenormalsToZero , has_bf16_overflows);
511+ ptr->load (*memory.get (), has_subnormals , has_bf16_overflows);
512512
513513 return ptr;
514514 };
@@ -536,7 +536,7 @@ void Input::cloneBlobIfRequired() {
536536 prec != element::string &&
537537 // IRs already have all subnormals flushed to zero, but in
538538 // read_model scenario with directly loaded original model still can have subnormals
539- isBlobAligned (m_constOp) && (!needFlushDenormalsToZero || ! has_subnormals) && !has_bf16_overflows &&
539+ isBlobAligned (m_constOp) && ! has_subnormals && !has_bf16_overflows &&
540540 // Blob should be cloned in cache only if original weights are stored on other numa node.
541541 // This is possible only in multistream case on multisocket machine.
542542 // TODO: don't clone blob for multisocket + multistream case if current stream is run on the numa node where
0 commit comments