Skip to content

Commit efd8560

Browse files
Leahli and juang-husam authored
feat(core): implement optimized serialization tensor saving (#12)
- Implement `_save_tensor_optimized` in `CheckpointSaver` to support writing tensors directly to mmap buffers using the `MLF_TENS` format.
- Update `CheckpointLoader` to transparently detect and load both the new `MLF_TENS` format and the legacy `torch.save` format.
- Enable `use_optimized_save` by default in NeMo `wrapper_util` to leverage the performance improvements.

---------

Co-authored-by: g-husam <husameldawi@google.com>
1 parent a2450dc commit efd8560

File tree

11 files changed

+722
-34
lines changed

11 files changed

+722
-34
lines changed

src/ml_flashpoint/adapter/nemo/wrapper_util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
4343
always_save_context: bool = False,
4444
write_thread_count: int = 1,
4545
initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
46+
use_optimized_save: bool = True,
4647
) -> MLFlashpointAutoResume:
4748
"""Wraps the trainer and creates an MLFlashpointAutoResume instance wrapping `default_auto_resume`.
4849
@@ -87,6 +88,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
8788
always_save_context=always_save_context,
8889
write_thread_count=write_thread_count,
8990
initial_write_buffer_size_bytes=initial_write_buffer_size_bytes,
91+
use_optimized_save=use_optimized_save,
9092
)
9193

9294
default_auto_resume_args = vars(default_auto_resume) if default_auto_resume else {}
@@ -107,6 +109,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
107109
always_save_context: bool = False,
108110
write_thread_count: int = 1,
109111
initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
112+
use_optimized_save: bool = True,
110113
):
111114
"""Wraps the trainer's checkpoint I/O with ML Flashpoint capabilities.
112115
@@ -202,6 +205,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
202205
ckpt_obj_manager=ckpt_obj_manager,
203206
replication_manager=replication_manager,
204207
initial_buffer_size_bytes=initial_write_buffer_size_bytes,
208+
use_optimized_save=use_optimized_save,
205209
),
206210
mp_manager=torch_mp.Manager(),
207211
thread_count=write_thread_count,

src/ml_flashpoint/checkpoint_object_manager/buffer_io.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,40 @@ def tell(self) -> int:
323323
self._check_validity()
324324
return self._pos
325325

326+
def next_buffer_slice(self, size: int) -> memoryview:
    """Hand out a writable window into the buffer at the current position.

    The returned memoryview aliases the underlying buffer, so callers can
    perform zero-copy writes into it (e.g. copying tensor data directly).
    The stream position is advanced by `size` bytes and the recorded
    written-data length is updated to match.

    Args:
        size: Number of bytes the slice should cover; must be non-negative.

    Returns:
        A writable memoryview over the requested region.

    Raises:
        ValueError: If `size` is negative or the request exceeds capacity.
    """
    self._check_validity("write")
    if size < 0:
        raise ValueError(f"Size must be non-negative, got {size}")

    # Data lives after the fixed-size metadata header.
    begin = METADATA_SIZE + self._pos
    end = begin + size
    if end > len(self._mv):
        raise ValueError(
            f"Requested slice (size={size}) exceeds buffer capacity "
            f"(pos={self._pos}, cap={len(self._mv) - METADATA_SIZE})"
        )

    window = self._mv[begin:end]

    # Advance the stream and keep the metadata length in sync.
    self._pos += size
    self._update_written_data_length(self._pos)
    return window
