diff --git a/docs/source-pytorch/deploy/production_advanced.rst b/docs/source-pytorch/deploy/production_advanced.rst index da4934a170fac..2fc882726eed9 100644 --- a/docs/source-pytorch/deploy/production_advanced.rst +++ b/docs/source-pytorch/deploy/production_advanced.rst @@ -59,6 +59,13 @@ Once you have the exported model, you can run it on your ONNX runtime in the fol ort_inputs = {input_name: np.random.randn(1, 64)} ort_outs = ort_session.run(None, ort_inputs) +If you want to catch a malformed export early, pass ``input_check=True`` to run +``onnx.checker.check_model`` on the saved file before you ever hand it off to a runtime: + +.. code-block:: python + + model.to_onnx(filepath, input_sample, input_check=True) + ---- **************************** diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 89e7da957f7a2..567d72571ec4a 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `log_key_prefix` parameter to `LearningRateMonitor` callback for prefixing logged metric names ([#21612](https://github.com/Lightning-AI/pytorch-lightning/issues/21612)) +- Added `input_check` argument to `LightningModule.to_onnx` to run `onnx.checker.check_model` on the exported model ([#7279](https://github.com/Lightning-AI/pytorch-lightning/issues/7279)) + ### Changed - diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 9c09aee4a8f24..0e1841b768f3b 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1436,6 +1436,7 @@ def to_onnx( self, file_path: Union[str, Path, BytesIO, None] = None, input_sample: Optional[Any] = None, + input_check: bool = False, **kwargs: Any, ) -> Optional["ONNXProgram"]: """Saves the model in ONNX format. @@ -1443,6 +1444,9 @@ def to_onnx( Args: file_path: The path of the file the onnx model should be saved to. Default: None (no file saved). input_sample: An input for tracing. Default: None (Use self.example_input_array) + input_check: If ``True``, run :func:`onnx.checker.check_model` on the exported model to + validate its protobuf structure. Requires ``file_path`` to be a path or ``BytesIO``; + cannot be used when ``dynamo=True`` is passed via ``kwargs``. Default: ``False``. **kwargs: Will be passed to torch.onnx.export function. @@ -1470,6 +1474,14 @@ def forward(self, x): "requires `onnxscript` and `torch>=2.5.0` to be installed." ) + if input_check and kwargs.get("dynamo", False): + # dynamo path returns an ONNXProgram and may not produce a standalone protobuf file + # that `onnx.checker.check_model` can load. + raise ValueError("`input_check=True` is not supported together with `dynamo=True`.") + + if input_check and file_path is None: + raise ValueError("`input_check=True` requires `file_path` to be a path or BytesIO, got None.") + mode = self.training if input_sample is None: @@ -1488,6 +1500,19 @@ def forward(self, x): # BytesIO does work, too. ret = torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore self.train(mode) + + if input_check: + import onnx + + if isinstance(file_path, BytesIO): + pos = file_path.tell() + file_path.seek(0) + onnx_model = onnx.load_model_from_string(file_path.read()) + file_path.seek(pos) + else: + onnx_model = onnx.load(file_path) + onnx.checker.check_model(onnx_model) + return ret @torch.no_grad() diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 9f51332fbbdfa..972da311b38fe 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -142,6 +142,63 @@ def test_error_if_no_input(tmp_path): model.to_onnx(file_path) +@RunIf(onnx=True) +def test_input_check_runs_onnx_checker(tmp_path): + """`input_check=True` should load the exported model and run `onnx.checker.check_model`.""" + import onnx + + model = BoringModel() + input_sample = torch.randn((1, 32)) + + file_path = os.path.join(tmp_path, "model.onnx") + model.to_onnx(file_path, input_sample, input_check=True) + assert os.path.isfile(file_path) + # Sanity-check: same file should also pass onnx.checker when loaded independently. + onnx.checker.check_model(onnx.load(file_path)) + + # BytesIO path: the cursor position should be unchanged after the check reads the buffer. + buf = BytesIO() + model.to_onnx(file_path=buf, input_sample=input_sample, input_check=True) + end_pos = buf.tell() + assert end_pos > 4e2 + assert len(buf.getvalue()) == end_pos + + +@RunIf(onnx=True) +def test_input_check_raises_without_file_path(): + """`input_check=True` needs a path or BytesIO to load the exported model from.""" + model = BoringModel() + model.example_input_array = torch.randn((1, 32)) + with pytest.raises(ValueError, match=r"`input_check=True` requires `file_path`"): + model.to_onnx(file_path=None, input_check=True) + + +@RunIf(onnx=True) +def test_input_check_detects_invalid_model(tmp_path, monkeypatch): + """If the saved file isn't a valid ONNX model, `input_check=True` should raise.""" + import onnx + + model = BoringModel() + input_sample = torch.randn((1, 32)) + file_path = os.path.join(tmp_path, "model.onnx") + + def _raise(_): + raise onnx.checker.ValidationError("forced failure") + + monkeypatch.setattr(onnx.checker, "check_model", _raise) + with pytest.raises(onnx.checker.ValidationError, match="forced failure"): + model.to_onnx(file_path, input_sample, input_check=True) + + +@RunIf(onnx=True, min_torch="2.5.0", dynamo=True, onnxscript=True) +def test_input_check_rejects_dynamo(): + """`input_check=True` is not compatible with the dynamo exporter.""" + model = BoringModel() + model.example_input_array = torch.randn((1, 32)) + with pytest.raises(ValueError, match=r"`input_check=True` is not supported together with `dynamo=True`"): + model.to_onnx(input_check=True, dynamo=True) + + @pytest.mark.parametrize( "dynamo", [