Skip to content

Commit ac50644

Browse files
committed
bf16 saturation during model compilation
1 parent 6b5c0a5 commit ac50644

File tree

5 files changed

+95
-61
lines changed

5 files changed

+95
-61
lines changed

src/plugins/intel_cpu/src/cpu_memory.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,26 @@ BlockedMemoryDescPtr IMemory::getDescWithType<BlockedMemoryDesc, 0, 0>() const {
3030
}
3131

3232
namespace {
33-
inline void setSubnormalsToZero(float* data, size_t size) {
33+
inline void setSubnormalsToZeroAndbf16Saturation(float* data, size_t size, bool ftz, bool bf16saturation) {
3434
uint32_t* u32data = reinterpret_cast<uint32_t*>(data);
35+
float* floatdata = reinterpret_cast<float*>(data);
3536
for (size_t i = 0; i < size; ++i) {
36-
if ((u32data[i] & (0xFF << 23)) == 0) {
37+
if (ftz && ((u32data[i] & (0xFF << 23)) == 0)) {
3738
u32data[i] = 0;
39+
} else if (bf16saturation) {
40+
if (floatdata[i] < -3.3895313899137927e38f) {
41+
floatdata[i] = -3.3895313899137927e38f;
42+
} else if (floatdata[i] > 3.3895313899137927e38f) {
43+
floatdata[i] = 3.3895313899137927e38f;
44+
}
3845
}
3946
}
4047
}
4148

42-
void transferData(const IMemory& src, const IMemory& dst, bool ftz) {
49+
void transferData(const IMemory& src, const IMemory& dst, bool ftz, bool bf16saturation) {
4350
node::Reorder::reorderData(src, dst);
4451

45-
if (!ftz) {
52+
if (!ftz && !bf16saturation) {
4653
return;
4754
}
4855
if (src.getDesc().getPrecision() != ov::element::f32 || dst.getDesc().getPrecision() != ov::element::f32) {
@@ -62,7 +69,7 @@ void transferData(const IMemory& src, const IMemory& dst, bool ftz) {
6269
// actual FTZ
6370
auto* memData = static_cast<float*>(dst.getData());
6471
memData += offset;
65-
setSubnormalsToZero(memData, dst.getSize() / sizeof(float));
72+
setSubnormalsToZeroAndbf16Saturation(memData, dst.getSize() / sizeof(float), ftz, bf16saturation);
6673
}
6774

6875
} // namespace
@@ -125,11 +132,11 @@ void Memory::create(MemoryDescPtr desc, const void* data, bool pads_zeroing) {
125132
}
126133
}
127134

128-
void Memory::load(const IMemory& src, bool ftz) const {
135+
void Memory::load(const IMemory& src, bool ftz, bool bf16saturation) const {
129136
if (src.getDesc().getPrecision() == element::string) {
130137
OPENVINO_THROW("[CPU] Memory object cannot load string data.");
131138
}
132-
transferData(src, *this, ftz);
139+
transferData(src, *this, ftz, bf16saturation);
133140
}
134141

135142
void Memory::nullify() {
@@ -271,12 +278,12 @@ StringMemory::StringMemory(const dnnl::engine& engine, const MemoryDescPtr& desc
271278
}
272279
}
273280

274-
void StringMemory::load(const IMemory& src, bool ftz) const {
281+
void StringMemory::load(const IMemory& src, bool ftz, bool bf16saturation) const {
275282
if (src.getDesc().getPrecision() != element::string) {
276283
OPENVINO_THROW("[CPU] String memory cannot load a non-string object.");
277284
}
278285

279-
transferData(src, *this, false);
286+
transferData(src, *this, false, false);
280287
}
281288

282289
void* StringMemory::getData() const {
@@ -470,11 +477,11 @@ void StaticMemory::redefineDesc(MemoryDescPtr desc) {
470477
OPENVINO_THROW("Unexpected: Memory descriptor may not be modified in StaticMemory object");
471478
}
472479

473-
void StaticMemory::load(const IMemory& src, bool ftz) const {
480+
void StaticMemory::load(const IMemory& src, bool ftz, bool bf16saturation) const {
474481
if (src.getDesc().getPrecision() == element::string) {
475482
OPENVINO_THROW("[CPU] StaticMemory cannot load string data.");
476483
}
477-
transferData(src, *this, ftz);
484+
transferData(src, *this, ftz, bf16saturation);
478485
}
479486

480487
MemoryBlockPtr StaticMemory::getMemoryBlock() const {

src/plugins/intel_cpu/src/cpu_memory.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ class IMemory {
187187
// Caution!!! This action invalidates the previous data layout. The old data may become unreachable.
188188
virtual void redefineDesc(MemoryDescPtr desc) = 0;
189189

190-
virtual void load(const IMemory& src, bool ftz = true) const = 0;
190+
virtual void load(const IMemory& src, bool ftz = true, bool bf16saturation = false) const = 0;
191191

192192
virtual MemoryBlockPtr getMemoryBlock() const = 0;
193193

@@ -259,7 +259,7 @@ class StaticMemory final : public IMemory {
259259
// Always throws since a static memory descriptor should not be modified
260260
void redefineDesc(MemoryDescPtr desc) override;
261261

262-
void load(const IMemory& src, bool ftz = true) const override;
262+
void load(const IMemory& src, bool ftz = true, bool bf16saturation = false) const override;
263263

264264
MemoryBlockPtr getMemoryBlock() const override;
265265

@@ -314,7 +314,7 @@ class Memory : public IMemory {
314314

315315
void redefineDesc(MemoryDescPtr desc) override;
316316

317-
void load(const IMemory& src, bool ftz = true) const override;
317+
void load(const IMemory& src, bool ftz = true, bool bf16saturation = false) const override;
318318
void nullify() override;
319319

320320
dnnl::engine getEngine() const {
@@ -420,7 +420,7 @@ class StringMemory : public IMemory {
420420

421421
void redefineDesc(MemoryDescPtr desc) override;
422422

423-
void load(const IMemory& src, bool ftz = false) const override;
423+
void load(const IMemory& src, bool ftz = false, bool bf16saturation = false) const override;
424424

425425
MemoryBlockPtr getMemoryBlock() const override;
426426

src/plugins/intel_cpu/src/nodes/input.cpp

Lines changed: 71 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,71 @@ void Input::cloneBlobIfRequired() {
262262
needFlushDenormalsToZero = false;
263263
}
264264

265+
// The presence of subnormals is better to determined at IR read time.
266+
auto checkSubnormalsAndBF16Overflows = [&](bool& has_subnormals, bool& has_bf16_overflows) {
267+
if (prec == ov::element::f32) {
268+
uint32_t const* u32data = m_constOp->get_data_ptr<uint32_t>();
269+
float const* f32data = m_constOp->get_data_ptr<float>();
270+
271+
if (!size)
272+
return;
273+
274+
const float bf16_max = 3.3895313899137927e38f;
275+
276+
#if defined(OPENVINO_ARCH_X86_64)
277+
if (auto fn = jit_has_subnormals_function()) {
278+
static const size_t batch_size = 2048;
279+
const size_t iterations_num = size / batch_size + 1;
280+
281+
volatile bool has_subnormals_local = false;
282+
283+
parallel_for(iterations_num, [&](int n) {
284+
auto ptr = u32data + n * batch_size;
285+
const jit_has_subnormals_base::args_t args = {reinterpret_cast<float const*>(ptr),
286+
std::min(batch_size, (size_t)(u32data + size - ptr)),
287+
false};
288+
289+
fn(&args);
290+
291+
if (args.hasSubnormals)
292+
has_subnormals_local = true;
293+
});
294+
295+
has_subnormals = has_subnormals_local;
296+
//TODO: opt with jit
297+
for (size_t i = 0; i < size; ++i) {
298+
if (!std::isnan(f32data[i]) && !std::isinf(f32data[i]) &&
299+
(f32data[i] < -bf16_max || f32data[i] > bf16_max)) {
300+
has_bf16_overflows = true;
301+
return;
302+
}
303+
}
304+
return;
305+
}
306+
#endif
307+
308+
uint32_t mantissaMask = 0x007fffff;
309+
uint32_t exponentMask = 0x7f800000;
310+
for (size_t i = 0; i < size; ++i) {
311+
if ((u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0) {
312+
has_subnormals = true;
313+
}
314+
if (!std::isnan(f32data[i]) && !std::isinf(f32data[i]) &&
315+
(f32data[i] < -bf16_max || f32data[i] > bf16_max)) {
316+
has_bf16_overflows = true;
317+
}
318+
if (has_subnormals && has_bf16_overflows) {
319+
return;
320+
}
321+
}
322+
}
323+
};
324+
325+
bool has_subnormals = false;
326+
bool has_bf16_overflows = false;
327+
328+
checkSubnormalsAndBF16Overflows(has_subnormals, has_bf16_overflows);
329+
265330
auto cloneBlob = [&, this]() {
266331
MemoryPtr memory;
267332

@@ -294,7 +359,7 @@ void Input::cloneBlobIfRequired() {
294359
} else {
295360
ptr = std::make_shared<StaticMemory>(getEngine(), memDesc);
296361
}
297-
ptr->load(*memory.get(), needFlushDenormalsToZero);
362+
ptr->load(*memory.get(), needFlushDenormalsToZero, has_bf16_overflows);
298363

299364
return ptr;
300365
};
@@ -311,60 +376,22 @@ void Input::cloneBlobIfRequired() {
311376
#endif
312377
};
313378

314-
// The presence of subnormals is better to determined at IR read time.
315-
auto hasSubnormals = [&]() {
316-
if (prec == ov::element::f32) {
317-
uint32_t const* u32data = m_constOp->get_data_ptr<uint32_t>();
318-
319-
if (!size)
320-
return false;
321-
322-
#if defined(OPENVINO_ARCH_X86_64)
323-
if (auto fn = jit_has_subnormals_function()) {
324-
static const size_t batch_size = 2048;
325-
const size_t iterations_num = size / batch_size + 1;
326-
327-
volatile bool has_subnormals = false;
328-
329-
parallel_for(iterations_num, [&](int n) {
330-
auto ptr = u32data + n * batch_size;
331-
const jit_has_subnormals_base::args_t args = {reinterpret_cast<float const*>(ptr),
332-
std::min(batch_size, (size_t)(u32data + size - ptr)),
333-
false};
334-
335-
fn(&args);
336-
337-
if (args.hasSubnormals)
338-
has_subnormals = true;
339-
});
340-
341-
return has_subnormals;
342-
}
343-
#endif
344-
345-
uint32_t mantissaMask = 0x007fffff;
346-
uint32_t exponentMask = 0x7f800000;
347-
for (size_t i = 0; i < size; ++i) {
348-
if ((u32data[i] & exponentMask) == 0 && (u32data[i] & mantissaMask) != 0) {
349-
return true;
350-
}
351-
}
352-
}
353-
return false;
354-
};
355-
356379
auto blobKey = [&]() {
357380
char ptr[32];
358381
snprintf(ptr, sizeof ptr, "%p", m_constOp->get_data_ptr());
359382
return getName() + "_" + std::to_string(size * prec.size()) + "_" + ptr;
360383
};
361384

385+
// my test
386+
if (has_bf16_overflows) {
387+
std::cout << "my test: has_bf16_overflows" << std::endl;
388+
}
362389
const auto weightCache = context->getWeightsCache();
363390
const bool clone_is_not_needed =
364391
prec != element::string &&
365392
// IRs already have all subnormals flushed to zero, but in
366393
// read_model scenario with directly loaded original model still can have subnormals
367-
isBlobAligned(m_constOp) && (!needFlushDenormalsToZero || !hasSubnormals()) &&
394+
isBlobAligned(m_constOp) && (!needFlushDenormalsToZero || !has_subnormals) && !has_bf16_overflows &&
368395
// Blob should be cloned in cache only if original weights are stored on other numa node.
369396
// This is possible only in multistream case on multisocket machine.
370397
// TODO: don't clone blob for multisocket + multistream case if current stream is run on the numa node where

src/plugins/intel_cpu/src/nodes/memory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class MemoryStub : public IMemory {
8383
m_pMemDesc = desc;
8484
}
8585

86-
void load(const IMemory& src, bool ftz = true) const override {
86+
void load(const IMemory& src, bool ftz = true, bool bf16saturation = false) const override {
8787
OPENVINO_THROW("Unexpected call MemoryStub::load()");
8888
}
8989

src/plugins/intel_cpu/tests/unit/cpu_tensor_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class MockIMemory : public IMemory {
7171
MOCK_METHOD(const VectorDims&, getStaticDims, (), (const, override));
7272

7373
MOCK_METHOD(void, redefineDesc, (MemoryDescPtr), (override));
74-
MOCK_METHOD(void, load, (const IMemory&, bool), (const, override));
74+
MOCK_METHOD(void, load, (const IMemory&, bool, bool), (const, override));
7575
MOCK_METHOD(MemoryBlockPtr, getMemoryBlock, (), (const, override));
7676

7777
MOCK_METHOD(dnnl::memory, getPrimitive, (), (const, override));

0 commit comments

Comments
 (0)