fix(llm-observability): parallel traces (#172)
* fix: parallel traces

* fix: linters

* chore: bump

* fix: better naming for clarity
skoob13 authored Jan 23, 2025
1 parent 2835af4 commit 45dc933
Showing 3 changed files with 85 additions and 18 deletions.
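Before this commit, the handler kept the root trace's name and input in instance-level fields (`self._trace_name`, `self._trace_input`), so two traces running concurrently through one `CallbackHandler` could overwrite each other. The fix keys that metadata by `run_id` in the existing `RunStorage` dict. A minimal sketch of the difference; `SharedStateHandler` and `PerRunHandler` are hypothetical stand-ins for the real callback handler, not code from this commit:

    # Minimal sketch of the failure mode, not the real CallbackHandler.
    import uuid


    class SharedStateHandler:
        def __init__(self):
            self._trace_name = None  # single slot shared by all concurrent traces

        def on_start(self, name):
            if self._trace_name is None:  # a second concurrent trace is dropped
                self._trace_name = name

        def on_end(self):
            name, self._trace_name = self._trace_name, None
            return name


    class PerRunHandler:
        def __init__(self):
            self._runs = {}  # run_id -> metadata, one entry per in-flight trace

        def on_start(self, run_id, name):
            self._runs[run_id] = {"name": name}

        def on_end(self, run_id):
            return self._runs.pop(run_id)["name"]


    # Two interleaved traces: shared state loses one trace's name entirely,
    # while per-run storage keeps both runs apart.
    shared, per_run = SharedStateHandler(), PerRunHandler()
    a, b = uuid.uuid4(), uuid.uuid4()

    shared.on_start("trace-a")
    shared.on_start("trace-b")  # silently ignored
    assert shared.on_end() == "trace-a"
    assert shared.on_end() is None  # second trace's name was lost

    per_run.on_start(a, "trace-a")
    per_run.on_start(b, "trace-b")
    assert per_run.on_end(b) == "trace-b"
    assert per_run.on_end(a) == "trace-a"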
47 changes: 33 additions & 14 deletions posthog/ai/langchain/callbacks.py
@@ -32,13 +32,22 @@


 class RunMetadata(TypedDict, total=False):
-    messages: Union[List[Dict[str, Any]], List[str]]
+    input: Any
+    """Input of the run: messages, prompt variables, etc."""
+    name: str
+    """Name of the run: chain name, model name, etc."""
     provider: str
+    """Provider of the run: OpenAI, Anthropic"""
     model: str
+    """Model used in the run"""
     model_params: Dict[str, Any]
+    """Model parameters of the run: temperature, max_tokens, etc."""
     base_url: str
+    """Base URL of the provider's API used in the run."""
     start_time: float
+    """Start time of the run."""
     end_time: float
+    """End time of the run."""


 RunStorage = Dict[UUID, RunMetadata]
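For illustration, a hypothetical entry that `_set_llm_metadata` (below) might store under a chat-model run's UUID; the field values here are invented, not taken from the commit:

    # Illustrative RunMetadata entry for a chat-model run (values are made up):
    run_metadata = {
        "input": [{"role": "user", "content": "Who won the world series in 2020?"}],
        "name": "ChatOpenAI",
        "provider": "openai",
        "model": "gpt-4o-mini",
        "model_params": {"temperature": 0.5},
        "base_url": "https://api.openai.com/v1",
        "start_time": 1737590000.0,
        "end_time": 1737590001.2,
    }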
@@ -119,8 +128,7 @@ def on_chain_start(
         self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
         self._set_parent_of_run(run_id, parent_run_id)
         if parent_run_id is None and self._trace_name is None:
-            self._trace_name = self._get_langchain_run_name(serialized, **kwargs)
-            self._trace_input = inputs
+            self._set_span_metadata(run_id, self._get_langchain_run_name(serialized, **kwargs), inputs)

     def on_chat_model_start(
         self,
@@ -134,7 +142,7 @@ def on_chat_model_start(
         self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages)
         self._set_parent_of_run(run_id, parent_run_id)
         input = [_convert_message_to_dict(message) for row in messages for message in row]
-        self._set_run_metadata(serialized, run_id, input, **kwargs)
+        self._set_llm_metadata(serialized, run_id, input, **kwargs)

     def on_llm_start(
         self,
@@ -147,7 +155,7 @@
     ):
         self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts)
         self._set_parent_of_run(run_id, parent_run_id)
-        self._set_run_metadata(serialized, run_id, prompts, **kwargs)
+        self._set_llm_metadata(serialized, run_id, prompts, **kwargs)

     def on_llm_new_token(
         self,
@@ -204,7 +212,7 @@ def on_chain_end(
         self._pop_parent_of_run(run_id)

         if parent_run_id is None:
-            self._capture_trace(run_id, outputs=outputs)
+            self._pop_trace_and_capture(run_id, outputs=outputs)

     def on_chain_error(
         self,
@@ -218,7 +226,7 @@
         self._pop_parent_of_run(run_id)

         if parent_run_id is None:
-            self._capture_trace(run_id, outputs=None)
+            self._pop_trace_and_capture(run_id, outputs=None)

     def on_llm_end(
         self,
@@ -253,7 +261,7 @@ def on_llm_end(
             "$ai_provider": run.get("provider"),
             "$ai_model": run.get("model"),
             "$ai_model_parameters": run.get("model_params"),
-            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
+            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
             "$ai_output_choices": with_privacy_mode(self._client, self._privacy_mode, output),
             "$ai_http_status": 200,
             "$ai_input_tokens": input_tokens,
@@ -292,7 +300,7 @@ def on_llm_error(
             "$ai_provider": run.get("provider"),
             "$ai_model": run.get("model"),
             "$ai_model_parameters": run.get("model_params"),
-            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
+            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
             "$ai_http_status": _get_http_status(error),
             "$ai_latency": latency,
             "$ai_trace_id": trace_id,
@@ -377,7 +385,14 @@ def _find_root_run(self, run_id: UUID) -> UUID:
             id = self._parent_tree[id]
         return id

-    def _set_run_metadata(
+    def _set_span_metadata(self, run_id: UUID, name: str, input: Any):
+        self._runs[run_id] = {
+            "name": name,
+            "input": input,
+            "start_time": time.time(),
+        }
+
+    def _set_llm_metadata(
         self,
         serialized: Dict[str, Any],
         run_id: UUID,
@@ -387,7 +402,7 @@ def _set_run_metadata(
         **kwargs,
     ):
         run: RunMetadata = {
-            "messages": messages,
+            "input": messages,
             "start_time": time.time(),
         }
         if isinstance(invocation_params, dict):
@@ -450,12 +465,16 @@ def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs
         except (KeyError, TypeError):
             pass

-    def _capture_trace(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]):
+    def _pop_trace_and_capture(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]):
         trace_id = self._get_trace_id(run_id)
+        run = self._pop_run_metadata(run_id)
+        if not run:
+            return
         event_properties = {
-            "$ai_trace_name": self._trace_name,
+            "$ai_trace_name": run.get("name"),
             "$ai_trace_id": trace_id,
-            "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, self._trace_input),
+            "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
+            "$ai_latency": run.get("end_time", 0) - run.get("start_time", 0),
             **self._properties,
         }
         if outputs is not None:
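Taken together, the new flow for a root run is: `on_chain_start` stores name and input under the run's UUID via `_set_span_metadata`, and `on_chain_end`/`on_chain_error` pop that entry via `_pop_trace_and_capture`. A simplified, self-contained sketch of that pairing (no parent tracking or privacy mode; this is not the library's API):

    # Simplified pairing of _set_span_metadata / _pop_trace_and_capture.
    import time
    import uuid

    runs = {}  # stands in for RunStorage: run_id -> metadata


    def set_span_metadata(run_id, name, input):
        runs[run_id] = {"name": name, "input": input, "start_time": time.time()}


    def pop_trace_and_capture(run_id):
        run = runs.pop(run_id, None)
        if not run:
            return None  # unknown run_id: nothing to capture
        run.setdefault("end_time", time.time())
        return {
            "$ai_trace_name": run["name"],
            "$ai_input_state": run["input"],
            "$ai_latency": run["end_time"] - run["start_time"],
        }


    # Interleaved root runs now resolve independently:
    a, b = uuid.uuid4(), uuid.uuid4()
    set_span_metadata(a, "RunnableSequence", {"topic": "a"})
    set_span_metadata(b, "sleep", {"topic": "b"})
    assert pop_trace_and_capture(b)["$ai_trace_name"] == "sleep"
    assert pop_trace_and_capture(a)["$ai_trace_name"] == "RunnableSequence"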
54 changes: 51 additions & 3 deletions posthog/test/ai/langchain/test_callbacks.py
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 import math
 import os
@@ -67,7 +68,7 @@ def test_metadata_capture(mock_client):
     callbacks = CallbackHandler(mock_client)
     run_id = uuid.uuid4()
     with patch("time.time", return_value=1234567890):
-        callbacks._set_run_metadata(
+        callbacks._set_llm_metadata(
             {"kwargs": {"openai_api_base": "https://us.posthog.com"}},
             run_id,
             messages=[{"role": "user", "content": "Who won the world series in 2020?"}],
@@ -76,7 +77,7 @@ def test_metadata_capture(mock_client):
         )
     expected = {
         "model": "hog-mini",
-        "messages": [{"role": "user", "content": "Who won the world series in 2020?"}],
+        "input": [{"role": "user", "content": "Who won the world series in 2020?"}],
         "start_time": 1234567890,
         "model_params": {"temperature": 0.5},
         "provider": "posthog",
@@ -90,6 +91,19 @@
     callbacks._pop_run_metadata(uuid.uuid4())  # should not raise


+def test_run_metadata_capture(mock_client):
+    callbacks = CallbackHandler(mock_client)
+    run_id = uuid.uuid4()
+    with patch("time.time", return_value=1234567890):
+        callbacks._set_span_metadata(run_id, "test", 1)
+    expected = {
+        "name": "test",
+        "input": 1,
+        "start_time": 1234567890,
+    }
+    assert callbacks._runs[run_id] == expected
+
+
 @pytest.mark.parametrize("stream", [True, False])
 def test_basic_chat_chain(mock_client, stream):
     prompt = ChatPromptTemplate.from_messages(
@@ -514,7 +528,11 @@ def test_callbacks_logic(mock_client):
     assert callbacks._parent_tree == {}

     def assert_intermediary_run(m):
-        assert callbacks._runs == {}
+        assert len(callbacks._runs) != 0
+        run = next(iter(callbacks._runs.values()))
+        assert run["name"] == "RunnableSequence"
+        assert run["input"] == {}
+        assert run["start_time"] is not None
         assert len(callbacks._parent_tree.items()) == 1
         return [m]

@@ -981,3 +999,33 @@ def test_tool_calls(mock_client):
         }
     ]
     assert "additional_kwargs" not in generation_call["properties"]["$ai_output_choices"][0]
+
+
+async def test_async_traces(mock_client):
+    async def sleep(x):
+        await asyncio.sleep(0.1)
+        return x
+
+    prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
+    chain1 = RunnableLambda(sleep)
+    chain2 = prompt | FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])
+
+    cb = CallbackHandler(mock_client)
+
+    start_time = time.time()
+    await asyncio.gather(
+        chain1.ainvoke({}, config={"callbacks": [cb]}),
+        chain2.ainvoke({}, config={"callbacks": [cb]}),
+    )
+    approximate_latency = math.floor(time.time() - start_time)
+    assert mock_client.capture.call_count == 3
+
+    first_call, second_call, third_call = mock_client.capture.call_args_list
+    assert first_call[1]["event"] == "$ai_generation"
+    assert second_call[1]["event"] == "$ai_trace"
+    assert second_call[1]["properties"]["$ai_trace_name"] == "RunnableSequence"
+    assert third_call[1]["event"] == "$ai_trace"
+    assert third_call[1]["properties"]["$ai_trace_name"] == "sleep"
+    assert (
+        max(approximate_latency - 1, 0) <= math.floor(third_call[1]["properties"]["$ai_latency"]) <= approximate_latency
+    )
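From the caller's side, the pattern this test exercises is a single handler shared across concurrent `ainvoke` calls. A usage sketch under assumptions: the chains are placeholders, and the import paths and placeholder API key reflect the package's documented layout rather than anything in this diff:

    # Usage sketch: one CallbackHandler shared by concurrent chains.
    import asyncio

    from posthog import Posthog
    from posthog.ai.langchain import CallbackHandler


    async def run_chains(chain1, chain2):
        posthog = Posthog("<project-api-key>", host="https://us.i.posthog.com")
        cb = CallbackHandler(posthog)
        # Each ainvoke starts its own root run_id, so the two traces are
        # captured separately instead of overwriting one another.
        return await asyncio.gather(
            chain1.ainvoke({}, config={"callbacks": [cb]}),
            chain2.ainvoke({}, config={"callbacks": [cb]}),
        )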
2 changes: 1 addition & 1 deletion posthog/version.py
@@ -1,4 +1,4 @@
-VERSION = "3.9.2"
+VERSION = "3.9.3"

 if __name__ == "__main__":
     print(VERSION, end="")  # noqa: T201
