[PyTorch] transformer_engine.pytorch.autocast support inside torch.compile #2759
pggPL wants to merge 7 commits into NVIDIA:main
Conversation
Move FP8 global state onto an instance so Dynamo can trace autocast state updates, explicitly reject DelayedScaling under torch.compile, and add toy compile tests that keep TE forward/backward opaque while covering supported recipes. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
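The explicit rejection of DelayedScaling under torch.compile that this commit describes can be sketched without Transformer Engine. `DelayedScaling` and `_is_compiling` below are simplified stand-ins for the real recipe class and for `torch.compiler.is_compiling()`; only the overall pattern comes from the PR:

```python
class DelayedScaling:
    """Stand-in for the TE recipe class rejected under torch.compile."""


def _is_compiling():
    # Stand-in for torch.compiler.is_compiling(); always False outside Dynamo.
    return False


def check_recipe_support(recipe, compiling=None):
    """Raise a clear error for recipes that cannot run under torch.compile."""
    compiling = _is_compiling() if compiling is None else compiling
    if compiling and isinstance(recipe, DelayedScaling):
        raise RuntimeError(
            "DelayedScaling is not supported inside torch.compile; "
            "use a current- or block-scaling recipe instead."
        )
```

Failing fast at autocast entry gives the user an actionable message instead of an opaque Dynamo graph break.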
Drop the standalone global dict and dataclass mutation experiments now that the torch.compile regression coverage lives in the focused autocast test file. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Use compiler constant-result wrappers for support checks and rename the module-level FP8 singleton to `_FP8_GLOBAL_STATE` for clearer semantics. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
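The constant-result wrapper pattern from this commit can be illustrated in isolation. Here `assume_constant_result` is an identity stand-in for `torch.compiler.assume_constant_result` (which tells Dynamo to treat the function's return value as a constant instead of tracing into it), and the cache/probe names are hypothetical:

```python
def assume_constant_result(fn):
    # Identity stand-in for torch.compiler.assume_constant_result.
    return fn


# Explicit module-level cache instead of functools.lru_cache: the result is
# computed once and thereafter returned without re-running the probe.
_fp8_support_cache = None


def _check_fp8_support_impl():
    # Placeholder for the real hardware/driver capability probe.
    return (True, "")


@assume_constant_result
def check_fp8_support():
    global _fp8_support_cache
    if _fp8_support_cache is None:
        _fp8_support_cache = _check_fp8_support_impl()
    return _fp8_support_cache
```

With the real decorator, Dynamo bakes the cached tuple into the compiled graph rather than tracing the probe's control flow.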
338ddae to b5d46fd
Restore the FP8 naming and remove extra state access helpers so the torch.compile changes stay focused on the instance-backed global state. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Drop stale availability fields from FP8GlobalState now that support checks use module-level cached results instead of manager state. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Resolve conflicts in the FP8 torch.compile changes while preserving the upstream updates in graph.py and module/base.py. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR enables `transformer_engine.pytorch.autocast` support inside `torch.compile`.
Confidence Score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code (torch.compile)
    participant AC as te.autocast()
    participant CRS as check_recipe_support()
    participant FGSM as FP8GlobalStateManager
    participant GS as FP8GlobalState (singleton)
    User->>AC: enter with(te.autocast(recipe=R, enabled=True))
    AC->>CRS: check_recipe_support(R)
    alt R is DelayedScaling and compiling
        CRS-->>User: raise RuntimeError (blocked)
    else R is supported
        CRS-->>AC: ok
    end
    AC->>GS: read fp8_enabled, fp8_recipe, ... (save state)
    AC->>FGSM: autocast_enter(enabled, recipe, ...)
    FGSM->>GS: write fp8_enabled=True, fp8_recipe=R
    FGSM->>GS: autocast_depth += 1
    AC-->>User: yield (body executes)
    User->>FGSM: is_fp8_enabled() / get_fp8_recipe()
    FGSM->>GS: read fp8_enabled, fp8_recipe
    FGSM-->>User: True / R
    User-->>AC: body complete
    AC->>GS: restore saved fp8_enabled, fp8_recipe, ...
    AC->>FGSM: autocast_exit(enabled)
    FGSM->>GS: autocast_depth -= 1
    alt depth == 0 and enabled and grad_enabled
        FGSM->>FGSM: reduce_and_update_fp8_tensors()
    end
```
```python
_recipes = [
    pytest.param(recipe.DelayedScaling(), False, id="no_fp8"),
    pytest.param(
        recipe.Float8CurrentScaling(),
        True,
        id="float8_current_scaling",
        marks=pytest.mark.skipif(not _fp8_available, reason="FP8 not supported"),
    ),
    pytest.param(
        recipe.MXFP8BlockScaling(),
        True,
        id="mxfp8_block_scaling",
        marks=pytest.mark.skipif(not _mxfp8_available, reason="MXFP8 not supported"),
    ),
    pytest.param(
        recipe.Float8BlockScaling(),
        True,
        id="float8_block_scaling",
        marks=pytest.mark.skipif(
            not _fp8_block_scaling_available, reason="FP8 block scaling not supported"
        ),
    ),
]
```
NVFP4BlockScaling absent from parametrized recipe list without explanation
Float8CurrentScaling, MXFP8BlockScaling, and Float8BlockScaling all have entries, but NVFP4BlockScaling is not included. If NVFP4BlockScaling is unsupported under torch.compile, a test similar to test_autocast_delayed_scaling_unsupported should confirm that a clear error is raised. If it is supported, a parametrized entry with an appropriate skipif mark (e.g., not te.is_nvfp4_available()) would close the coverage gap.
The absence is currently unexplained, which may leave users uncertain about NVFP4 + torch.compile behavior.
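If NVFP4 turns out to be supported, the entry the comment asks for could look like the sketch below. The availability flag and the string placeholder stand in for `te.is_nvfp4_available()` and `recipe.NVFP4BlockScaling()` (both names are taken from the review comment's suggestion, not verified against the codebase):

```python
import pytest

# Hypothetical availability flag; the real check would be something like
# te.is_nvfp4_available(), per the review comment's suggestion.
_nvfp4_available = False

# Parametrized entry mirroring the existing recipes: skipped on hardware
# without NVFP4 support, otherwise exercised under torch.compile.
nvfp4_param = pytest.param(
    "NVFP4BlockScaling()",  # placeholder for recipe.NVFP4BlockScaling()
    True,
    id="nvfp4_block_scaling",
    marks=pytest.mark.skipif(not _nvfp4_available, reason="NVFP4 not supported"),
)
```

Alternatively, if NVFP4 is unsupported under torch.compile, a test mirroring `test_autocast_delayed_scaling_unsupported` would document that instead.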
Description
Enable `torch.compile(fullgraph=True)` for FP8 autocast by moving compile-visible mutable state off class attributes, avoiding tracing through support checks, and adding tests.

Type of change
Changes
Please list the changes introduced in this PR:
- Move `cls` attribute writes to a dataclass-backed singleton object, because `torch.compile` does not support writes directly to class attributes.
- Replace `lru_cache`-based support checks with explicit module-level caches and mark the wrapper functions with `@torch.compiler.assume_constant_result` so `torch.compile` does not trace into `check_*_support()`.
- Add `torch.compile` coverage for FP8 autocast using a custom test module; the test is more involved because there is currently no simple TE layer that supports both FP8 and `torch.compile`.
- Make `DelayedScaling` explicitly unsupported under `torch.compile` and raise a clear error for that case.

Checklist: