Skip to content

Commit eaeec9f

Browse files
authored
Fix a few things for the documentation. (#405)
* tiny changes * fix asert * lint * fix documentation
1 parent 0aa52ce commit eaeec9f

File tree

7 files changed

+70
-18
lines changed

7 files changed

+70
-18
lines changed

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def linkcode_resolve(domain, info):
144144
("py:class", "torch.utils._pytree.Context"),
145145
("py:class", "torch.utils._pytree.KeyEntry"),
146146
("py:class", "torch.utils._pytree.TreeSpec"),
147+
("py:class", "torch.utils._sympy.value_ranges.ValueRanges"),
147148
("py:class", "transformers.BartForConditionalGeneration"),
148149
("py:class", "transformers.LlamaConfig"),
149150
("py:class", "transformers.cache_utils.Cache"),

_doc/final/plot_export_tiny_llm_method_generate.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,14 @@
2828

2929

3030
def generate_text(
31-
prompt, model, tokenizer, max_length=50, temperature=1, top_k=50, top_p=0.95
31+
prompt,
32+
model,
33+
tokenizer,
34+
max_length=50,
35+
temperature=1,
36+
top_k=50,
37+
top_p=0.95,
38+
do_sample=False,
3239
):
3340
inputs = tokenizer(prompt, return_tensors="pt")
3441
input_ids = inputs["input_ids"]
@@ -41,7 +48,7 @@ def generate_text(
4148
temperature=temperature,
4249
top_k=top_k,
4350
top_p=top_p,
44-
do_sample=True,
51+
do_sample=do_sample,
4552
)
4653

4754
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

onnx_diagnostic/export/api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,16 @@ def _reorder_kwargs(self, kwargs):
428428
new_kwargs[k] = v
429429
return new_kwargs
430430

431+
def is_empty_cache(self, cache):
432+
if cache.__class__.__name__ == "DynamicCache" and hasattr(cache, "layers"):
433+
if len(cache.layers) == 1 and cache.layers[0].keys is None:
434+
return True
435+
if len(cache.layers) == 0:
436+
return True
437+
if cache is None:
438+
return True
439+
return False
440+
431441
def forward(self, *args, **kwargs):
432442
if not self._export_done:
433443
inp_args = args
@@ -443,6 +453,7 @@ def forward(self, *args, **kwargs):
443453
if v is not None
444454
and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
445455
and not isinstance(v, (bool, int, float))
456+
and not self.is_empty_cache(v)
446457
}
447458
)
448459
inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,7 @@ def guess_dynamic_shape_object(
834834
"""Guesses the dynamic shapes for one argument."""
835835
if len(objs) == 0:
836836
return None
837-
set_types = set(type(o) for o in objs)
837+
set_types = set(type(o) for o in objs if o is not None)
838838
assert (
839839
len(set_types) == 1
840840
), f"Unexpected variety of input type {set_types}{msg() if msg else ''})"

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,4 +832,13 @@ def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_
832832
# This is used to expand the cache when it does not contain enough layers.
833833
# This is needed since transformers>4.55.3
834834
cache.layer_class_to_replicate = cache.layers[0].__class__
835+
assert (
836+
not hasattr(cache, "layers")
837+
or len(cache.layers) != 1
838+
or cache.layers[0].keys is not None
839+
), (
840+
f"Size mismatch between {len(cache.layers)=}, "
841+
f"first key={cache.layers[0].keys}, " # type: ignore[attr-defined]
842+
f"first value={cache.layers[0].values}" # type: ignore[attr-defined]
843+
)
835844
return cache

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -851,9 +851,14 @@ def torch_deepcopy(value: Any) -> Any:
851851
from .cache_helper import CacheKeyValue
852852

853853
ca = CacheKeyValue(value)
854-
return make_dynamic_cache(
855-
torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))), cls_layers=ca.cls_layers
854+
pairs = list(zip(ca.key_cache, ca.value_cache))
855+
assert not hasattr(value, "layers") or len(value.layers) == len(pairs), (
856+
f"Size mismatch between {len(value.layers)=} and {len(pairs)=}. "
857+
f"value={string_type(value, with_shape=True)}, "
858+
f"first key={value.layers[0].keys}, "
859+
f"first value={value.layers[0].values}"
856860
)
861+
return make_dynamic_cache(torch_deepcopy(pairs), cls_layers=ca.cls_layers)
857862
if value.__class__.__name__ == "StaticCache":
858863
from .cache_helper import CacheKeyValue
859864

