Commit 8d149ba

style quality
1 parent c21abc1 commit 8d149ba

1 file changed

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 49 additions & 44 deletions
@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import List, Optional
 
 from torch.utils.data import DataLoader
 from transformers import PreTrainedModel
@@ -190,13 +190,11 @@ def oneshot(
     save_compressed: bool = True,
     oneshot_device: str = "cuda:0",
     model_revision: str = "main",
-
     # Recipe parameters
     recipe: Optional[str] = None,
     recipe_args: Optional[List[str]] = None,
     clear_sparse_session: bool = False,
     stage: Optional[str] = None,
-
     # Dataset parameters
     dataset: Optional[str] = None,
     dataset_config_name: Optional[str] = None,
@@ -212,65 +210,75 @@ def oneshot(
     preprocessing_num_workers: Optional[int] = None,
     min_tokens_per_module: Optional[float] = None,
     trust_remote_code_data: bool = False,
-
     # Output parameters
     output_dir: Optional[str] = None,
-
     # For backward compatibility
-    **kwargs
+    **kwargs,
 ) -> PreTrainedModel:
     """
     Performs oneshot calibration on a model.
-
+
     Args:
         # Model arguments
-        model (str): A pretrained model identifier from huggingface.co/models or a path to a
-            local model. Required parameter.
-        distill_teacher (Optional[str]): Teacher model (a trained text generation model) for
-            distillation.
-        config_name (Optional[str]): Pretrained config name or path if not the same as model_name.
-        tokenizer (Optional[str]): Pretrained tokenizer name or path if not the same as model_name.
-        processor (Optional[str]): Pretrained processor name or path if not the same as model_name.
-        cache_dir (Optional[str]): Where to store the pretrained data from huggingface.co.
-        use_auth_token (bool): Whether to use Hugging Face auth token for private models.
+        model (str): A pretrained model identifier from huggingface.co/models or a path
+            to a local model. Required parameter.
+        distill_teacher (Optional[str]): Teacher model (a trained text generation model)
+            for distillation.
+        config_name (Optional[str]): Pretrained config name or path if not the same as
+            model_name.
+        tokenizer (Optional[str]): Pretrained tokenizer name or path if not the same as
+            model_name.
+        processor (Optional[str]): Pretrained processor name or path if not the same as
+            model_name.
+        cache_dir (Optional[str]): Where to store the pretrained data from
+            huggingface.co.
+        use_auth_token (bool): Whether to use Hugging Face auth token for private
+            models.
         precision (str): Precision to cast model weights to, default to auto.
         tie_word_embeddings (bool): Whether the model's input and output word embeddings
             should be tied.
-        trust_remote_code_model (bool): Whether to allow for custom models to execute their
-            own modeling files.
+        trust_remote_code_model (bool): Whether to allow for custom models to execute
+            their own modeling files.
         save_compressed (bool): Whether to compress sparse models during save.
         oneshot_device (str): Device to run oneshot calibration on.
-        model_revision (str): The specific model version to use (can be branch name, tag, or commit id).
-
+        model_revision (str): The specific model version to use (can be branch name,
+            tag, or commit id).
+
         # Recipe arguments
         recipe (Optional[str]): Path to a LLM Compressor sparsification recipe.
-        recipe_args (Optional[List[str]]): List of recipe arguments to evaluate, in the format
-            "key1=value1", "key2=value2".
-        clear_sparse_session (bool): Whether to clear CompressionSession/CompressionLifecycle
-            data between runs.
+        recipe_args (Optional[List[str]]): List of recipe arguments to evaluate, in the
+            format "key1=value1", "key2=value2".
+        clear_sparse_session (bool): Whether to clear CompressionSession/
+            CompressionLifecycle data between runs.
         stage (Optional[str]): The stage of the recipe to use for oneshot.
-
+
         # Dataset arguments
-        dataset (Optional[str]): The name of the dataset to use (via the datasets library).
-        dataset_config_name (Optional[str]): The configuration name of the dataset to use.
+        dataset (Optional[str]): The name of the dataset to use (via the datasets
+            library).
+        dataset_config_name (Optional[str]): The configuration name of the dataset
+            to use.
         dataset_path (Optional[str]): Path to a custom dataset. Supports json, csv, dvc.
-        num_calibration_samples (int): Number of samples to use for one-shot calibration.
-        shuffle_calibration_samples (bool): Whether to shuffle the dataset before calibration.
+        num_calibration_samples (int): Number of samples to use for one-shot
+            calibration.
+        shuffle_calibration_samples (bool): Whether to shuffle the dataset before
+            calibration.
         max_seq_length (int): Maximum total input sequence length after tokenization.
         pad_to_max_length (bool): Whether to pad all samples to `max_seq_length`.
         text_column (str): Key to use as the `text` input to tokenizer/processor.
-        concatenate_data (bool): Whether to concatenate datapoints to fill max_seq_length.
+        concatenate_data (bool): Whether to concatenate datapoints to fill
+            max_seq_length.
         streaming (bool): True to stream data from a cloud dataset.
         overwrite_cache (bool): Whether to overwrite the cached preprocessed datasets.
-        preprocessing_num_workers (Optional[int]): Number of processes for preprocessing.
-        min_tokens_per_module (Optional[float]): Minimum percentage of tokens per module,
-            relevant for MoE models.
-        trust_remote_code_data (bool): Whether to allow for datasets defined on the Hub using
-            a dataset script.
-
+        preprocessing_num_workers (Optional[int]): Number of processes for
+            preprocessing.
+        min_tokens_per_module (Optional[float]): Minimum percentage of tokens per
+            module, relevant for MoE models.
+        trust_remote_code_data (bool): Whether to allow for datasets defined on the Hub
+            using a dataset script.
+
         # Output arguments
         output_dir (Optional[str]): Path to save the output model after calibration.
-
+
     Returns:
         PreTrainedModel: The calibrated model
     """
@@ -289,13 +297,11 @@ def oneshot(
         "save_compressed": save_compressed,
         "oneshot_device": oneshot_device,
         "model_revision": model_revision,
-
         # Recipe parameters
         "recipe": recipe,
         "recipe_args": recipe_args,
         "clear_sparse_session": clear_sparse_session,
         "stage": stage,
-
         # Dataset parameters
         "dataset": dataset,
         "dataset_config_name": dataset_config_name,
@@ -311,18 +317,17 @@ def oneshot(
         "preprocessing_num_workers": preprocessing_num_workers,
         "min_tokens_per_module": min_tokens_per_module,
         "trust_remote_code_data": trust_remote_code_data,
-
         # Output parameters
         "output_dir": output_dir,
     }
-
+
     params = {k: v for k, v in params.items() if v is not None}
-
+
     # Merge with any kwargs (this preserves backward compatibility)
     # kwargs take precedence over explicit params if same key exists
     all_args = {**params, **kwargs}
-
+
     one_shot = Oneshot(**all_args)
     one_shot()
-
+
     return one_shot.model
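
The closing hunk leaves the argument-merging behavior intact: explicit parameters set to None are dropped, then **kwargs are merged on top and win any key collision. A standalone sketch of that pattern, with illustrative keys only:

    # Filter-then-merge pattern from the hunk above; keys are illustrative.
    params = {"recipe": "recipe.yaml", "dataset": None, "output_dir": "./out"}
    kwargs = {"output_dir": "./override"}

    params = {k: v for k, v in params.items() if v is not None}  # drop unset args
    all_args = {**params, **kwargs}  # kwargs take precedence on shared keys

    assert all_args == {"recipe": "recipe.yaml", "output_dir": "./override"}
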
