Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source-pytorch/deploy/production_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ Once you have the exported model, you can run it on your ONNX runtime in the fol
ort_inputs = {input_name: np.random.randn(1, 64)}
ort_outs = ort_session.run(None, ort_inputs)

If you want to catch a malformed export early, pass ``input_check=True`` to run
``onnx.checker.check_model`` on the saved file before you ever hand it off to a runtime:

.. code-block:: python

model.to_onnx(filepath, input_sample, input_check=True)

----

****************************
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `log_key_prefix` parameter to `LearningRateMonitor` callback for prefixing logged metric names ([#21612](https://github.com/Lightning-AI/pytorch-lightning/issues/21612))

- Added `input_check` argument to `LightningModule.to_onnx` to run `onnx.checker.check_model` on the exported model ([#7279](https://github.com/Lightning-AI/pytorch-lightning/issues/7279))

### Changed

-
Expand Down
25 changes: 25 additions & 0 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,13 +1436,17 @@ def to_onnx(
self,
file_path: Union[str, Path, BytesIO, None] = None,
input_sample: Optional[Any] = None,
input_check: bool = False,
**kwargs: Any,
) -> Optional["ONNXProgram"]:
"""Saves the model in ONNX format.

Args:
file_path: The path of the file the onnx model should be saved to. Default: None (no file saved).
input_sample: An input for tracing. Default: None (Use self.example_input_array)
input_check: If ``True``, run :func:`onnx.checker.check_model` on the exported model to
validate its protobuf structure. Requires ``file_path`` to be a path or ``BytesIO``;
cannot be used when ``dynamo=True`` is passed via ``kwargs``. Default: ``False``.

**kwargs: Will be passed to torch.onnx.export function.

Expand Down Expand Up @@ -1470,6 +1474,14 @@ def forward(self, x):
"requires `onnxscript` and `torch>=2.5.0` to be installed."
)

if input_check and kwargs.get("dynamo", False):
# dynamo path returns an ONNXProgram and may not produce a standalone protobuf file
# that `onnx.checker.check_model` can load.
raise ValueError("`input_check=True` is not supported together with `dynamo=True`.")

if input_check and file_path is None:
raise ValueError("`input_check=True` requires `file_path` to be a path or BytesIO, got None.")

mode = self.training

if input_sample is None:
Expand All @@ -1488,6 +1500,19 @@ def forward(self, x):
# BytesIO does work, too.
ret = torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore
self.train(mode)

if input_check:
import onnx

if isinstance(file_path, BytesIO):
pos = file_path.tell()
file_path.seek(0)
onnx_model = onnx.load_model_from_string(file_path.read())
file_path.seek(pos)
else:
onnx_model = onnx.load(file_path)
onnx.checker.check_model(onnx_model)

return ret

@torch.no_grad()
Expand Down
57 changes: 57 additions & 0 deletions tests/tests_pytorch/models/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,63 @@ def test_error_if_no_input(tmp_path):
model.to_onnx(file_path)


@RunIf(onnx=True)
def test_input_check_runs_onnx_checker(tmp_path):
"""`input_check=True` should load the exported model and run `onnx.checker.check_model`."""
import onnx

model = BoringModel()
input_sample = torch.randn((1, 32))

file_path = os.path.join(tmp_path, "model.onnx")
model.to_onnx(file_path, input_sample, input_check=True)
assert os.path.isfile(file_path)
# Sanity-check: same file should also pass onnx.checker when loaded independently.
onnx.checker.check_model(onnx.load(file_path))

# BytesIO path: the cursor position should be unchanged after the check reads the buffer.
buf = BytesIO()
model.to_onnx(file_path=buf, input_sample=input_sample, input_check=True)
end_pos = buf.tell()
assert end_pos > 4e2
assert len(buf.getvalue()) == end_pos


@RunIf(onnx=True)
def test_input_check_raises_without_file_path():
"""`input_check=True` needs a path or BytesIO to load the exported model from."""
model = BoringModel()
model.example_input_array = torch.randn((1, 32))
with pytest.raises(ValueError, match=r"`input_check=True` requires `file_path`"):
model.to_onnx(file_path=None, input_check=True)


@RunIf(onnx=True)
def test_input_check_detects_invalid_model(tmp_path, monkeypatch):
"""If the saved file isn't a valid ONNX model, `input_check=True` should raise."""
import onnx

model = BoringModel()
input_sample = torch.randn((1, 32))
file_path = os.path.join(tmp_path, "model.onnx")

def _raise(_):
raise onnx.checker.ValidationError("forced failure")

monkeypatch.setattr(onnx.checker, "check_model", _raise)
with pytest.raises(onnx.checker.ValidationError, match="forced failure"):
model.to_onnx(file_path, input_sample, input_check=True)


@RunIf(onnx=True, min_torch="2.5.0", dynamo=True, onnxscript=True)
def test_input_check_rejects_dynamo():
"""`input_check=True` is not compatible with the dynamo exporter."""
model = BoringModel()
model.example_input_array = torch.randn((1, 32))
with pytest.raises(ValueError, match=r"`input_check=True` is not supported together with `dynamo=True`"):
model.to_onnx(input_check=True, dynamo=True)


@pytest.mark.parametrize(
"dynamo",
[
Expand Down
Loading