onnx_diagnostic/investigate/input_observer.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _flatten_unflatten_for_dynamic_shapes(
3232
like replace them by a shape
3333
3434
Returns:
35-
the serialized object
35+
the flattened object
3636
"""
3737
if isinstance(obj, torch.Tensor):
3838
return change_function(obj) if change_function else obj
@@ -104,14 +104,22 @@ def _infer_dynamic_dimensions(
104104
class InputCandidate:
105105
"""Retains one set of inputs given to the forward method or any
106106
other method the class :class:`InputObserver` is stealing from.
107+
Any class is allowed as long as it can be flattened.
107108
108109
Args:
109-
args: Positional arguments.
110-
kwargs: Optional arguments.
111-
clone: Clone the inputs before storing them. Some tensors
110+
args:
111+
Positional arguments.
112+
kwargs:
113+
Optional arguments.
114+
clone:
115+
Clones the inputs before storing them. Some tensors
112116
may be modified inplace, the original value must be retained.
113-
cst_kwargs: Any optional arguments constant over multiple calls.
117+
cst_kwargs:
118+
Any optional arguments constant over multiple calls.
114119
int, float, str, bool values must be stored here.
120+
121+
The constructor flattens the received arguments.
122+
Any necessary flattening function should have been registered first.
115123
"""
116124

117125
def __init__(
@@ -671,18 +679,20 @@ class InputObserver:
671679
>>> )
672680
673681
With LLM:
682+
674683
>>> input_observer = InputObserver()
675684
>>> with input_observer(model):
676685
>>> model.generate(input_ids)
677686
>>> ep = torch.export.export( # or torch.onnx.export
678687
>>> model,
679-
>>> ()
688+
>>> (),
680689
>>> kwargs=input_observer.infer_arguments(),
681690
>>> dynamic_shapes=input_observer.infer_dynamic_shapes(),
682691
>>> )
683692
684693
Examples can be found in :ref:`l-plot-tiny-llm-export-input-observer`,
685-
:ref:`l-plot-whisper-tiny-export-input-observer`.
694+
:ref:`l-plot-whisper-tiny-export-input-observer`,
695+
:ref:`l-plot-gemma3-tiny-export-input-observer`.
686696
"""
687697

688698
def __init__(self, missing: dict[str, Any] | None = None):
@@ -865,17 +875,26 @@ def check_discrepancies(
865875
with the saved onnx model.
866876
867877
Args:
868-
onnx_model: ONNX Model to verify.
869-
atol: Absolute tolerance, recommended values, 1e-4 for float, 1e-2 for float16.
870-
rtol: Relative tolerance.
871-
hist: Thresholds, the function determines the number of discrepancies
878+
onnx_model:
879+
ONNX Model to verify.
880+
atol:
881+
Absolute tolerance, recommended values, 1e-4 for float, 1e-2 for float16.
882+
rtol:
883+
Relative tolerance.
884+
hist:
885+
Thresholds, the function determines the number of discrepancies
872886
above these thresholds.
873-
progress_bar: Shows a progress bar (requires :epkg:`tqdm`).
874-
include_io: Shows inputs/outputs shapes in the summary
887+
progress_bar:
888+
Shows a progress bar (requires :epkg:`tqdm`).
889+
include_io:
890+
Shows inputs/outputs shapes in the summary
875891
returned by this function.
876892
877893
Returns:
878894
A list of dictionaries, ready to be consumed by a dataframe.
895+
896+
The function catches exceptions and shows the error in the returned
897+
summary.
879898
"""
880899
sess = OnnxruntimeEvaluator(onnx_model, whole=True)
881900
input_names = sess.input_names

0 commit comments

Comments
 (0)