Fix some bugs in inference engine tests #1682

Merged: 15 commits, Mar 18, 2025
Changes from 10 commits
7 changes: 4 additions & 3 deletions .github/workflows/inference_tests.yml
@@ -27,10 +27,11 @@ jobs:
WML_URL: ${{ secrets.WML_URL }}
WML_PROJECT_ID: ${{ secrets.WML_PROJECT_ID }}
WML_APIKEY: ${{ secrets.WML_APIKEY }}
-      WX_URL: ${{ secrets.WX_URL }}
-      WX_PROJECT_ID: ${{ secrets.WX_PROJECT_ID }}
-      WX_API_KEY: ${{ secrets.WX_API_KEY }}
+      WX_URL: ${{ secrets.WML_URL }} # Similar to WML_URL
+      WX_PROJECT_ID: ${{ secrets.WML_PROJECT_ID }} # Similar to WML_PROJECT_ID
+      WX_API_KEY: ${{ secrets.WML_APIKEY }} # Similar to WML_APIKEY
GENAI_KEY: ${{ secrets.GENAI_KEY }}

steps:
- uses: actions/checkout@v4

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -115,7 +115,7 @@ watsonx = [
"ibm-watsonx-ai==1.1.14"
]
inference-tests = [
"litellm==v1.52.9",
"litellm>=1.52.9",
"tenacity",
"diskcache",
"numpy==1.26.4",
12 changes: 9 additions & 3 deletions src/unitxt/api.py
@@ -21,7 +21,7 @@
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
-from .schema import loads_instance
+from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task
@@ -98,6 +98,7 @@ def create_dataset(
train_set: Optional[List[Dict[Any, Any]]] = None,
validation_set: Optional[List[Dict[Any, Any]]] = None,
split: Optional[str] = None,
+    data_classification_policy: Optional[List[str]] = None,
**kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
"""Creates dataset from input data based on a specific task.
@@ -108,6 +109,7 @@
train_set : optional train_set
validation_set: optional validation set
split: optional one split to choose
+        data_classification_policy: data_classification_policy
**kwargs: Arguments used to load dataset from provided datasets (see load_dataset())

Returns:
@@ -129,7 +131,11 @@
f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
)

-    card = TaskCard(loader=LoadFromDictionary(data=data), task=task)
+    args = {"data": data}
+    if data_classification_policy is not None:
+        args["default_data_classification_policy"] = data_classification_policy
+
+    card = TaskCard(loader=LoadFromDictionary(**args), task=task)
Review comment (Member): why not:

    card = TaskCard(loader=LoadFromDictionary(data=data), task=task, data_classification_policy=data_classification_policy) ?

return load_dataset(card=card, split=split, **kwargs)
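For illustration, a minimal sketch of how the new parameter could be used; the task id and data below are hypothetical, not taken from this PR:

    # Hypothetical usage sketch: forwarding a data classification policy
    # through create_dataset() to the underlying LoadFromDictionary loader.
    from unitxt.api import create_dataset

    data = [{"question": "What is the capital of France?", "answer": "Paris"}]
    dataset = create_dataset(
        task="tasks.qa.open",  # assumed task id, for illustration only
        test_set=data,
        split="test",
        data_classification_policy=["public"],
    )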


@@ -283,7 +289,7 @@ def produce(
result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
if not is_list:
return result[0]
-    return Dataset.from_list(result).with_transform(loads_instance)
+    return Dataset.from_list(result).with_transform(loads_batch)


def infer(
13 changes: 10 additions & 3 deletions src/unitxt/dataset.py
@@ -1,5 +1,5 @@
import os
-from typing import Optional, Union
+from typing import Dict, Optional, Union

import datasets

@@ -46,7 +46,7 @@
from .random_utils import __file__ as _
from .recipe import __file__ as _
from .register import __file__ as _
-from .schema import loads_instance
+from .schema import loads_batch, loads_instance
from .serializers import __file__ as _
from .settings_utils import get_constants
from .span_lableing_operators import __file__ as _
@@ -115,6 +115,13 @@ def _download_and_prepare(
dl_manager, "no_checks", **prepare_splits_kwargs
)

+    def as_streaming_dataset(self, split: Optional[str] = None, base_path: Optional[str] = None) -> Union[Dict[str, datasets.IterableDataset], datasets.IterableDataset]:
+        return (
+            super()
+            .as_streaming_dataset(split, base_path=base_path)
+            .map(loads_instance)
+        )

def as_dataset(
self,
split: Optional[datasets.Split] = None,
@@ -157,5 +164,5 @@ def as_dataset(
return (
super()
.as_dataset(split, run_post_process, verification_mode, in_memory)
-            .with_transform(loads_instance)
+            .with_transform(loads_batch)
)
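The split into loads_batch and loads_instance mirrors how the two dataset types consume transforms: with_transform on a materialized Dataset receives a columnar batch (each field maps to a list of values), while map on a streaming IterableDataset receives one example dict at a time. A minimal sketch of the two shapes (values are illustrative, not from this PR):

    # Illustrative shapes only.
    batch = {"source": ['[{"role": "user", "content": "hi"}]'], "task_data": ["{}"]}  # seen by loads_batch
    instance = {"source": '[{"role": "user", "content": "hi"}]', "task_data": "{}"}   # seen by loads_instance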
32 changes: 20 additions & 12 deletions src/unitxt/inference.py
@@ -13,6 +13,7 @@
import uuid
from collections import Counter
from datetime import datetime
+from itertools import islice
from multiprocessing.pool import ThreadPool
from typing import (
Any,
@@ -55,6 +56,11 @@
logger = get_logger()


+def batched(lst, n):
+    it = iter(lst)
+    while batch := list(islice(it, n)):
+        yield batch

class StandardAPIParamsMixin(Artifact):
model: str
frequency_penalty: Optional[float] = None
@@ -227,12 +233,8 @@ def infer(
result = self._mock_infer(dataset)
else:
if self.use_cache:
-                if isinstance(dataset, Dataset):
-                    dataset = dataset.to_list()
-                dataset_batches = [dataset[i:i + self.cache_batch_size]
-                                   for i in range(0, len(dataset), self.cache_batch_size)]
                result = []
-                for batch_num, batch in enumerate(dataset_batches):
+                for batch_num, batch in enumerate(batched(dataset, self.cache_batch_size)):
cached_results = []
missing_examples = []
for i, item in enumerate(batch):
@@ -243,7 +245,7 @@
else:
missing_examples.append((i, item)) # each element is index in batch and example
# infer on missing examples only, without indices
logger.info(f"Inferring batch {batch_num} / {len(dataset_batches)}")
logger.info(f"Inferring batch {batch_num} / {len(dataset) // self.cache_batch_size}")
inferred_results = self._infer([e[1] for e in missing_examples], return_meta_data)
# recombined to index and value
inferred_results = list(zip([e[0] for e in missing_examples], inferred_results))
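A side benefit of the new batched() helper is that it consumes any iterable lazily, which is why the earlier Dataset.to_list() conversion and list slicing are no longer needed. A quick illustration of its behavior (values are illustrative):

    # Illustration of the batched() helper added earlier in this diff.
    list(batched([1, 2, 3, 4, 5], 2))  # -> [[1, 2], [3, 4], [5]]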
@@ -1797,6 +1799,10 @@ class RITSInferenceEngine(
label: str = "rits"
data_classification_policy = ["public", "proprietary"]

+    model_names_dict = {
+        "microsoft/phi-4": "microsoft-phi-4"
+    }

def get_default_headers(self):
return {"RITS_API_KEY": self.credentials["api_key"]}

@@ -1817,8 +1823,10 @@ def get_base_url_from_model_name(model_name: str):
RITSInferenceEngine._get_model_name_for_endpoint(model_name)
)

-    @staticmethod
-    def _get_model_name_for_endpoint(model_name: str):
+    @classmethod
+    def _get_model_name_for_endpoint(cls, model_name: str):
+        if model_name in cls.model_names_dict:
+            return cls.model_names_dict[model_name]
return (
model_name.split("/")[-1]
.lower()
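The effect of the new lookup, sketched from the dictionary above: names listed in model_names_dict bypass the generic name munging that follows.

    # Sketch based on the model_names_dict entry above.
    RITSInferenceEngine._get_model_name_for_endpoint("microsoft/phi-4")  # -> "microsoft-phi-4"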
@@ -2958,15 +2966,12 @@ def prepare_engine(self):
capacity=self.max_requests_per_second,
)
self.inference_type = "litellm"
-        import litellm
        from litellm import acompletion
-        from litellm.caching.caching import Cache

-        litellm.cache = Cache(type="disk")

self._completion = acompletion
# Initialize a semaphore to limit concurrency
-        self._semaphore = asyncio.Semaphore(self.max_requests_per_second)
+        self._semaphore = asyncio.Semaphore(round(self.max_requests_per_second))

async def _infer_instance(
self, index: int, instance: Dict[str, Any]
@@ -3302,6 +3307,9 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):
"transformers": "Install huggingface package using 'pip install --upgrade transformers"
}

+    def get_engine_id(self):
+        return get_model_and_label_id(self.model_name, "hf_option_selecting")

def prepare_engine(self):
from transformers import AutoModelForCausalLM, AutoTokenizer

5 changes: 3 additions & 2 deletions src/unitxt/loaders.py
@@ -66,7 +66,7 @@
from huggingface_hub import HfApi
from tqdm import tqdm

-from .dataclass import NonPositionalField
+from .dataclass import Field, NonPositionalField
from .error_utils import Documentation, UnitxtError, UnitxtWarning
from .fusion import FixedFusion
from .logging_utils import get_logger
@@ -823,6 +823,7 @@ class LoadFromDictionary(Loader):
"""

data: Dict[str, List[Dict[str, Any]]]
+    default_data_classification_policy: List[str] = Field(default_factory=lambda: ["proprietary"])

def verify(self):
super().verify()
@@ -845,7 +846,7 @@ def verify(self):

def _maybe_set_classification_policy(self):
self.set_default_data_classification(
["proprietary"], "when loading from python dictionary"
self.default_data_classification_policy, "when loading from python dictionary"
)
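A sketch of how the new field can be overridden by a caller; the data shown is illustrative:

    # Illustrative: relaxing the default ["proprietary"] policy for toy data.
    loader = LoadFromDictionary(
        data={"test": [{"text": "hello", "label": "greeting"}]},
        default_data_classification_policy=["public"],
    )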

def load_iterables(self) -> MultiStream:
21 changes: 19 additions & 2 deletions src/unitxt/schema.py
@@ -67,8 +67,7 @@ def load_chat_source(chat_str):
)
return chat


-def loads_instance(batch):
+def loads_batch(batch):
if (
"source" in batch
and isinstance(batch["source"][0], str)
@@ -86,6 +85,24 @@
batch["task_data"] = [json.loads(d) for d in batch["task_data"]]
return batch

+def loads_instance(instance):
+    if (
+        "source" in instance
+        and isinstance(instance["source"], str)
+        and (
+            instance["source"].startswith('[{"role":')
Review comment (@yoavkatz, Mar 18, 2025): This seems very risky; a minor added space will cause the code to fail. Need to consider alternatives.

or instance["source"].startswith('[{"content":')
)
):
instance["source"] = load_chat_source(instance["source"])
if (
not settings.task_data_as_text
and "task_data" in instance
and isinstance(instance["task_data"], str)
):
instance["task_data"] = json.loads(instance["task_data"])
return instance


class FinalizeDataset(InstanceOperatorValidator):
group_by: List[List[str]]