Skip to content

Commit 85c78b8

Browse files
authored
feat: support tracing on streaming endpoints (#70)
1 parent ba1431f commit 85c78b8

5 files changed

Lines changed: 80 additions & 59 deletions

File tree

python/src/cairo_coder/core/rag_pipeline.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Any
1111

1212
import dspy
13+
import langsmith as ls
1314
import structlog
1415
from dspy.adapters import XMLAdapter
1516
from dspy.utils.callback import BaseCallback
@@ -164,6 +165,7 @@ async def aforward(
164165
query=query, context=context, chat_history=chat_history_str
165166
)
166167

168+
167169
async def aforward_streaming(
168170
self,
169171
query: str,
@@ -218,13 +220,15 @@ async def aforward_streaming(
218220

219221
# Stream response generation. Use ChatAdapter for streaming, which performs better.
220222
with dspy.context(
221-
lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000),
222-
adapter=dspy.adapters.XMLAdapter(),
223-
):
224-
async for chunk in self.generation_program.aforward_streaming(
225-
query=query, context=context, chat_history=chat_history_str
226-
):
227-
yield StreamEvent(type=StreamEventType.RESPONSE, data=chunk)
223+
adapter=dspy.adapters.ChatAdapter()
224+
), ls.trace(name="GenerationProgramStreaming", run_type="llm", inputs={"query": query, "chat_history": chat_history_str, "context": context}) as rt:
225+
chunk_accumulator = ""
226+
async for chunk in self.generation_program.aforward_streaming(
227+
query=query, context=context, chat_history=chat_history_str
228+
):
229+
chunk_accumulator += chunk
230+
yield StreamEvent(type=StreamEventType.RESPONSE, data=chunk)
231+
rt.end(outputs={"output": chunk_accumulator})
228232

229233
# Pipeline completed
230234
yield StreamEvent(type=StreamEventType.END, data=None)

python/src/cairo_coder/dspy/generation_program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def get_lm_usage(self) -> dict[str, int]:
203203
"""
204204
return self.generation_program.get_lm_usage()
205205

206-
@traceable(name="GenerationProgram", run_type="llm")
206+
@traceable(name="GenerationProgram", run_type="llm", metadata={"llm_provider": dspy.settings.lm})
207207
async def aforward(self, query: str, context: str, chat_history: Optional[str] = None) -> dspy.Prediction | None :
208208
"""
209209
Generate Cairo code response based on query and context - async

python/src/cairo_coder/dspy/query_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def __init__(self):
123123
"foundry",
124124
}
125125

126-
@traceable(name="QueryProcessorProgram", run_type="llm")
126+
@traceable(name="QueryProcessorProgram", run_type="llm", metadata={"llm_provider": dspy.settings.lm})
127127
async def aforward(self, query: str, chat_history: Optional[str] = None) -> ProcessedQuery:
128128
"""
129129
Process a user query into a structured format for document retrieval.

python/src/cairo_coder/dspy/retrieval_judge.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,21 @@ class RetrievalRecallPrecision(dspy.Signature):
4747
"""
4848

4949
query: str = dspy.InputField()
50-
system_resource: str = dspy.InputField(desc="Single resource text (content + minimal metadata/title)")
50+
system_resource: str = dspy.InputField(
51+
desc="Single resource text (content + minimal metadata/title)"
52+
)
5153
reasoning: str = dspy.OutputField(
5254
desc="A short sentence, on why a selected resource will be useful. If it's not selected, reason about why it's not going to be useful. Start by Resource <resource_title>..."
5355
)
5456
resource_note: float = dspy.OutputField(
5557
desc="A note between 0 and 1.0 on how useful the resource is to directly answer the query. 0 being completely unrelated, 1.0 being very relevant, 0.5 being 'not directly related but still informative and can be useful for context'."
5658
)
5759

60+
5861
DEFAULT_THRESHOLD = 0.4
5962
DEFAULT_PARALLEL_THREADS = 5
6063

