Skip to content

Commit 4e5433a

Browse files
authored
Add hf to cross provider inference engine (#1866)
* Added option to run HF inference in CrossProviderInferenceEngline Signed-off-by: Yoav Katz <katz@il.ibm.com> * Changed example to use hf in CrossProviderInferenceEngine Signed-off-by: Yoav Katz <katz@il.ibm.com> * Revert unintended deletion Signed-off-by: Yoav Katz <katz@il.ibm.com> * More example changes Signed-off-by: Yoav Katz <katz@il.ibm.com> * Changed to provider name from hf to hf-local Fixed additional examples. Signed-off-by: Yoav Katz <katz@il.ibm.com> --------- Signed-off-by: Yoav Katz <katz@il.ibm.com>
1 parent 69e3386 commit 4e5433a

File tree

7 files changed

+34
-56
lines changed

7 files changed

+34
-56
lines changed

examples/evaluate_rag_response_generation.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
TaskCard,
44
)
55
from unitxt.collections_operators import Wrap
6-
from unitxt.inference import (
7-
HFPipelineBasedInferenceEngine,
8-
)
6+
from unitxt.inference import CrossProviderInferenceEngine
97
from unitxt.loaders import LoadFromDictionary
108
from unitxt.operators import Rename, Set
119
from unitxt.templates import MultiReferenceTemplate, TemplatesDict
@@ -78,13 +76,8 @@
7876
)
7977

8078

81-
# Infer using Llama-3.2-1B base using HF API
82-
model = HFPipelineBasedInferenceEngine(
83-
model_name="meta-llama/Llama-3.2-1B", max_new_tokens=32
84-
)
85-
# Change to this to infer with external APIs:
86-
# CrossProviderInferenceEngine(model="llama-3-2-1b-instruct", provider="watsonx")
87-
# The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama", "bam"]
79+
model = CrossProviderInferenceEngine(model="llama-3-2-1b-instruct", provider="watsonx")
80+
# The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama", "hf-local"]
8881

8982
predictions = model(dataset)
9083
results = evaluate(predictions=predictions, data=dataset)

examples/evaluate_using_metrics_ensemble.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
from unitxt import get_logger
22
from unitxt.api import evaluate, load_dataset
3-
from unitxt.inference import (
4-
HFPipelineBasedInferenceEngine,
5-
)
3+
from unitxt.inference import CrossProviderInferenceEngine
64
from unitxt.metrics import MetricsEnsemble
75

86
logger = get_logger()
97

108
# define the metrics ensemble
119
ensemble_metric = MetricsEnsemble(
1210
metrics=[
13-
"metrics.llm_as_judge.rating.llama_3_70b_instruct.generic_single_turn",
14-
"metrics.llm_as_judge.rating.llama_3_8b_instruct_ibm_genai_template_mt_bench_single_turn",
11+
"metrics.llm_as_judge.direct.watsonx.llama3_3_70b[criteria=metrics.llm_as_judge.direct.criteria.answer_relevance, context_fields=[question]]",
12+
"metrics.llm_as_judge.direct.watsonx.llama3_3_70b[criteria=metrics.llm_as_judge.direct.criteria.correctness_based_on_ground_truth, context_fields=[question,answers]]",
1513
],
1614
weights=[0.75, 0.25],
1715
)
@@ -27,13 +25,8 @@
2725
split="test",
2826
)
2927

30-
# Infer using SmolLM2 using HF API
31-
model = HFPipelineBasedInferenceEngine(
32-
model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct", max_new_tokens=32
33-
)
3428
# Change to this to infer with external APIs:
35-
# CrossProviderInferenceEngine(model="llama-3-2-1b-instruct", provider="watsonx")
36-
# The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama", "bam"]
29+
model = CrossProviderInferenceEngine(model="llama-3-2-1b-instruct", provider="watsonx")
3730

3831
predictions = model(dataset)
3932

examples/inference_using_cross_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from unitxt.text_utils import print_dict
33

44
if __name__ == "__main__":
5-
for provider in ["watsonx", "rits", "watsonx-sdk"]:
5+
for provider in ["watsonx", "rits", "watsonx-sdk", "hf-local"]:
66
print()
77
print("------------------------------------------------ ")
88
print("PROVIDER:", provider)

examples/multiple_choice_qa_evaluation.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from unitxt import get_logger, load_dataset
44
from unitxt.api import LoadFromDictionary, TaskCard, evaluate
55
from unitxt.blocks import Rename
6-
from unitxt.inference import HFPipelineBasedInferenceEngine
6+
from unitxt.inference import CrossProviderInferenceEngine
77
from unitxt.operators import IndexOf, ListFieldValues
88
from unitxt.templates import MultipleChoiceTemplate
99

@@ -61,14 +61,8 @@
6161
format="formats.chat_api",
6262
)
6363

64-
# Infer using Llama-3.2-1B base using HF API
65-
model = HFPipelineBasedInferenceEngine(
66-
model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct", max_new_tokens=32
67-
)
68-
# Change to this to infer with external APIs:
69-
# from unitxt.inference import CrossProviderInferenceEngine
70-
# model = CrossProviderInferenceEngine(model="llama-3-2-1b-instruct", provider="watsonx")
71-
# The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama", "bam"]
64+
model = CrossProviderInferenceEngine(model="SmolLM2-1.7B-Instruct", provider="hf-local")
65+
# The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama","hf-local"]
7266

7367

7468
predictions = model(dataset)
@@ -79,7 +73,7 @@
7973

8074

8175
print("Instance Results:")
82-
print(results.instance_scores)
76+
print(results.instance_scores.summary)
8377

8478
print("Global Results:")
8579
print(results.global_scores.summary)

