@@ -30,6 +30,7 @@ limitations under the License.
3030
3131#include " absl/container/flat_hash_map.h"
3232#include " absl/container/flat_hash_set.h"
33+ #include " absl/log/check.h"
3334#include " absl/log/log.h"
3435#include " absl/status/status.h"
3536#include " absl/status/statusor.h"
@@ -95,18 +96,6 @@ std::string ShapeDescription(const Shape& shape) {
9596 return ShapeUtil::HumanStringWithLayout (shape);
9697}
9798
98- // A wrapper around ShapeUtil::ByteSizeOf that clears out the layout/padding,
99- // since that is considered in the ByteSizeOf calculation.
100- int64_t ShapeUnpaddedSize (Shape shape) {
101- // Ensure the layout has no padding by making it the default layout.
102- LayoutUtil::SetToDefaultLayout (&shape);
103- // Note: we make a simplifying assumption here that a "minimal" size for a
104- // tuple member would be the size of a `void*` -- there may be even fancier
105- // ways of doing things, but this should give a good enough approximation of
106- // what a minimal tuple size is.
107- return ShapeUtil::ByteSizeOf (shape, /* pointer_size=*/ sizeof (void *));
108- }
109-
11099class BufferAllocationStruct {
111100 public:
112101 explicit BufferAllocationStruct (const BufferAllocationProto& proto)
@@ -157,19 +146,21 @@ class BufferAllocationStruct {
157146struct LogicalBufferStruct {
158147 LogicalBufferStruct (const LogicalBufferProto& p,
159148 const BufferAllocationStruct& b,
160- const ::xla::HloInstructionProto& i, uint64_t offset)
149+ const ::xla::HloInstructionProto& i, uint64_t offset,
150+ int64_t unpadded_size)
161151 : proto(p),
162152 buffer_allocation (b),
163153 hlo_instruction(i),
164154 offset(offset),
165155 shape(ResolveShapeIndex(hlo_instruction.shape(),
166- proto.defined_at().shape_index())) {}
156+ proto.defined_at().shape_index())),
157+ unpadded_size_(unpadded_size) {}
167158
168159 absl::string_view instruction_name () const { return hlo_instruction.name (); }
169160
170161 int64_t color () const { return proto.color (); }
171162 size_t size () const { return proto.size (); }
172- size_t unpadded_size () const { return ShapeUnpaddedSize (shape) ; }
163+ size_t unpadded_size () const { return unpadded_size_ ; }
173164
174165 // reference counting related
175166 int64_t inc () {
@@ -217,6 +208,7 @@ struct LogicalBufferStruct {
217208 xla::Shape shape;
218209 int64_t ref_count = 0 ;
219210 LogicalBufferStruct* canonical_buffer = nullptr ;
211+ int64_t unpadded_size_;
220212};
221213
222214// A wrapper of HLO BufferAssignment, with lookup maps for logical buffers and
@@ -312,6 +304,11 @@ class HloProtoBufferWrapper {
312304 id_to_logical_buffer_proto[logical_buffer.id ()] = &logical_buffer;
313305 }
314306
307+ absl::StatusOr<absl::flat_hash_map<int64_t , int64_t >>
308+ logical_buffer_unpadded_sizes = ComputeLogicalBufferUnpaddedSizes (
309+ hlo_proto_.hlo_module (), hlo_proto_.buffer_assignment ());
310+ CHECK_OK (logical_buffer_unpadded_sizes);
311+
315312 for (const auto & buffer_allocation :
316313 hlo_proto_.buffer_assignment ().buffer_allocations ()) {
317314 auto & buffer_allocation_s =
@@ -333,7 +330,8 @@ class HloProtoBufferWrapper {
333330 const auto * instruction = unique_id_to_hlo.at (inst_id);
334331 id_to_logical_buffer_[id] = std::make_unique<LogicalBufferStruct>(
335332 *logical_buffer, *buffer_allocation_s, *instruction,
336- assigned.offset ());
333+ assigned.offset (),
334+ logical_buffer_unpadded_sizes->at (logical_buffer->id ()));
337335 }
338336 }
339337
@@ -514,7 +512,6 @@ void NoteSpecialAllocations(const HloProtoBufferWrapper& wrapper,
514512 int64_t entry_parameters_bytes = 0 ;
515513 int64_t non_reusable_bytes = 0 ;
516514 int64_t maybe_live_out_bytes = 0 ;
517- int64_t indefinite_buffer_allocation_bytes = 0 ;
518515 for (const auto * buffer_allocation_struct :
519516 wrapper.GetBufferAllocations (memory_color)) {
520517 const auto & buffer_allocation = buffer_allocation_struct->proto ();
@@ -533,7 +530,6 @@ void NoteSpecialAllocations(const HloProtoBufferWrapper& wrapper,
533530 maybe_live_out_bytes += buffer_allocation.size ();
534531 }
535532 if (buffer_allocation_struct->IsIndefinite ()) {
536- indefinite_buffer_allocation_bytes += buffer_allocation.size ();
537533 Convert (buffer_allocation, wrapper, result->add_indefinite_lifetimes ());
538534 }
539535 }
@@ -546,7 +542,8 @@ void NoteSpecialAllocations(const HloProtoBufferWrapper& wrapper,
546542 BytesToMiB (xla::ComputeTotalAllocationBytes (
547543 wrapper.GetHloProto ().buffer_assignment (), memory_color)));
548544 result->set_indefinite_buffer_allocation_mib (
549- BytesToMiB (indefinite_buffer_allocation_bytes));
545+ BytesToMiB (xla::ComputeIndefiniteAllocationsInBytes (
546+ wrapper.GetHloProto ().buffer_assignment (), memory_color)));
550547}
551548
552549// Memory usage statistics collected from heap simulator trace.
0 commit comments