[OMNIML-3050] Enable torch.compile on _get_log_softmax_dist #1479
Conversation
Wrap the distributed log-softmax helper used by `AutoQuantizeKLDivSearcher` with `@torch.compile(dynamic=True)`. Inductor fuses the amax / all_reduce(MAX) / logsumexp / all_reduce(SUM) / log+sub+cast pipeline into one bandwidth-bound kernel and avoids materializing intermediates across the KL-div sensitivity loop. `dynamic=True` keeps the same compiled graph across varying `[batch, seq]` shapes per calibration batch and prevents Dynamo cache-size thrash.

Drops the stale TODOs: the function is only reached under TP > 1 from `AutoQuantizeKLDivSearcher.estimate_sensitivity_scores`, never from the ONNX export path, and the `@torch.compile` decorator itself is import-time safe (matching the existing pattern in `backends/fp8_per_tensor_gemm.py`).

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
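For orientation, a minimal sketch of what the fused pipeline computes; the real helper lives in `modelopt/torch/quantization/algorithms.py`, and its actual signature and process-group plumbing may differ:

```python
import torch
import torch.distributed as dist

@torch.compile(dynamic=True)
def _get_log_softmax_dist(logits: torch.Tensor, tp_group) -> torch.Tensor:
    x = logits.float()  # upcast for numerical stability (assumed)
    # 1) local max over the sharded vocab dim, then global MAX across TP ranks
    x_max = x.amax(dim=-1, keepdim=True)
    dist.all_reduce(x_max, op=dist.ReduceOp.MAX, group=tp_group)
    # 2) local sum of exp against the global max, then global SUM (logsumexp)
    sum_exp = (x - x_max).exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
    # 3) log-softmax = x - max - log(sum_exp), cast back to the input dtype
    return (x - x_max - sum_exp.log()).to(logits.dtype)
```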
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff             @@
##             main    #1479      +/-   ##
==========================================
+ Coverage   76.78%   76.87%   +0.09%
==========================================
  Files         473      473
  Lines       51413    51414       +1
==========================================
+ Hits        39476    39525      +49
+ Misses      11937    11889      -48
```
What does this PR do?
Type of change: new feature
Wraps `_get_log_softmax_dist` (in `modelopt/torch/quantization/algorithms.py`), the distributed log-softmax helper used by `AutoQuantizeKLDivSearcher` under TP > 1, with `@torch.compile(dynamic=True)`. Inductor fuses the amax / all_reduce(MAX) / logsumexp / all_reduce(SUM) / log+sub+cast pipeline into one kernel. `dynamic=True` avoids recompiles across the varying `[batch, seq]` shapes seen during calibration, matching the existing pattern in `backends/fp8_per_tensor_gemm.py`. Stale TODOs are removed; the prior ONNX-Windows concern no longer applies because the function is only reachable when a TP group is initialized (never on the Windows CPU unit-test job), and `@torch.compile` is import-time safe.

Usage
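A hypothetical call-shape sketch; the real entry point is internal to `AutoQuantizeKLDivSearcher.estimate_sensitivity_scores`, and `tp_group` plus the shard width below are placeholders, not the module's API:

```python
import torch
import torch.distributed as dist

# Placeholder TP group for illustration; in the real code path the group
# comes from the model-parallel state, not from dist.group.WORLD.
tp_group = dist.group.WORLD

for batch, seq in [(2, 64), (4, 128), (1, 257)]:  # varying calibration shapes
    # Each rank holds a vocab shard of the logits; 4096 is an arbitrary width.
    logits = torch.randn(batch, seq, 4096, device="cuda", dtype=torch.float16)
    # With dynamic=True, every iteration after warm-up reuses one compiled
    # graph instead of recompiling for the new [batch, seq].
    log_probs = _get_log_softmax_dist(logits, tp_group)
```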
Testing
Verified locally against the affected unit-test paths:
- `pytest tests/unit/torch/quantization/test_autoquant.py -k kl_div` → 22 passed (covers the call chain into `_get_log_prob`).
- `pytest tests/unit/onnx/` → 516 passed. The single failure (`test_autocast_quantize.py::test_autocast_quantize_int8[False-True]`: onnxruntime CopyTensorAsync is not implemented) is a pre-existing failure on `main`, unrelated to this change (confirmed by stashing the diff and re-running).
- `pytest tests/unit/torch/quantization/test_autoquant.py tests/unit/torch/quantization/test_quantize_cpu.py tests/unit/torch/quantization/test_config_validation.py` → 159 passed.
- Output matches the `torch.log_softmax` reference for fp32 and fp16, with no recompiles across varying `[batch, seq]` shapes (see the sketch after this list).
- `pre-commit run --files modelopt/torch/quantization/algorithms.py` → ruff / mypy / bandit all pass.
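For the numeric check, a single-process analogue is enough, since with world_size == 1 the all-reduces are no-ops; this harness is an assumption, not the repo's actual test:

```python
import torch

# Eager reference and compiled variant of the same stable log-softmax.
def ref(x: torch.Tensor) -> torch.Tensor:
    return torch.log_softmax(x.float(), dim=-1).to(x.dtype)

compiled = torch.compile(ref, dynamic=True)

for batch, seq in [(2, 64), (4, 128), (1, 257)]:
    for dtype in (torch.float32, torch.float16):
        x = torch.randn(batch, seq, 512, dtype=dtype)
        # dynamic=True: one graph serves all [batch, seq] shapes above.
        torch.testing.assert_close(compiled(x), ref(x), rtol=1e-3, atol=1e-3)
```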
Before your PR is "Ready for review"

- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
Additional Information
Single-line behavior change: `_get_log_softmax_dist` is now decorated with `@torch.compile(dynamic=True)`. No API surface changes.
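For reviewers, the functional diff reduces to the decorator line (signature elided here; the actual parameters are in `algorithms.py`):

```diff
+@torch.compile(dynamic=True)
 def _get_log_softmax_dist(...):
     ...
```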