@@ -29,8 +29,10 @@ limitations under the License.
2929#include < utility>
3030#include < vector>
3131
32+ #include " absl/base/casts.h"
3233#include " absl/base/thread_annotations.h"
3334#include " absl/container/flat_hash_set.h"
35+ #include " absl/numeric/bits.h"
3436#include " absl/strings/str_cat.h"
3537#include " absl/strings/string_view.h"
3638#include " absl/synchronization/mutex.h"
@@ -39,7 +41,6 @@ limitations under the License.
3941#include " xla/tsl/platform/env.h"
4042#include " xla/tsl/platform/file_system.h"
4143#include " xla/tsl/platform/logging.h"
42- #include " xla/tsl/platform/types.h"
4344#include " xla/tsl/profiler/utils/trace_filter_utils.h"
4445#include " xla/tsl/protobuf/bfc_memory_map.pb.h"
4546#include " tsl/platform/numbers.h"
@@ -173,6 +174,10 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
173174 curr_region_allocation_bytes_ *= 2 ;
174175 }
175176
177+ CHECK_EQ (absl::bit_cast<uintptr_t >(mem_addr) & (kMinAllocationSize - 1 ), 0 )
178+ << " SubAllocator must return memory aligned to at least "
179+ << kMinAllocationSize << " bytes, got " << mem_addr;
180+
176181 VLOG (1 ) << " Extending allocation by "
177182 << strings::HumanReadableNumBytes (bytes_received) << " bytes for "
178183 << Name () << " ." ;
@@ -250,15 +255,14 @@ void BFCAllocator::DeallocateChunk(ChunkHandle h) {
250255}
251256
252257void * BFCAllocator::AllocateRawInternalWithRetry (
253- size_t unused_alignment , size_t num_bytes,
258+ size_t alignment , size_t num_bytes,
254259 const AllocationAttributes& allocation_attr) {
255260 // Fast path: Try once to allocate without getting the retry_helper_ involved
256261 uint64_t freed_by_count = 0 ;
257262 if (allocation_attr.freed_by_func != nullptr ) {
258263 freed_by_count = (*allocation_attr.freed_by_func )();
259264 }
260- void * r =
261- AllocateRawInternal (unused_alignment, num_bytes, false , freed_by_count);
265+ void * r = AllocateRawInternal (alignment, num_bytes, false , freed_by_count);
262266 if (r != nullptr ) {
263267 return r;
264268 } else {
@@ -271,14 +275,15 @@ void* BFCAllocator::AllocateRawInternalWithRetry(
271275 }
272276 return AllocateRawInternal (a, nb, v, freed_by_count);
273277 },
274- kMaxMillisToWait , unused_alignment , num_bytes);
278+ kMaxMillisToWait , alignment , num_bytes);
275279 return r;
276280 }
277281}
278282
279- void * BFCAllocator::AllocateRaw (size_t unused_alignment , size_t num_bytes,
283+ void * BFCAllocator::AllocateRaw (size_t alignment , size_t num_bytes,
280284 const AllocationAttributes& allocation_attr) {
281- VLOG (3 ) << " AllocateRaw " << Name () << " " << num_bytes;
285+ VLOG (3 ) << " AllocateRaw " << Name () << " " << num_bytes
286+ << " alignment=" << alignment;
282287 void * result = [&] {
283288 if (!opts_.allow_retry_on_failure || !allocation_attr.retry_on_failure ) {
284289 // If we have globally disabled retry-on-failure and fail to allocate an
@@ -302,8 +307,8 @@ void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes,
302307 if (allocation_attr.freed_by_func != nullptr ) {
303308 freed_by_count = (*allocation_attr.freed_by_func )();
304309 }
305- void * res = AllocateRawInternal (unused_alignment , num_bytes,
306- dump_log_on_failure, freed_by_count);
310+ void * res = AllocateRawInternal (alignment , num_bytes, dump_log_on_failure ,
311+ freed_by_count);
307312 if (res == nullptr ) {
308313 int32_t counter_value = log_counter.load (std::memory_order_relaxed);
309314 if (counter_value < kMaxFailureLogs ) {
@@ -321,7 +326,7 @@ void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes,
321326 }
322327 return res;
323328 } else {
324- return AllocateRawInternalWithRetry (unused_alignment , num_bytes,
329+ return AllocateRawInternalWithRetry (alignment , num_bytes,
325330 allocation_attr);
326331 }
327332 }();
@@ -431,8 +436,7 @@ void BFCAllocator::DeallocateRegions(
431436 }
432437}
433438
434- void * BFCAllocator::AllocateRawInternal (size_t unused_alignment,
435- size_t num_bytes,
439+ void * BFCAllocator::AllocateRawInternal (size_t alignment, size_t num_bytes,
436440 bool dump_log_on_failure,
437441 uint64_t freed_before) {
438442 if (num_bytes == 0 ) {
@@ -444,6 +448,11 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
444448 // so all memory addresses are nicely byte aligned.
445449 size_t rounded_bytes = RoundedBytes (num_bytes);
446450
451+ // Alignment must be a power of two and at least kMinAllocationSize so that
452+ // splitting for alignment always produces kMinAllocationSize-aligned chunks.
453+ DCHECK (absl::has_single_bit (alignment)) << " alignment must be a power of 2" ;
454+ alignment = std::max (alignment, kMinAllocationSize );
455+
447456 // The BFC allocator tries to find the best fit first.
448457 BinNum bin_num = BinNumForSize (rounded_bytes);
449458
@@ -452,15 +461,17 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
452461 // Merge timestamped chunks whose counts have become safe for general use.
453462 MergeTimestampedChunks (0 );
454463 }
455- void * ptr = FindChunkPtr (bin_num, rounded_bytes, num_bytes, freed_before);
464+ void * ptr =
465+ FindChunkPtr (bin_num, rounded_bytes, num_bytes, alignment, freed_before);
456466 if (ptr != nullptr ) {
457467 AddTraceMe (" MemoryAllocation" , ptr);
458468 return ptr;
459469 }
460470
461471 // Try to extend
462- if (Extend (unused_alignment, rounded_bytes)) {
463- ptr = FindChunkPtr (bin_num, rounded_bytes, num_bytes, freed_before);
472+ if (Extend (alignment, rounded_bytes)) {
473+ ptr = FindChunkPtr (bin_num, rounded_bytes, num_bytes, alignment,
474+ freed_before);
464475 if (ptr != nullptr ) {
465476 AddTraceMe (" MemoryAllocation" , ptr);
466477 return ptr;
@@ -473,7 +484,8 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
473484 // timestamped chunks more aggressively until a free chunk of the necessary
474485 // size is formed.
475486 if (MergeTimestampedChunks (rounded_bytes)) {
476- ptr = FindChunkPtr (bin_num, rounded_bytes, num_bytes, freed_before);
487+ ptr = FindChunkPtr (bin_num, rounded_bytes, num_bytes, alignment,
488+ freed_before);
477489 if (ptr != nullptr ) {
478490 AddTraceMe (" MemoryAllocation" , ptr);
479491 return ptr;
@@ -486,8 +498,9 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
486498 // try deallocating free regions so that suballocator can combine them with
487499 // the unallocated bytes and form a larger region.
488500 if (DeallocateFreeRegions (rounded_bytes) &&
489- Extend (unused_alignment, rounded_bytes)) {
490- ptr = FindChunkPtr (bin_num, rounded_bytes, num_bytes, freed_before);
501+ Extend (alignment, rounded_bytes)) {
502+ ptr = FindChunkPtr (bin_num, rounded_bytes, num_bytes, alignment,
503+ freed_before);
491504 if (ptr != nullptr ) {
492505 AddTraceMe (" MemoryAllocation" , ptr);
493506 return ptr;
@@ -555,7 +568,7 @@ void BFCAllocator::AddTraceMe(absl::string_view traceme_name,
555568 {" peak_bytes_in_use" , stats_.peak_bytes_in_use },
556569 {" requested_bytes" , req_bytes},
557570 {" allocation_bytes" , alloc_bytes},
558- {" addr" , reinterpret_cast <uint64_t >(chunk_ptr)},
571+ {" addr" , absl::bit_cast <uint64_t >(chunk_ptr)},
559572 {" tf_op" , annotation.pending_op_name },
560573 {" id" , annotation.pending_step_id },
561574 {" region_type" , annotation.pending_region_type },
@@ -567,25 +580,52 @@ void BFCAllocator::AddTraceMe(absl::string_view traceme_name,
567580}
568581
569582void * BFCAllocator::FindChunkPtr (BinNum bin_num, size_t rounded_bytes,
570- size_t num_bytes, uint64_t freed_before) {
583+ size_t num_bytes, size_t alignment,
584+ uint64_t freed_before) {
571585 // First identify the first bin that could satisfy rounded_bytes.
572586 for (; bin_num < kNumBins ; bin_num++) {
573587 // Start searching from the first bin for the smallest chunk that fits
574588 // rounded_bytes.
575589 Bin* b = BinFromIndex (bin_num);
576590 for (auto citer = b->free_chunks .begin (); citer != b->free_chunks .end ();
577591 ++citer) {
578- const BFCAllocator::ChunkHandle h = (*citer);
592+ BFCAllocator::ChunkHandle h = (*citer);
579593 BFCAllocator::Chunk* chunk = ChunkFromHandle (h);
580594 DCHECK (!chunk->in_use ());
581595 if (freed_before > 0 && freed_before < chunk->freed_at_count ) {
582596 continue ;
583597 }
584- if (chunk->size >= rounded_bytes) {
598+
599+ // Compute how many bytes we need to skip at the front of this chunk
600+ // to reach the requested alignment boundary.
601+ uintptr_t ptr_int = absl::bit_cast<uintptr_t >(chunk->ptr );
602+ size_t align_padding =
603+ (alignment - (ptr_int & (alignment - 1 ))) % alignment;
604+ // Round padding up to kMinAllocationSize so the prefix chunk is valid.
605+ align_padding = RoundedBytes (align_padding);
606+
607+ if (chunk->size >= rounded_bytes + align_padding) {
585608 // We found an existing chunk that fits us that wasn't in use, so remove
586609 // it from the free bin structure prior to using.
587610 RemoveFreeChunkIterFromBin (&b->free_chunks , citer);
588611
612+ // If alignment requires it, split off the unaligned prefix as a
613+ // separate free chunk.
614+ if (align_padding > 0 ) {
615+ SplitChunk (h, align_padding);
616+ // After splitting, h still points to the prefix chunk (size =
617+ // align_padding). The new aligned chunk is h's next and was
618+ // inserted into a free bin by SplitChunk.
619+ chunk = ChunkFromHandle (h);
620+ // Put the prefix back into the free bin.
621+ InsertFreeChunkIntoBin (h);
622+ // Advance to the aligned chunk and remove it from its free bin
623+ // so we can use it (and potentially split it again below).
624+ h = chunk->next ;
625+ chunk = ChunkFromHandle (h);
626+ RemoveFreeChunkFromBin (h);
627+ }
628+
589629 // If we can break the size of the chunk into two reasonably large
590630 // pieces, do so, but don't waste more than max_internal_fragmentation_bytes on
591631 // padding. If this threshold is not set by the user, then use 128MB as
@@ -1091,7 +1131,7 @@ void BFCAllocator::DumpMemoryLog(size_t num_bytes) {
10911131 }
10921132 std::string buf = absl::StrCat (
10931133 (c->in_use () ? " InUse" : " Free " ), " at " ,
1094- absl::Hex (reinterpret_cast <uint64_t >(c->ptr )), " of size " , c->size );
1134+ absl::Hex (absl::bit_cast <uint64_t >(c->ptr )), " of size " , c->size );
10951135#ifdef TENSORFLOW_MEM_DEBUG
10961136 if (ShouldRecordOpName ()) {
10971137 absl::StrAppend (&buf, " by op " , c->op_name , " action_count " ,
@@ -1187,7 +1227,7 @@ MemoryDump BFCAllocator::RecordMemoryMapInternal() {
11871227 const Chunk* c = ChunkFromHandle (h);
11881228 tensorflow::MemChunk* mc = md.add_chunk ();
11891229 mc->set_in_use (c->in_use ());
1190- mc->set_address (reinterpret_cast <uint64_t >(c->ptr ));
1230+ mc->set_address (absl::bit_cast <uint64_t >(c->ptr ));
11911231 mc->set_size (c->size );
11921232 mc->set_requested_size (c->requested_size );
11931233 mc->set_bin (c->bin_num );
0 commit comments