48
48
49
49
50
50
def infer_framework_from_model (
51
- model , model_classes : Optional [Dict [str , type ]] = None , revision : Optional [str ] = None , task : Optional [ str ] = None
51
+ model , model_classes : Optional [Dict [str , type ]] = None , task : Optional [str ] = None , ** model_kwargs
52
52
):
53
53
"""
54
54
Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
@@ -65,10 +65,11 @@ def infer_framework_from_model(
65
65
from.
66
66
model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
67
67
A mapping framework to class.
68
- revision (:obj:`str`, `optional`):
69
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
70
- git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
71
- identifier allowed by git.
68
+ task (:obj:`str`):
69
+ The task defining which pipeline will be returned.
70
+ model_kwargs:
71
+ Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
72
+ **model_kwargs)` function.
72
73
73
74
Returns:
74
75
:obj:`Tuple`: A tuple framework, model.
@@ -80,19 +81,20 @@ def infer_framework_from_model(
80
81
"To install PyTorch, read the instructions at https://pytorch.org/."
81
82
)
82
83
if isinstance (model , str ):
84
+ model_kwargs ["_from_pipeline" ] = task
83
85
if is_torch_available () and not is_tf_available ():
84
86
model_class = model_classes .get ("pt" , AutoModel )
85
- model = model_class .from_pretrained (model , revision = revision , _from_pipeline = task )
87
+ model = model_class .from_pretrained (model , ** model_kwargs )
86
88
elif is_tf_available () and not is_torch_available ():
87
89
model_class = model_classes .get ("tf" , TFAutoModel )
88
- model = model_class .from_pretrained (model , revision = revision , _from_pipeline = task )
90
+ model = model_class .from_pretrained (model , ** model_kwargs )
89
91
else :
90
92
try :
91
93
model_class = model_classes .get ("pt" , AutoModel )
92
- model = model_class .from_pretrained (model , revision = revision , _from_pipeline = task )
94
+ model = model_class .from_pretrained (model , ** model_kwargs )
93
95
except OSError :
94
96
model_class = model_classes .get ("tf" , TFAutoModel )
95
- model = model_class .from_pretrained (model , revision = revision , _from_pipeline = task )
97
+ model = model_class .from_pretrained (model , ** model_kwargs )
96
98
97
99
framework = "tf" if model .__class__ .__name__ .startswith ("TF" ) else "pt"
98
100
return framework , model
0 commit comments