
[PyTorch] transformer_engine.pytorch.autocast support inside torch.compile #2759

Open
pggPL wants to merge 7 commits into NVIDIA:main from pggPL:torch_compile_autocast

Conversation

@pggPL
Collaborator

@pggPL pggPL commented Mar 13, 2026

Description

Enable torch.compile(fullgraph=True) for FP8 autocast by moving compile-visible mutable state off class attributes, avoiding tracing through support checks, and adding a test.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Move mutable FP8 autocast state from direct cls attribute writes to a dataclass-backed singleton object, because torch.compile does not support writes directly to class attributes.
  • Replace lru_cache-based support checks with explicit module-level caches and mark the wrapper functions with @torch.compiler.assume_constant_result so torch.compile does not trace into check_*_support().
  • Add torch.compile coverage for FP8 autocast using a custom test module; the test is more involved because there is currently no simple TE layer that supports both FP8 and torch.compile.
  • Make DelayedScaling explicitly unsupported under torch.compile and raise a clear error for that case.
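The instance-backed state change above can be sketched in a few lines. This is a simplified stand-in, not the PR's actual code: the class and singleton names (FP8GlobalState, quantization_state) come from the PR, but the fields and helper signatures shown are assumptions, and the real dataclass additionally uses slots=True.

```python
from dataclasses import dataclass
from typing import Any, Optional

# Sketch of the pattern: mutable state lives on a dataclass *instance*,
# because Dynamo can trace attribute writes on an object but not direct
# writes to class attributes of a manager class.
@dataclass
class FP8GlobalState:
    fp8_enabled: bool = False
    fp8_recipe: Optional[Any] = None
    autocast_depth: int = 0

# Module-level singleton; all reads and writes go through this instance.
quantization_state = FP8GlobalState()

def autocast_enter(enabled: bool, recipe: Any) -> None:
    # Entering an autocast region mutates the singleton, not the class.
    quantization_state.fp8_enabled = enabled
    quantization_state.fp8_recipe = recipe
    quantization_state.autocast_depth += 1

def autocast_exit() -> None:
    # Depth is decremented symmetrically on exit.
    quantization_state.autocast_depth -= 1
```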

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


pggPL added 3 commits March 13, 2026 14:02
Move FP8 global state onto an instance so Dynamo can trace autocast state updates, explicitly reject DelayedScaling under torch.compile, and add toy compile tests that keep TE forward/backward opaque while covering supported recipes.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop the standalone global dict and dataclass mutation experiments now that the torch.compile regression coverage lives in the focused autocast test file.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Use compiler constant-result wrappers for support checks and rename the module-level FP8 singleton to `_FP8_GLOBAL_STATE` for clearer semantics.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the torch_compile_autocast branch from 338ddae to b5d46fd on March 13, 2026 13:03
pggPL and others added 4 commits March 13, 2026 14:24
Restore the FP8 naming and remove extra state access helpers so the torch.compile changes stay focused on the instance-backed global state.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop stale availability fields from FP8GlobalState now that support checks use module-level cached results instead of manager state.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Resolve conflicts in the FP8 torch.compile changes while preserving the upstream updates in graph.py and module/base.py.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL marked this pull request as ready for review March 13, 2026 14:09
@greptile-apps
Contributor

greptile-apps bot commented Mar 13, 2026

Greptile Summary

This PR enables torch.compile(fullgraph=True) for te.autocast FP8 regions by making three coordinated changes: (1) moving all mutable process-global FP8 state off FP8GlobalStateManager class attributes into a FP8GlobalState dataclass singleton (quantization_state), which torch.compile can trace through without hitting class-level setattr barriers; (2) replacing lru_cache-wrapped hardware-check functions with explicit module-level caches decorated with @torch.compiler.assume_constant_result, so the compiler treats GPU capability as a constant; and (3) explicitly blocking DelayedScaling under torch.compile with an early RuntimeError in check_recipe_support. A custom-op–backed ToyLinear test harness is added to validate fullgraph compilation across all three currently-supported recipes.

Key observations:

  • The FP8GlobalState dataclass uses slots=True, preventing accidental dynamic attribute addition and improving performance.
  • autocast() correctly saves/restores FP8 scalar state while autocast_depth is managed symmetrically via autocast_enter/autocast_exit, making nested autocast contexts work correctly.
  • check_recipe_support(recipe) does not guard against recipe=None during torch.compile, so calling te.autocast(enabled=True) without an explicit recipe inside a compiled function bypasses the gatekeeper and surfaces as an AssertionError in get_default_fp8_recipe() instead of the expected RuntimeError.
  • NVFP4BlockScaling has no compile-time guard in check_recipe_support and no test coverage in test_torch_compile.py; it is unclear whether it is supported or intentionally excluded.
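The constant-result caching pattern from the summary can be sketched as follows. This is a plain-Python stand-in: in the PR the wrapper is decorated with @torch.compiler.assume_constant_result and check_fp8_support() probes real GPU capability; the function names mirror the PR, but the cached values here are placeholders.

```python
from typing import Optional, Tuple

# Explicit module-level cache replacing functools.lru_cache.
_fp8_support_cache: Optional[Tuple[bool, str]] = None

def check_fp8_support() -> Tuple[bool, str]:
    """Run the (expensive) support probe once and memoize the result in a
    module-level global."""
    global _fp8_support_cache
    if _fp8_support_cache is None:
        _fp8_support_cache = (True, "")  # stand-in for a hardware query
    return _fp8_support_cache

# In the real change this wrapper carries @torch.compiler.assume_constant_result,
# telling the compiler the result is constant so it never traces into
# check_fp8_support() during graph capture.
def is_fp8_available() -> bool:
    return check_fp8_support()[0]
```

Because GPU capability cannot change within a process, treating the probe's result as a compile-time constant is safe and avoids graph breaks at every support check.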

Confidence Score: 4/5

  • Safe to merge with minor follow-up: the None-recipe gatekeeper gap and missing NVFP4 coverage should be addressed but are not blockers for the core FP8+compile feature.
  • The core refactor (dataclass singleton, assume_constant_result decorators, DelayedScaling guard) is well-structured and the nested-autocast state management is correct. The two gaps — recipe=None not raising a RuntimeError in check_recipe_support during compile, and NVFP4BlockScaling having no compile guard or test — are real omissions but affect only niche usage paths and do not break existing functionality.
  • Follow-up areas: transformer_engine/pytorch/quantization.py (check_recipe_support gatekeeper completeness) and tests/pytorch/test_torch_compile.py (missing NVFP4 coverage).

Important Files Changed

  • transformer_engine/pytorch/quantization.py: Core change: introduces the FP8GlobalState dataclass singleton and module-level caches with @torch.compiler.assume_constant_result for hardware-check functions. check_recipe_support() doesn't handle recipe=None during compilation — a None recipe with enabled=True under torch.compile bypasses the gatekeeper and surfaces as an AssertionError inside get_default_fp8_recipe() rather than a clear RuntimeError. NVFP4BlockScaling has no torch.compile guard here (only in autocast_enter via assert).
  • tests/pytorch/test_torch_compile.py: New test file with a ToyLinear custom-op harness for torch.compile+FP8 testing. Tests cover Float8CurrentScaling, MXFP8BlockScaling, Float8BlockScaling, nested autocast, and the DelayedScaling-under-compile error path. NVFP4BlockScaling is absent from the parametrized recipe list without explanation.
  • transformer_engine/pytorch/module/layernorm_mlp.py: Replaces FP8GlobalStateManager.get_autocast_state()/set_autocast_state() calls with direct quantization_state attribute reads/writes. Also replaces IS_FIRST_FP8_MODULE class-attribute access with quantization_state.is_first_fp8_module and get_skip_fp8_weight_update_tensor() with direct attribute access.
  • transformer_engine/pytorch/graph.py: Replaces the set_skip_fp8_weight_update_tensor() helper with explicit lazy initialization + fill_ on quantization_state.skip_fp8_weight_update_tensor. Correct and safe within CUDA graph context.
  • transformer_engine/pytorch/distributed.py: Minimal change: replaces IS_FIRST_FP8_MODULE class-attribute access with quantization_state.is_first_fp8_module in activation_recompute_forward.

Sequence Diagram

sequenceDiagram
    participant User as User Code (torch.compile)
    participant AC as te.autocast()
    participant CRS as check_recipe_support()
    participant FGSM as FP8GlobalStateManager
    participant GS as FP8GlobalState (singleton)

    User->>AC: enter with(te.autocast(recipe=R, enabled=True))
    AC->>CRS: check_recipe_support(R)
    alt R is DelayedScaling and compiling
        CRS-->>User: raise RuntimeError (blocked)
    else R is supported
        CRS-->>AC: ok
    end
    AC->>GS: read fp8_enabled, fp8_recipe, ... (save state)
    AC->>FGSM: autocast_enter(enabled, recipe, ...)
    FGSM->>GS: write fp8_enabled=True, fp8_recipe=R
    FGSM->>GS: autocast_depth += 1
    AC-->>User: yield (body executes)
    User->>FGSM: is_fp8_enabled() / get_fp8_recipe()
    FGSM->>GS: read fp8_enabled, fp8_recipe
    FGSM-->>User: True / R
    User-->>AC: body complete
    AC->>GS: restore saved fp8_enabled, fp8_recipe, ...
    AC->>FGSM: autocast_exit(enabled)
    FGSM->>GS: autocast_depth -= 1
    alt depth == 0 and enabled and grad_enabled
        FGSM->>FGSM: reduce_and_update_fp8_tensors()
    end

Comments Outside Diff (2)

  1. transformer_engine/pytorch/quantization.py, line 128-144 (link)

    None recipe not guarded during torch.compile tracing

    check_recipe_support is the new compile-time gatekeeper, but it silently passes when recipe=None. If a user calls te.autocast(enabled=True) without an explicit recipe inside torch.compile, none of the isinstance branches match None, so recipe_supported stays True and no error is raised here. Control then falls into autocast_enter, which calls get_default_fp8_recipe() — which raises an AssertionError (not a RuntimeError). The assertion message is reasonable, but the inconsistency means the gatekeeper doesn't catch all invalid-under-compile cases.

    A minimal guard would make the failure earlier and consistent with the DelayedScaling path:

    def check_recipe_support(recipe: Recipe) -> None:
        """Check if the given recipe is supported."""
        if torch.compiler.is_compiling():
            if recipe is None:
                raise RuntimeError(
                    "te.autocast() must be called with an explicit recipe under torch.compile. "
                    "Pass recipe=Float8CurrentScaling(), MXFP8BlockScaling(), or similar."
                )
            if isinstance(recipe, DelayedScaling):
                raise RuntimeError(
                    "DelayedScaling is not supported under torch.compile yet. "
                    "Use Float8CurrentScaling, MXFP8BlockScaling, or Float8BlockScaling instead."
                )
        ...
  2. transformer_engine/pytorch/quantization.py, line 128-144 (link)

    NVFP4BlockScaling missing from compile-time guard

    DelayedScaling is explicitly blocked under torch.compile with a RuntimeError, but NVFP4BlockScaling has no corresponding guard here. If NVFP4 is also unsupported under torch.compile (it is absent from the test parametrization in test_torch_compile.py), users will get no early error — they will instead hit whatever graph-break or tracing failure occurs deeper in the FP8 path.

    If NVFP4 is intentionally unsupported under torch.compile, the same pattern used for DelayedScaling should be applied:

    if torch.compiler.is_compiling() and isinstance(recipe, NVFP4BlockScaling):
        raise RuntimeError(
            "NVFP4BlockScaling is not supported under torch.compile yet."
        )

    If it is supported, a test case should be added to test_torch_compile.py to confirm there are no graph breaks.

Last reviewed commit: 16ea6e7

Comment on lines +178 to +200
_recipes = [
    pytest.param(recipe.DelayedScaling(), False, id="no_fp8"),
    pytest.param(
        recipe.Float8CurrentScaling(),
        True,
        id="float8_current_scaling",
        marks=pytest.mark.skipif(not _fp8_available, reason="FP8 not supported"),
    ),
    pytest.param(
        recipe.MXFP8BlockScaling(),
        True,
        id="mxfp8_block_scaling",
        marks=pytest.mark.skipif(not _mxfp8_available, reason="MXFP8 not supported"),
    ),
    pytest.param(
        recipe.Float8BlockScaling(),
        True,
        id="float8_block_scaling",
        marks=pytest.mark.skipif(
            not _fp8_block_scaling_available, reason="FP8 block scaling not supported"
        ),
    ),
]

NVFP4BlockScaling absent from parametrized recipe list without explanation

Float8CurrentScaling, MXFP8BlockScaling, and Float8BlockScaling all have entries, but NVFP4BlockScaling is not included. If NVFP4BlockScaling is unsupported under torch.compile, a test similar to test_autocast_delayed_scaling_unsupported should confirm that a clear error is raised. If it is supported, a parametrized entry with an appropriate skipif mark (e.g., not te.is_nvfp4_available()) would close the coverage gap.

The absence is currently unexplained, which may leave users uncertain about NVFP4 + torch.compile behavior.
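If NVFP4 is in fact supported, the missing parametrized entry could look like the sketch below. Everything here is hypothetical: the availability probe is a placeholder, and the recipe constructor is replaced with a string so the sketch stays self-contained (the real entry would pass recipe.NVFP4BlockScaling()).

```python
import pytest

# Stand-in for a real availability probe (e.g. the te.is_nvfp4_available()
# helper suggested in the review comment above, whose exact name is unconfirmed).
_nvfp4_available = False

# Hypothetical entry mirroring the existing _recipes items.
nvfp4_param = pytest.param(
    "recipe.NVFP4BlockScaling()",  # placeholder for the real recipe object
    True,
    id="nvfp4_block_scaling",
    marks=pytest.mark.skipif(not _nvfp4_available, reason="NVFP4 not supported"),
)
```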