359+
326360
def close(self, truncate: bool = True) -> None:
327361
"""Closes the BufferIO stream and the underlying C++ BufferObject.
328362
@@ -433,3 +467,24 @@ def flush(self):
433467
# Check validity first, as flush() is still an I/O operation.
434468
self._check_validity()
435469
pass
470+
471+
@property
def format_signature(self) -> bytes:
    """The format signature currently recorded in the buffer metadata.

    Returns:
        The raw signature bytes stored in the metadata block.
    """
    self._check_validity()
    metadata = self._metadata
    return metadata.format_signature
480+
481+
def set_format_signature(self, signature: bytes) -> None:
    """Record `signature` in the buffer metadata block.

    Args:
        signature: Signature bytes to store; at most 8 bytes (the width of
            the metadata field).

    Raises:
        ValueError: If `signature` is longer than 8 bytes.
    """
    self._check_validity("write")
    if not len(signature) <= 8:
        raise ValueError(f"Format signature must be at most 8 bytes, got {len(signature)}")
    self._metadata.format_signature = signature

src/ml_flashpoint/checkpoint_object_manager/buffer_metadata.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ class BufferMetadataType(ctypes.LittleEndianStructure):
2727
_fields_ = [
2828
# 8 bytes for the length of valid data written *after* the metadata block
2929
("len_written_data", ctypes.c_uint64),
30+
# 8 bytes for checkpoint format signature to identify the file format version
31+
("format_signature", ctypes.c_char * 8),
3032
# Pad the rest of the structure to reach METADATA_SIZE
3133
(
3234
"reserved",
33-
ctypes.c_uint8 * (METADATA_SIZE - ctypes.sizeof(ctypes.c_uint64)),
35+
ctypes.c_uint8 * (METADATA_SIZE - ctypes.sizeof(ctypes.c_uint64) - 8),
3436
),
3537
]
3638

src/ml_flashpoint/core/checkpoint_loader.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import pickle
2121
import re
22+
import struct
2223
from collections import defaultdict
2324
from pathlib import Path
2425
from typing import IO, List, Optional, Tuple, TypeVar, cast
@@ -42,6 +43,7 @@
4243
COMMON_STATE_FNAME,
4344
DIRTY_MARKER_SUFFIX,
4445
GLOBAL_RANK_PATTERN,
46+
CheckpointFormat,
4547
default_metadata_object_name,
4648
)
4749
from ml_flashpoint.core.mlf_logging import get_logger
@@ -158,20 +160,48 @@ def read_metadata(
158160
_LOGGER.exception("Error reading metadata from '%s'", metadata_path)
159161
raise
160162

161-
def read_tensor(self, buffer_slice: IO[bytes], req: ReadItem) -> torch.Tensor:
163+
def read_tensor(self, buffer_slice: IO[bytes], req: ReadItem, use_optimized_loader: bool = False) -> torch.Tensor:
    """Read a tensor from a file slice.

    Two on-disk layouts are supported:
      * Optimized (`MLF_TENS`) layout: ``[4-byte LE header length]``
        followed by a pickled ``TensorHeader`` and then the raw tensor
        bytes (C-contiguous, no stride information).
      * Legacy layout: a plain ``torch.save`` payload.

    When `use_optimized_loader` is True the optimized layout is attempted
    first; if the slice does not look like that layout (short read, or an
    implausibly large header length — likely legacy data) the reader
    rewinds and falls back to ``torch.load``. A header that looks valid
    but fails to decode is treated as corruption and re-raised.

    Args:
        buffer_slice (IO[bytes]): file slice to read from.
        req (ReadItem): read item whose storage offsets/lengths narrow the result.
        use_optimized_loader (bool): whether to try the optimized layout first.

    Returns:
        torch.Tensor: the tensor read from the slice, narrowed per `req`.
    """
    pos = buffer_slice.tell()
    tensor: Optional[torch.Tensor] = None

    if use_optimized_loader:
        # Optimized layout starts with a 4-byte little-endian header length.
        len_bytes = buffer_slice.read(4)
        if len(len_bytes) == 4:
            header_len = struct.unpack("<I", len_bytes)[0]
            # A sane header is tiny; anything larger means this slice is not
            # in the optimized layout, so fall through to torch.load below.
            if header_len < 1024 * 1024:
                pickle_bytes = buffer_slice.read(header_len)
                try:
                    tensor_header = pickle.loads(pickle_bytes)
                    data_bytes = buffer_slice.read()
                    # Copy into a mutable buffer first: torch.frombuffer on
                    # immutable `bytes` emits a UserWarning and produces a
                    # tensor backed by read-only memory.
                    tensor = torch.frombuffer(
                        bytearray(data_bytes), dtype=tensor_header.dtype
                    ).reshape(tensor_header.shape)
                except Exception:
                    # Header looked valid but did not decode: corruption,
                    # not a legacy file — surface the error.
                    _LOGGER.exception("Failed to parse tensor header")
                    raise

    # Legacy layout, or the optimized probe declined: rewind and torch.load.
    if tensor is None:
        buffer_slice.seek(pos)
        tensor = cast(
            torch.Tensor,
            torch.load(cast(IO[bytes], buffer_slice), map_location="cpu", weights_only=True),
        )
    return narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
176206

177207
def _try_retrieve_object_if_missing(self, checkpoint_object_id: CheckpointObjectId) -> bool:
@@ -270,6 +300,11 @@ def read_data(
270300
raise FileNotFoundError(error_msg)
271301

272302
with self._checkpoint_object_manager.get_buffer(checkpoint_object_id) as stream:
303+
use_optimized_loader = False
304+
if stream.format_signature == CheckpointFormat.MLF_FORMAT:
305+
use_optimized_loader = True
306+
_LOGGER.debug("Using optimized loader for '%s'", checkpoint_object_id.data)
307+
273308
for req in read_items:
274309
item_md = storage_data[req.storage_index]
275310
buffer_slice = cast(IO[bytes], _create_file_view(stream, item_md.offset, item_md.length))
@@ -278,7 +313,7 @@ def read_data(
278313
read_bytes.seek(0)
279314
planner.load_bytes(req, read_bytes)
280315
else:
281-
tensor = self.read_tensor(buffer_slice, req)
316+
tensor = self.read_tensor(buffer_slice, req, use_optimized_loader=use_optimized_loader)
282317
target_tensor = planner.resolve_tensor(req).detach()
283318
assert target_tensor.size() == tensor.size(), (
284319
f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"

src/ml_flashpoint/core/checkpoint_saver.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333
from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager
3434
from ml_flashpoint.checkpoint_object_manager.object_manager import object_manager_ext
3535
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId, CheckpointObjectId
36-
from ml_flashpoint.core.defaults import DIRTY_MARKER_SUFFIX, default_metadata_object_name
36+
from ml_flashpoint.core.defaults import DIRTY_MARKER_SUFFIX, CheckpointFormat, default_metadata_object_name
3737
from ml_flashpoint.core.mlf_logging import get_logger
38+
from ml_flashpoint.core.tensor_header import TensorHeader
3839
from ml_flashpoint.core.utils import log_execution_time
3940
from ml_flashpoint.replication.replication_manager import ReplicationManager
4041

@@ -294,6 +295,7 @@ def __init__(
294295
ckpt_obj_manager: CheckpointObjectManager,
295296
replication_manager: ReplicationManager,
296297
initial_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
298+
use_optimized_save: bool = True,
297299
):
298300
"""Initializes the DefaultMLFlashpointCheckpointSaver.
299301
@@ -307,13 +309,16 @@ def __init__(
307309
across nodes.
308310
initial_buffer_size_bytes: The initial buffer size in bytes to use
309311
for writing data.
312+
use_optimized_save: Whether to use the optimized zero-copy tensor saving.
313+
Defaults to True.
310314
"""
311315
self._global_rank_getter = global_rank_getter
312316
self._local_rank_getter = local_rank_getter
313317
self._barrier_func = global_barrier_func
314318
self._chkpt_obj_manager = ckpt_obj_manager
315319
self._replication_manager = replication_manager
316320
self._initial_buffer_size_bytes = initial_buffer_size_bytes
321+
self._use_optimized_save = use_optimized_save
317322

318323
@override
319324
@log_execution_time(logger=_LOGGER, name="initialize_checkpoint")
@@ -443,14 +448,16 @@ def write_data(
443448
for i in range(1, thread_count):
444449
thread = threading.Thread(
445450
target=self._write_to_buffer_from_queue_worker,
446-
args=(object_items_queue, results_from_threads, replicate_after_write),
451+
args=(object_items_queue, results_from_threads, replicate_after_write, self._use_optimized_save),
447452
name=f"{self.__class__.__name__}-Thread-{i}",
448453
)
449454
threads.append(thread)
450455
thread.start()
451456

452457
# Main thread execution.
453-
self._write_to_buffer_from_queue_worker(object_items_queue, results_from_threads, replicate_after_write)
458+
self._write_to_buffer_from_queue_worker(
459+
object_items_queue, results_from_threads, replicate_after_write, self._use_optimized_save
460+
)
454461

455462
for thread in threads:
456463
thread.join()
@@ -581,13 +588,15 @@ def _write_to_buffer_from_queue_worker(
581588
object_write_bucket_queue: queue.Queue,
582589
results_from_threads: queue.Queue,
583590
replicate_after_write: bool,
591+
use_optimized_write: bool,
584592
):
585593
"""Worker function for writing data from a queue to buffer objects.
586594
587595
Args:
588596
object_write_bucket_queue: A queue containing `ObjectWriteBucket` instances to process.
589597
results_from_threads: A queue to put `(List[WriteResult], Exception)` tuples into.
590598
replicate_after_write: Whether to trigger async replication of each object after it is written.
599+
use_optimized_write: Whether to use optimized write.
591600
"""
592601
while not object_write_bucket_queue.empty():
593602
try:
@@ -614,11 +623,19 @@ def _write_to_buffer_from_queue_worker(
614623
self._initial_buffer_size_bytes,
615624
overwrite=True,
616625
) as buffer_io_writer:
626+
# Set the format signature
627+
if use_optimized_write:
628+
buffer_io_writer.set_format_signature(CheckpointFormat.MLF_FORMAT)
629+
else:
630+
buffer_io_writer.set_format_signature(CheckpointFormat.TORCH_SAVE)
631+
617632
# First write tensors.
618633
for tensor_item, tensor in tensor_tuples:
619634
write_start_offset = buffer_io_writer.tell()
620-
621-
torch.save(tensor, buffer_io_writer)
635+
if use_optimized_write:
636+
self._save_tensor_optimized(tensor, buffer_io_writer)
637+
else:
638+
torch.save(tensor, buffer_io_writer)
622639

623640
num_bytes_written = buffer_io_writer.tell() - write_start_offset
624641
item_storage_data = _StorageInfo(
@@ -690,3 +707,41 @@ def _remove_older_checkpoints(
690707
siblings_to_delete.add(full_path)
691708

692709
return object_manager_ext.delete_directories_async(list(siblings_to_delete))
710+
711+
def _save_tensor_optimized(self, tensor: torch.Tensor, buffer_io_writer):
    """Saves a tensor to the buffer using a zero-copy approach where possible.

    NOTE: This method saves the tensor's data in a C-contiguous format,
    regardless of its original memory layout (stride).
    The stride information is not saved.

    Format:
        [4 bytes HEADER_LEN] [HEADER_BYTES] [RAW_BYTES]

    The header bytes are whatever `TensorHeader.to_bytes()` produces; the
    corresponding loader decodes them with `pickle.loads`, so the header
    is a pickled `TensorHeader`, not JSON.

    Args:
        tensor: The tensor to save.
        buffer_io_writer: The BufferIO instance to write to. Must support
            `next_buffer_slice` for the zero-copy data path.

    Raises:
        AttributeError: If `buffer_io_writer` lacks `next_buffer_slice`
            (re-raised after logging; disable `use_optimized_save` instead).
    """
    # Metadata describing dtype/shape; stride is intentionally dropped.
    tensor_header = TensorHeader(dtype=tensor.dtype, shape=tensor.shape)

    # Write header (length prefix + pickled TensorHeader bytes).
    header_data = tensor_header.to_bytes()
    buffer_io_writer.write(header_data)

    # Raw payload size in bytes.
    num_bytes = tensor.numel() * tensor.element_size()

    # Get a writable slice of the underlying C++ buffer; empty tensors
    # write only the header.
    if num_bytes > 0:
        try:
            dest_mv = buffer_io_writer.next_buffer_slice(num_bytes)
        except AttributeError:
            _LOGGER.exception("BufferIO does not support next_buffer_slice, try to disable use_optimized_save.")
            raise

        # Create a tensor wrapper around the buffer slice (no allocation).
        dest_tensor = torch.frombuffer(dest_mv, dtype=tensor.dtype, count=tensor.numel()).reshape(tensor.shape)

        # Perform the actual copy; copy_ handles device transfer and
        # non-contiguous sources, materializing C-contiguous data.
        dest_tensor.copy_(tensor)

src/ml_flashpoint/core/defaults.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,20 @@
1313
# limitations under the License.
1414

1515
import re
16+
from enum import Enum
1617

1718
DIRTY_MARKER_SUFFIX = "unfinished"
1819
GLOBAL_RANK_PATTERN = re.compile(r"src(\d+)")
1920
COMMON_STATE_FNAME = "common.pt"
2021

2122

23+
class CheckpointFormat(bytes, Enum):
    """8-byte signatures identifying how checkpoint tensors were serialized.

    The value is written into the buffer metadata's `format_signature`
    field (at most 8 bytes) so loaders can detect the payload format.
    """

    # Standard PyTorch `torch.save` serialization (legacy format).
    TORCH_SAVE = b"TORCH___"
    # Our custom optimized format: length-prefixed header + raw tensor bytes.
    MLF_FORMAT = b"MLF_TENS"
28+
29+
2230
def default_metadata_object_name() -> str:
2331
"""Returns the default object name for metadata files (i.e. filename).
2432

0 commit comments

Comments
 (0)