feat(OM-054): ONNX and WebGPU export targets#216
Open
matdou wants to merge 3 commits into
Open
Conversation
…iter (sub-issue 1) - openmed/onnx/convert.py: ONNX exporter mirroring coreml/convert.py; uses torch.onnx.export + dynamic_axes instead of ct.convert + RangeDim. - openmed/eval/gates.py: fail-closed quant_delta/tier_fit seam with TODO(OM-032); raises NotImplementedError until OM-032 lands. - openmed/mlx/artifact.py: write_manifest now emits formats[] alongside the existing format scalar for back-compat (additive, no reader changes). - pyproject.toml: add [onnx] optional dependency group. - tests/unit/onnx/: T3 manifest back-compat (14 passing), T1 parity (skips until onnxruntime installed), gate seam assertions. - tests/node/: T4 onnxruntime-web WASM EP scaffold (skips until T1 generates fixture).
The dynamic_axes API silently produces static batch=1 in torch>=2.1's dynamo exporter. Fix: use dynamic_shapes with torch.export.Dim, and trace with batch=2 so BERT's position embeddings don't specialize the batch dimension as a constant. Also bump default opset to 18 (required by LayerNormalization in the new exporter). Legacy path for torch<2.1 unchanged (dynamic_axes + opset 14). T1 now fully passes with correct dynamic batch/seq shapes verified at batch=1,2,4.
…-issue 3) - openmed/onnx/webgpu.py: convert_to_fp16() using onnxruntime.transformers float16 converter with keep_io_types=True (int64 inputs, float32 outputs) - openmed/onnx/convert.py: add --webgpu CLI flag, fix --opset default 14→18 - tests/unit/onnx/test_webgpu.py: T5 tests — fp16 initializers present, default naming, IO types preserved, argmax matches fp32 within 0.1 - tests/node/test_onnx_webgpu.mjs: T6 WebGPU smoke (skip unless WEBGPU_SMOKE=1) Verified end-to-end with hf-internal-testing/tiny-bert-for-token-classification: torch vs fp32 max abs diff 7.45e-09, torch vs fp16 0.0000, dynamic shapes batch/seq 1–8 × 8–64 all correct.
Owner
|
Thank you @matdou for your great contribution. I will review this today and we'll get this merged 🎉 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Adds an ONNX export pipeline and a WebGPU-targeted fp16 artifact to OpenMed, so token-classification models can run in browsers and cross-platform runtimes. Implements the core of OM-054 decomposed into two independently-shippable pieces.
Type of Change
Changes Made
openmed/onnx/convert.py— exports any HF token-classification model to ONNX; handles torch 2.x dynamo exporter correctly (torch.export.Dim, opset 18, batch=2 trace to avoid BERT position-embedding specialisation); CLI mirrorscoreml/convert.pyopenmed/onnx/webgpu.py— fp16 post-export conversion viaonnxruntime.transformers.float16withkeep_io_types=True(int64 inputs, float32 outputs preserved for WebGPU EP); standalone CLI +--webgpuflag on the convert CLIopenmed/eval/gates.py— fail-closed stub for G4quant_deltaand G5tier_fitwithTODO(OM-032); raisesNotImplementedErroruntil OM-032 landsopenmed/mlx/artifact.py—write_manifest()now also emitsformats[]alongside the existing scalarformat(additive, no reader breakage)pyproject.toml—[onnx]optional dependency group:onnx>=1.14,onnxruntime>=1.16,torch>=2.0,transformers>=4.50Testing
T1 parity: torch vs ONNX CPU EP max abs diff
7.45e-09, 100% argmax agreement on batch 1–8 × seq 8–64. Verified end-to-end withhf-internal-testing/tiny-bert-for-token-classification.Documentation
Code Quality
Dependencies
onnx,onnxruntime, andtransformersare required for the ONNX export path; all gated behindpip install openmed[onnx]so existing installs are unaffectedChecklist
Related Issues
Closes #175