Skip to content

Commit 6eaf528

Browse files
committed
[tsl] Make BFCAllocator respect user alignment
1 parent 1f10433 commit 6eaf528

File tree

5 files changed

+323
-51
lines changed

5 files changed

+323
-51
lines changed

xla/tsl/framework/BUILD

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,10 @@ cc_library(
104104
"tracking_allocator.cc",
105105
"tracking_allocator.h",
106106
],
107-
hdrs = [
108-
"allocator.h",
109-
],
107+
hdrs = ["allocator.h"],
110108
features = ["parse_headers"],
111109
visibility = ["//visibility:public"],
112110
deps = [
113-
":numeric_types",
114-
":type_traits",
115111
"@com_google_absl//absl/base:core_headers",
116112
"@com_google_absl//absl/strings",
117113
"@com_google_absl//absl/strings:str_format",
@@ -126,9 +122,9 @@ cc_library(
126122
"//xla/tsl/platform:env_impl",
127123
"//xla/tsl/platform:logging",
128124
"//xla/tsl/platform:macros",
125+
"//xla/tsl/platform:types",
129126
"@tsl//tsl/platform:platform_port",
130127
"@tsl//tsl/platform:thread_annotations",
131-
"//xla/tsl/platform:types",
132128
],
133129
otherwise = [
134130
"//xla/tsl/lib/gtl:inlined_vector",
@@ -163,14 +159,11 @@ cc_library(
163159
"//third_party/xprof:__subpackages__",
164160
]),
165161
deps = [
166-
":numeric_types",
167-
":type_traits",
168162
"//xla/tsl/lib/gtl:inlined_vector",
169163
"//xla/tsl/platform:logging",
170164
"//xla/tsl/platform:macros",
171165
"//xla/tsl/platform:types",
172166
"@com_google_absl//absl/base:core_headers",
173-
"@com_google_absl//absl/strings",
174167
"@com_google_absl//absl/synchronization",
175168
"@tsl//tsl/platform:platform_port",
176169
"@tsl//tsl/platform:thread_annotations",
@@ -200,10 +193,12 @@ cc_library(
200193
"//xla/tsl/platform:types",
201194
"//xla/tsl/profiler/utils:trace_filter_utils",
202195
"//xla/tsl/protobuf:bfc_memory_map_proto_cc",
203-
"//xla/tsl/util:safe_reinterpret_cast",
196+
"@com_google_absl//absl/base",
204197
"@com_google_absl//absl/base:core_headers",
205198
"@com_google_absl//absl/container:flat_hash_set",
199+
"@com_google_absl//absl/numeric:bits",
206200
"@com_google_absl//absl/strings",
201+
"@com_google_absl//absl/strings:string_view",
207202
"@com_google_absl//absl/synchronization",
208203
"@com_google_absl//absl/time",
209204
"@tsl//tsl/platform:numbers",
@@ -213,6 +208,20 @@ cc_library(
213208
],
214209
)
215210

211+
tsl_cc_test(
212+
name = "bfc_allocator_test",
213+
srcs = ["bfc_allocator_test.cc"],
214+
deps = [
215+
":allocator",
216+
":bfc_allocator",
217+
"//xla/tsl/platform:env_impl", # buildcleaner: keep
218+
"//xla/tsl/platform:test",
219+
"@com_google_absl//absl/base",
220+
"@com_google_googletest//:gtest_main",
221+
"@tsl//tsl/platform:platform_port",
222+
],
223+
)
224+
216225
cc_library(
217226
name = "device_type",
218227
srcs = ["device_type.cc"],

xla/tsl/framework/allocator.h

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,15 @@ limitations under the License.
1616
#ifndef XLA_TSL_FRAMEWORK_ALLOCATOR_H_
1717
#define XLA_TSL_FRAMEWORK_ALLOCATOR_H_
1818

19-
#include <stdlib.h>
20-
19+
#include <cstdint>
2120
#include <functional>
22-
#include <limits>
2321
#include <optional>
22+
#include <string>
23+
#include <vector>
2424

25-
#include "absl/strings/string_view.h"
26-
#include "xla/tsl/framework/numeric_types.h"
27-
#include "xla/tsl/framework/type_traits.h"
25+
#include <stdlib.h>
2826
#include "xla/tsl/platform/logging.h"
2927
#include "xla/tsl/platform/macros.h"
30-
#include "xla/tsl/platform/types.h"
3128
#include "tsl/platform/numa.h"
3229

3330
namespace tsl {
@@ -108,7 +105,7 @@ struct AllocatorStats {
108105
enum class AllocatorMemoryType {
109106
kUnknown = 0, // Memory type unknown.
110107
kDevice = 1, // Memory on device.
111-
kHostPageable = 2, // Memory on host and it is pagable.
108+
kHostPageable = 2, // Memory on host and it is pageable.
112109
kHostPinned = 3, // Memory on host and it is pinned.
113110
};
114111

@@ -404,7 +401,7 @@ class SubAllocator {
404401
virtual ~SubAllocator() {}
405402
// Allocates at least num_bytes. Returns actual number of bytes allocated in
406403
// bytes_received. The caller can safely use the full bytes_received sized
407-
// buffer following the returend pointer.
404+
// buffer following the returned pointer.
408405
virtual void* Alloc(size_t alignment, size_t num_bytes,
409406
size_t* bytes_received) = 0;
410407
virtual void Free(void* ptr, size_t num_bytes) = 0;

xla/tsl/framework/bfc_allocator.cc

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@ limitations under the License.
2929
#include <utility>
3030
#include <vector>
3131

32+
#include "absl/base/casts.h"
3233
#include "absl/base/thread_annotations.h"
3334
#include "absl/container/flat_hash_set.h"
35+
#include "absl/numeric/bits.h"
3436
#include "absl/strings/str_cat.h"
3537
#include "absl/strings/string_view.h"
3638
#include "absl/synchronization/mutex.h"
@@ -39,7 +41,6 @@ limitations under the License.
3941
#include "xla/tsl/platform/env.h"
4042
#include "xla/tsl/platform/file_system.h"
4143
#include "xla/tsl/platform/logging.h"
42-
#include "xla/tsl/platform/types.h"
4344
#include "xla/tsl/profiler/utils/trace_filter_utils.h"
4445
#include "xla/tsl/protobuf/bfc_memory_map.pb.h"
4546
#include "tsl/platform/numbers.h"
@@ -173,6 +174,10 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
173174
curr_region_allocation_bytes_ *= 2;
174175
}
175176

177+
CHECK_EQ(absl::bit_cast<uintptr_t>(mem_addr) & (kMinAllocationSize - 1), 0)
178+
<< "SubAllocator must return memory aligned to at least "
179+
<< kMinAllocationSize << " bytes, got " << mem_addr;
180+
176181
VLOG(1) << "Extending allocation by "
177182
<< strings::HumanReadableNumBytes(bytes_received) << " bytes for "
178183
<< Name() << ".";
@@ -250,15 +255,14 @@ void BFCAllocator::DeallocateChunk(ChunkHandle h) {
250255
}
251256

252257
void* BFCAllocator::AllocateRawInternalWithRetry(
253-
size_t unused_alignment, size_t num_bytes,
258+
size_t alignment, size_t num_bytes,
254259
const AllocationAttributes& allocation_attr) {
255260
// Fast path: Try once to allocate without getting the retry_helper_ involved
256261
uint64_t freed_by_count = 0;
257262
if (allocation_attr.freed_by_func != nullptr) {
258263
freed_by_count = (*allocation_attr.freed_by_func)();
259264
}
260-
void* r =
261-
AllocateRawInternal(unused_alignment, num_bytes, false, freed_by_count);
265+
void* r = AllocateRawInternal(alignment, num_bytes, false, freed_by_count);
262266
if (r != nullptr) {
263267
return r;
264268
} else {
@@ -271,14 +275,15 @@ void* BFCAllocator::AllocateRawInternalWithRetry(
271275
}
272276
return AllocateRawInternal(a, nb, v, freed_by_count);
273277
},
274-
kMaxMillisToWait, unused_alignment, num_bytes);
278+
kMaxMillisToWait, alignment, num_bytes);
275279
return r;
276280
}
277281
}
278282

279-
void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes,
283+
void* BFCAllocator::AllocateRaw(size_t alignment, size_t num_bytes,
280284
const AllocationAttributes& allocation_attr) {
281-
VLOG(3) << "AllocateRaw " << Name() << " " << num_bytes;
285+
VLOG(3) << "AllocateRaw " << Name() << " " << num_bytes
286+
<< " alignment=" << alignment;
282287
void* result = [&] {
283288
if (!opts_.allow_retry_on_failure || !allocation_attr.retry_on_failure) {
284289
// If we have globally disabled retry-on-failure and fail to allocate an
@@ -302,8 +307,8 @@ void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes,
302307
if (allocation_attr.freed_by_func != nullptr) {
303308
freed_by_count = (*allocation_attr.freed_by_func)();
304309
}
305-
void* res = AllocateRawInternal(unused_alignment, num_bytes,
306-
dump_log_on_failure, freed_by_count);
310+
void* res = AllocateRawInternal(alignment, num_bytes, dump_log_on_failure,
311+
freed_by_count);
307312
if (res == nullptr) {
308313
int32_t counter_value = log_counter.load(std::memory_order_relaxed);
309314
if (counter_value < kMaxFailureLogs) {
@@ -321,7 +326,7 @@ void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes,
321326
}
322327
return res;
323328
} else {
324-
return AllocateRawInternalWithRetry(unused_alignment, num_bytes,
329+
return AllocateRawInternalWithRetry(alignment, num_bytes,
325330
allocation_attr);
326331
}
327332
}();
@@ -431,8 +436,7 @@ void BFCAllocator::DeallocateRegions(
431436
}
432437
}
433438

434-
void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
435-
size_t num_bytes,
439+
void* BFCAllocator::AllocateRawInternal(size_t alignment, size_t num_bytes,
436440
bool dump_log_on_failure,
437441
uint64_t freed_before) {
438442
if (num_bytes == 0) {
@@ -444,6 +448,11 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
444448
// so all memory addresses are nicely byte aligned.
445449
size_t rounded_bytes = RoundedBytes(num_bytes);
446450

451+
// Alignment must be a power of two and at least kMinAllocationSize so that
452+
// splitting for alignment always produces kMinAllocationSize-aligned chunks.
453+
DCHECK(absl::has_single_bit(alignment)) << "alignment must be a power of 2";
454+
alignment = std::max(alignment, kMinAllocationSize);
455+
447456
// The BFC allocator tries to find the best fit first.
448457
BinNum bin_num = BinNumForSize(rounded_bytes);
449458

@@ -452,15 +461,17 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
452461
// Merge timestamped chunks whose counts have become safe for general use.
453462
MergeTimestampedChunks(0);
454463
}
455-
void* ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
464+
void* ptr =
465+
FindChunkPtr(bin_num, rounded_bytes, num_bytes, alignment, freed_before);
456466
if (ptr != nullptr) {
457467
AddTraceMe("MemoryAllocation", ptr);
458468
return ptr;
459469
}
460470

461471
// Try to extend
462-
if (Extend(unused_alignment, rounded_bytes)) {
463-
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
472+
if (Extend(alignment, rounded_bytes)) {
473+
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, alignment,
474+
freed_before);
464475
if (ptr != nullptr) {
465476
AddTraceMe("MemoryAllocation", ptr);
466477
return ptr;
@@ -473,7 +484,8 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
473484
// timestamped chunks more aggressively until a free chunk of the necessary
474485
// size is formed.
475486
if (MergeTimestampedChunks(rounded_bytes)) {
476-
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
487+
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, alignment,
488+
freed_before);
477489
if (ptr != nullptr) {
478490
AddTraceMe("MemoryAllocation", ptr);
479491
return ptr;
@@ -486,8 +498,9 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
486498
// try deallocating free regions so that suballocator can combine them with
487499
// the unallocated bytes and form a larger region.
488500
if (DeallocateFreeRegions(rounded_bytes) &&
489-
Extend(unused_alignment, rounded_bytes)) {
490-
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
501+
Extend(alignment, rounded_bytes)) {
502+
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, alignment,
503+
freed_before);
491504
if (ptr != nullptr) {
492505
AddTraceMe("MemoryAllocation", ptr);
493506
return ptr;
@@ -555,7 +568,7 @@ void BFCAllocator::AddTraceMe(absl::string_view traceme_name,
555568
{"peak_bytes_in_use", stats_.peak_bytes_in_use},
556569
{"requested_bytes", req_bytes},
557570
{"allocation_bytes", alloc_bytes},
558-
{"addr", reinterpret_cast<uint64_t>(chunk_ptr)},
571+
{"addr", absl::bit_cast<uint64_t>(chunk_ptr)},
559572
{"tf_op", annotation.pending_op_name},
560573
{"id", annotation.pending_step_id},
561574
{"region_type", annotation.pending_region_type},
@@ -567,25 +580,52 @@ void BFCAllocator::AddTraceMe(absl::string_view traceme_name,
567580
}
568581

569582
void* BFCAllocator::FindChunkPtr(BinNum bin_num, size_t rounded_bytes,
570-
size_t num_bytes, uint64_t freed_before) {
583+
size_t num_bytes, size_t alignment,
584+
uint64_t freed_before) {
571585
// First identify the first bin that could satisfy rounded_bytes.
572586
for (; bin_num < kNumBins; bin_num++) {
573587
// Start searching from the first bin for the smallest chunk that fits
574588
// rounded_bytes.
575589
Bin* b = BinFromIndex(bin_num);
576590
for (auto citer = b->free_chunks.begin(); citer != b->free_chunks.end();
577591
++citer) {
578-
const BFCAllocator::ChunkHandle h = (*citer);
592+
BFCAllocator::ChunkHandle h = (*citer);
579593
BFCAllocator::Chunk* chunk = ChunkFromHandle(h);
580594
DCHECK(!chunk->in_use());
581595
if (freed_before > 0 && freed_before < chunk->freed_at_count) {
582596
continue;
583597
}
584-
if (chunk->size >= rounded_bytes) {
598+
599+
// Compute how many bytes we need to skip at the front of this chunk
600+
// to reach the requested alignment boundary.
601+
uintptr_t ptr_int = absl::bit_cast<uintptr_t>(chunk->ptr);
602+
size_t align_padding =
603+
(alignment - (ptr_int & (alignment - 1))) % alignment;
604+
// Round padding up to kMinAllocationSize so the prefix chunk is valid.
605+
align_padding = RoundedBytes(align_padding);
606+
607+
if (chunk->size >= rounded_bytes + align_padding) {
585608
// We found an existing chunk that fits us that wasn't in use, so remove
586609
// it from the free bin structure prior to using.
587610
RemoveFreeChunkIterFromBin(&b->free_chunks, citer);
588611

612+
// If alignment requires it, split off the unaligned prefix as a
613+
// separate free chunk.
614+
if (align_padding > 0) {
615+
SplitChunk(h, align_padding);
616+
// After splitting, h still points to the prefix chunk (size =
617+
// align_padding). The new aligned chunk is h's next and was
618+
// inserted into a free bin by SplitChunk.
619+
chunk = ChunkFromHandle(h);
620+
// Put the prefix back into the free bin.
621+
InsertFreeChunkIntoBin(h);
622+
// Advance to the aligned chunk and remove it from its free bin
623+
// so we can use it (and potentially split it again below).
624+
h = chunk->next;
625+
chunk = ChunkFromHandle(h);
626+
RemoveFreeChunkFromBin(h);
627+
}
628+
589629
// If we can break the size of the chunk into two reasonably large
590630
// pieces, do so, but don't waste more than max_internal_fragmentation_bytes on
591631
// padding. If this threshold is not set by the user, then use 128MB as
@@ -1091,7 +1131,7 @@ void BFCAllocator::DumpMemoryLog(size_t num_bytes) {
10911131
}
10921132
std::string buf = absl::StrCat(
10931133
(c->in_use() ? "InUse" : "Free "), " at ",
1094-
absl::Hex(reinterpret_cast<uint64_t>(c->ptr)), " of size ", c->size);
1134+
absl::Hex(absl::bit_cast<uint64_t>(c->ptr)), " of size ", c->size);
10951135
#ifdef TENSORFLOW_MEM_DEBUG
10961136
if (ShouldRecordOpName()) {
10971137
absl::StrAppend(&buf, " by op ", c->op_name, " action_count ",
@@ -1187,7 +1227,7 @@ MemoryDump BFCAllocator::RecordMemoryMapInternal() {
11871227
const Chunk* c = ChunkFromHandle(h);
11881228
tensorflow::MemChunk* mc = md.add_chunk();
11891229
mc->set_in_use(c->in_use());
1190-
mc->set_address(reinterpret_cast<uint64_t>(c->ptr));
1230+
mc->set_address(absl::bit_cast<uint64_t>(c->ptr));
11911231
mc->set_size(c->size);
11921232
mc->set_requested_size(c->requested_size);
11931233
mc->set_bin(c->bin_num);

0 commit comments

Comments
 (0)