Skip to content

Commit 794513d

Browse files
Fix: AzureOpenAIInferenceEngine fails if api_version is not set (#1680)
* Fix: AzureOpenAIInferenceEngine fails if api_version is not set Signed-off-by: Martín Santillán Cooper <[email protected]> * Add HFOptionSelectingInferenceEngine get_engine_id method Signed-off-by: Martín Santillán Cooper <[email protected]> --------- Signed-off-by: Martín Santillán Cooper <[email protected]>
1 parent 287a801 commit 794513d

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/unitxt/inference.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1774,7 +1774,7 @@ def _prepare_credentials(self) -> CredentialsOpenAi:
17741774
), "Error while trying to run AzureOpenAIInferenceEngine: Missing environment variable param AZURE_OPENAI_HOST or OPENAI_API_VERSION"
17751775
api_url = f"{azure_openapi_host}/openai/deployments/{self.model_name}/chat/completions?api-version={api_version}"
17761776

1777-
return {"api_key": api_key, "api_url": api_url}
1777+
return {"api_key": api_key, "api_url": api_url, "api_version": api_version}
17781778

17791779
def create_client(self):
17801780
from openai import AzureOpenAI
@@ -1783,6 +1783,7 @@ def create_client(self):
17831783
return AzureOpenAI(
17841784
api_key=self.credentials["api_key"],
17851785
base_url=self.credentials["api_url"],
1786+
api_version=self.credentials["api_version"],
17861787
default_headers=self.get_default_headers(),
17871788
)
17881789

@@ -3294,14 +3295,18 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):
32943295
32953296
This class uses models from the HuggingFace Transformers library to calculate log probabilities for text inputs.
32963297
"""
3297-
3298+
label = "hf_option_selection"
32983299
model_name: str
32993300
batch_size: int
33003301

33013302
_requirements_list = {
33023303
"transformers": "Install huggingface package using 'pip install --upgrade transformers"
33033304
}
33043305

3306+
def get_engine_id(self):
3307+
return get_model_and_label_id(self.model, self.label)
3308+
3309+
33053310
def prepare_engine(self):
33063311
from transformers import AutoModelForCausalLM, AutoTokenizer
33073312

0 commit comments

Comments
 (0)