Commit 7a28903

removed bool, int, float, None as input dummies for the exporter in method_to_onnx (#383)
* removed bool, int, float, None as input dummies for the exporter in method_to_onnx
* doc
1 parent bc71c8b commit 7a28903

File tree

3 files changed: +92 -11 lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.8.9
 +++++
 
+* :pr:`383`: removed bool, int, float, None as input dummies for the exporter in ``method_to_onnx``
 * :pr:`382`: make the ordering of the inferred dynamic shapes more robust
 * :pr:`381`: add parameter *expand_batch_for* to ``method_to_onnx``
 * :pr:`378`: implements the computation of discrepancies in ``method_to_onnx``

_doc/examples/plot_export_tiny_llm_method_generate.py

Lines changed: 3 additions & 7 deletions
@@ -84,10 +84,6 @@ def generate_text(
     # the others are used to infer the dynamic shapes if they are not
     # specified below
     convert_after_n_calls=3,
-    # skips the following inputs even though they are captured,
-    # these ones are filled with default values we don't want in
-    # the onnx model
-    skip_kwargs_names={"kwargs", "use_cache", "return_dict", "inputs_embeds"},
     # The input used in the example has a batch size equal to 1, all
     # inputs going through method forward will have the same batch size.
     # To force the dynamism of this dimension, we need to indicate
@@ -105,20 +101,20 @@ def generate_text(
 # .. code-block:: python
 #
 #     dynamic_shapes={
-#         "cache_position": {0: "total_sequence_length"},
+#         "cache_position": {0: "sequence_length"},
 #         "past_key_values": [
 #             {0: "batch_size", 2: "past_sequence_length"},
 #             {0: "batch_size", 2: "past_sequence_length"},
 #         ],
 #         "input_ids": {0: "batch_size", 1: "sequence_length"},
-#         "attention_mask": {0: "batch_size", 1: "sequence_length"},
+#         "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
 #     }
 #
 # Finally, we need to replace the forward method.
 # As ``forward_replacement`` is a module of type
 # :class:`onnx_diagnostic.export.api.WrapperToExportMethodToOnnx`,
 # a lambda function must be used to avoid this one to be
-# included as a submodule (and an infinite loop).
+# included as a submodule (and create an infinite loop).
 
 print(f"type(forward_replacement)={type(forward_replacement)}")
 model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)
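
The comment above can be checked outside of onnx_diagnostic. A minimal sketch, assuming only torch is available: ``torch.nn.Module.__setattr__`` registers any ``nn.Module`` assigned to an attribute as a submodule, while a lambda is stored as a plain attribute, so wrapping the replacement in a lambda keeps it out of the module hierarchy.

    import torch

    model = torch.nn.Linear(4, 4)
    replacement = torch.nn.Identity()  # stand-in for forward_replacement

    # a lambda is a plain attribute: the replacement does not become a submodule
    model.forward = lambda *args, **kwargs: replacement(*args, **kwargs)
    print([name for name, _ in model.named_modules()])  # ['']

    # assigning the module directly would register it as a submodule
    model.sub = replacement
    print([name for name, _ in model.named_modules()])  # ['', 'sub']

Since the wrapper presumably keeps a reference to the model it wraps, registering it as a submodule of that same model would create a cycle in the module tree, which would explain the infinite loop the comment mentions.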

onnx_diagnostic/export/api.py

Lines changed: 88 additions & 4 deletions
@@ -431,11 +431,19 @@ def _reorder_kwargs(self, kwargs):
     def forward(self, *args, **kwargs):
         if not self._export_done:
             inp_args = args
-            # filters out the inputs not desired
+            # filters out the inputs not desired; int, float, bool, None
+            # are considered as constants by the exporter, they are removed
+            # from the named arguments.
             inp_kwargs = (
                 kwargs
-                if not kwargs or not self.skip_kwargs_names
-                else {k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names}
+                if not kwargs
+                else {
+                    k: v
+                    for k, v in kwargs.items()
+                    if v is not None
+                    and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
+                    and not isinstance(v, (bool, int, float))
+                }
             )
             if self.expand_batch_for:
                 # extends the inputs to artificially create a batch dimension != 1.
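
The effect of the new comprehension can be reproduced standalone. A minimal sketch with made-up inputs, assuming ``skip_kwargs_names`` is empty: tensor arguments survive, named arguments holding ``None``, ``bool``, ``int`` or ``float`` values are dropped before the exporter captures them as constants.

    import torch

    kwargs = dict(
        input_ids=torch.tensor([[1, 2, 3]]),  # kept, it is a tensor
        use_cache=True,                       # bool: dropped
        top_k=50,                             # int: dropped
        temperature=0.7,                      # float: dropped
        inputs_embeds=None,                   # None: dropped
    )
    skip_kwargs_names = None  # assume nothing is explicitly skipped

    inp_kwargs = (
        kwargs
        if not kwargs
        else {
            k: v
            for k, v in kwargs.items()
            if v is not None
            and (not skip_kwargs_names or k not in skip_kwargs_names)
            and not isinstance(v, (bool, int, float))
        }
    )
    print(list(inp_kwargs))  # ['input_ids']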
@@ -538,7 +546,10 @@ def __init__(self, parent):
         if self.dynamic_batch_for:
             nds = (
                 self._dynamic_batch_dimension(nds[0], self.dynamic_batch_for),
-                self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for),
+                self.rename_dynamic_shapes(
+                    self._dynamic_batch_dimension(nds[1], self.dynamic_batch_for),
+                    verbose=self.verbose,
+                ),
             )
         if self.verbose:
             print(f"[method_to_onnx] dynamic_batch_for={self.dynamic_batch_for}")
@@ -842,6 +853,79 @@ def check_discrepancies(
             print("[method_to_onnx.check_discrepancies] done")
         return data
 
+    @classmethod
+    def _apply_known_shape_pattern(
+        cls, shape: Dict[int, Any], pattern: Dict[int, str]
+    ) -> Dict[int, Any]:
+        return {k: pattern.get(k, v) for k, v in shape.items()}
+
+    @classmethod
+    def get_dynamic_shape_patterns(cls) -> Dict[str, Any]:
+        """
+        Returns the known patterns for the dynamic shapes.
+
+        .. runpython::
+            :showcode:
+
+            import pprint
+            from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx
+            pprint.pprint(WrapperToExportMethodToOnnx.get_dynamic_shape_patterns())
+        """
+        return {
+            "LLM.text": {
+                "cache_position": {0: "seqlength"},
+                "past_key_values": {0: "batch", 2: "pastlength"},
+                "input_ids": {0: "batch", 1: "seqlength"},
+                "attention_mask": {0: "batch", 1: "totallength"},  # pastlength+seqlength
+            }
+        }
+
+    @classmethod
+    def rename_dynamic_shapes(cls, ds: Dict[str, Any], verbose: int = 0) -> Dict[str, Any]:
+        """
+        Renames the dynamic shapes with meaningful names.
+        Tries to rename any dynamic dimension
+        before export. It is not very clever, it just tries
+        to recognize a known configuration based on input names.
+        Dimension names in dynamic shapes are renamed if *ds* contains
+        all the named arguments of one of the patterns
+        returned by :meth:`get_dynamic_shape_patterns
+        <onnx_diagnostic.export.api.WrapperToExportMethodToOnnx.get_dynamic_shape_patterns>`.
+        """
+        is_shape = lambda s: isinstance(s, dict) and all(  # noqa: E731
+            isinstance(_, int) for _ in s
+        )
+        llm_patterns = cls.get_dynamic_shape_patterns()
+        for pattern_name, pattern_shape in llm_patterns.items():
+            if len(set(ds) & set(pattern_shape)) == len(pattern_shape):
+                if verbose:
+                    print(
+                        f"[method_to_onnx.rename_dynamic_shapes] "
+                        f"apply pattern shapes {pattern_name!r}"
+                    )
+                new_ds = {}
+                for k, v in ds.items():
+                    if k not in pattern_shape:
+                        new_ds[k] = v
+                        continue
+                    if is_shape(v):
+                        # A shape
+                        new_ds[k] = cls._apply_known_shape_pattern(v, pattern_shape[k])
+                    elif isinstance(v, list):
+                        # A cache
+                        new_ds[k] = [
+                            (
+                                cls._apply_known_shape_pattern(s, pattern_shape[k])
+                                if is_shape(s)
+                                else s
+                            )
+                            for s in v
+                        ]
+                return new_ds
+
+        # unchanged
+        return ds
+
 
 def method_to_onnx(
     mod: "torch.nn.Module",
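
A usage sketch of the new classmethod, assuming this revision of onnx_diagnostic is installed; the ``dim_*`` names stand for automatically inferred dimension names. The four input names match the ``LLM.text`` pattern, so the renaming applies:

    from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx

    ds = {
        "cache_position": {0: "dim_0"},
        "past_key_values": [
            {0: "dim_1", 2: "dim_2"},
            {0: "dim_1", 2: "dim_2"},
        ],
        "input_ids": {0: "dim_1", 1: "dim_0"},
        "attention_mask": {0: "dim_1", 1: "dim_3"},
    }

    renamed = WrapperToExportMethodToOnnx.rename_dynamic_shapes(ds, verbose=1)
    print(renamed)
    # expected output:
    # {'cache_position': {0: 'seqlength'},
    #  'past_key_values': [{0: 'batch', 2: 'pastlength'},
    #                      {0: 'batch', 2: 'pastlength'}],
    #  'input_ids': {0: 'batch', 1: 'seqlength'},
    #  'attention_mask': {0: 'batch', 1: 'totallength'}}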
