@@ -23,18 +23,18 @@ namespace node {
2323
2424#if defined(OPENVINO_ARCH_X86_64)
2525namespace {
26- struct jit_has_subnormals_base : public jit_generator {
27- DECLARE_CPU_JIT_AUX_FUNCTIONS (jit_has_subnormals_base )
26+ struct jit_subnormals_bf16saturation_check_base : public jit_generator {
27+ DECLARE_CPU_JIT_AUX_FUNCTIONS (jit_subnormals_bf16saturation_check_base )
2828
2929 typedef struct {
3030 const float * src;
3131 const size_t count;
32- bool hasSubnormals ;
32+ bool hasTargetValues ;
3333 } args_t ;
3434
3535 typedef void (*fn_t )(const args_t *);
3636
37- jit_has_subnormals_base () : jit_generator(jit_name()) {
37+ jit_subnormals_bf16saturation_check_base () : jit_generator(jit_name()) {
3838 jit_ker_ = nullptr ;
3939 }
4040
@@ -110,8 +110,35 @@ struct jit_has_subnormals_base : public jit_generator {
110110 uni_vtestps (b, c); // if ((!b & c) == 0) CF = 1 else CF = 0
111111 }
112112
113+ void check_bf16_saturations (const Xbyak::Reg64& src,
114+ const Xbyak::Ymm& bf16_max_mask,
115+ const Xbyak::Ymm& bf16_min_mask) {
116+ auto a = ymm1;
117+ auto b = ymm2;
118+ auto c = ymm3;
119+ vmovdqu (a, yword[src]); // load 8 floats
120+ vcmpps (b, a, bf16_max_mask, 0x1e ); // b = (a > bf16_max) ? 1 : 0
121+ vcmpps (c, a, bf16_min_mask, 0x11 ); // c = (a < bf16_min) ? 1 : 0
122+ vorps (b, b, c); // b = b | c
123+ vptest (b, b); // if (b != 0) CF = 1 else CF = 0
124+ }
125+
126+ void check_bf16_saturations (const Xbyak::Reg64& src,
127+ const Xbyak::Xmm& bf16_max_mask,
128+ const Xbyak::Xmm& bf16_min_mask) {
129+ auto a = xmm1;
130+ auto b = xmm2;
131+ auto c = xmm3;
132+
133+ uni_vmovdqu (a, xword[src]); // load 4 floats
134+ uni_vcmpps (b, a, bf16_max_mask, 0x1e ); // b = (a > bf16_max) ? 1 : 0
135+ uni_vcmpps (c, a, bf16_max_mask, 0x11 ); // c = (a < bf16_min) ? 1 : 0
136+ uni_vorps (b, b, c); // b = b | c
137+ uni_vtestps (b, b); // if (b != 0) CF = 1 else CF = 0
138+ }
139+
113140protected:
114- Label exit, has_subnormals, no_subnormals ;
141+ Label exit, has_target_values, no_target_values ;
115142
116143 const Reg64& reg_src = rax;
117144 const Reg64& reg_dst = rbx;
@@ -121,16 +148,35 @@ struct jit_has_subnormals_base : public jit_generator {
121148
122149 static const uint32_t exponent_mask_data[8 ];
123150 static const uint32_t mantissa_mask_data[8 ];
151+ static const float bf16_max_mask_data[8 ];
152+ static const float bf16_min_mask_data[8 ];
124153};
125154
126- const uint32_t jit_has_subnormals_base ::exponent_mask_data[8 ] =
155+ const uint32_t jit_subnormals_bf16saturation_check_base ::exponent_mask_data[8 ] =
127156 {0x7f800000 , 0x7f800000 , 0x7f800000 , 0x7f800000 , 0x7f800000 , 0x7f800000 , 0x7f800000 , 0x7f800000 };
128157
129- const uint32_t jit_has_subnormals_base ::mantissa_mask_data[8 ] =
158+ const uint32_t jit_subnormals_bf16saturation_check_base ::mantissa_mask_data[8 ] =
130159 {0x007fffff , 0x007fffff , 0x007fffff , 0x007fffff , 0x007fffff , 0x007fffff , 0x007fffff , 0x007fffff };
131160
161+ const float jit_subnormals_bf16saturation_check_base::bf16_max_mask_data[8 ] = {3 .38953139e+38f ,
162+ 3 .38953139e+38f ,
163+ 3 .38953139e+38f ,
164+ 3 .38953139e+38f ,
165+ 3 .38953139e+38f ,
166+ 3 .38953139e+38f ,
167+ 3 .38953139e+38f ,
168+ 3 .38953139e+38f };
169+
170+ const float jit_subnormals_bf16saturation_check_base::bf16_min_mask_data[8 ] = {-3 .38953139e+38f ,
171+ -3 .38953139e+38f ,
172+ -3 .38953139e+38f ,
173+ -3 .38953139e+38f ,
174+ -3 .38953139e+38f ,
175+ -3 .38953139e+38f ,
176+ -3 .38953139e+38f ,
177+ -3 .38953139e+38f };
132178template <cpu_isa_t isa>
133- struct jit_has_subnormals : public jit_has_subnormals_base {
179+ struct jit_has_subnormals : public jit_subnormals_bf16saturation_check_base {
134180 using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type;
135181
136182 const Vmm rmm4 = Vmm(4 );
@@ -150,7 +196,7 @@ struct jit_has_subnormals : public jit_has_subnormals_base {
150196
151197 // Get arguments addresses
152198 mov (reg_src, ptr[param1 + offsetof (args_t , src)]);
153- lea (reg_dst, ptr[param1 + offsetof (args_t , hasSubnormals )]);
199+ lea (reg_dst, ptr[param1 + offsetof (args_t , hasTargetValues )]);
154200 mov (reg_sz, ptr[param1 + offsetof (args_t , count)]);
155201
156202 // Initialize necessary consts
@@ -167,7 +213,7 @@ struct jit_has_subnormals : public jit_has_subnormals_base {
167213
168214 foreach (reg_idx, 1 , r8, [&, this ](const Xbyak::Reg64& idx) {
169215 check_subnormals (reg_src, exponent_mask, mantissa_mask, zero);
170- jnc (has_subnormals );
216+ jnc (has_target_values );
171217 add (reg_src, sizeof (float ) * vlen);
172218 })
173219 ;
@@ -186,25 +232,98 @@ struct jit_has_subnormals : public jit_has_subnormals_base {
186232
187233 copy_floats (r8, reg_src, reg_sz);
188234 check_subnormals (r8, exponent_mask, mantissa_mask, zero);
189- jc (no_subnormals );
235+ jc (no_target_values );
190236 add (rsp, vlen * sizeof (float ));
191237
192- L (has_subnormals );
238+ L (has_target_values );
193239
194240 mov (rax, 1 );
195241 mov (byte[reg_dst], al);
196242 jmp (exit);
197243
198- L (no_subnormals );
244+ L (no_target_values );
199245 add (rsp, vlen * sizeof (float ));
200246
201247 L (exit);
202248
203249 postamble ();
204250 }
205251};
252+ template <cpu_isa_t isa>
253+ struct jit_has_bf16_overflows : public jit_subnormals_bf16saturation_check_base {
254+ using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type;
255+
256+ const Vmm rmm4 = Vmm(4 );
257+ const Vmm rmm5 = Vmm(5 );
258+ const Vmm rmm6 = Vmm(6 );
259+ const int length = isa == sse41 ? 4 : 8 ;
260+
261+ void generate () override final { // NOLINT
262+ size_t const vlen = length;
263+ const int sh_bits = std::ilogb (vlen);
264+
265+ auto zero = rmm4;
266+ auto bf16_max_mask = rmm5;
267+ auto bf16_min_mask = rmm6;
268+
269+ preamble ();
270+
271+ // Get arguments addresses
272+ mov (reg_src, ptr[param1 + offsetof (args_t , src)]);
273+ lea (reg_dst, ptr[param1 + offsetof (args_t , hasTargetValues)]);
274+ mov (reg_sz, ptr[param1 + offsetof (args_t , count)]);
275+
276+ // Initialize necessary consts
277+ uni_vpxor (zero, zero, zero);
278+ mov (reg_mask_addr, (size_t )bf16_max_mask_data);
279+ uni_vmovdqu (bf16_max_mask, ptr[reg_mask_addr]);
280+ mov (reg_mask_addr, (size_t )bf16_min_mask_data);
281+ uni_vmovdqu (bf16_min_mask, ptr[reg_mask_addr]);
282+
283+ // Main loop
284+ xor_ (reg_idx, reg_idx);
285+ mov (r8, reg_sz);
286+ shr (r8, sh_bits);
287+
288+ foreach (reg_idx, 1 , r8, [&, this ](const Xbyak::Reg64& idx) {
289+ check_bf16_saturations (reg_src, bf16_max_mask, bf16_min_mask);
290+ jnz (has_target_values, T_NEAR);
291+ add (reg_src, sizeof (float ) * vlen);
292+ })
293+ ;
294+
295+ // Tail
296+ shl (reg_idx, sh_bits);
297+ sub (reg_sz, reg_idx);
298+ test (reg_sz, reg_sz);
299+ jz (exit);
206300
207- jit_has_subnormals_base::fn_t jit_has_subnormals_function () {
301+ // use space on stack for 4 or 8 floats
302+ sub (rsp, vlen * sizeof (float ));
303+ mov (r8, rsp);
304+
305+ uni_vmovdqu (ptr[r8], zero);
306+
307+ copy_floats (r8, reg_src, reg_sz);
308+ check_bf16_saturations (r8, bf16_max_mask, bf16_min_mask);
309+ jz (no_target_values, T_NEAR);
310+ add (rsp, vlen * sizeof (float ));
311+
312+ L (has_target_values);
313+
314+ mov (rax, 1 );
315+ mov (byte[reg_dst], al);
316+ jmp (exit);
317+
318+ L (no_target_values);
319+ add (rsp, vlen * sizeof (float ));
320+
321+ L (exit);
322+
323+ postamble ();
324+ }
325+ };
326+ jit_subnormals_bf16saturation_check_base::fn_t jit_has_subnormals_function () {
208327 if (mayiuse (cpu_isa_t ::avx2)) {
209328 static jit_has_subnormals<cpu_isa_t ::avx2> generator;
210329 static auto fn = generator.get ();
@@ -216,6 +335,18 @@ jit_has_subnormals_base::fn_t jit_has_subnormals_function() {
216335 }
217336 return nullptr ;
218337}
338+ jit_subnormals_bf16saturation_check_base::fn_t jit_has_bf16_overflows_function () {
339+ if (mayiuse (cpu_isa_t ::avx2)) {
340+ static jit_has_bf16_overflows<cpu_isa_t ::avx2> generator;
341+ static auto fn = generator.get ();
342+ return fn;
343+ } else if (mayiuse (cpu_isa_t ::sse41)) {
344+ static jit_has_bf16_overflows<cpu_isa_t ::sse41> generator;
345+ static auto fn = generator.get ();
346+ return fn;
347+ }
348+ return nullptr ;
349+ }
219350
220351} // namespace
221352#endif
@@ -271,49 +402,69 @@ void Input::cloneBlobIfRequired() {
271402 if (!size)
272403 return ;
273404
274- const float bf16_max = 3 .3895313899137927e38f;
405+ const bool do_bf16_saturation_check =
406+ (context->getConfig ().inferencePrecision == ov::element::bf16 ) ? true : false ;
275407
276408#if defined(OPENVINO_ARCH_X86_64)
277- if (auto fn = jit_has_subnormals_function ()) {
409+ auto fn = jit_has_subnormals_function ();
410+ auto fn_bf16_check = jit_has_bf16_overflows_function ();
411+ if (fn && fn_bf16_check) {
278412 static const size_t batch_size = 2048 ;
279413 const size_t iterations_num = size / batch_size + 1 ;
280414
281415 volatile bool has_subnormals_local = false ;
416+ volatile bool has_bf16_overflows_local = false ;
282417
283418 parallel_for (iterations_num, [&](int n) {
284419 auto ptr = u32data + n * batch_size;
285- const jit_has_subnormals_base::args_t args = {reinterpret_cast <float const *>(ptr),
286- std::min (batch_size, (size_t )(u32data + size - ptr)),
287- false };
420+ const jit_subnormals_bf16saturation_check_base::args_t args1 = {
421+ reinterpret_cast <float const *>(ptr),
422+ std::min (batch_size, (size_t )(u32data + size - ptr)),
423+ false };
288424
289- fn (&args );
425+ fn (&args1 );
290426
291- if (args. hasSubnormals )
427+ if (args1. hasTargetValues )
292428 has_subnormals_local = true ;
293429 });
294430
295- has_subnormals = has_subnormals_local;
296- // TODO: opt with jit
297- for (size_t i = 0 ; i < size; ++i) {
298- if (f32data[i] < -bf16_max || f32data[i] > bf16_max) {
299- has_bf16_overflows = true ;
300- return ;
301- }
431+ if (do_bf16_saturation_check) {
432+ parallel_for (iterations_num, [&](int n) {
433+ auto ptr2 = f32data + n * batch_size;
434+ const jit_subnormals_bf16saturation_check_base::args_t args2 = {
435+ reinterpret_cast <float const *>(ptr2),
436+ std::min (batch_size, (size_t )(f32data + size - ptr2)),
437+ false };
438+
439+ fn_bf16_check (&args2);
440+
441+ if (args2.hasTargetValues )
442+ has_bf16_overflows_local = true ;
443+ });
302444 }
445+
446+ has_subnormals = has_subnormals_local;
447+ has_bf16_overflows = has_bf16_overflows_local;
448+
303449 return ;
304450 }
305451#endif
306452
307453 uint32_t mantissaMask = 0x007fffff ;
308454 uint32_t exponentMask = 0x7f800000 ;
455+ const float bf16_max = 3 .3895313899137927e38f;
309456 for (size_t i = 0 ; i < size; ++i) {
310457 if ((u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0 ) {
311458 has_subnormals = true ;
312459 }
313- if (f32data[i] < -bf16_max || f32data[i] > bf16_max) {
314- has_bf16_overflows = true ;
315- }
316- if (has_subnormals && has_bf16_overflows) {
460+ if (do_bf16_saturation_check) {
461+ if (f32data[i] < -bf16_max || f32data[i] > bf16_max) {
462+ has_bf16_overflows = true ;
463+ }
464+ if (has_subnormals && has_bf16_overflows) {
465+ return ;
466+ }
467+ } else if (has_subnormals) {
317468 return ;
318469 }
319470 }
0 commit comments