examples/qa_evaluation.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from unitxt import get_logger
22
from unitxt.api import create_dataset, evaluate
3-
from unitxt.inference import (
4-
HFPipelineBasedInferenceEngine,
5-
)
3+
from unitxt.inference import CrossProviderInferenceEngine
64

75
logger = get_logger()
86

@@ -30,14 +28,9 @@
3028
format="formats.chat_api",
3129
)
3230

33-
# Infer using SmolLM2 using HF API
34-
model = HFPipelineBasedInferenceEngine(
35-
model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct", max_new_tokens=32
36-
)
37-
# Change to this to infer with external APIs:
38-
# from unitxt.inference import CrossProviderInferenceEngine
39-
# engine = CrossProviderInferenceEngine(model="llama-3-2-1b-instruct", provider="watsonx")
40-
# The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama", "bam"]
31+
model = CrossProviderInferenceEngine(model="SmolLM2-1.7B-Instruct", provider="hf-local")
32+
# The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama", "hf-local"]
33+
# (model must be available in the provider service)
4134

4235

4336
predictions = model(dataset)

examples/standalone_qa_evaluation.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from unitxt import get_logger
22
from unitxt.api import create_dataset, evaluate
33
from unitxt.blocks import Task
4-
from unitxt.inference import HFPipelineBasedInferenceEngine
4+
from unitxt.inference import CrossProviderInferenceEngine
55
from unitxt.templates import InputOutputTemplate
66

77
logger = get_logger()
@@ -37,21 +37,17 @@
3737
)
3838

3939

40-
# Infer using SmolLM2 using HF API
41-
model = HFPipelineBasedInferenceEngine(
42-
model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct", max_new_tokens=32
40+
model = CrossProviderInferenceEngine(
41+
model="SmolLM2-1.7B-Instruct", provider="hf-local", use_cache=False
4342
)
44-
# Change to this to infer with external APIs:
45-
# from unitxt.inference import CrossProviderInferenceEngine
46-
# model = CrossProviderInferenceEngine(model="llama-3-2-1b-instruct", provider="watsonx")
47-
# The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama", "bam". "rits"]
48-
43+
# The provider can be one of: ["watsonx", "together-ai", "open-ai", "aws", "ollama", "rits", "hf-local"]
44+
# (model must be available in the provider service)
4945

5046
predictions = model(dataset)
5147
results = evaluate(predictions=predictions, data=dataset)
5248

5349
print("Instance Results:")
54-
print(results.instance_scores)
50+
print(results.instance_scores.summary)
5551

5652
print("Global Results:")
5753
print(results.global_scores.summary)

src/unitxt/inference.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class StandardAPIParamsMixin(Artifact):
7979
n: Optional[int] = None
8080
parallel_tool_calls: Optional[bool] = None
8181
service_tier: Optional[Literal["auto", "default"]] = None
82-
credentials: Optional[Dict[str, str]] = {}
82+
credentials: Optional[Dict[str, str]] = None
8383
extra_headers: Optional[Dict[str, str]] = None
8484

8585

@@ -468,7 +468,7 @@ def _is_loaded(self):
468468

469469

470470
class HFGenerationParamsMixin(Artifact):
471-
max_new_tokens: int
471+
max_new_tokens: Optional[int] = None
472472
do_sample: bool = False
473473
temperature: Optional[float] = None
474474
top_p: Optional[float] = None
@@ -3362,6 +3362,8 @@ def get_engine_id(self):
33623362
return get_model_and_label_id(self.model, self.label)
33633363

33643364
def prepare_engine(self):
3365+
if self.credentials is None:
3366+
self.credentials = {}
33653367
# Initialize the token bucket rate limiter
33663368
self._rate_limiter = AsyncTokenBucket(
33673369
rate=self.max_requests_per_second,
@@ -3477,7 +3479,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
34773479
user requests.
34783480
34793481
Current _supported_apis = ["watsonx", "together-ai", "open-ai", "aws", "ollama",
3480-
"bam", "watsonx-sdk", "rits", "vertex-ai"]
3482+
"bam", "watsonx-sdk", "rits", "vertex-ai","hf-local"]
34813483
34823484
Args:
34833485
provider (Optional):
@@ -3684,6 +3686,11 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
36843686
"mixtral-8x7b-instruct-v0.1": "replicate/mistralai/mixtral-8x7b-instruct-v0.1",
36853687
"gpt-4-1": "replicate/openai/gpt-4.1",
36863688
},
3689+
"hf-local": {
3690+
"granite-3-3-8b-instruct": "ibm-granite/granite-3.3-8b-instruct",
3691+
"llama-3-3-8b-instruct": "meta-llama/Llama-3.3-8B-Instruct",
3692+
"SmolLM2-1.7B-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
3693+
},
36873694
}
36883695
provider_model_map["watsonx"] = {
36893696
k: f"watsonx/{v}" for k, v in provider_model_map["watsonx-sdk"].items()
@@ -3701,12 +3708,14 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
37013708
"azure": LiteLLMInferenceEngine,
37023709
"vertex-ai": LiteLLMInferenceEngine,
37033710
"replicate": LiteLLMInferenceEngine,
3711+
"hf-local": HFAutoModelInferenceEngine,
37043712
}
37053713

37063714
_provider_param_renaming = {
37073715
"bam": {"max_tokens": "max_new_tokens", "model": "model_name"},
37083716
"watsonx-sdk": {"model": "model_name"},
37093717
"rits": {"model": "model_name"},
3718+
"hf-local": {"model": "model_name"},
37103719
}
37113720

37123721
def get_return_object(self, **kwargs):

0 commit comments

Comments
 (0)