
Fix for async dcp checkpointing with Float8Tensors#2721

Open
pstjohn wants to merge 7 commits into NVIDIA:main from pstjohn:pstjohn/fix-async-dcp

Conversation

@pstjohn
Contributor

@pstjohn pstjohn commented Mar 2, 2026

(includes changes from #2698)

dcp.async_save fails silently with QuantizedTensor (Float8Tensor) — staged tensors contain uninitialized (NaN) data instead of actual FP8 values.

PyTorch's async save stages tensors to CPU by copying raw storage via new_empty() + deep_copy. Float8Tensor is a wrapper subclass with data_ptr()==0 (empty storage), so:

  1. new_empty() falls through to default dispatch, returning a plain tensor instead of a Float8Tensor
  2. The deep-copied _data/_scale_inv attributes land on the plain tensor but are ignored by DCP's write path
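The subclass-preservation issue can be illustrated with a minimal, self-contained toy (illustrative names, not the actual TE code): a wrapper tensor subclass that intercepts `aten.new_empty.default` in `__torch_dispatch__` so the staged buffer keeps the wrapper type instead of decaying to a plain tensor.

```python
import torch

class WrapperTensor(torch.Tensor):
    """Toy wrapper subclass: the outer tensor owns no storage of its own,
    analogous to Float8Tensor holding its payload in _data."""

    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

    def __init__(self, data):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten.new_empty.default:
            # Preserve the subclass type for the staged buffer, analogous
            # to delegating to quantizer.make_empty() in the actual fix.
            inner = torch.empty(args[1], dtype=kwargs.get("dtype", args[0].dtype))
            return WrapperTensor(inner)
        raise NotImplementedError(func)

t = WrapperTensor(torch.ones(4))
staged = t.new_empty(t.shape)
print(type(staged).__name__)  # the wrapper type survives staging
```

Without the `aten.new_empty.default` branch, the call would fall through to default dispatch and return a plain `torch.Tensor`, which is exactly the failure mode above.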

Changes

  • quantized_tensor.py: Handle aten.new_empty.default in torch_dispatch so staging preserves the Float8Tensor subclass type
  • float8_tensor_storage.py: Add a CPU fallback in dequantize() using PyTorch native FP8 dtypes, since tex.dequantize is CUDA-only and the staged tensor lives on CPU
  • run_fsdp2_fused_adam.py: Remove the _dequantize_state_dict workaround — dcp.async_save now works transparently

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/fix-async-dcp branch from 1fba8c7 to 641898c on March 11, 2026 17:05
@pstjohn pstjohn marked this pull request as ready for review March 11, 2026 17:06
@greptile-apps
Contributor

greptile-apps bot commented Mar 11, 2026

Greptile Summary

This PR fixes silent data corruption during dcp.async_save with Float8Tensor by intercepting three failure points in PyTorch's async staging pipeline. When staging tensors to CPU, DCP calls new_empty() to allocate a host buffer and then deep-copies the tensor's attributes — but Float8Tensor is a wrapper subclass with data_ptr()==0, so the default new_empty dispatch returned a plain tensor, and the copied _data/_scale_inv attributes were silently ignored during the write step.

Key changes:

  • quantized_tensor.py: Intercepts aten.new_empty.default in __torch_dispatch__ and delegates to tensor._quantizer.make_empty(...), preserving the Float8Tensor subclass type for the staged buffer. Includes a RuntimeError guard when _quantizer is None.
  • float8_tensor_storage.py: Adds a CPU fallback in _FromFloat8Func.forward that reinterprets the uint8 buffer as the corresponding FP8 dtype via .view(), then multiplies by _scale_inv on the correct device. Also fixes get_metadata() to include a device field with a safe None fallback for cleared tensors.
  • float8_tensor.py: Adds is_cpu property, and overrides __reduce_ex__ so that CPU-staged Float8Tensor instances serialize as dequantized plain tensors (compatible with torch.load(weights_only=True)). Both Float8Quantizer.make_empty and Float8CurrentScalingQuantizer.make_empty now forward device to the Float8Tensor constructor.
  • Tests: future.result() is moved to immediately after async_save (eliminating a potential race), the pytest.xfail blocks for DelayedScaling and Float8CurrentScaling are removed, and assert_close msg arguments are upgraded to lambdas for richer failure output.

One open issue found: The __reduce_ex__ correctly detects the case where only _transpose is on CPU (_data is None), but the underlying _FromFloat8Func.forward CPU path only handles the case when _data is not None. Calling dequantize() on a columnwise-only CPU-staged tensor will therefore still raise NotImplementedError("Casting back from the transpose not implemented yet!"). This edge case is unlikely to be hit in the FSDP2 DCP save scenario (where row-wise data is typically present), but the __reduce_ex__ logic implies it should be handled.

Confidence Score: 3/5

  • Safe to merge for the common rowwise Float8Tensor DCP path; one open edge-case bug remains for columnwise-only CPU-staged tensors.
  • The fix correctly addresses the core staging bug for the primary use case (rowwise Float8Tensors with DelayedScaling and Float8CurrentScaling). However, the __reduce_ex__ code claims to handle the transpose-only CPU path but delegates to dequantize(), which still raises NotImplementedError in _FromFloat8Func.forward when _data is None. This inconsistency leaves a silent landmine for columnwise-only checkpointing. The remainder of the changes (new_empty dispatch, CPU dequantize, get_metadata device field, test improvements) are correct and well-guarded.
  • transformer_engine/pytorch/tensor/float8_tensor.py and transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py — the transpose-only CPU dequantize path is unimplemented but silently claimed to be handled.

Important Files Changed

Filename Overview
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Adds CPU fallback dequantization path in _FromFloat8Func.forward and fixes get_metadata() to include device with a safe None fallback. The CPU path correctly reinterprets uint8 data as FP8 and applies _scale_inv. Minor: the TE-dtype lookup on line 32 (te_dtype = torch_to_transformer_engine_dtype[dtype]) is computed eagerly before the CPU branch diverges, so it is wasted work on the CPU path (and would KeyError for uncommon dtypes), though this is pre-existing behaviour.
transformer_engine/pytorch/tensor/float8_tensor.py Adds is_cpu property, fixes __reduce_ex__ to dequantize CPU-staged tensors using explicit _data/_transpose checks, and propagates device=device to Float8Tensor constructors in make_empty. The __reduce_ex__ branch for transpose_is_cpu=True, _data=None calls dequantize() which hits NotImplementedError in _FromFloat8Func.forward, since only the _data-present CPU path was implemented there.
transformer_engine/pytorch/quantized_tensor.py Adds aten.new_empty.default dispatch in __torch_dispatch__ so DCP async staging returns a proper Float8Tensor instead of falling through to a plain tensor. Guards against _quantizer=None with a clear RuntimeError, and correctly forwards dtype, device, pin_memory, and requires_grad to make_empty.
tests/pytorch/distributed/run_fsdp2_fused_adam.py Moves future.result() to immediately after dcp.async_save, eliminating a latent ordering ambiguity where CPU-staged data could be garbage-collected before blocking. Upgrades assert_close msg from a string to a lambda for richer failure output.
tests/pytorch/distributed/test_torch_fsdp2.py Removes the pytest.xfail block for DelayedScaling and Float8CurrentScaling recipes in the async DCP test, reflecting that the underlying staging bug is now fixed. Remaining xfail markers for MXFP8BlockScaling, NVFP4BlockScaling, and Float8BlockScaling are preserved.

Sequence Diagram

sequenceDiagram
    participant DCP as dcp.async_save
    participant Stager as StateDictStager
    participant FT as Float8Tensor
    participant QT as QuantizedTensor.__torch_dispatch__
    participant MakeEmpty as quantizer.make_empty
    participant ReduceEx as Float8Tensor.__reduce_ex__
    participant Deq as _FromFloat8Func.forward (CPU)

    DCP->>Stager: stage tensor to CPU
    Stager->>FT: new_empty(shape, device='cpu')
    FT->>QT: dispatch aten.new_empty.default
    QT->>MakeEmpty: make_empty(shape, device='cpu', ...)
    MakeEmpty-->>QT: Float8Tensor(_data=uint8[CPU], _scale_inv[CPU])
    QT-->>Stager: staged Float8Tensor (CPU)
    Stager->>FT: deep_copy(_data, _scale_inv, _transpose → CPU buffers)

    Note over DCP,Stager: DCP serializes the staged tensor via pickle

    DCP->>ReduceEx: __reduce_ex__(protocol)
    ReduceEx->>ReduceEx: data_is_cpu=True → dequantize(dtype=self.dtype)
    ReduceEx->>Deq: _FromFloat8Func.forward(tensor, dtype)
    Deq->>Deq: _data.view(fp8_dtype).float() * _scale_inv → .to(dtype)
    Deq-->>ReduceEx: plain BF16/FP32 tensor
    ReduceEx-->>DCP: (tensor.__reduce_ex__,) for plain tensor

    Note over DCP,Deq: DCP loads checkpoint back via dcp.load → model.load_state_dict

Comments Outside Diff (1)

  1. transformer_engine/pytorch/tensor/float8_tensor.py, line 997-1000 (link)

    dequantize raises NotImplementedError for transpose-only CPU tensors

    When transpose_is_cpu is True but _data is None (a columnwise-only Float8Tensor, produced when rowwise_usage=False), __reduce_ex__ correctly detects the CPU staging case and calls self.dequantize(dtype=self.dtype). However, _FromFloat8Func.forward only implements the CPU fallback when _data is not None:

    if tensor._data is not None:
        if tensor._data.is_cpu:
            # CPU fallback ...
        return tex.dequantize(...)
    raise NotImplementedError("Casting back from the transpose not implemented yet!")

    When _data is None the function always falls through to raise NotImplementedError, regardless of the device of _transpose. So a columnwise-only Float8Tensor staged to CPU via DCP async save will fail to pickle with an unhandled NotImplementedError rather than a clear error message or a successful serialization.

    The __reduce_ex__ code implies the intent to handle this path, but the underlying _FromFloat8Func.forward was not updated to match. Either the CPU path in _FromFloat8Func.forward needs a symmetric branch for _transpose, or the __reduce_ex__ check should be narrowed to only trigger when _data is not None and _data.is_cpu (acknowledging the transpose-only case is unimplemented):

    # Only dequantize when _data is on CPU; transpose-only CPU case is not yet supported
    data_is_cpu = self._data is not None and self._data.is_cpu
    if data_is_cpu:
        return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)

Last reviewed commit: 9e80320

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Peter St. John <peterc.stjohn@gmail.com>
Comment on lines +566 to +572
out = tensor._quantizer.make_empty(
    shape=torch.Size(size),
    dtype=dtype,
    device=device,
    requires_grad=tensor.requires_grad,
    pin_memory=pin_memory,
)
Contributor


AttributeError when _quantizer is None

tensor._quantizer can be None for Float8Tensor objects deserialized via the GPU path (_make_in_reduce_ex), which does not pass a quantizer argument. If a second async DCP save is attempted after a load/save round-trip, new_empty will be dispatched on the deserialized tensor, causing AttributeError: 'NoneType' object has no attribute 'make_empty'.

A guard is needed before calling make_empty:

if func == torch.ops.aten.new_empty.default:
    tensor = args[0]
    size = args[1]
    dtype = kwargs.get("dtype", tensor.dtype)
    device = kwargs.get("device", tensor.device)
    pin_memory = kwargs.get("pin_memory", False)
    if tensor._quantizer is None:
        raise RuntimeError(
            f"{type(tensor).__name__} does not have a quantizer; "
            "cannot create new_empty QuantizedTensor"
        )
    out = tensor._quantizer.make_empty(
        shape=torch.Size(size),
        dtype=dtype,
        device=device,
        requires_grad=tensor.requires_grad,
        pin_memory=pin_memory,
    )
    return out

Contributor Author


Good catch — added a guard that raises a clear RuntimeError if _quantizer is None.

Comment on lines +995 to +996
if self.is_cpu:
    return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
Contributor


is_cpu raises RuntimeError when both _data and _transpose are None

The is_cpu property raises RuntimeError("Both data and transpose are None") when both _data and _transpose are None — for example, after prepare_for_saving() or clear() sets them to None. The pre-existing GPU path in __reduce_ex__ handled this gracefully by passing None to _make_in_reduce_ex. The new check unconditionally calls self.is_cpu, introducing a regression for any Float8Tensor pickled in a cleared state.

A safe guard would be:

def __reduce_ex__(self, protocol: int) -> tuple:
    if self._data is not None and self._data.is_cpu:
        return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
    return (
        Float8Tensor._make_in_reduce_ex,
        (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape),
    )

This covers only the case where _data is a CPU tensor (the staged DCP scenario) and falls through to the original path for all other cases, including _data is None.

Contributor Author


Wait, didn't I just change this from the other comment starting with "CPU-only transpose path not handled in reduce_ex"?

Contributor


You're right — I suggested changing to if self.is_cpu: to handle the CPU transpose path, but that creates a new issue when both _data and _transpose are None.

The is_cpu property raises RuntimeError("Both data and transpose are None") in that case, which the original GPU path handled gracefully by passing None to _make_in_reduce_ex.

The fix needs to handle all three cases:

  1. _data on CPU → dequantize
  2. _transpose on CPU (no _data) → dequantize
  3. Both None → fall through to original path

Here's a clean solution using try-except:

Suggested change
if self.is_cpu:
    return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)

def __reduce_ex__(self, protocol: int) -> tuple:
    """Custom pickling to remove references to FP8 metadata objects.

    CPU Float8Tensors are serialized as dequantized plain tensors
    for compatibility with torch.load(weights_only=True), which is
    used by DCP async save staging.
    """
    try:
        is_cpu_tensor = self.is_cpu
    except RuntimeError:
        # Both _data and _transpose are None
        is_cpu_tensor = False
    if is_cpu_tensor:
        return self.dequantize(dtype=self.dtype).__reduce_ex__(protocol)
    return (
        Float8Tensor._make_in_reduce_ex,
        (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape),
    )

Contributor Author


Fixed by replacing self.is_cpu with explicit checks on _data and _transpose, so cleared tensors fall through to the existing GPU pickling path without raising.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment on lines 138 to 140
"quantizer": self._quantizer,
"device": self.device,
"fake_dtype": self._dtype,
Contributor


get_metadata() raises when tensor is in cleared state

Adding "device": self.device is correct for the normal lifecycle, but Float8TensorStorage.device raises RuntimeError("Float8TensorStorage has no data!") when both _data and _transpose are None — exactly the state left by prepare_for_saving() or clear().

Before this PR, get_metadata() returned None for data and data_transpose without raising. Now any call to get_metadata() (e.g., via make_like()) on a cleared tensor would raise instead of propagating gracefully.

A safe guard:

"device": self._data.device if self._data is not None
          else (self._transpose.device if self._transpose is not None else None),

Contributor Author


Fixed — the device field now uses explicit _data/_transpose checks with a None fallback, matching the pattern used elsewhere.

…e.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Member

/te-ci pytorch

pre-commit-ci bot and others added 2 commits March 11, 2026 19:25
- Add _quantizer None guard in new_empty dispatch
- Replace self.is_cpu with explicit _data/_transpose checks in __reduce_ex__
- Make get_metadata() safe for cleared tensors (both _data and _transpose None)

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/fix-async-dcp branch from 368869f to 5ce704e on March 13, 2026 20:37
@ksivaman
Member

@greptileai

@ksivaman
Member

/te-ci pytorch
