-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathazure_openai.py
More file actions
62 lines (47 loc) · 2.24 KB
/
Copy pathazure_openai.py
File metadata and controls
62 lines (47 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from collections import defaultdict
from typing import List, Optional
import numpy as np
from pymilvus.model.base import BaseEmbeddingFunction
from pymilvus.model.utils import import_openai
class AzureOpenAIEmbeddingFunction(BaseEmbeddingFunction):
    """Embedding function that delegates to an ``AzureOpenAI`` client.

    Texts are sent through ``client.embeddings.create`` and every returned
    embedding is converted to a NumPy array.
    """

    def __init__(
        self,
        model_name: str = "text-embedding-ada-002",
        api_key: Optional[str] = None,
        azure_endpoint: Optional[str] = None,
        dimensions: Optional[int] = None,
        api_version: Optional[str] = None,
        **kwargs,
    ):
        """Build the client and record the per-request encode settings.

        Args:
            model_name: Value sent as ``model`` on every embeddings request.
                NOTE(review): on Azure this is commonly the deployment name
                rather than the underlying model id - confirm with callers.
            api_key: Forwarded to the ``AzureOpenAI`` constructor.
            azure_endpoint: Forwarded to the ``AzureOpenAI`` constructor.
            dimensions: When given, sent as ``dimensions`` on every request
                and recorded as the value reported by :attr:`dim` for
                ``model_name``.
            api_version: Forwarded to the ``AzureOpenAI`` constructor.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                ``AzureOpenAI`` constructor.
        """
        # Kept ahead of the local import on purpose: presumably this helper
        # verifies that the optional ``openai`` package is usable.
        import_openai()
        from openai import AzureOpenAI

        # Embedding sizes recorded for the model names this class knows about.
        self._openai_model_meta_info = defaultdict(dict)
        self._openai_model_meta_info["text-embedding-3-small"]["dim"] = 1536
        self._openai_model_meta_info["text-embedding-3-large"]["dim"] = 3072
        self._openai_model_meta_info["text-embedding-ada-002"]["dim"] = 1536

        self._model_config = {
            "api_key": api_key,
            "azure_endpoint": azure_endpoint,
            "api_version": api_version,
            **kwargs,
        }

        self._encode_config = {"model": model_name}
        if dimensions is not None:
            # An explicit size is both requested from the API and recorded
            # locally so that ``dim`` reports the same number.
            self._encode_config["dimensions"] = dimensions
            self._openai_model_meta_info[model_name]["dim"] = dimensions

        self.model_name = model_name
        self.client = AzureOpenAI(**self._model_config)

    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
        """Return one embedding array per query text."""
        return self._encode(queries)

    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
        """Return one embedding array per document text."""
        return self._encode(documents)

    @property
    def dim(self) -> int:
        """Embedding size recorded for ``self.model_name``.

        Raises:
            KeyError: If no size is recorded for the model name, i.e. it is
                not one of the built-in entries and ``dimensions`` was not
                passed to the constructor.
        """
        # ``.get`` avoids the defaultdict factory, so a failed lookup does not
        # leave an empty entry behind.
        meta = self._openai_model_meta_info.get(self.model_name)
        if not meta or "dim" not in meta:
            raise KeyError(
                f"No embedding dimension recorded for model {self.model_name!r}; "
                "pass `dimensions` when constructing the embedding function."
            )
        return meta["dim"]

    def __call__(self, texts: List[str]) -> List[np.ndarray]:
        """Return one embedding array per input text."""
        return self._encode(texts)

    def _encode_query(self, query: str) -> np.ndarray:
        """Embed a single query string and return its array."""
        # Wrap the text so ``_encode`` always receives a list, as annotated.
        return self._encode([query])[0]

    def _encode_document(self, document: str) -> np.ndarray:
        """Embed a single document string and return its array."""
        return self._encode([document])[0]

    def _call_openai_api(self, texts: List[str]) -> List[np.ndarray]:
        """Request embeddings for ``texts`` and convert each one to an array."""
        response = self.client.embeddings.create(input=texts, **self._encode_config)
        return [np.array(item.embedding) for item in response.data]

    def _encode(self, texts: List[str]) -> List[np.ndarray]:
        """Embed ``texts``; an empty input yields an empty list without a request."""
        if not texts:
            # Nothing to embed, so skip the network round trip entirely.
            return []
        return self._call_openai_api(texts)