2020"""
2121
2222import dataclasses
23- import inspect
2423from dataclasses import dataclass , field
2524from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple
2625
@@ -100,7 +99,7 @@ def __init__(
10099 visual_keys : Sequence [str ] = ("pixel_values" ,),
101100 min_pixels : Optional [int ] = None ,
102101 max_pixels : Optional [int ] = None ,
103- collate_fn : Callable [[ list , Any ] , dict [str , Any ]] | None = None ,
102+ collate_fn : Callable [... , dict [str , Any ]] | None = None ,
104103 pad_to_max_length : bool = False ,
105104 pad_to_multiple_of : int = 128 ,
106105 enable_in_batch_packing : bool = False ,
@@ -127,31 +126,6 @@ def __init__(
127126 )
128127 self ._collate_impl = COLLATE_FNS [collate_key ]
129128
130- def _supported_collate_kwargs (self ) -> dict [str , Any ]:
131- """Return encoder options accepted by the selected collate function."""
132- try :
133- parameters = inspect .signature (self ._collate_impl ).parameters
134- except (TypeError , ValueError ):
135- return {}
136-
137- accepts_kwargs = any (param .kind == inspect .Parameter .VAR_KEYWORD for param in parameters .values ())
138- candidates : dict [str , Any ] = {
139- "visual_keys" : self .visual_keys ,
140- "sequence_length" : self .seq_length ,
141- "pad_to_max_length" : self .pad_to_max_length ,
142- "pad_to_multiple_of" : self .pad_to_multiple_of ,
143- "pack_sequences" : self .enable_in_batch_packing ,
144- "pack_sequences_pad_to_multiple_of" : self .in_batch_packing_pad_to_multiple_of ,
145- }
146- if self .min_pixels is not None :
147- candidates ["min_pixels" ] = self .min_pixels
148- if self .max_pixels is not None :
149- candidates ["max_pixels" ] = self .max_pixels
150-
151- if accepts_kwargs :
152- return candidates
153- return {key : value for key , value in candidates .items () if key in parameters }
154-
155129 def encode_sample (self , sample : ChatMLSample ) -> HFEnergonSample :
156130 """Normalize a single ChatML sample into a HF-style collate example.
157131
@@ -185,7 +159,18 @@ def collate_fn(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
185159 The exact batch dictionary returned by the selected HF collate
186160 function for this processor type.
187161 """
188- return self ._collate_impl (examples , self .processor , ** self ._supported_collate_kwargs ())
162+ return self ._collate_impl (
163+ examples ,
164+ self .processor ,
165+ visual_keys = self .visual_keys ,
166+ min_pixels = self .min_pixels ,
167+ max_pixels = self .max_pixels ,
168+ sequence_length = self .seq_length ,
169+ pad_to_max_length = self .pad_to_max_length ,
170+ pad_to_multiple_of = self .pad_to_multiple_of ,
171+ pack_sequences = self .enable_in_batch_packing ,
172+ in_batch_packing_pad_to_multiple_of = self .in_batch_packing_pad_to_multiple_of ,
173+ )
189174
190175 # ------------------------------------------------------------------
191176 # batch
0 commit comments