
Commit f677edb

preliminary support for submodule export (#341)
* preliminary support for submodule export * fix mypy * spell * fix * fix * fix * fix
1 parent 0ee1179 commit f677edb

8 files changed: +158 −61 lines changed


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.8.4
 +++++
 
+* :pr:`341`: preliminary support to export submodule
 * :pr:`340`: supports devices in onnx plugs
 * :pr:`338`: fixes ReplayConfiguration.dump, add function to select of part of a model
 * :pr:`337`: fixes extract_subset_of_nodes

_doc/cmds/validate.rst

Lines changed: 20 additions & 1 deletion
@@ -124,7 +124,7 @@ of function :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`.
 
     main("validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir --ortfusiontype ALL".split())
 
-Sdpa or Eager implementation or Use a StaticCache
+SDPA or Eager implementation or Use a StaticCache
 +++++++++++++++++++++++++++++++++++++++++++++++++
 
 Add ``--mop cache_implementation=static --iop cls_cache=StaticCache`` to use a StaticCache instead of a DynamicCache (default).
@@ -147,3 +147,22 @@ Add ``--mop attn_implementation=eager`` to explicitly select eager implementatio
         --mop attn_implementation=eager \
         --mop cache_implementation=static \
         --iop cls_cache=StaticCache
+
+Frequent examples used to test
+++++++++++++++++++++++++++++++
+
+.. code-block:: bash
+
+    python -m onnx_diagnostic validate -m arnir0/Tiny-LLM --run -v 1 --device cuda --dtype float16 -o dump_models --patch --opt default+onnxruntime --export custom
+
+About the exporter 'custom'
++++++++++++++++++++++++++++
+
+It is used to investigate issues or scenarios. It is usually very strict
+and fails whenever it encounters an unexpected situation.
+It calls :func:`experimental_experiment.torch_interpreter.to_onnx`.
+Some useful environment variables can be set before running the command line:
+
+* ``DROPPATTERN=<pattern1,pattern2,...>``: do not apply those patterns when optimizing a model
+* ``DUMPPATTERNS=<folder>``: dumps all matched and applied nodes when a pattern is applied
+* ``PATTERN=<pattern1,pattern2,...>``: increases verbosity for specific patterns, to understand why a pattern was not applied
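A minimal sketch, not part of this commit, of how the three variables above can be combined with the validate command documented on this page; the pattern names are hypothetical placeholders and the GPU-specific flags are omitted:

.. code-block:: python

    # Set the optimizer-related environment variables described above, then run
    # the validate command from this page in a subprocess.
    import os
    import subprocess

    os.environ["DROPPATTERN"] = "PatternA,PatternB"  # hypothetical patterns to skip during optimization
    os.environ["DUMPPATTERNS"] = "dump_patterns"     # folder receiving matched/applied nodes
    os.environ["PATTERN"] = "PatternA"               # more verbose output for this pattern only

    subprocess.run(
        "python -m onnx_diagnostic validate -m arnir0/Tiny-LLM --run -v 1 "
        "-o dump_models --patch --opt default+onnxruntime --export custom".split(),
        check=True,
    )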

_unittests/ut_tasks/test_tasks.py

Lines changed: 18 additions & 0 deletions
@@ -47,6 +47,24 @@ def test_text_generation(self):
             model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
         )
 
+    @hide_stdout()
+    def test_submodule(self):
+        mid = "arnir0/Tiny-LLM::model"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text-generation")
+        self.assertIn("inputs", data)
+        self.assertIn("inputs2", data)
+        self.assertIn("inputs_batch1", data)
+        self.assertIn("inputs_empty_cache", data)
+        self.assertIn((data["size"], data["n_weights"]), [(27379968, 6844992)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
     @hide_stdout()
     def test_text_generation_empty_cache(self):
         mid = "arnir0/Tiny-LLM"

onnx_diagnostic/export/api.py

Lines changed: 60 additions & 46 deletions
@@ -3,6 +3,52 @@
 from .onnx_plug import EagerDirectReplacementWithOnnx
 
 
+def get_main_dispatcher(
+    use_control_flow_dispatcher: bool = False,
+    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
+) -> Any:  # Dispatcher
+    """Creates a custom dispatcher for the custom exporter."""
+    from experimental_experiment.torch_interpreter import Dispatcher
+
+    if use_control_flow_dispatcher:
+        from .control_flow_onnx import create_global_dispatcher
+
+        control_flow_dispatcher = create_global_dispatcher()
+    else:
+        control_flow_dispatcher = None
+
+    class MainDispatcher(Dispatcher):
+        def __init__(self, previous_dispatcher=None):
+            super().__init__({})
+            self.previous_dispatcher = previous_dispatcher
+
+        @property
+        def supported(self):
+            if self.previous_dispatcher:
+                return set(self.registered_functions) | self.previous_dispatcher.supported
+            return set(self.registered_functions)
+
+        def find_function(self, name: Any):
+            if self.previous_dispatcher:
+                find = self.previous_dispatcher.find_function(name)
+                if find:
+                    return find
+            return Dispatcher.find_function(self, name)
+
+        def find_method(self, name: Any):
+            if self.previous_dispatcher:
+                find = self.previous_dispatcher.find_method(name)
+                if find:
+                    return find
+            return Dispatcher.find_method(self, name)
+
+    main_dispatcher = MainDispatcher(control_flow_dispatcher)
+    if onnx_plugs:
+        for plug in onnx_plugs:
+            main_dispatcher.registered_functions[plug.target_name] = plug.custom_converter()
+    return main_dispatcher
+
+
 def to_onnx(
     mod: Union["torch.nn.Module", "torch.fx.GraphModule"],  # noqa: F821
     args: Optional[Sequence["torch.Tensor"]] = None,  # noqa: F821
@@ -82,51 +128,11 @@ def to_onnx(
     options = exporter_kwargs.pop("options", None)
     if options is None:
         options = OptimizationOptions(patterns="default+onnxruntime")
-    if onnx_plugs or use_control_flow_dispatcher:
-        from experimental_experiment.torch_interpreter import Dispatcher
-
-        if use_control_flow_dispatcher:
-            from .control_flow_onnx import create_global_dispatcher
-
-            control_flow_dispatcher = create_global_dispatcher()
-        else:
-            control_flow_dispatcher = None
-
-        class MainDispatcher(Dispatcher):
-            def __init__(self, previous_dispatcher=None):
-                super().__init__({})
-                self.previous_dispatcher = previous_dispatcher
-
-            @property
-            def supported(self):
-                if self.previous_dispatcher:
-                    return (
-                        set(self.registered_functions) | self.previous_dispatcher.supported
-                    )
-                return set(self.registered_functions)
-
-            def find_function(self, name: Any):
-                if self.previous_dispatcher:
-                    find = self.previous_dispatcher.find_function(name)
-                    if find:
-                        return find
-                return Dispatcher.find_function(self, name)
-
-            def find_method(self, name: Any):
-                if self.previous_dispatcher:
-                    find = self.previous_dispatcher.find_method(name)
-                    if find:
-                        return find
-                return Dispatcher.find_method(self, name)
-
-        main_dispatcher = MainDispatcher(control_flow_dispatcher)
-        if onnx_plugs:
-            for plug in onnx_plugs:
-                main_dispatcher.registered_functions[plug.target_name] = (
-                    plug.custom_converter()
-                )
-    else:
-        main_dispatcher = None
+    main_dispatcher = (
+        get_main_dispatcher(use_control_flow_dispatcher, onnx_plugs)
+        if onnx_plugs or use_control_flow_dispatcher
+        else None
+    )
 
     return _to_onnx(
         mod,
@@ -181,9 +187,17 @@ def find_method(self, name: Any):
     import onnx_ir as ir
     import onnx_ir.passes.common as common_passes
 
+    opset = (
+        18
+        if target_opset is None
+        else (target_opset if isinstance(target_opset, int) else target_opset[""])
+    )
+
     irfunctions = [
         ir.from_proto(
-            plug.get_function_proto(*flatten_object((args, kwargs), drop_keys=True))
+            plug.get_function_proto(
+                opset, *flatten_object((args, kwargs), drop_keys=True)
+            )
         )
         for plug in onnx_plugs
     ]

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 9 additions & 5 deletions
@@ -262,12 +262,14 @@ def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.d
     itype = torch_dtype_to_onnx_dtype(dtype)
     if strategy is not None:
         return strategy, itype
-    if dtype == torch.float32:
+    if dtype == torch.float32 or itype == onnx.TensorProto.FLOAT:
         if opset >= 24:
             return "LOOPA24", itype
         return "LOOPMHA", itype
-    if dtype == torch.float16:
-        if first_tensor.is_cuda:
+    if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
+        # first_tensor may be a SymbolicTensor (onnx).
+        # is_cuda is not available.
+        if hasattr(first_tensor, "is_cuda") and first_tensor.is_cuda:
             return "PACKED", itype
         return "LOOPMHA", itype
     raise AssertionError(
@@ -638,12 +640,14 @@ def forward(
             self.config._attn_implementation
         ]
 
-        is_sdpa = (
+        is_sdpa_or_eager = (
             attention_interface
             is transformers.integrations.sdpa_attention.sdpa_attention_forward
             or attention_interface is patched_sdpa_attention_forward
+            or attention_interface
+            is transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
         )
-        if is_sdpa:
+        if is_sdpa_or_eager:
             attn_output = qwen_sdpa_attention_versatile(
                 query_states,
                 key_states,

onnx_diagnostic/torch_models/code_sample.py

Lines changed: 2 additions & 1 deletion
@@ -236,7 +236,7 @@ def code_sample(
         )
     )
     """
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
@@ -256,6 +256,7 @@ def code_sample(
         model_kwargs=mop,
         subfolder=subfolder,
         add_second_input=False,
+        submodule=submodule,
     )
     if drop_inputs:
         update = {}

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 34 additions & 7 deletions
@@ -26,17 +26,26 @@ def _code_needing_rewriting(model: Any) -> Any:
 
 
 def _preprocess_model_id(
-    model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
-) -> Tuple[str, Optional[str], bool, bool]:
+    model_id: str,
+    subfolder: Optional[str],
+    same_as_pretrained: bool,
+    use_pretrained: bool,
+    submodule: Optional[str] = None,
+) -> Tuple[str, Optional[str], bool, bool, Optional[str]]:
+    if "::" in model_id:
+        assert (
+            not submodule
+        ), f"submodule={submodule!r} cannot be defined in model_id={model_id!r} as well"
+        model_id, submodule = model_id.split("::", maxsplit=1)
     if subfolder or "//" not in model_id:
-        return model_id, subfolder, same_as_pretrained, use_pretrained
+        return model_id, subfolder, same_as_pretrained, use_pretrained, submodule
     spl = model_id.split("//")
     if spl[-1] == "pretrained":
-        return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
+        return _preprocess_model_id("//".join(spl[:-1]), "", True, True, submodule)
     if spl[-1] in {"transformer", "vae"}:
         # known subfolder
-        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
-    return model_id, subfolder, same_as_pretrained, use_pretrained
+        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained, submodule
+    return model_id, subfolder, same_as_pretrained, use_pretrained, submodule
 
 
 def get_untrained_model_with_inputs(
@@ -54,6 +63,7 @@ def get_untrained_model_with_inputs(
     subfolder: Optional[str] = None,
     use_only_preinstalled: bool = False,
     config_reduction: Optional[Callable[[Any, str], Dict]] = None,
+    submodule: Optional[str] = None,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model similar to the original model
@@ -82,6 +92,7 @@ def get_untrained_model_with_inputs(
         <onnx_diagnostic.torch_models.hghub.reduce_model_config>`,
         this function takes a configuration and a task (string)
         as arguments
+    :param submodule: use a submodule instead of the main model
     :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
         some necessary rewriting as well
 
@@ -108,11 +119,12 @@ def get_untrained_model_with_inputs(
             f"model_id={model_id!r}, preinstalled model is only available "
             f"if use_only_preinstalled is False."
         )
-    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+    model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
         model_id,
         subfolder,
         same_as_pretrained=same_as_pretrained,
         use_pretrained=use_pretrained,
+        submodule=submodule,
     )
     if verbose:
         print(
@@ -147,6 +159,8 @@ def get_untrained_model_with_inputs(
     if verbose:
         print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
+        if submodule:
+            print(f"[get_untrained_model_with_inputs] submodule={submodule!r}")
     if task is None:
         task = task_from_arch(arch, model_id=model_id, subfolder=subfolder)
     if verbose:
@@ -357,6 +371,19 @@ def get_untrained_model_with_inputs(
     if diff_config is not None:
         res["dump_info"] = dict(config_diff=diff_config)
 
+    if submodule:
+        path = submodule.split("::") if "::" in submodule else [submodule]
+        for p in path:
+            assert hasattr(model, p), (
+                f"Unable to find submodule {p!r} in class {type(model)}, "
+                f"submodule={submodule!r}, possible candidates: "
+                f"{[k for k in dir(model) if isinstance(getattr(model, k), torch.nn.Module)]}"
+            )
+            model = getattr(model, p)
+
+    if verbose:
+        print(f"[get_untrained_model_with_inputs] model class={model.__class__.__name__!r}")
+
     sizes = compute_model_size(model)
     res["model"] = model
     res["configuration"] = config
