[OMNIML-3050] Enable torch.compile on _get_log_softmax_dist #1479

Merged
ajrasane merged 1 commit into main from ajrasane/torch_compile on May 13, 2026

Conversation

ajrasane (Contributor) commented May 13, 2026

What does this PR do?

Type of change: new feature

Wraps `_get_log_softmax_dist` (`modelopt/torch/quantization/algorithms.py`), the distributed log-softmax helper used by `AutoQuantizeKLDivSearcher` under TP > 1, with `@torch.compile(dynamic=True)`. Inductor fuses the amax / all_reduce(MAX) / logsumexp / all_reduce(SUM) / log+sub+cast pipeline into one kernel, and `dynamic=True` avoids recompiles across the varying `[batch, seq]` shapes seen during calibration, matching the existing pattern in `backends/fp8_per_tensor_gemm.py`.

Stale TODOs are removed: the prior ONNX-Windows concern no longer applies because the function is only reachable when a TP group is initialized (never on the Windows CPU unit-test job), and `@torch.compile` is import-time safe.
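
For reference, a minimal sketch of what the decorated helper looks like under the pipeline described above. This is an illustration, not the exact source: the real signature lives in `modelopt/torch/quantization/algorithms.py`, and the `tp_group` argument here is an assumption.

```python
import torch
import torch.distributed as dist


@torch.compile(dynamic=True)
def _get_log_softmax_dist(logits: torch.Tensor, tp_group) -> torch.Tensor:
    """Log-softmax over a vocab dimension sharded across TP ranks (sketch)."""
    # Local row max, then global max across the TP group (all_reduce MAX).
    row_max = logits.amax(dim=-1, keepdim=True)
    dist.all_reduce(row_max, op=dist.ReduceOp.MAX, group=tp_group)
    # Shift for numerical stability; local sum of exponentials, then global sum.
    shifted = (logits - row_max).float()
    sumexp = shifted.exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sumexp, op=dist.ReduceOp.SUM, group=tp_group)
    # log + sub + cast back to the input dtype: the pointwise tail Inductor fuses.
    return (shifted - sumexp.log()).to(logits.dtype)
```

Eagerly, each of these steps materializes an intermediate the size of the logits; under Inductor the pointwise work fuses around the two collectives, which is where the bandwidth win comes from.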

Usage

# Internal — invoked automatically by AutoQuantize KL-divergence search under TP > 1:
import modelopt.torch.quantization as mtq

model, _ = mtq.auto_quantize(
    model,
    constraints={"effective_bits": 6.0},
    quantization_formats=[mtq.INT4_AWQ_CFG, mtq.INT8_DEFAULT_CFG],
    data_loader=calib_loader,
    forward_step=lambda m, b: m(b),
    method="kl_div",
)

Testing

Verified locally against the affected unit-test paths:

  • pytest tests/unit/torch/quantization/test_autoquant.py -k kl_div → 22 passed (covers the call chain into _get_log_prob).
  • pytest tests/unit/onnx/ → 516 passed. The single failure (test_autocast_quantize.py::test_autocast_quantize_int8[False-True] — onnxruntime CopyTensorAsync is not implemented) is a pre-existing failure on main, unrelated to this change (confirmed by stashing the diff and re-running).
  • pytest tests/unit/torch/quantization/test_autoquant.py tests/unit/torch/quantization/test_quantize_cpu.py tests/unit/torch/quantization/test_config_validation.py → 159 passed.
  • Functional sanity in a single-process gloo group (see the sketch after this list): compiled output matches torch.log_softmax reference for fp32 and fp16, with no recompiles across varying [batch, seq] shapes.
  • pre-commit run --files modelopt/torch/quantization/algorithms.py → ruff / mypy / bandit all pass.
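
For reproducibility, a hedged sketch of the gloo sanity check above. The exact signature of `_get_log_softmax_dist` is assumed here (the real one may differ), and the shapes and tolerances are illustrative:

```python
import os

import torch
import torch.distributed as dist

# Single-process gloo group: all_reduce over world_size=1 is an identity,
# so the distributed helper must match plain torch.log_softmax.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

from modelopt.torch.quantization.algorithms import _get_log_softmax_dist

for dtype in (torch.float32, torch.float16):
    # Varying [batch, seq] shapes; dynamic=True should compile once.
    for batch, seq in ((2, 128), (4, 37), (1, 513)):
        logits = torch.randn(batch, seq, 512, dtype=dtype)
        ref = torch.log_softmax(logits.float(), dim=-1).to(dtype)
        out = _get_log_softmax_dist(logits, dist.group.WORLD)  # assumed signature
        torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)  # loose, illustrative

dist.destroy_process_group()
```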

Before your PR is "Ready for review"

Make sure you read and follow the Contributor guidelines and that your commits are signed (`git commit -s -S`).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A
  • Did you write any new necessary tests?: N/A — existing `kl_div` autoquant tests already exercise the call chain.
  • Did you update Changelog?: N/A
  • Did you get Claude approval on this PR?: ❌ — will run `/claude review` after marking ready.

Additional Information

Single-line behavior change: `_get_log_softmax_dist` is now `@torch.compile(dynamic=True)`-decorated. No API surface changes.

Wrap the distributed log-softmax helper used by AutoQuantizeKLDivSearcher
with @torch.compile(dynamic=True). Inductor fuses the
amax / all_reduce(MAX) / logsumexp / all_reduce(SUM) / log+sub+cast
pipeline into one bandwidth-bound kernel and avoids intermediate
materialization across the KL-div sensitivity loop. dynamic=True keeps
the same compiled graph across varying [batch, seq] shapes per
calibration batch and prevents Dynamo cache-size thrash.

Drops the stale TODOs: the function is only reached under TP > 1 from
AutoQuantizeKLDivSearcher.estimate_sensitivity_scores, never from the
ONNX export path, and the @torch.compile decorator itself is import-time
safe (matches the existing pattern in backends/fp8_per_tensor_gemm.py).

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
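
As a spot-check of the "no recompiles" claim, Dynamo's internal counters can confirm that a single graph serves many shapes. This relies on an internal API (`torch._dynamo.utils.counters`) that may change across PyTorch releases, and the toy function below stands in for the real helper:

```python
import torch
import torch._dynamo
from torch._dynamo.utils import counters


@torch.compile(dynamic=True)
def log_softmax_last(x: torch.Tensor) -> torch.Tensor:
    return torch.log_softmax(x, dim=-1)


torch._dynamo.reset()
for batch, seq in ((2, 128), (8, 64), (3, 517)):
    log_softmax_last(torch.randn(batch, seq, 1024))

# Expect one compiled graph despite three distinct [batch, seq] shapes.
print(counters["stats"]["unique_graphs"])  # expected: 1
```
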
copy-pr-bot Bot commented May 13, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

coderabbitai Bot (Contributor) commented May 13, 2026

📝 Walkthrough

This PR enables Torch compilation on the _get_log_softmax_dist helper function by adding a @torch.compile(dynamic=True) decorator, replacing a prior TODO comment. The function signature and implementation remain unchanged.

Changes

Enable Torch Compilation

| Layer / File(s) | Summary |
| --- | --- |
| Torch compile decorator for log-softmax distribution (`modelopt/torch/quantization/algorithms.py`) | `_get_log_softmax_dist` gains the `@torch.compile(dynamic=True)` decorator, replacing a previous TODO comment about external breaking issues preventing compilation. |

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~2 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and specifically describes the main change: enabling torch.compile on `_get_log_softmax_dist`, which aligns with the file modification shown in the summary. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |


@ajrasane ajrasane marked this pull request as ready for review May 13, 2026 18:11
@ajrasane ajrasane requested a review from a team as a code owner May 13, 2026 18:11
@ajrasane ajrasane requested a review from kinjalpatel27 May 13, 2026 18:11
github-actions Bot (Contributor) commented May 13, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-05-13 21:04 UTC

codecov Bot commented May 13, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.87%. Comparing base (62401e1) to head (922093d).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1479      +/-   ##
==========================================
+ Coverage   76.78%   76.87%   +0.09%     
==========================================
  Files         473      473              
  Lines       51413    51414       +1     
==========================================
+ Hits        39476    39525      +49     
+ Misses      11937    11889      -48     
| Flag | Coverage Δ |
| --- | --- |
| examples | 41.60% <100.00%> (+2.62%) ⬆️ |
| gpu | 59.72% <100.00%> (-0.59%) ⬇️ |
| regression | 14.98% <100.00%> (+0.07%) ⬆️ |
| unit | 52.55% <100.00%> (+0.01%) ⬆️ |

Flags with carried forward coverage won't be shown.


@ajrasane ajrasane merged commit 229ba61 into main May 13, 2026
49 checks passed
@ajrasane ajrasane deleted the ajrasane/torch_compile branch May 13, 2026 21:04