Skip to content

Commit 64829d6

Browse files
committed
First version of langchain adapter
1 parent 0e58a73 commit 64829d6

File tree

2 files changed

+254
-0
lines changed

2 files changed

+254
-0
lines changed

nuclia/lib/langchain.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
import json
2+
3+
from typing import Optional, Dict, List, Any, Iterator, AsyncIterator
4+
from datetime import datetime, timezone
5+
from base64 import b64decode
6+
7+
from langchain_core.language_models import BaseChatModel
8+
from langchain_core.callbacks import (
9+
CallbackManagerForLLMRun,
10+
AsyncCallbackManagerForLLMRun,
11+
)
12+
from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk
13+
from langchain_core.messages import HumanMessage, SystemMessage
14+
from langchain_core.outputs import (
15+
ChatResult,
16+
ChatGeneration,
17+
ChatGenerationChunk,
18+
)
19+
from pydantic import Field
20+
21+
# Nuclia (sync & async)
22+
from nuclia.lib.nua import NuaClient, AsyncNuaClient
23+
from nuclia.sdk.predict import NucliaPredict, AsyncNucliaPredict
24+
from nuclia.lib.nua_responses import ChatModel, UserPrompt
25+
from nuclia_models.predict.generative_responses import (
26+
GenerativeFullResponse,
27+
TextGenerativeResponse,
28+
)
29+
30+
31+
class NucliaNuaChat(BaseChatModel):
    """
    A LangChain-compatible ChatModel that uses the Nuclia NUA client under the hood.

    The NUA token (a JWT) is parsed at construction time to extract the regional
    base URL (``iss`` claim) and to reject expired tokens (``exp`` claim). Both
    sync and async NUA clients are created so that all four LangChain entry
    points (`_generate`, `_agenerate`, `_stream`, `_astream`) are supported.
    """

    model_name: str = Field(
        ..., description="Which model to call, e.g. 'chatgpt-azure-4o'"
    )
    token: str = Field(..., description="Nua api Key")
    user_id: str = Field("nuclia-nua-chat", description="User ID for the chat session")
    system_prompt: Optional[str] = Field(
        None, description="Optional system instructions"
    )
    query_context: Optional[Dict[str, str]] = Field(
        None, description="Extra context for the LLM"
    )

    # Internal state, populated in __init__ from the parsed token. Left as None
    # when no token is provided; the generate/stream methods guard against that.
    region_base_url: Optional[str] = None
    nc_sync: Optional[NuaClient] = None
    predict_sync: Optional[NucliaPredict] = None
    nc_async: Optional[AsyncNuaClient] = None
    predict_async: Optional[AsyncNucliaPredict] = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        if self.token:
            regional_url, expiration_date = self._parse_token(self.token)
            now = datetime.now(timezone.utc)
            if expiration_date <= now:
                raise ValueError("Expired nua token")
            self.region_base_url = regional_url

            self.nc_sync = NuaClient(
                region=self.region_base_url,
                token=self.token,
                account="",  # Not needed for current implementation, required by the client
            )
            self.predict_sync = NucliaPredict()

            self.nc_async = AsyncNuaClient(
                region=self.region_base_url,
                token=self.token,
                account="",  # Not needed for current implementation, required by the client
            )
            self.predict_async = AsyncNucliaPredict()

    @staticmethod
    def _parse_token(token: str):
        """Extract (regional_url, expiration_date) from a NUA JWT.

        The payload is NOT signature-verified here; we only read the ``iss``
        and ``exp`` claims. Extra ``==`` padding is appended so that
        ``b64decode`` accepts unpadded base64url payloads.

        Raises:
            ValueError: if the token does not have the 3 JWT segments.
        """
        parts = token.split(".")
        if len(parts) < 3:
            raise ValueError("Invalid JWT token, missing segments")

        b64_payload = parts[1]
        payload = json.loads(b64decode(b64_payload + "=="))
        regional_url = payload["iss"]
        expiration_date = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
        return regional_url, expiration_date

    @property
    def _llm_type(self) -> str:
        """LangChain identifier for this chat model implementation."""
        return "nuclia-nua-chat"

    @property
    def _identifying_params(self) -> dict:
        """Parameters that uniquely identify this model instance (for caching/tracing)."""
        return {"model_name": self.model_name, "region_base_url": self.region_base_url}

    def _combine_messages(self, messages: List[BaseMessage]) -> tuple[str, str]:
        """
        Reduce a LangChain message list to (question, user_prompt_str).

        For now this just discards anything that is not a Human message, to be
        improved. The LAST HumanMessage wins as the question; ``user_prompt_str``
        is currently always empty (no parts are ever collected yet).
        """
        user_parts: List[str] = []
        question = ""
        for m in messages:
            if isinstance(m, SystemMessage) and self.system_prompt is None:
                # We could override self.system_prompt from the prompt if we want
                pass
            elif isinstance(m, HumanMessage):
                question = (
                    m.content
                )  # Overwrite each time, so the last human message is the question
            else:
                pass

        user_prompt_str = "\n".join(user_parts)
        return question, user_prompt_str

    def _build_chat_body(self, messages: List[BaseMessage]) -> ChatModel:
        """Build the NUA ChatModel request shared by all generate/stream paths.

        Centralized so the four entry points cannot drift apart (the async
        streaming path previously referenced a non-existent ``self.retrieval``
        attribute, which would raise AttributeError at runtime).
        """
        question, user_prompt_str = self._combine_messages(messages)
        return ChatModel(
            question=question,
            retrieval=False,
            user_id=self.user_id,
            system=self.system_prompt,
            user_prompt=UserPrompt(prompt=user_prompt_str),
            query_context=self.query_context or {},
        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronous, non-streaming completion."""
        if not self.predict_sync or not self.nc_sync:
            raise RuntimeError("Sync clients not initialized.")

        body = self._build_chat_body(messages)
        response: GenerativeFullResponse = self.predict_sync.generate(
            text=body,
            model=self.model_name,
            nc=self.nc_sync,
        )
        ai_message = AIMessage(content=response.answer)

        return ChatResult(generations=[ChatGeneration(message=ai_message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous, non-streaming completion."""
        if not self.predict_async or not self.nc_async:
            raise RuntimeError("Async clients not initialized.")

        body = self._build_chat_body(messages)
        response: GenerativeFullResponse = await self.predict_async.generate(
            text=body,
            model=self.model_name,
            nc=self.nc_async,
        )
        ai_message = AIMessage(content=response.answer)
        return ChatResult(generations=[ChatGeneration(message=ai_message)])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Synchronous streaming: yields one ChatGenerationChunk per text partial."""
        if not self.predict_sync or not self.nc_sync:
            raise RuntimeError("Sync clients not initialized.")

        body = self._build_chat_body(messages)

        # Loop through each partial from the Nuclia synchronous streaming method
        for partial in self.predict_sync.generate_stream(
            text=body,
            model=self.model_name,
            nc=self.nc_sync,
        ):
            # Only forward "generative chunks" carrying a TextGenerativeResponse;
            # skip empty partials and non-text chunk types (status, citations, ...).
            if not partial or not partial.chunk:
                continue
            if not isinstance(partial.chunk, TextGenerativeResponse):
                continue

            text = partial.chunk.text or ""
            msg_chunk = AIMessageChunk(content=text)
            chunk = ChatGenerationChunk(message=msg_chunk)

            if run_manager:
                run_manager.on_llm_new_token(token=text, chunk=chunk)

            yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronous streaming: yields one ChatGenerationChunk per text partial."""
        if not self.predict_async or not self.nc_async:
            raise RuntimeError("Async clients not initialized.")

        # BUGFIX: previously built the body inline with `retrieval=self.retrieval`,
        # but no `retrieval` field exists on this model (every other path used
        # False) — that raised AttributeError as soon as this method ran.
        body = self._build_chat_body(messages)

        async for partial in self.predict_async.generate_stream(
            text=body,
            model=self.model_name,
            nc=self.nc_async,
        ):
            if not partial or not partial.chunk:
                continue
            if not isinstance(partial.chunk, TextGenerativeResponse):
                continue

            text = partial.chunk.text or ""
            msg_chunk = AIMessageChunk(content=text)
            chunk = ChatGenerationChunk(message=msg_chunk)

            if run_manager:
                await run_manager.on_llm_new_token(token=text, chunk=chunk)

            yield chunk

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ aiofiles
1414
backoff
1515
deprecated
1616
tabulate
17+
langchain_core>=0.3.29

0 commit comments

Comments
 (0)