@@ -431,11 +431,19 @@ def _reorder_kwargs(self, kwargs):
431431 def forward (self , * args , ** kwargs ):
432432 if not self ._export_done :
433433 inp_args = args
434- # filters out the inputs not desired
434+ # filters out the inputs not desired, int, float, bool, None
435+ # are considered as constant for the exporter, they are removed
436+ # from the named arguments.
435437 inp_kwargs = (
436438 kwargs
437- if not kwargs or not self .skip_kwargs_names
438- else {k : v for k , v in kwargs .items () if k not in self .skip_kwargs_names }
439+ if not kwargs
440+ else {
441+ k : v
442+ for k , v in kwargs .items ()
443+ if v is not None
444+ and (not self .skip_kwargs_names or k not in self .skip_kwargs_names )
445+ and not isinstance (v , (bool , int , float ))
446+ }
439447 )
440448 if self .expand_batch_for :
441449 # extends the inputs to artificially create a batch dimension != 1.
@@ -538,7 +546,10 @@ def __init__(self, parent):
538546 if self .dynamic_batch_for :
539547 nds = (
540548 self ._dynamic_batch_dimension (nds [0 ], self .dynamic_batch_for ),
541- self ._dynamic_batch_dimension (nds [1 ], self .dynamic_batch_for ),
549+ self .rename_dynamic_shapes (
550+ self ._dynamic_batch_dimension (nds [1 ], self .dynamic_batch_for ),
551+ verbose = self .verbose ,
552+ ),
542553 )
543554 if self .verbose :
544555 print (f"[method_to_onnx] dynamic_batch_for={ self .dynamic_batch_for } " )
@@ -842,6 +853,79 @@ def check_discrepancies(
842853 print ("[method_to_onnx.check_discrepancies] done" )
843854 return data
844855
@classmethod
def _apply_known_shape_pattern(
    cls, shape: Dict[int, Any], pattern: Dict[int, str]
) -> Dict[int, Any]:
    """
    Replaces the dimensions of *shape* by the names found in *pattern*.

    :param shape: a dynamic shape, it maps an axis to a dimension
    :param pattern: the known names, it maps an axis to a dimension name
    :return: a new dictionary with the same axes as *shape*, an axis
        listed in *pattern* receives the name from *pattern*,
        any other axis keeps the value it has in *shape*
    """
    renamed: Dict[int, Any] = {}
    for axis, dim in shape.items():
        # The pattern wins whenever it knows the axis.
        renamed[axis] = pattern[axis] if axis in pattern else dim
    return renamed
861+
@classmethod
def get_dynamic_shape_patterns(cls) -> Dict[str, Any]:
    """
    Lists the known patterns for the dynamic shapes.
    The result is indexed by pattern name. Every pattern maps
    an argument name to the dimension names indexed by axis.

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.export.api import WrapperToExportMethodToOnnx
        pprint.pprint(WrapperToExportMethodToOnnx.get_dynamic_shape_patterns())
    """
    llm_text: Dict[str, Dict[int, str]] = {}
    llm_text["cache_position"] = {0: "seqlength"}
    llm_text["past_key_values"] = {0: "batch", 2: "pastlength"}
    llm_text["input_ids"] = {0: "batch", 1: "seqlength"}
    # totallength = pastlength+seqlength
    llm_text["attention_mask"] = {0: "batch", 1: "totallength"}
    return {"LLM.text": llm_text}
882+
@classmethod
def rename_dynamic_shapes(cls, ds: Dict[str, Any], verbose: int = 0) -> Dict[str, Any]:
    """
    Renames the dimensions of the dynamic shapes with known names.

    Tries to rename the dynamic dimensions before export.
    It is not very clever, it just tries to recognize a known
    configuration based on input names. A pattern returned by
    :meth:`get_dynamic_shape_patterns
    <onnx_diagnostic.export.api.WrapperToExportMethodToOnnx.get_dynamic_shape_patterns>`
    is applied if every argument name it lists is also a key of *ds*.
    The first matching pattern is applied and the others are ignored.

    For an argument listed in the applied pattern:

    * a shape (a dictionary with only integer keys) is renamed,
    * a list or a tuple (a cache) gets each of its shapes renamed,
      its other items are left as they are,
    * any other value (``None`` for example) is kept unchanged.

    Arguments not listed in the pattern are kept unchanged.

    :param ds: dynamic shapes indexed by argument name
    :param verbose: if not null, prints the name of the applied pattern
    :return: a new dictionary if a pattern was applied,
        *ds* itself otherwise
    """
    if not isinstance(ds, dict):
        # No argument name to match against a pattern.
        return ds

    def is_shape(candidate: Any) -> bool:
        return isinstance(candidate, dict) and all(isinstance(axis, int) for axis in candidate)

    def rename(value: Any, pattern: Dict[int, str]) -> Any:
        if is_shape(value):
            # A shape
            return cls._apply_known_shape_pattern(value, pattern)
        if isinstance(value, (list, tuple)):
            # A cache
            items = [
                cls._apply_known_shape_pattern(s, pattern) if is_shape(s) else s
                for s in value
            ]
            return items if isinstance(value, list) else tuple(items)
        # Anything else must be kept, dropping it would remove
        # the argument from the dynamic shapes.
        return value

    names = set(ds)
    for pattern_name, pattern_shape in cls.get_dynamic_shape_patterns().items():
        if not names.issuperset(pattern_shape):
            continue
        if verbose:
            print(
                f"[method_to_onnx.rename_dynamic_shapes] "
                f"apply pattern shapes {pattern_name!r}"
            )
        return {
            k: (rename(v, pattern_shape[k]) if k in pattern_shape else v)
            for k, v in ds.items()
        }

    # unchanged
    return ds
928+
845929
846930def method_to_onnx (
847931 mod : "torch.nn.Module" ,
0 commit comments