Skip to content

feat(OM-054): ONNX and WebGPU export targets#216

Open
matdou wants to merge 3 commits into
maziyarpanahi:masterfrom
matdou:feature/om-054-onnx-webgpu-export
Open

feat(OM-054): ONNX and WebGPU export targets#216
matdou wants to merge 3 commits into
maziyarpanahi:masterfrom
matdou:feature/om-054-onnx-webgpu-export

Conversation

@matdou

@matdou matdou commented Jun 13, 2026

Copy link
Copy Markdown

Description

Adds an ONNX export pipeline and a WebGPU-targeted fp16 artifact to OpenMed, so token-classification models can run in browsers and cross-platform runtimes. Implements the core of OM-054 decomposed into two independently-shippable pieces.

Type of Change

  • New feature (non-breaking change which adds functionality)
  • Test addition/improvement

Changes Made

  • openmed/onnx/convert.py — exports any HF token-classification model to ONNX; handles torch 2.x dynamo exporter correctly (torch.export.Dim, opset 18, batch=2 trace to avoid BERT position-embedding specialisation); CLI mirrors coreml/convert.py
  • openmed/onnx/webgpu.py — fp16 post-export conversion via onnxruntime.transformers.float16 with keep_io_types=True (int64 inputs, float32 outputs preserved for WebGPU EP); standalone CLI + --webgpu flag on the convert CLI
  • openmed/eval/gates.py — fail-closed stub for G4 quant_delta and G5 tier_fit with TODO(OM-032); raises NotImplementedError until OM-032 lands
  • openmed/mlx/artifact.pywrite_manifest() now also emits formats[] alongside the existing scalar format (additive, no reader breakage)
  • pyproject.toml[onnx] optional dependency group: onnx>=1.14, onnxruntime>=1.16, torch>=2.0, transformers>=4.50

Testing

  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have tested this change with different models/inputs

T1 parity: torch vs ONNX CPU EP max abs diff 7.45e-09, 100% argmax agreement on batch 1–8 × seq 8–64. Verified end-to-end with hf-internal-testing/tiny-bert-for-token-classification.

uv run pytest tests/unit/onnx/ -v          # T1 parity, T3 manifest, T5 WebGPU
uv run pytest tests/ -q --ignore=tests/integration  # full suite: 1259 passed
cd tests/node && npm install && npm test   # T4 Node WASM EP (generate fixtures first)
WEBGPU_SMOKE=1 node tests/node/test_onnx_webgpu.mjs  # T6 headed-Chrome WebGPU smoke

Documentation

  • I have added docstrings to new functions/classes

Code Quality

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • My changes generate no new warnings

Dependencies

  • I have added new dependencies and they are justified because: onnx, onnxruntime, and transformers are required for the ONNX export path; all gated behind pip install openmed[onnx] so existing installs are unaffected

Checklist

  • I have read the contributing guidelines
  • My commits have clear, descriptive messages

Related Issues

Closes #175

matdou added 3 commits June 12, 2026 09:54
…iter (sub-issue 1)

- openmed/onnx/convert.py: ONNX exporter mirroring coreml/convert.py;
  uses torch.onnx.export + dynamic_axes instead of ct.convert + RangeDim.
- openmed/eval/gates.py: fail-closed quant_delta/tier_fit seam with
  TODO(OM-032); raises NotImplementedError until OM-032 lands.
- openmed/mlx/artifact.py: write_manifest now emits formats[] alongside
  the existing format scalar for back-compat (additive, no reader changes).
- pyproject.toml: add [onnx] optional dependency group.
- tests/unit/onnx/: T3 manifest back-compat (14 passing), T1 parity
  (skips until onnxruntime installed), gate seam assertions.
- tests/node/: T4 onnxruntime-web WASM EP scaffold (skips until T1
  generates fixture).
The dynamic_axes API silently produces static batch=1 in torch>=2.1's
dynamo exporter. Fix: use dynamic_shapes with torch.export.Dim, and
trace with batch=2 so BERT's position embeddings don't specialize the
batch dimension as a constant. Also bump default opset to 18 (required
by LayerNormalization in the new exporter). Legacy path for torch<2.1
unchanged (dynamic_axes + opset 14). T1 now fully passes with correct
dynamic batch/seq shapes verified at batch=1,2,4.
…-issue 3)

- openmed/onnx/webgpu.py: convert_to_fp16() using onnxruntime.transformers
  float16 converter with keep_io_types=True (int64 inputs, float32 outputs)
- openmed/onnx/convert.py: add --webgpu CLI flag, fix --opset default 14→18
- tests/unit/onnx/test_webgpu.py: T5 tests — fp16 initializers present,
  default naming, IO types preserved, argmax matches fp32 within 0.1
- tests/node/test_onnx_webgpu.mjs: T6 WebGPU smoke (skip unless WEBGPU_SMOKE=1)

Verified end-to-end with hf-internal-testing/tiny-bert-for-token-classification:
torch vs fp32 max abs diff 7.45e-09, torch vs fp16 0.0000, dynamic shapes
batch/seq 1–8 × 8–64 all correct.
@maziyarpanahi maziyarpanahi self-requested a review June 13, 2026 09:30
@maziyarpanahi

Copy link
Copy Markdown
Owner

Thank you @matdou for your great contribution. I will review this today and we'll get this merged 🎉

@maziyarpanahi maziyarpanahi added roadmap-v2 OpenMed V2 roadmap backlog epic Large; decompose into child issues first P1 High blocked Blocked by unmet dependencies or prerequisite roadmap work labels Jun 13, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

blocked Blocked by unmet dependencies or prerequisite roadmap work epic Large; decompose into child issues first P1 High roadmap-v2 OpenMed V2 roadmap backlog

Projects

None yet

Development

Successfully merging this pull request may close these issues.

EPIC: Add ONNX and WebGPU export targets to model conversion

2 participants