1+ """Embedding service using NVIDIA NIM or OpenAI."""
2+
3+ from __future__ import annotations
4+
5+ import os
6+ from typing import Any
7+
8+ import httpx
9+
10+ from smp .logging import get_logger
11+
12+ log = get_logger (__name__ )
13+
14+
15+ class EmbeddingService :
16+ """Generate embeddings via NVIDIA NIM or OpenAI."""
17+
18+ def __init__ (
19+ self ,
20+ provider : str = "nvidia" ,
21+ api_key : str | None = None ,
22+ model : str | None = None ,
23+ base_url : str | None = None ,
24+ dimension : int = 768 ,
25+ ) -> None :
26+ self ._provider = provider
27+ self ._api_key = api_key or os .environ .get ("NVIDIA_NIM_API_KEY" ) or os .environ .get ("OPENAI_API_KEY" , "" )
28+ self ._model = model or os .environ .get ("EMBEDDING_MODEL" , "nvidia/nv-embed-qa-4" )
29+ self ._base_url = base_url or os .environ .get (
30+ "EMBEDDING_BASE_URL" , "https://integrate.api.nvidia.com/v1"
31+ )
32+ self ._dimension = dimension
33+ self ._client : httpx .AsyncClient | None = None
34+
35+ async def connect (self ) -> None :
36+ self ._client = httpx .AsyncClient (
37+ base_url = self ._base_url ,
38+ headers = {"Authorization" : f"Bearer { self ._api_key } " },
39+ timeout = 60.0 ,
40+ )
41+ log .info ("embedding_service_connected" , provider = self ._provider , model = self ._model )
42+
43+ async def close (self ) -> None :
44+ if self ._client :
45+ await self ._client .aclose ()
46+ self ._client = None
47+
48+ @property
49+ def dimension (self ) -> int :
50+ return self ._dimension
51+
52+ async def embed (self , text : str ) -> list [float ]:
53+ """Generate embedding for a single text."""
54+ if self ._client is None :
55+ raise RuntimeError ("EmbeddingService not connected" )
56+
57+ if self ._provider == "nvidia" :
58+ return await self ._embed_nvidia (text )
59+ elif self ._provider == "openai" :
60+ return await self ._embed_openai (text )
61+ else :
62+ raise ValueError (f"Unknown provider: { self ._provider } " )
63+
64+ async def embed_batch (self , texts : list [str ]) -> list [list [float ]]:
65+ """Generate embeddings for multiple texts."""
66+ if self ._client is None :
67+ raise RuntimeError ("EmbeddingService not connected" )
68+
69+ if self ._provider == "nvidia" :
70+ return await self ._embed_batch_nvidia (texts )
71+ elif self ._provider == "openai" :
72+ return await self ._embed_batch_openai (texts )
73+ else :
74+ raise ValueError (f"Unknown provider: { self ._provider } " )
75+
76+ async def _embed_nvidia (self , text : str ) -> list [float ]:
77+ payload = {
78+ "input" : text ,
79+ "model" : self ._model ,
80+ }
81+ response = await self ._client .post ("/embeddings" , json = payload )
82+ response .raise_for_status ()
83+ data = response .json ()
84+ return data ["data" ][0 ]["embedding" ]
85+
86+ async def _embed_batch_nvidia (self , texts : list [str ]) -> list [list [float ]]:
87+ payload = {
88+ "input" : texts ,
89+ "model" : self ._model ,
90+ }
91+ response = await self ._client .post ("/embeddings" , json = payload )
92+ response .raise_for_status ()
93+ data = response .json ()
94+ return [item ["embedding" ] for item in data ["data" ]]
95+
96+ async def _embed_openai (self , text : str ) -> list [float ]:
97+ payload = {
98+ "input" : text ,
99+ "model" : self ._model ,
100+ }
101+ response = await self ._client .post ("/embeddings" , json = payload )
102+ response .raise_for_status ()
103+ data = response .json ()
104+ return data ["data" ][0 ]["embedding" ]
105+
106+ async def _embed_batch_openai (self , texts : list [str ]) -> list [list [float ]]:
107+ payload = {
108+ "input" : texts ,
109+ "model" : self ._model ,
110+ }
111+ response = await self ._client .post ("/embeddings" , json = payload )
112+ response .raise_for_status ()
113+ data = response .json ()
114+ return [item ["embedding" ] for item in data ["data" ]]
115+
116+
117+ def create_embedding_service () -> EmbeddingService :
118+ """Create embedding service from environment variables."""
119+ provider = os .getenv ("EMBEDDING_PROVIDER" , "nvidia" )
120+ api_key = os .getenv ("NVIDIA_NIM_API_KEY" ) or os .getenv ("OPENAI_API_KEY" )
121+ model = os .getenv ("EMBEDDING_MODEL" )
122+ base_url = os .getenv ("EMBEDDING_BASE_URL" )
123+ dimension = int (os .getenv ("EMBEDDING_DIMENSION" , "768" ))
124+ return EmbeddingService (provider = provider , api_key = api_key , model = model , base_url = base_url , dimension = dimension )
0 commit comments