Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/pymilvus/model/dense/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pymilvus.model.dense.openai import OpenAIEmbeddingFunction
from pymilvus.model.dense.azure_openai import AzureOpenAIEmbeddingFunction
from pymilvus.model.dense.sentence_transformer import SentenceTransformerEmbeddingFunction
from pymilvus.model.dense.voyageai import VoyageEmbeddingFunction
from pymilvus.model.dense.jinaai import JinaEmbeddingFunction
Expand All @@ -13,6 +14,7 @@

__all__ = [
"OpenAIEmbeddingFunction",
"AzureOpenAIEmbeddingFunction",
"SentenceTransformerEmbeddingFunction",
"VoyageEmbeddingFunction",
"JinaEmbeddingFunction",
Expand Down
62 changes: 62 additions & 0 deletions src/pymilvus/model/dense/azure_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from collections import defaultdict
from typing import List, Optional

import numpy as np

from pymilvus.model.base import BaseEmbeddingFunction
from pymilvus.model.utils import import_openai


class AzureOpenAIEmbeddingFunction(BaseEmbeddingFunction):
def __init__(
self,
model_name: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
azure_endpoint: Optional[str] = None,
dimensions: Optional[int] = None,
api_version: Optional[str] = None,
**kwargs,
):
import_openai()
from openai import AzureOpenAI

self._openai_model_meta_info = defaultdict(dict)
self._openai_model_meta_info["text-embedding-3-small"]["dim"] = 1536
self._openai_model_meta_info["text-embedding-3-large"]["dim"] = 3072
self._openai_model_meta_info["text-embedding-ada-002"]["dim"] = 1536

self._model_config = dict({"api_key": api_key, "azure_endpoint": azure_endpoint, "api_version": api_version}, **kwargs)
additional_encode_config = {}
if dimensions is not None:
additional_encode_config = {"dimensions": dimensions}
self._openai_model_meta_info[model_name]["dim"] = dimensions

self._encode_config = {"model": model_name, **additional_encode_config}
self.model_name = model_name
self.client = AzureOpenAI(**self._model_config)

def encode_queries(self, queries: List[str]) -> List[np.array]:
return self._encode(queries)

def encode_documents(self, documents: List[str]) -> List[np.array]:
return self._encode(documents)

@property
def dim(self):
return self._openai_model_meta_info[self.model_name]["dim"]

def __call__(self, texts: List[str]) -> List[np.array]:
return self._encode(texts)

def _encode_query(self, query: str) -> np.array:
return self._encode(query)[0]

def _encode_document(self, document: str) -> np.array:
return self._encode(document)[0]

def _call_openai_api(self, texts: List[str]):
results = self.client.embeddings.create(input=texts, **self._encode_config).data
return [np.array(data.embedding) for data in results]

def _encode(self, texts: List[str]):
return self._call_openai_api(texts)