Skip to content

Commit 9c4070b

Browse files
philschmid authored and sgugger committed
Adds use_auth_token with pipelines (#11123)
* added model_kwargs to infer_framework_from_model
* added model_kwargs to tokenizer
* added use_auth_token as named parameter
* added dynamic get for use_auth_token
1 parent cd39c8e commit 9c4070b

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

src/transformers/pipelines/__init__.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def pipeline(
246246
framework: Optional[str] = None,
247247
revision: Optional[str] = None,
248248
use_fast: bool = True,
249+
use_auth_token: Optional[Union[str, bool]] = None,
249250
model_kwargs: Dict[str, Any] = {},
250251
**kwargs
251252
) -> Pipeline:
@@ -308,6 +309,10 @@ def pipeline(
308309
artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
309310
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
310311
Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
312+
use_auth_token (:obj:`str` or `bool`, `optional`):
313+
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
314+
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
315+
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
311316
model_kwargs:
312317
Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
313318
**model_kwargs)` function.
@@ -367,6 +372,9 @@ def pipeline(
367372

368373
task_class, model_class = targeted_task["impl"], targeted_task[framework]
369374

375+
# Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
376+
model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
377+
370378
# Instantiate tokenizer if needed
371379
if isinstance(tokenizer, (str, tuple)):
372380
if isinstance(tokenizer, tuple):
@@ -377,12 +385,12 @@ def pipeline(
377385
)
378386
else:
379387
tokenizer = AutoTokenizer.from_pretrained(
380-
tokenizer, revision=revision, use_fast=use_fast, _from_pipeline=task
388+
tokenizer, revision=revision, use_fast=use_fast, _from_pipeline=task, **model_kwargs
381389
)
382390

383391
# Instantiate config if needed
384392
if isinstance(config, str):
385-
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task)
393+
config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
386394

387395
# Instantiate modelcard if needed
388396
if isinstance(modelcard, str):

src/transformers/pipelines/base.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949

5050
def infer_framework_from_model(
51-
model, model_classes: Optional[Dict[str, type]] = None, revision: Optional[str] = None, task: Optional[str] = None
51+
model, model_classes: Optional[Dict[str, type]] = None, task: Optional[str] = None, **model_kwargs
5252
):
5353
"""
5454
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
@@ -65,10 +65,11 @@ def infer_framework_from_model(
6565
from.
6666
model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
6767
A mapping framework to class.
68-
revision (:obj:`str`, `optional`):
69-
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
70-
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
71-
identifier allowed by git.
68+
task (:obj:`str`):
69+
The task defining which pipeline will be returned.
70+
model_kwargs:
71+
Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
72+
**model_kwargs)` function.
7273
7374
Returns:
7475
:obj:`Tuple`: A tuple framework, model.
@@ -80,19 +81,20 @@ def infer_framework_from_model(
8081
"To install PyTorch, read the instructions at https://pytorch.org/."
8182
)
8283
if isinstance(model, str):
84+
model_kwargs["_from_pipeline"] = task
8385
if is_torch_available() and not is_tf_available():
8486
model_class = model_classes.get("pt", AutoModel)
85-
model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task)
87+
model = model_class.from_pretrained(model, **model_kwargs)
8688
elif is_tf_available() and not is_torch_available():
8789
model_class = model_classes.get("tf", TFAutoModel)
88-
model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task)
90+
model = model_class.from_pretrained(model, **model_kwargs)
8991
else:
9092
try:
9193
model_class = model_classes.get("pt", AutoModel)
92-
model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task)
94+
model = model_class.from_pretrained(model, **model_kwargs)
9395
except OSError:
9496
model_class = model_classes.get("tf", TFAutoModel)
95-
model = model_class.from_pretrained(model, revision=revision, _from_pipeline=task)
97+
model = model_class.from_pretrained(model, **model_kwargs)
9698

9799
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
98100
return framework, model

0 commit comments

Comments
 (0)