@@ -32,7 +32,7 @@ def _flatten_unflatten_for_dynamic_shapes(
3232 like replace them by a shape
3333
3434 Returns:
35- the serialized object
35+ the flattened object
3636 """
3737 if isinstance (obj , torch .Tensor ):
3838 return change_function (obj ) if change_function else obj
@@ -104,14 +104,22 @@ def _infer_dynamic_dimensions(
104104class InputCandidate :
105105 """Retains one set of inputs given to the forward method or any
106106 other method the class :class:`InputObserver` is stealing from.
107+ Any class is allowed as long as it can be flattened.
107108
108109 Args:
109- args: Positional arguments.
110- kwargs: Optional arguments.
111- clone: Clone the inputs before storing them. Some tensors
110+ args:
111+ Positional arguments.
112+ kwargs:
113+ Optional arguments.
114+ clone:
115+ Clones the inputs before storing them. Some tensors
112116 may be modified inplace, the original value must be retained.
113- cst_kwargs: Any optional arguments constant over multiple calls.
117+ cst_kwargs:
118+ Any optional arguments constant over multiple calls.
114119 int, float, str, bool values must be stored here.
120+
121+ The constructor flattens the received arguments.
122+ Any necessary flattening function should have been registered first.
115123 """
116124
117125 def __init__ (
@@ -671,18 +679,20 @@ class InputObserver:
671679 >>> )
672680
673681 With LLM:
682+
674683 >>> input_observer = InputObserver()
675684 >>> with input_observer(model):
676685 >>> model.generate(input_ids)
677686 >>> ep = torch.export.export( # or torch.onnx.export
678687 >>> model,
679- >>> ()
688+ >>> (),
680689 >>> kwargs=input_observer.infer_arguments(),
681690 >>> dynamic_shapes.input_observer.infer_dynamic_shapes(),
682691 >>> )
683692
684693 Examples can be found in :ref:`l-plot-tiny-llm-export-input-observer`,
685- :ref:`l-plot-whisper-tiny-export-input-observer`.
694+ :ref:`l-plot-whisper-tiny-export-input-observer`,
695+ :ref:`l-plot-gemma3-tiny-export-input-observer`.
686696 """
687697
688698 def __init__ (self , missing : dict [str , Any ] | None = None ):
@@ -865,17 +875,26 @@ def check_discrepancies(
865875 with the saved onnx model.
866876
867877 Args:
868- onnx_model: ONNX Model to verify.
869- atol: Absolute tolerance, recommended values, 1e-4 for float, 1e-2 flot float16.
870- rtol: Relative tolerance.
871- hist: Thresholds, the function determines the number of discrepancies
878+ onnx_model:
879+ ONNX Model to verify.
880+ atol:
881+ Absolute tolerance, recommended values, 1e-4 for float, 1e-2 flot float16.
882+ rtol:
883+ Relative tolerance.
884+ hist:
885+ Thresholds, the function determines the number of discrepancies
872886 above these thresholds.
873- progress_bar: Shows a progress bar (requires :epkg:`tqdm`).
874- include_io: Shows inputs/outputs shapes in the summary
887+ progress_bar:
888+ Shows a progress bar (requires :epkg:`tqdm`).
889+ include_io:
890+ Shows inputs/outputs shapes in the summary
875891 returned by this function.
876892
877893 Returns:
878894 A list of dictionaries, ready to be consumed by a dataframe.
895+
896+ The function catches exceptions, it shows the error in the returned
897+ summary.
879898 """
880899 sess = OnnxruntimeEvaluator (onnx_model , whole = True )
881900 input_names = sess .input_names
0 commit comments