Skip to content

Commit 2139cd4

Browse files
committed
bf16_staturation Jit Impl
1 parent 4a325ce commit 2139cd4

File tree

2 files changed

+185
-40
lines changed

2 files changed

+185
-40
lines changed

src/plugins/intel_cpu/src/nodes/eltwise.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2895,15 +2895,9 @@ void Eltwise::prepareParams() {
28952895

28962896
// FP32 constant inputs may contain values out of BF16 representable range. In case output precision is BF16 we
28972897
// choose "saturation" mode for fp32->bf16 conversion procedure to prevent getting -Inf/+Inf values in the
2898-
// outputs. Since "saturation" conversion is more time consuming, better solution would be to clamp constants on
2899-
// compilation stage (ticket: 159589).
2898+
// outputs. Since "saturation" conversion during kernel runtime is more time consuming, current solution is
2899+
// clamp constants on compilation stage.
29002900
key.doOutputSaturation = false;
2901-
for (size_t i = 0; i < getParentEdges().size(); i++) {
2902-
if (getParentEdgeAt(i)->getParent()->isConstant()) {
2903-
key.doOutputSaturation = true;
2904-
break;
2905-
}
2906-
}
29072901

29082902
auto cache = context->getParamsCache();
29092903
auto result = cache->getOrCreate(key, buildExecutor);

src/plugins/intel_cpu/src/nodes/input.cpp

Lines changed: 183 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,18 @@ namespace node {
2323

2424
#if defined(OPENVINO_ARCH_X86_64)
2525
namespace {
26-
struct jit_has_subnormals_base : public jit_generator {
27-
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_has_subnormals_base)
26+
struct jit_subnormals_bf16saturation_check_base : public jit_generator {
27+
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_subnormals_bf16saturation_check_base)
2828

2929
typedef struct {
3030
const float* src;
3131
const size_t count;
32-
bool hasSubnormals;
32+
bool hasTargetValues;
3333
} args_t;
3434

3535
typedef void (*fn_t)(const args_t*);
3636

37-
jit_has_subnormals_base() : jit_generator(jit_name()) {
37+
jit_subnormals_bf16saturation_check_base() : jit_generator(jit_name()) {
3838
jit_ker_ = nullptr;
3939
}
4040

@@ -110,8 +110,35 @@ struct jit_has_subnormals_base : public jit_generator {
110110
uni_vtestps(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0
111111
}
112112

113+
void check_bf16_saturations(const Xbyak::Reg64& src,
114+
const Xbyak::Ymm& bf16_max_mask,
115+
const Xbyak::Ymm& bf16_min_mask) {
116+
auto a = ymm1;
117+
auto b = ymm2;
118+
auto c = ymm3;
119+
vmovdqu(a, yword[src]); // load 8 floats
120+
vcmpps(b, a, bf16_max_mask, _CMP_GT_OQ); // b = (a > bf16_max) ? 1 : 0
121+
vcmpps(c, a, bf16_min_mask, _CMP_LT_OQ); // c = (a < bf16_min) ? 1 : 0
122+
vorps(b, b, c); // b = b | c
123+
vptest(b, b); // if (b != 0) CF = 1 else CF = 0
124+
}
125+
126+
void check_bf16_saturations(const Xbyak::Reg64& src,
127+
const Xbyak::Xmm& bf16_max_mask,
128+
const Xbyak::Xmm& bf16_min_mask) {
129+
auto a = xmm1;
130+
auto b = xmm2;
131+
auto c = xmm3;
132+
133+
uni_vmovdqu(a, xword[src]); // load 4 floats
134+
uni_vcmpps(b, a, bf16_max_mask, _CMP_GT_OQ); // b = (a > bf16_max) ? 1 : 0
135+
uni_vcmpps(c, a, bf16_max_mask, _CMP_LT_OQ); // c = (a < bf16_min) ? 1 : 0
136+
uni_vorps(b, b, c); // b = b | c
137+
uni_vtestps(b, b); // if (b != 0) CF = 1 else CF = 0
138+
}
139+
113140
protected:
114-
Label exit, has_subnormals, no_subnormals;
141+
Label exit, has_target_values, no_target_values;
115142

116143
const Reg64& reg_src = rax;
117144
const Reg64& reg_dst = rbx;
@@ -121,16 +148,35 @@ struct jit_has_subnormals_base : public jit_generator {
121148

122149
static const uint32_t exponent_mask_data[8];
123150
static const uint32_t mantissa_mask_data[8];
151+
static const float bf16_max_mask_data[8];
152+
static const float bf16_min_mask_data[8];
124153
};
125154

126-
const uint32_t jit_has_subnormals_base::exponent_mask_data[8] =
155+
const uint32_t jit_subnormals_bf16saturation_check_base::exponent_mask_data[8] =
127156
{0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000};
128157

129-
const uint32_t jit_has_subnormals_base::mantissa_mask_data[8] =
158+
const uint32_t jit_subnormals_bf16saturation_check_base::mantissa_mask_data[8] =
130159
{0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff};
131160

161+
const float jit_subnormals_bf16saturation_check_base::bf16_max_mask_data[8] = {3.38953139e+38f,
162+
3.38953139e+38f,
163+
3.38953139e+38f,
164+
3.38953139e+38f,
165+
3.38953139e+38f,
166+
3.38953139e+38f,
167+
3.38953139e+38f,
168+
3.38953139e+38f};
169+
170+
const float jit_subnormals_bf16saturation_check_base::bf16_min_mask_data[8] = {-3.38953139e+38f,
171+
-3.38953139e+38f,
172+
-3.38953139e+38f,
173+
-3.38953139e+38f,
174+
-3.38953139e+38f,
175+
-3.38953139e+38f,
176+
-3.38953139e+38f,
177+
-3.38953139e+38f};
132178
template <cpu_isa_t isa>
133-
struct jit_has_subnormals : public jit_has_subnormals_base {
179+
struct jit_has_subnormals : public jit_subnormals_bf16saturation_check_base {
134180
using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type;
135181

136182
const Vmm rmm4 = Vmm(4);
@@ -150,7 +196,7 @@ struct jit_has_subnormals : public jit_has_subnormals_base {
150196

151197
// Get arguments addresses
152198
mov(reg_src, ptr[param1 + offsetof(args_t, src)]);
153-
lea(reg_dst, ptr[param1 + offsetof(args_t, hasSubnormals)]);
199+
lea(reg_dst, ptr[param1 + offsetof(args_t, hasTargetValues)]);
154200
mov(reg_sz, ptr[param1 + offsetof(args_t, count)]);
155201

156202
// Initialize necessary consts
@@ -167,7 +213,7 @@ struct jit_has_subnormals : public jit_has_subnormals_base {
167213

168214
foreach (reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) {
169215
check_subnormals(reg_src, exponent_mask, mantissa_mask, zero);
170-
jnc(has_subnormals);
216+
jnc(has_target_values);
171217
add(reg_src, sizeof(float) * vlen);
172218
})
173219
;
@@ -186,25 +232,98 @@ struct jit_has_subnormals : public jit_has_subnormals_base {
186232

187233
copy_floats(r8, reg_src, reg_sz);
188234
check_subnormals(r8, exponent_mask, mantissa_mask, zero);
189-
jc(no_subnormals);
235+
jc(no_target_values);
190236
add(rsp, vlen * sizeof(float));
191237

192-
L(has_subnormals);
238+
L(has_target_values);
193239

194240
mov(rax, 1);
195241
mov(byte[reg_dst], al);
196242
jmp(exit);
197243

198-
L(no_subnormals);
244+
L(no_target_values);
199245
add(rsp, vlen * sizeof(float));
200246

201247
L(exit);
202248

203249
postamble();
204250
}
205251
};
252+
template <cpu_isa_t isa>
253+
struct jit_has_bf16_overflows : public jit_subnormals_bf16saturation_check_base {
254+
using Vmm = typename dnnl::impl::utils::conditional<isa == sse41, Xbyak::Xmm, Xbyak::Ymm>::type;
255+
256+
const Vmm rmm4 = Vmm(4);
257+
const Vmm rmm5 = Vmm(5);
258+
const Vmm rmm6 = Vmm(6);
259+
const int length = isa == sse41 ? 4 : 8;
260+
261+
void generate() override final { // NOLINT
262+
size_t const vlen = length;
263+
const int sh_bits = std::ilogb(vlen);
264+
265+
auto zero = rmm4;
266+
auto bf16_max_mask = rmm5;
267+
auto bf16_min_mask = rmm6;
268+
269+
preamble();
270+
271+
// Get arguments addresses
272+
mov(reg_src, ptr[param1 + offsetof(args_t, src)]);
273+
lea(reg_dst, ptr[param1 + offsetof(args_t, hasTargetValues)]);
274+
mov(reg_sz, ptr[param1 + offsetof(args_t, count)]);
275+
276+
// Initialize necessary consts
277+
uni_vpxor(zero, zero, zero);
278+
mov(reg_mask_addr, (size_t)bf16_max_mask_data);
279+
uni_vmovdqu(bf16_max_mask, ptr[reg_mask_addr]);
280+
mov(reg_mask_addr, (size_t)bf16_min_mask_data);
281+
uni_vmovdqu(bf16_min_mask, ptr[reg_mask_addr]);
282+
283+
// Main loop
284+
xor_(reg_idx, reg_idx);
285+
mov(r8, reg_sz);
286+
shr(r8, sh_bits);
287+
288+
foreach (reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) {
289+
check_bf16_saturations(reg_src, bf16_max_mask, bf16_min_mask);
290+
jnz(has_target_values, T_NEAR);
291+
add(reg_src, sizeof(float) * vlen);
292+
})
293+
;
294+
295+
// Tail
296+
shl(reg_idx, sh_bits);
297+
sub(reg_sz, reg_idx);
298+
test(reg_sz, reg_sz);
299+
jz(exit);
206300

207-
jit_has_subnormals_base::fn_t jit_has_subnormals_function() {
301+
// use space on stack for 4 or 8 floats
302+
sub(rsp, vlen * sizeof(float));
303+
mov(r8, rsp);
304+
305+
uni_vmovdqu(ptr[r8], zero);
306+
307+
copy_floats(r8, reg_src, reg_sz);
308+
check_bf16_saturations(r8, bf16_max_mask, bf16_min_mask);
309+
jz(no_target_values, T_NEAR);
310+
add(rsp, vlen * sizeof(float));
311+
312+
L(has_target_values);
313+
314+
mov(rax, 1);
315+
mov(byte[reg_dst], al);
316+
jmp(exit);
317+
318+
L(no_target_values);
319+
add(rsp, vlen * sizeof(float));
320+
321+
L(exit);
322+
323+
postamble();
324+
}
325+
};
326+
jit_subnormals_bf16saturation_check_base::fn_t jit_has_subnormals_function() {
208327
if (mayiuse(cpu_isa_t::avx2)) {
209328
static jit_has_subnormals<cpu_isa_t::avx2> generator;
210329
static auto fn = generator.get();
@@ -216,6 +335,18 @@ jit_has_subnormals_base::fn_t jit_has_subnormals_function() {
216335
}
217336
return nullptr;
218337
}
338+
jit_subnormals_bf16saturation_check_base::fn_t jit_has_bf16_overflows_function() {
339+
if (mayiuse(cpu_isa_t::avx2)) {
340+
static jit_has_bf16_overflows<cpu_isa_t::avx2> generator;
341+
static auto fn = generator.get();
342+
return fn;
343+
} else if (mayiuse(cpu_isa_t::sse41)) {
344+
static jit_has_bf16_overflows<cpu_isa_t::sse41> generator;
345+
static auto fn = generator.get();
346+
return fn;
347+
}
348+
return nullptr;
349+
}
219350

220351
} // namespace
221352
#endif
@@ -271,49 +402,69 @@ void Input::cloneBlobIfRequired() {
271402
if (!size)
272403
return;
273404

274-
const float bf16_max = 3.3895313899137927e38f;
405+
const bool do_bf16_saturation_check =
406+
(context->getConfig().inferencePrecision == ov::element::bf16) ? true : false;
275407

276408
#if defined(OPENVINO_ARCH_X86_64)
277-
if (auto fn = jit_has_subnormals_function()) {
409+
auto fn = jit_has_subnormals_function();
410+
auto fn_bf16_check = jit_has_bf16_overflows_function();
411+
if (fn && fn_bf16_check) {
278412
static const size_t batch_size = 2048;
279413
const size_t iterations_num = size / batch_size + 1;
280414

281415
volatile bool has_subnormals_local = false;
416+
volatile bool has_bf16_overflows_local = false;
282417

283418
parallel_for(iterations_num, [&](int n) {
284419
auto ptr = u32data + n * batch_size;
285-
const jit_has_subnormals_base::args_t args = {reinterpret_cast<float const*>(ptr),
286-
std::min(batch_size, (size_t)(u32data + size - ptr)),
287-
false};
420+
const jit_subnormals_bf16saturation_check_base::args_t args1 = {
421+
reinterpret_cast<float const*>(ptr),
422+
std::min(batch_size, (size_t)(u32data + size - ptr)),
423+
false};
288424

289-
fn(&args);
425+
fn(&args1);
290426

291-
if (args.hasSubnormals)
427+
if (args1.hasTargetValues)
292428
has_subnormals_local = true;
293429
});
294430

295-
has_subnormals = has_subnormals_local;
296-
//TODO: opt with jit
297-
for (size_t i = 0; i < size; ++i) {
298-
if (f32data[i] < -bf16_max || f32data[i] > bf16_max) {
299-
has_bf16_overflows = true;
300-
return;
301-
}
431+
if (do_bf16_saturation_check) {
432+
parallel_for(iterations_num, [&](int n) {
433+
auto ptr2 = f32data + n * batch_size;
434+
const jit_subnormals_bf16saturation_check_base::args_t args2 = {
435+
reinterpret_cast<float const*>(ptr2),
436+
std::min(batch_size, (size_t)(f32data + size - ptr2)),
437+
false};
438+
439+
fn_bf16_check(&args2);
440+
441+
if (args2.hasTargetValues)
442+
has_bf16_overflows_local = true;
443+
});
302444
}
445+
446+
has_subnormals = has_subnormals_local;
447+
has_bf16_overflows = has_bf16_overflows_local;
448+
303449
return;
304450
}
305451
#endif
306452

307453
uint32_t mantissaMask = 0x007fffff;
308454
uint32_t exponentMask = 0x7f800000;
455+
const float bf16_max = 3.3895313899137927e38f;
309456
for (size_t i = 0; i < size; ++i) {
310457
if ((u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0) {
311458
has_subnormals = true;
312459
}
313-
if (f32data[i] < -bf16_max || f32data[i] > bf16_max) {
314-
has_bf16_overflows = true;
315-
}
316-
if (has_subnormals && has_bf16_overflows) {
460+
if (do_bf16_saturation_check) {
461+
if (f32data[i] < -bf16_max || f32data[i] > bf16_max) {
462+
has_bf16_overflows = true;
463+
}
464+
if (has_subnormals && has_bf16_overflows) {
465+
return;
466+
}
467+
} else if (has_subnormals) {
317468
return;
318469
}
319470
}

0 commit comments

Comments
 (0)