1
- from typing import Optional , List
1
+ from typing import List , Optional
2
2
3
3
from torch .utils .data import DataLoader
4
4
from transformers import PreTrainedModel
@@ -190,13 +190,11 @@ def oneshot(
190
190
save_compressed : bool = True ,
191
191
oneshot_device : str = "cuda:0" ,
192
192
model_revision : str = "main" ,
193
-
194
193
# Recipe parameters
195
194
recipe : Optional [str ] = None ,
196
195
recipe_args : Optional [List [str ]] = None ,
197
196
clear_sparse_session : bool = False ,
198
197
stage : Optional [str ] = None ,
199
-
200
198
# Dataset parameters
201
199
dataset : Optional [str ] = None ,
202
200
dataset_config_name : Optional [str ] = None ,
@@ -212,65 +210,75 @@ def oneshot(
212
210
preprocessing_num_workers : Optional [int ] = None ,
213
211
min_tokens_per_module : Optional [float ] = None ,
214
212
trust_remote_code_data : bool = False ,
215
-
216
213
# Output parameters
217
214
output_dir : Optional [str ] = None ,
218
-
219
215
# For backward compatibility
220
- ** kwargs
216
+ ** kwargs ,
221
217
) -> PreTrainedModel :
222
218
"""
223
219
Performs oneshot calibration on a model.
224
-
220
+
225
221
Args:
226
222
# Model arguments
227
- model (str): A pretrained model identifier from huggingface.co/models or a path to a
228
- local model. Required parameter.
229
- distill_teacher (Optional[str]): Teacher model (a trained text generation model) for
230
- distillation.
231
- config_name (Optional[str]): Pretrained config name or path if not the same as model_name.
232
- tokenizer (Optional[str]): Pretrained tokenizer name or path if not the same as model_name.
233
- processor (Optional[str]): Pretrained processor name or path if not the same as model_name.
234
- cache_dir (Optional[str]): Where to store the pretrained data from huggingface.co.
235
- use_auth_token (bool): Whether to use Hugging Face auth token for private models.
223
+ model (str): A pretrained model identifier from huggingface.co/models or a path
224
+ to a local model. Required parameter.
225
+ distill_teacher (Optional[str]): Teacher model (a trained text generation model)
226
+ for distillation.
227
+ config_name (Optional[str]): Pretrained config name or path if not the same as
228
+ model_name.
229
+ tokenizer (Optional[str]): Pretrained tokenizer name or path if not the same as
230
+ model_name.
231
+ processor (Optional[str]): Pretrained processor name or path if not the same as
232
+ model_name.
233
+ cache_dir (Optional[str]): Where to store the pretrained data from
234
+ huggingface.co.
235
+ use_auth_token (bool): Whether to use Hugging Face auth token for private
236
+ models.
236
237
precision (str): Precision to cast model weights to, default to auto.
237
238
tie_word_embeddings (bool): Whether the model's input and output word embeddings
238
239
should be tied.
239
- trust_remote_code_model (bool): Whether to allow for custom models to execute their
240
- own modeling files.
240
+ trust_remote_code_model (bool): Whether to allow for custom models to execute
241
+ their own modeling files.
241
242
save_compressed (bool): Whether to compress sparse models during save.
242
243
oneshot_device (str): Device to run oneshot calibration on.
243
- model_revision (str): The specific model version to use (can be branch name, tag, or commit id).
244
-
244
+ model_revision (str): The specific model version to use (can be branch name,
245
+ tag, or commit id).
246
+
245
247
# Recipe arguments
246
248
recipe (Optional[str]): Path to a LLM Compressor sparsification recipe.
247
- recipe_args (Optional[List[str]]): List of recipe arguments to evaluate, in the format
248
- "key1=value1", "key2=value2".
249
- clear_sparse_session (bool): Whether to clear CompressionSession/CompressionLifecycle
250
- data between runs.
249
+ recipe_args (Optional[List[str]]): List of recipe arguments to evaluate, in the
250
+ format "key1=value1", "key2=value2".
251
+ clear_sparse_session (bool): Whether to clear CompressionSession/
252
+ CompressionLifecycle data between runs.
251
253
stage (Optional[str]): The stage of the recipe to use for oneshot.
252
-
254
+
253
255
# Dataset arguments
254
- dataset (Optional[str]): The name of the dataset to use (via the datasets library).
255
- dataset_config_name (Optional[str]): The configuration name of the dataset to use.
256
+ dataset (Optional[str]): The name of the dataset to use (via the datasets
257
+ library).
258
+ dataset_config_name (Optional[str]): The configuration name of the dataset
259
+ to use.
256
260
dataset_path (Optional[str]): Path to a custom dataset. Supports json, csv, dvc.
257
- num_calibration_samples (int): Number of samples to use for one-shot calibration.
258
- shuffle_calibration_samples (bool): Whether to shuffle the dataset before calibration.
261
+ num_calibration_samples (int): Number of samples to use for one-shot
262
+ calibration.
263
+ shuffle_calibration_samples (bool): Whether to shuffle the dataset before
264
+ calibration.
259
265
max_seq_length (int): Maximum total input sequence length after tokenization.
260
266
pad_to_max_length (bool): Whether to pad all samples to `max_seq_length`.
261
267
text_column (str): Key to use as the `text` input to tokenizer/processor.
262
- concatenate_data (bool): Whether to concatenate datapoints to fill max_seq_length.
268
+ concatenate_data (bool): Whether to concatenate datapoints to fill
269
+ max_seq_length.
263
270
streaming (bool): True to stream data from a cloud dataset.
264
271
overwrite_cache (bool): Whether to overwrite the cached preprocessed datasets.
265
- preprocessing_num_workers (Optional[int]): Number of processes for preprocessing.
266
- min_tokens_per_module (Optional[float]): Minimum percentage of tokens per module,
267
- relevant for MoE models.
268
- trust_remote_code_data (bool): Whether to allow for datasets defined on the Hub using
269
- a dataset script.
270
-
272
+ preprocessing_num_workers (Optional[int]): Number of processes for
273
+ preprocessing.
274
+ min_tokens_per_module (Optional[float]): Minimum percentage of tokens per
275
+ module, relevant for MoE models.
276
+ trust_remote_code_data (bool): Whether to allow for datasets defined on the Hub
277
+ using a dataset script.
278
+
271
279
# Output arguments
272
280
output_dir (Optional[str]): Path to save the output model after calibration.
273
-
281
+
274
282
Returns:
275
283
PreTrainedModel: The calibrated model
276
284
"""
@@ -289,13 +297,11 @@ def oneshot(
289
297
"save_compressed" : save_compressed ,
290
298
"oneshot_device" : oneshot_device ,
291
299
"model_revision" : model_revision ,
292
-
293
300
# Recipe parameters
294
301
"recipe" : recipe ,
295
302
"recipe_args" : recipe_args ,
296
303
"clear_sparse_session" : clear_sparse_session ,
297
304
"stage" : stage ,
298
-
299
305
# Dataset parameters
300
306
"dataset" : dataset ,
301
307
"dataset_config_name" : dataset_config_name ,
@@ -311,18 +317,17 @@ def oneshot(
311
317
"preprocessing_num_workers" : preprocessing_num_workers ,
312
318
"min_tokens_per_module" : min_tokens_per_module ,
313
319
"trust_remote_code_data" : trust_remote_code_data ,
314
-
315
320
# Output parameters
316
321
"output_dir" : output_dir ,
317
322
}
318
-
323
+
319
324
params = {k : v for k , v in params .items () if v is not None }
320
-
325
+
321
326
# Merge with any kwargs (this preserves backward compatibility)
322
327
# kwargs take precedence over explicit params if same key exists
323
328
all_args = {** params , ** kwargs }
324
-
329
+
325
330
one_shot = Oneshot (** all_args )
326
331
one_shot ()
327
-
332
+
328
333
return one_shot .model
0 commit comments