
Commit a3e37ac

elronbandel and eladven authored
Fix some bugs in inference engine tests (#1682)
* Fix some bugs in inference engine tests
* Fix some bugs in inference engine tests
* Fix bug in name conversion in rits
* Add engine id
* fix
* fix
* Use greedy decoding and remove redundant cache
* Fix hf-auto model test
* Touch up watsonx tests
* Fix inference tests:
  1. Use a local inference engine on CPU when testing inference engines, for reproducibility.
  2. In the cache mechanism, don't assume that inferring on an empty list yields an empty list.
* Fix setting of data classification policy

Signed-off-by: elronbandel <[email protected]>
Signed-off-by: Elad Venezian <[email protected]>
Co-authored-by: Elad Venezian <[email protected]>
1 parent 794513d · commit a3e37ac

8 files changed: +199 −225 lines changed

.github/workflows/inference_tests.yml

+4 −3

@@ -27,10 +27,11 @@ jobs:
       WML_URL: ${{ secrets.WML_URL }}
       WML_PROJECT_ID: ${{ secrets.WML_PROJECT_ID }}
       WML_APIKEY: ${{ secrets.WML_APIKEY }}
-      WX_URL: ${{ secrets.WX_URL }}
-      WX_PROJECT_ID: ${{ secrets.WX_PROJECT_ID }}
-      WX_API_KEY: ${{ secrets.WX_API_KEY }}
+      WX_URL: ${{ secrets.WML_URL }} # Similar to WML_URL
+      WX_PROJECT_ID: ${{ secrets.WML_PROJECT_ID }} # Similar to WML_PROJECT_ID
+      WX_API_KEY: ${{ secrets.WML_APIKEY }} # Similar to WML_APIKEY
       GENAI_KEY: ${{ secrets.GENAI_KEY }}
+
     steps:
       - uses: actions/checkout@v4

pyproject.toml

+1 −1

@@ -115,7 +115,7 @@ watsonx = [
     "ibm-watsonx-ai==1.1.14"
 ]
 inference-tests = [
-    "litellm==v1.52.9",
+    "litellm>=1.52.9",
     "tenacity",
     "diskcache",
     "numpy==1.26.4",

src/unitxt/api.py

+5 −3

@@ -21,7 +21,7 @@
 from .logging_utils import get_logger
 from .metric_utils import EvaluationResults, _compute, _inference_post_process
 from .operator import SourceOperator
-from .schema import loads_instance
+from .schema import loads_batch
 from .settings_utils import get_constants, get_settings
 from .standard import DatasetRecipe
 from .task import Task
@@ -98,6 +98,7 @@ def create_dataset(
     train_set: Optional[List[Dict[Any, Any]]] = None,
     validation_set: Optional[List[Dict[Any, Any]]] = None,
     split: Optional[str] = None,
+    data_classification_policy: Optional[List[str]] = None,
     **kwargs,
 ) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
     """Creates dataset from input data based on a specific task.
@@ -108,6 +109,7 @@ def create_dataset(
         train_set : optional train_set
         validation_set: optional validation set
         split: optional one split to choose
+        data_classification_policy: data_classification_policy
         **kwargs: Arguments used to load dataset from provided datasets (see load_dataset())

     Returns:
@@ -129,7 +131,7 @@ def create_dataset(
             f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
         )

-    card = TaskCard(loader=LoadFromDictionary(data=data), task=task)
+    card = TaskCard(loader=LoadFromDictionary(data=data, data_classification_policy=data_classification_policy), task=task)
     return load_dataset(card=card, split=split, **kwargs)


@@ -283,7 +285,7 @@ def produce(
     result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
     if not is_list:
         return result[0]
-    return Dataset.from_list(result).with_transform(loads_instance)
+    return Dataset.from_list(result).with_transform(loads_batch)


 def infer(
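For context, a minimal usage sketch of the new create_dataset argument. This is not part of the commit: the task id and the test_set argument are assumptions based on unitxt's public API, and only data_classification_policy itself comes from the diff above.

    from unitxt.api import create_dataset

    data = [{"text": "great movie", "label": "positive"}]

    ds = create_dataset(
        task="tasks.classification.multi_class",  # assumed task id, for illustration only
        test_set=data,                            # assumed argument name for the test split
        split="test",
        data_classification_policy=["public"],    # forwarded to LoadFromDictionary per the diff
    )

If the chosen task has no default_template, a template= argument is also required, per the error message visible in the diff.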

src/unitxt/dataset.py

+10 −3

@@ -1,5 +1,5 @@
 import os
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import datasets

@@ -46,7 +46,7 @@
 from .random_utils import __file__ as _
 from .recipe import __file__ as _
 from .register import __file__ as _
-from .schema import loads_instance
+from .schema import loads_batch, loads_instance
 from .serializers import __file__ as _
 from .settings_utils import get_constants
 from .span_lableing_operators import __file__ as _
@@ -115,6 +115,13 @@ def _download_and_prepare(
             dl_manager, "no_checks", **prepare_splits_kwargs
         )

+    def as_streaming_dataset(self, split: Optional[str] = None, base_path: Optional[str] = None) -> Union[Dict[str, datasets.IterableDataset], datasets.IterableDataset]:
+        return (
+            super()
+            .as_streaming_dataset(split, base_path=base_path)
+            .map(loads_instance)
+        )
+
     def as_dataset(
         self,
         split: Optional[datasets.Split] = None,
@@ -157,5 +164,5 @@ def as_dataset(
         return (
             super()
             .as_dataset(split, run_post_process, verification_mode, in_memory)
-            .with_transform(loads_instance)
+            .with_transform(loads_batch)
         )
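The reason streaming needs its own deserializer: with_transform on a map-style Dataset hands the function a batch (each column as a list), while IterableDataset.map hands it one example at a time. A toy sketch with made-up data (not unitxt code) showing the two shapes; the "task_data" field name is taken from the schema module:

    import json
    from datasets import Dataset

    def loads_batch(batch):
        # with_transform passes a batch: each field maps to a list of values
        batch["task_data"] = [json.loads(d) for d in batch["task_data"]]
        return batch

    def loads_instance(example):
        # IterableDataset.map passes a single example: each field maps to one value
        example["task_data"] = json.loads(example["task_data"])
        return example

    ds = Dataset.from_list([{"task_data": '{"a": 1}'}])

    # map-style dataset: batched transform applied lazily on access
    print(ds.with_transform(loads_batch)[0]["task_data"])          # {'a': 1}

    # streaming/iterable dataset: per-example map
    streamed = ds.to_iterable_dataset().map(loads_instance)
    print(next(iter(streamed))["task_data"])                       # {'a': 1}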

src/unitxt/inference.py

+29 −21

@@ -13,6 +13,7 @@
 import uuid
 from collections import Counter
 from datetime import datetime
+from itertools import islice
 from multiprocessing.pool import ThreadPool
 from typing import (
     Any,
@@ -55,6 +56,11 @@
 logger = get_logger()


+def batched(lst, n):
+    it = iter(lst)
+    while batch := list(islice(it, n)):
+        yield batch
+
 class StandardAPIParamsMixin(Artifact):
     model: str
     frequency_penalty: Optional[float] = None
@@ -227,12 +233,8 @@ def infer(
             result = self._mock_infer(dataset)
         else:
             if self.use_cache:
-                if isinstance(dataset, Dataset):
-                    dataset = dataset.to_list()
-                dataset_batches = [dataset[i:i + self.cache_batch_size]
-                                   for i in range(0, len(dataset), self.cache_batch_size)]
                 result = []
-                for batch_num, batch in enumerate(dataset_batches):
+                for batch_num, batch in enumerate(batched(dataset, self.cache_batch_size)):
                     cached_results = []
                     missing_examples = []
                     for i, item in enumerate(batch):
@@ -243,16 +245,19 @@ def infer(
                         else:
                             missing_examples.append((i, item))  # each element is index in batch and example
                     # infare on missing examples only, without indices
-                    logger.info(f"Inferring batch {batch_num} / {len(dataset_batches)}")
-                    inferred_results = self._infer([e[1] for e in missing_examples], return_meta_data)
-                    # recombined to index and value
-                    inferred_results = list(zip([e[0] for e in missing_examples], inferred_results))
-                    # Add missing examples to cache
-                    for (_, item), (_, prediction) in zip(missing_examples, inferred_results):
-                        if prediction is None:
-                            continue
-                        cache_key = self._get_cache_key(item)
-                        self._cache[cache_key] = prediction
+                    logger.info(f"Inferring batch {batch_num} / {len(dataset) // self.cache_batch_size}")
+                    if len(missing_examples) > 0:
+                        inferred_results = self._infer([e[1] for e in missing_examples], return_meta_data)
+                        # recombined to index and value
+                        inferred_results = list(zip([e[0] for e in missing_examples], inferred_results))
+                        # Add missing examples to cache
+                        for (_, item), (_, prediction) in zip(missing_examples, inferred_results):
+                            if prediction is None:
+                                continue
+                            cache_key = self._get_cache_key(item)
+                            self._cache[cache_key] = prediction
+                    else:
+                        inferred_results = []

                     # Combine cached and inferred results in original order
                     batch_predictions = [p[1] for p in sorted(cached_results + inferred_results)]
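To see what the new batched() helper does (illustrative values, not from the commit): it lazily slices any iterable into chunks of size n, with a shorter final chunk, and yields nothing for an empty input, so a Dataset no longer has to be converted to a list first:

    list(batched(range(7), 3))   # [[0, 1, 2], [3, 4, 5], [6]]
    list(batched([], 3))         # []  -- no batches at all for an empty input

Together with the new "if len(missing_examples) > 0" guard, _infer is no longer called with an empty list when every item in a batch is a cache hit, which is exactly the assumption the commit message says was removed.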
@@ -1798,6 +1803,10 @@ class RITSInferenceEngine(
     label: str = "rits"
     data_classification_policy = ["public", "proprietary"]

+    model_names_dict = {
+        "microsoft/phi-4": "microsoft-phi-4"
+    }
+
     def get_default_headers(self):
         return {"RITS_API_KEY": self.credentials["api_key"]}

@@ -1818,8 +1827,10 @@ def get_base_url_from_model_name(model_name: str):
             RITSInferenceEngine._get_model_name_for_endpoint(model_name)
         )

-    @staticmethod
-    def _get_model_name_for_endpoint(model_name: str):
+    @classmethod
+    def _get_model_name_for_endpoint(cls, model_name: str):
+        if model_name in cls.model_names_dict:
+            return cls.model_names_dict[model_name]
         return (
             model_name.split("/")[-1]
             .lower()
@@ -2959,15 +2970,12 @@ def prepare_engine(self):
             capacity=self.max_requests_per_second,
         )
         self.inference_type = "litellm"
-        import litellm
         from litellm import acompletion
-        from litellm.caching.caching import Cache

-        litellm.cache = Cache(type="disk")

         self._completion = acompletion
         # Initialize a semaphore to limit concurrency
-        self._semaphore = asyncio.Semaphore(self.max_requests_per_second)
+        self._semaphore = asyncio.Semaphore(round(self.max_requests_per_second))

     async def _infer_instance(
         self, index: int, instance: Dict[str, Any]
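A small expectation sketch for the name-conversion fix (the mapped value comes from model_names_dict in the diff; the call below is a hedged example, not test code from the commit). Without the explicit entry, the generic rule model_name.split("/")[-1].lower() would drop the organization prefix and yield "phi-4":

    from unitxt.inference import RITSInferenceEngine

    print(RITSInferenceEngine._get_model_name_for_endpoint("microsoft/phi-4"))
    # expected: "microsoft-phi-4"  (taken from model_names_dict, not the fallback rule)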

src/unitxt/loaders.py

+1 −1

@@ -845,7 +845,7 @@ def verify(self):

     def _maybe_set_classification_policy(self):
         self.set_default_data_classification(
-            ["proprietary"], "when loading from python dictionary"
+            self.data_classification_policy or ["proprietary"], "when loading from python dictionary"
         )

     def load_iterables(self) -> MultiStream:
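The effect of the one-line change, in isolation (plain-Python illustration, not repository code): an explicitly supplied policy wins, and the old default is kept only when none was set:

    def pick_policy(data_classification_policy):
        return data_classification_policy or ["proprietary"]

    print(pick_policy(None))        # ['proprietary']  -- previous behavior preserved
    print(pick_policy(["public"]))  # ['public']       -- an explicit policy is now respected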

src/unitxt/schema.py

+19 −2

@@ -67,8 +67,7 @@ def load_chat_source(chat_str):
     )
     return chat

-
-def loads_instance(batch):
+def loads_batch(batch):
     if (
         "source" in batch
         and isinstance(batch["source"][0], str)
@@ -86,6 +85,24 @@ def loads_instance(batch):
         batch["task_data"] = [json.loads(d) for d in batch["task_data"]]
     return batch

+def loads_instance(instance):
+    if (
+        "source" in instance
+        and isinstance(instance["source"], str)
+        and (
+            instance["source"].startswith('[{"role":')
+            or instance["source"].startswith('[{"content":')
+        )
+    ):
+        instance["source"] = load_chat_source(instance["source"])
+    if (
+        not settings.task_data_as_text
+        and "task_data" in instance
+        and isinstance(instance["task_data"], str)
+    ):
+        instance["task_data"] = json.loads(instance["task_data"])
+    return instance
+

 class FinalizeDataset(InstanceOperatorValidator):
     group_by: List[List[str]]
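A tiny check of the prefix heuristic shared by both deserializers (the sample message is made up): a chat serialized with json.dumps starts with '[{"role":', so it is routed through load_chat_source rather than left as plain text:

    import json

    source = json.dumps([{"role": "user", "content": "Hello"}])
    print(source.startswith('[{"role":'))      # True -> parsed back into a list of messages
    print(json.loads(source)[0]["content"])    # Hello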