64+
6165
class RetrievalJudge(dspy.Module):
6266
"""
6367
LLM-based judge that scores retrieved documents for relevance to a query.
@@ -88,13 +92,17 @@ def __init__(self):
8892
raise FileNotFoundError(f"{compiled_program_path} not found")
8993
self.rater.load(compiled_program_path)
9094

91-
@traceable(name="RetrievalJudge", run_type="llm")
95+
@traceable(
96+
name="RetrievalJudge", run_type="llm", metadata={"llm_provider": dspy.settings.lm}
97+
)
9298
async def aforward(self, query: str, documents: list[Document]) -> list[Document]:
9399
"""Async judge."""
94100
if not documents:
95101
return documents
96102

97-
keep_docs, judged_indices, judged_payloads = self._split_templates_and_prepare_docs(documents)
103+
keep_docs, judged_indices, judged_payloads = self._split_templates_and_prepare_docs(
104+
documents
105+
)
98106

99107
# TODO: can we use dspy.Parallel here instead of asyncio gather?
100108
if judged_payloads:
@@ -114,7 +122,11 @@ async def judge_one(doc_string: str):
114122
keep_docs=keep_docs,
115123
)
116124
except Exception as e:
117-
logger.error("Retrieval judge failed (async), returning all docs", error=str(e), exc_info=True)
125+
logger.error(
126+
"Retrieval judge failed (async), returning all docs",
127+
error=str(e),
128+
exc_info=True,
129+
)
118130
return documents
119131

120132
return keep_docs
@@ -155,7 +167,9 @@ def _split_templates_and_prepare_docs(
155167
return keep_docs, judged_indices, judged_payloads
156168

157169
@staticmethod
158-
def _document_to_string(title: str, content: str, max_len: int = JUDGE_DOCUMENT_PREVIEW_MAX_LEN) -> str:
170+
def _document_to_string(
171+
title: str, content: str, max_len: int = JUDGE_DOCUMENT_PREVIEW_MAX_LEN
172+
) -> str:
159173
"""Build the string seen by the judge for one doc."""
160174
preview = content[:max_len]
161175
if len(content) > max_len:

python/src/cairo_coder/server/app.py

Lines changed: 48 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
from contextlib import asynccontextmanager
1515

1616
import dspy
17+
import langsmith as ls
1718
import structlog
1819
import uvicorn
19-
from dspy.adapters import XMLAdapter
20+
from dspy.adapters import ChatAdapter, XMLAdapter
2021
from fastapi import Depends, FastAPI, Header, HTTPException, Request
2122
from fastapi.middleware.cors import CORSMiddleware
2223
from fastapi.responses import StreamingResponse
@@ -185,7 +186,7 @@ def __init__(
185186
embedder = dspy.Embedder("gemini/gemini-embedding-001", dimensions=3072, batch_size=512)
186187
dspy.configure(
187188
lm=dspy.LM("gemini/gemini-flash-latest", max_tokens=30000, cache=False),
188-
adapter=XMLAdapter(),
189+
adapter=ChatAdapter(),
189190
embedder=embedder,
190191
callbacks=[AgentLoggingCallback()],
191192
track_usage=True,
@@ -420,49 +421,51 @@ async def _stream_chat_completion(
420421
content_buffer = ""
421422

422423
try:
423-
async for event in agent.aforward_streaming(
424-
query=query, chat_history=history, mcp_mode=mcp_mode
425-
):
426-
if event.type == "sources":
427-
# Emit sources event for clients to display
428-
sources_chunk = {
429-
"type": "sources",
430-
"data": event.data,
431-
}
432-
yield f"data: {json.dumps(sources_chunk)}\n\n"
433-
elif event.type == "response":
434-
content_buffer += event.data
435-
436-
# Send content chunk
437-
chunk = {
438-
"id": response_id,
439-
"object": "chat.completion.chunk",
440-
"created": created,
441-
"model": "cairo-coder",
442-
"choices": [
443-
{"index": 0, "delta": {"content": event.data}, "finish_reason": None}
444-
],
445-
}
446-
yield f"data: {json.dumps(chunk)}\n\n"
447-
elif event.type == "error":
448-
# Emit an error as a final delta and stop
449-
error_chunk = {
450-
"id": response_id,
451-
"object": "chat.completion.chunk",
452-
"created": created,
453-
"model": "cairo-coder",
454-
"choices": [
455-
{
456-
"index": 0,
457-
"delta": {"content": f"\n\nError: {event.data}"},
458-
"finish_reason": "stop",
459-
}
460-
],
461-
}
462-
yield f"data: {json.dumps(error_chunk)}\n\n"
463-
break
464-
elif event.type == "end":
465-
break
424+
with ls.trace(name="RagPipelineStreaming", run_type="chain", inputs={"query": query, "chat_history": history, "mcp_mode": mcp_mode}) as rt:
425+
async for event in agent.aforward_streaming(
426+
query=query, chat_history=history, mcp_mode=mcp_mode
427+
):
428+
if event.type == "sources":
429+
# Emit sources event for clients to display
430+
sources_chunk = {
431+
"type": "sources",
432+
"data": event.data,
433+
}
434+
yield f"data: {json.dumps(sources_chunk)}\n\n"
435+
elif event.type == "response":
436+
content_buffer += event.data
437+
438+
# Send content chunk
439+
chunk = {
440+
"id": response_id,
441+
"object": "chat.completion.chunk",
442+
"created": created,
443+
"model": "cairo-coder",
444+
"choices": [
445+
{"index": 0, "delta": {"content": event.data}, "finish_reason": None}
446+
],
447+
}
448+
yield f"data: {json.dumps(chunk)}\n\n"
449+
elif event.type == "error":
450+
# Emit an error as a final delta and stop
451+
error_chunk = {
452+
"id": response_id,
453+
"object": "chat.completion.chunk",
454+
"created": created,
455+
"model": "cairo-coder",
456+
"choices": [
457+
{
458+
"index": 0,
459+
"delta": {"content": f"\n\nError: {event.data}"},
460+
"finish_reason": "stop",
461+
}
462+
],
463+
}
464+
yield f"data: {json.dumps(error_chunk)}\n\n"
465+
break
466+
elif event.type == "end":
467+
break
468+
rt.end(outputs={"output": content_buffer})
466469

467470
except Exception as e:
468471
logger.error("Error during agent streaming", error=str(e), exc_info=True)

0 commit comments

Comments
 (0)