diff --git a/CHANGELOG.md b/CHANGELOG.md index 582a1e5..35b9d0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,18 @@ -## 3.11.0 - 2025-01-27 +## 3.11.0 - 2025-01-28 -1. Fix serialiazation of Pydantic models in methods. +1. Add the `$ai_span` event to the LangChain callback handler to capture the input and output of intermediary chains. + > LLM observability naming change: event property `$ai_trace_name` is now `$ai_span_name`. + +2. Fix serialization of Pydantic models in methods. ## 3.10.0 - 2025-01-24 1. Add `$ai_error` and `$ai_is_error` properties to LangChain callback handler, OpenAI, and Anthropic. +## 3.9.3 - 2025-01-23 + +1. Fix capturing of multiple traces in the LangChain callback handler. + ## 3.9.2 - 2025-01-22 1. Fix importing of LangChain callback handler under certain circumstances. diff --git a/posthog/ai/langchain/callbacks.py b/posthog/ai/langchain/callbacks.py index d375fb5..1b865f0 100644 --- a/posthog/ai/langchain/callbacks.py +++ b/posthog/ai/langchain/callbacks.py @@ -5,14 +5,14 @@ import logging import time -import uuid +from dataclasses import dataclass from typing import ( Any, Dict, List, Optional, + Sequence, Tuple, - TypedDict, Union, cast, ) @@ -20,6 +20,7 @@ from langchain.callbacks.base import BaseCallbackHandler from langchain.schema.agent import AgentAction, AgentFinish +from langchain_core.documents import Document from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.outputs import ChatGeneration, LLMResult from pydantic import BaseModel @@ -31,26 +32,38 @@ log = logging.getLogger("posthog") -class RunMetadata(TypedDict, total=False): - input: Any - """Input of the run: messages, prompt variables, etc.""" +@dataclass +class SpanMetadata: name: str """Name of the run: chain name, model name, etc.""" - provider: str + start_time: float + """Start time of the run.""" + end_time: Optional[float] + """End time of the run.""" + input: Optional[Any] + """Input of the run: messages, prompt variables, etc.""" + + @property + def latency(self) -> float: + if not self.end_time: + return 0 + return self.end_time - self.start_time + + +@dataclass +class GenerationMetadata(SpanMetadata): + provider: Optional[str] = None """Provider of the run: OpenAI, Anthropic""" - model: str + model: Optional[str] = None """Model used in the run""" - model_params: Dict[str, Any] + model_params: Optional[Dict[str, Any]] = None """Model parameters of the run: temperature, max_tokens, etc.""" - base_url: str + base_url: Optional[str] = None """Base URL of the provider's API used in the run.""" - start_time: float - """Start time of the run.""" - end_time: float - """End time of the run.""" -RunStorage = Dict[UUID, RunMetadata] +RunMetadata = Union[SpanMetadata, GenerationMetadata] +RunMetadataStorage = Dict[UUID, RunMetadata] class CallbackHandler(BaseCallbackHandler): @@ -76,7 +89,7 @@ class CallbackHandler(BaseCallbackHandler): _properties: Optional[Dict[str, Any]] """Global properties to be sent with every event.""" - _runs: RunStorage + _runs: RunMetadataStorage """Mapping of run IDs to run metadata as run metadata is only available on the start of generation.""" _parent_tree: Dict[UUID, UUID] @@ -104,11 +117,12 @@ def __init__( privacy_mode: Whether to redact the input and output of the trace. groups: Optional additional PostHog groups to use for the trace. 
""" - self._client = client or default_client + posthog_client = client or default_client + if posthog_client is None: + raise ValueError("PostHog client is required") + self._client = posthog_client self._distinct_id = distinct_id self._trace_id = trace_id - self._trace_name = None - self._trace_input = None self._properties = properties or {} self._privacy_mode = privacy_mode self._groups = groups or {} @@ -127,8 +141,29 @@ def on_chain_start( ): self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs) self._set_parent_of_run(run_id, parent_run_id) - if parent_run_id is None and self._trace_name is None: - self._set_span_metadata(run_id, self._get_langchain_run_name(serialized, **kwargs), inputs) + self._set_trace_or_span_metadata(serialized, inputs, run_id, parent_run_id, **kwargs) + + def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ): + self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs) + self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, outputs) + + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ): + self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error) + self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error) def on_chat_model_start( self, @@ -168,178 +203,103 @@ def on_llm_new_token( """Run on new LLM token. Only available when streaming is enabled.""" self._log_debug_event("on_llm_new_token", run_id, parent_run_id, token=token) - def on_tool_start( + def on_llm_end( self, - serialized: Optional[Dict[str, Any]], - input_str: str, + response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> Any: - self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str) + ): + """ + The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM. 
+ """ + self._log_debug_event("on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs) + self._pop_run_and_capture_generation(run_id, parent_run_id, response) - def on_tool_end( + def on_llm_error( self, - output: str, + error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, - ) -> Any: - self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output) + ): + self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error) + self._pop_run_and_capture_generation(run_id, parent_run_id, error) - def on_tool_error( + def on_tool_start( self, - error: Union[Exception, KeyboardInterrupt], + serialized: Optional[Dict[str, Any]], + input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error) + self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str) + self._set_trace_or_span_metadata(serialized, input_str, run_id, parent_run_id, **kwargs) - def on_chain_end( + def on_tool_end( self, - outputs: Dict[str, Any], + output: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, - ): - self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs) - self._pop_parent_of_run(run_id) - - if parent_run_id is None: - self._pop_trace_and_capture(run_id, outputs=outputs) + ) -> Any: + self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output) + self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, output) - def on_chain_error( + def on_tool_error( self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, **kwargs: Any, - ): - self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error) - self._pop_parent_of_run(run_id) - - if parent_run_id is None: - self._pop_trace_and_capture(run_id, outputs=None) + ) -> Any: + self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error) + self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error) - def on_llm_end( + def on_retriever_start( self, - response: LLMResult, + serialized: Optional[Dict[str, Any]], + query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, - ): - """ - The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM. 
- """ - self._log_debug_event("on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs) - trace_id = self._get_trace_id(run_id) - self._pop_parent_of_run(run_id) - run = self._pop_run_metadata(run_id) - if not run: - return - - latency = run.get("end_time", 0) - run.get("start_time", 0) - input_tokens, output_tokens = _parse_usage(response) - - generation_result = response.generations[-1] - if isinstance(generation_result[-1], ChatGeneration): - output = [ - _convert_message_to_dict(cast(ChatGeneration, generation).message) for generation in generation_result - ] - else: - output = [_extract_raw_esponse(generation) for generation in generation_result] - - event_properties = { - "$ai_provider": run.get("provider"), - "$ai_model": run.get("model"), - "$ai_model_parameters": run.get("model_params"), - "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("input")), - "$ai_output_choices": with_privacy_mode(self._client, self._privacy_mode, output), - "$ai_http_status": 200, - "$ai_input_tokens": input_tokens, - "$ai_output_tokens": output_tokens, - "$ai_latency": latency, - "$ai_trace_id": trace_id, - "$ai_base_url": run.get("base_url"), - **self._properties, - } - if self._distinct_id is None: - event_properties["$process_person_profile"] = False - self._client.capture( - distinct_id=self._distinct_id or trace_id, - event="$ai_generation", - properties=event_properties, - groups=self._groups, - ) + ) -> Any: + self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query) + self._set_trace_or_span_metadata(serialized, query, run_id, parent_run_id, **kwargs) - def on_llm_error( + def on_retriever_end( self, - error: BaseException, + documents: Sequence[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, ): - self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error) - trace_id = self._get_trace_id(run_id) - self._pop_parent_of_run(run_id) - run = self._pop_run_metadata(run_id) - if not run: - return - - latency = run.get("end_time", 0) - run.get("start_time", 0) - event_properties = { - "$ai_provider": run.get("provider"), - "$ai_model": run.get("model"), - "$ai_model_parameters": run.get("model_params"), - "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("input")), - "$ai_http_status": _get_http_status(error), - "$ai_latency": latency, - "$ai_trace_id": trace_id, - "$ai_base_url": run.get("base_url"), - "$ai_is_error": True, - "$ai_error": error.__str__(), - **self._properties, - } - if self._distinct_id is None: - event_properties["$process_person_profile"] = False - self._client.capture( - distinct_id=self._distinct_id or trace_id, - event="$ai_generation", - properties=event_properties, - groups=self._groups, - ) - - def on_retriever_start( - self, - serialized: Optional[Dict[str, Any]], - query: str, - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Any: - self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query) + self._log_debug_event("on_retriever_end", run_id, parent_run_id, documents=documents) + self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, documents) def on_retriever_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> Any: """Run when Retriever errors.""" self._log_debug_event("on_retriever_error", run_id, 
parent_run_id, error=error) + self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error) def on_agent_action( self, @@ -351,6 +311,8 @@ def on_agent_action( ) -> Any: """Run on agent action.""" self._log_debug_event("on_agent_action", run_id, parent_run_id, action=action) + self._set_parent_of_run(run_id, parent_run_id) + self._set_trace_or_span_metadata(None, action, run_id, parent_run_id, **kwargs) def on_agent_finish( self, @@ -361,6 +323,7 @@ def on_agent_finish( **kwargs: Any, ) -> Any: self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish) + self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, finish) def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None): """ @@ -387,12 +350,17 @@ def _find_root_run(self, run_id: UUID) -> UUID: id = self._parent_tree[id] return id - def _set_span_metadata(self, run_id: UUID, name: str, input: Any): - self._runs[run_id] = { - "name": name, - "input": input, - "start_time": time.time(), - } + def _set_trace_or_span_metadata( + self, + serialized: Optional[Dict[str, Any]], + input: Any, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs, + ): + default_name = "trace" if parent_run_id is None else "span" + run_name = _get_langchain_run_name(serialized, **kwargs) or default_name + self._runs[run_id] = SpanMetadata(name=run_name, input=input, start_time=time.time(), end_time=None) def _set_llm_metadata( self, @@ -403,24 +371,22 @@ def _set_llm_metadata( invocation_params: Optional[Dict[str, Any]] = None, **kwargs, ): - run: RunMetadata = { - "input": messages, - "start_time": time.time(), - } + run_name = _get_langchain_run_name(serialized, **kwargs) or "generation" + generation = GenerationMetadata(name=run_name, input=messages, start_time=time.time(), end_time=None) if isinstance(invocation_params, dict): - run["model_params"] = get_model_params(invocation_params) + generation.model_params = get_model_params(invocation_params) if isinstance(metadata, dict): if model := metadata.get("ls_model_name"): - run["model"] = model + generation.model = model if provider := metadata.get("ls_provider"): - run["provider"] = provider + generation.provider = provider try: base_url = serialized["kwargs"]["openai_api_base"] if base_url is not None: - run["base_url"] = base_url + generation.base_url = base_url except KeyError: pass - self._runs[run_id] = run + self._runs[run_id] = generation def _pop_run_metadata(self, run_id: UUID) -> Optional[RunMetadata]: end_time = time.time() @@ -429,63 +395,140 @@ def _pop_run_metadata(self, run_id: UUID) -> Optional[RunMetadata]: except KeyError: log.warning(f"No run metadata found for run {run_id}") return None - run["end_time"] = end_time + run.end_time = end_time return run def _get_trace_id(self, run_id: UUID): trace_id = self._trace_id or self._find_root_run(run_id) if not trace_id: - trace_id = uuid.uuid4() + return run_id return trace_id - def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str: - """Retrieve the name of a serialized LangChain runnable. - - The prioritization for the determination of the run name is as follows: - - The value assigned to the "name" key in `kwargs`. - - The value assigned to the "name" key in `serialized`. - - The last entry of the value assigned to the "id" key in `serialized`. - - "". 
+ def _get_parent_run_id(self, trace_id: Any, run_id: UUID, parent_run_id: Optional[UUID]): + """ + Replace the parent run ID with the trace ID for second level runs when a custom trace ID is set. + """ + if parent_run_id is not None and parent_run_id not in self._parent_tree: + return trace_id + return parent_run_id - Args: - serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data. - **kwargs (Any): Additional keyword arguments, potentially including the 'name' override. + def _pop_run_and_capture_trace_or_span(self, run_id: UUID, parent_run_id: Optional[UUID], outputs: Any): + trace_id = self._get_trace_id(run_id) + self._pop_parent_of_run(run_id) + run = self._pop_run_metadata(run_id) + if not run: + return + if isinstance(run, GenerationMetadata): + log.warning(f"Run {run_id} is a generation, but attempted to be captured as a trace or span.") + return + self._capture_trace_or_span( + trace_id, run_id, run, outputs, self._get_parent_run_id(trace_id, run_id, parent_run_id) + ) - Returns: - str: The determined name of the Langchain runnable. - """ - if "name" in kwargs and kwargs["name"] is not None: - return kwargs["name"] + def _capture_trace_or_span( + self, + trace_id: Any, + run_id: UUID, + run: SpanMetadata, + outputs: Any, + parent_run_id: Optional[UUID], + ): + event_name = "$ai_trace" if parent_run_id is None else "$ai_span" + event_properties = { + "$ai_trace_id": trace_id, + "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, run.input), + "$ai_latency": run.latency, + "$ai_span_name": run.name, + "$ai_span_id": run_id, + } + if parent_run_id is not None: + event_properties["$ai_parent_id"] = parent_run_id + if self._properties: + event_properties.update(self._properties) + + if isinstance(outputs, BaseException): + event_properties["$ai_error"] = _stringify_exception(outputs) + event_properties["$ai_is_error"] = True + elif outputs is not None: + event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs) - try: - return serialized["name"] - except (KeyError, TypeError): - pass + if self._distinct_id is None: + event_properties["$process_person_profile"] = False - try: - return serialized["id"][-1] - except (KeyError, TypeError): - pass + self._client.capture( + distinct_id=self._distinct_id or run_id, + event=event_name, + properties=event_properties, + groups=self._groups, + ) - def _pop_trace_and_capture(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]): + def _pop_run_and_capture_generation( + self, run_id: UUID, parent_run_id: Optional[UUID], response: Union[LLMResult, BaseException] + ): trace_id = self._get_trace_id(run_id) + self._pop_parent_of_run(run_id) run = self._pop_run_metadata(run_id) if not run: return + if not isinstance(run, GenerationMetadata): + log.warning(f"Run {run_id} is not a generation, but attempted to be captured as a generation.") + return + self._capture_generation( + trace_id, run_id, run, response, self._get_parent_run_id(trace_id, run_id, parent_run_id) + ) + + def _capture_generation( + self, + trace_id: Any, + run_id: UUID, + run: GenerationMetadata, + output: Union[LLMResult, BaseException], + parent_run_id: Optional[UUID] = None, + ): event_properties = { - "$ai_trace_name": run.get("name"), "$ai_trace_id": trace_id, - "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, run.get("input")), - "$ai_latency": run.get("end_time", 0) - run.get("start_time", 0), - **self._properties, + "$ai_span_id": run_id, + 
"$ai_span_name": run.name, + "$ai_parent_id": parent_run_id, + "$ai_provider": run.provider, + "$ai_model": run.model, + "$ai_model_parameters": run.model_params, + "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.input), + "$ai_http_status": 200, + "$ai_latency": run.latency, + "$ai_base_url": run.base_url, } - if outputs is not None: - event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs) + + if isinstance(output, BaseException): + event_properties["$ai_http_status"] = _get_http_status(output) + event_properties["$ai_error"] = _stringify_exception(output) + event_properties["$ai_is_error"] = True + else: + # Add usage + input_tokens, output_tokens = _parse_usage(output) + event_properties["$ai_input_tokens"] = input_tokens + event_properties["$ai_output_tokens"] = output_tokens + + # Generation results + generation_result = output.generations[-1] + if isinstance(generation_result[-1], ChatGeneration): + completions = [ + _convert_message_to_dict(cast(ChatGeneration, generation).message) + for generation in generation_result + ] + else: + completions = [_extract_raw_esponse(generation) for generation in generation_result] + event_properties["$ai_output_choices"] = with_privacy_mode(self._client, self._privacy_mode, completions) + + if self._properties: + event_properties.update(self._properties) + if self._distinct_id is None: event_properties["$process_person_profile"] = False + self._client.capture( distinct_id=self._distinct_id or trace_id, - event="$ai_trace", + event="$ai_generation", properties=event_properties, groups=self._groups, ) @@ -616,3 +659,41 @@ def _get_http_status(error: BaseException) -> int: # Google: https://github.com/googleapis/python-api-core/blob/main/google/api_core/exceptions.py status_code = getattr(error, "status_code", getattr(error, "code", 0)) return status_code + + +def _get_langchain_run_name(serialized: Optional[Dict[str, Any]], **kwargs: Any) -> Optional[str]: + """Retrieve the name of a serialized LangChain runnable. + + The prioritization for the determination of the run name is as follows: + - The value assigned to the "name" key in `kwargs`. + - The value assigned to the "name" key in `serialized`. + - The last entry of the value assigned to the "id" key in `serialized`. + - "". + + Args: + serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data. + **kwargs (Any): Additional keyword arguments, potentially including the 'name' override. + + Returns: + str: The determined name of the Langchain runnable. 
+ """ + if "name" in kwargs and kwargs["name"] is not None: + return kwargs["name"] + if serialized is None: + return None + try: + return serialized["name"] + except (KeyError, TypeError): + pass + try: + return serialized["id"][-1] + except (KeyError, TypeError): + pass + return None + + +def _stringify_exception(exception: BaseException) -> str: + description = str(exception) + if description: + return f"{exception.__class__.__name__}: {description}" + return exception.__class__.__name__ diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py index a994bf5..ca63b51 100644 --- a/posthog/test/ai/langchain/test_callbacks.py +++ b/posthog/test/ai/langchain/test_callbacks.py @@ -4,7 +4,7 @@ import os import time import uuid -from typing import List, Optional, TypedDict, Union +from typing import List, Literal, Optional, TypedDict, Union from unittest.mock import patch import pytest @@ -14,10 +14,13 @@ from langchain_core.messages import AIMessage, HumanMessage from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableLambda +from langchain_core.tools import tool from langchain_openai.chat_models import ChatOpenAI from langgraph.graph.state import END, START, StateGraph +from langgraph.prebuilt import create_react_agent from posthog.ai.langchain import CallbackHandler +from posthog.ai.langchain.callbacks import GenerationMetadata, SpanMetadata OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") @@ -74,19 +77,23 @@ def test_metadata_capture(mock_client): messages=[{"role": "user", "content": "Who won the world series in 2020?"}], invocation_params={"temperature": 0.5}, metadata={"ls_model_name": "hog-mini", "ls_provider": "posthog"}, + name="test", ) - expected = { - "model": "hog-mini", - "input": [{"role": "user", "content": "Who won the world series in 2020?"}], - "start_time": 1234567890, - "model_params": {"temperature": 0.5}, - "provider": "posthog", - "base_url": "https://us.posthog.com", - } + expected = GenerationMetadata( + model="hog-mini", + input=[{"role": "user", "content": "Who won the world series in 2020?"}], + start_time=1234567890, + model_params={"temperature": 0.5}, + provider="posthog", + base_url="https://us.posthog.com", + name="test", + end_time=None, + ) assert callbacks._runs[run_id] == expected with patch("time.time", return_value=1234567891): run = callbacks._pop_run_metadata(run_id) - assert run == {**expected, "end_time": 1234567891} + expected.end_time = 1234567891 + assert run == expected assert callbacks._runs == {} callbacks._pop_run_metadata(uuid.uuid4()) # should not raise @@ -95,12 +102,32 @@ def test_run_metadata_capture(mock_client): callbacks = CallbackHandler(mock_client) run_id = uuid.uuid4() with patch("time.time", return_value=1234567890): - callbacks._set_span_metadata(run_id, "test", 1) - expected = { - "name": "test", - "input": 1, - "start_time": 1234567890, - } + callbacks._set_trace_or_span_metadata(None, 1, run_id) + expected = SpanMetadata( + name="trace", + input=1, + start_time=1234567890, + end_time=None, + ) + assert callbacks._runs[run_id] == expected + with patch("time.time", return_value=1234567890): + callbacks._set_trace_or_span_metadata(None, 1, run_id, uuid.uuid4()) + expected = SpanMetadata( + name="span", + input=1, + start_time=1234567890, + end_time=None, + ) + assert callbacks._runs[run_id] == expected + + with patch("time.time", return_value=1234567890): + 
callbacks._set_trace_or_span_metadata({"name": "test"}, 1, run_id) + expected = SpanMetadata( + name="test", + input=1, + start_time=1234567890, + end_time=None, + ) assert callbacks._runs[run_id] == expected @@ -132,11 +159,24 @@ def test_basic_chat_chain(mock_client, stream): result = chain.invoke({}, config={"callbacks": callbacks}) assert result.content == "The Los Angeles Dodgers won the World Series in 2020." - assert mock_client.capture.call_count == 2 - generation_args = mock_client.capture.call_args_list[0][1] + assert mock_client.capture.call_count == 3 + + span_args = mock_client.capture.call_args_list[0][1] + span_props = span_args["properties"] + + generation_args = mock_client.capture.call_args_list[1][1] generation_props = generation_args["properties"] - trace_args = mock_client.capture.call_args_list[1][1] + trace_args = mock_client.capture.call_args_list[2][1] + trace_props = trace_args["properties"] + + # Span is first + assert span_args["event"] == "$ai_span" + assert span_props["$ai_trace_id"] == generation_props["$ai_trace_id"] + assert span_props["$ai_parent_id"] == trace_props["$ai_trace_id"] + assert "$ai_span_id" in span_props + + # Generation is second assert generation_args["event"] == "$ai_generation" assert "distinct_id" in generation_args assert "$ai_model" in generation_props @@ -154,9 +194,16 @@ def test_basic_chat_chain(mock_client, stream): assert generation_props["$ai_input_tokens"] == 10 assert generation_props["$ai_output_tokens"] == 10 assert generation_props["$ai_http_status"] == 200 - assert generation_props["$ai_trace_id"] is not None assert isinstance(generation_props["$ai_latency"], float) + assert "$ai_span_id" in generation_props + assert generation_props["$ai_parent_id"] == trace_props["$ai_trace_id"] + assert generation_props["$ai_trace_id"] == trace_props["$ai_trace_id"] + assert generation_props["$ai_span_name"] == "FakeMessagesListChatModel" + + # Trace is last assert trace_args["event"] == "$ai_trace" + assert "$ai_trace_id" in trace_props + assert "$ai_parent_id" not in trace_props @pytest.mark.parametrize("stream", [True, False]) @@ -186,13 +233,22 @@ async def test_async_basic_chat_chain(mock_client, stream): else: result = await chain.ainvoke({}, config={"callbacks": callbacks}) assert result.content == "The Los Angeles Dodgers won the World Series in 2020." 
- assert mock_client.capture.call_count == 2 + assert mock_client.capture.call_count == 3 - generation_args = mock_client.capture.call_args_list[0][1] + span_args = mock_client.capture.call_args_list[0][1] + span_props = span_args["properties"] + generation_args = mock_client.capture.call_args_list[1][1] generation_props = generation_args["properties"] - trace_args = mock_client.capture.call_args_list[1][1] + trace_args = mock_client.capture.call_args_list[2][1] trace_props = trace_args["properties"] + # Span is first + assert span_args["event"] == "$ai_span" + assert span_props["$ai_trace_id"] == generation_props["$ai_trace_id"] + assert span_props["$ai_parent_id"] == trace_props["$ai_trace_id"] + assert "$ai_span_id" in span_props + + # Generation is second assert generation_args["event"] == "$ai_generation" assert "distinct_id" in generation_args assert "$ai_model" in generation_props @@ -210,12 +266,16 @@ async def test_async_basic_chat_chain(mock_client, stream): assert generation_props["$ai_input_tokens"] == 10 assert generation_props["$ai_output_tokens"] == 10 assert generation_props["$ai_http_status"] == 200 - assert generation_props["$ai_trace_id"] is not None assert isinstance(generation_props["$ai_latency"], float) + assert "$ai_span_id" in generation_props + assert generation_props["$ai_parent_id"] == trace_props["$ai_trace_id"] + assert generation_props["$ai_trace_id"] == trace_props["$ai_trace_id"] + # Trace is last assert trace_args["event"] == "$ai_trace" assert "distinct_id" in generation_args assert trace_props["$ai_trace_id"] == generation_props["$ai_trace_id"] + assert "$ai_parent_id" not in trace_props @pytest.mark.parametrize( @@ -290,34 +350,68 @@ async def test_async_basic_llm_chain(mock_client, Model, stream): assert isinstance(props["$ai_latency"], float) -def test_trace_id_for_multiple_chains(mock_client): +def test_trace_id_and_inputs_for_multiple_chains(mock_client): prompt = ChatPromptTemplate.from_messages( [ - ("user", "Foo"), + ("user", "Foo {var}"), ] ) model = FakeMessagesListChatModel(responses=[AIMessage(content="Bar")]) callbacks = [CallbackHandler(mock_client)] chain = prompt | model | RunnableLambda(lambda x: [x]) | model - result = chain.invoke({}, config={"callbacks": callbacks}) + result = chain.invoke({"var": "bar"}, config={"callbacks": callbacks}) assert result.content == "Bar" - assert mock_client.capture.call_count == 3 + # span, generation, span, generation, trace + assert mock_client.capture.call_count == 5 + + first_span_args = mock_client.capture.call_args_list[0][1] + first_span_props = first_span_args["properties"] - first_call_args = mock_client.capture.call_args_list[0][1] - first_call_props = first_call_args["properties"] - assert first_call_args["event"] == "$ai_generation" - assert "distinct_id" in first_call_args - assert "$ai_model" in first_call_props - assert "$ai_provider" in first_call_props - assert first_call_props["$ai_input"] == [{"role": "user", "content": "Foo"}] - assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] - assert first_call_props["$ai_http_status"] == 200 - assert first_call_props["$ai_trace_id"] is not None - assert isinstance(first_call_props["$ai_latency"], float) - - second_generation_args = mock_client.capture.call_args_list[1][1] + first_generation_args = mock_client.capture.call_args_list[1][1] + first_generation_props = first_generation_args["properties"] + + second_span_args = mock_client.capture.call_args_list[2][1] + second_span_props = 
second_span_args["properties"] + + second_generation_args = mock_client.capture.call_args_list[3][1] second_generation_props = second_generation_args["properties"] + + trace_args = mock_client.capture.call_args_list[4][1] + trace_props = trace_args["properties"] + + # Prompt span + assert first_span_args["event"] == "$ai_span" + assert first_span_props["$ai_input_state"] == {"var": "bar"} + assert first_span_props["$ai_trace_id"] == trace_props["$ai_trace_id"] + assert first_span_props["$ai_parent_id"] == trace_props["$ai_trace_id"] + assert "$ai_span_id" in first_span_props + assert first_span_props["$ai_output_state"] == ChatPromptTemplate( + messages=[HumanMessage(content="Foo bar")] + ).invoke({}) + + # first model + assert first_generation_args["event"] == "$ai_generation" + assert "distinct_id" in first_generation_args + assert "$ai_model" in first_generation_props + assert "$ai_provider" in first_generation_props + assert first_generation_props["$ai_input"] == [{"role": "user", "content": "Foo bar"}] + assert first_generation_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] + assert first_generation_props["$ai_http_status"] == 200 + assert isinstance(first_generation_props["$ai_latency"], float) + assert "$ai_span_id" in first_generation_props + assert first_generation_props["$ai_parent_id"] == trace_props["$ai_trace_id"] + assert first_generation_props["$ai_trace_id"] == trace_props["$ai_trace_id"] + + # lambda span + assert second_span_args["event"] == "$ai_span" + assert second_span_props["$ai_input_state"].content == "Bar" + assert second_span_props["$ai_trace_id"] == trace_props["$ai_trace_id"] + assert second_span_props["$ai_parent_id"] == trace_props["$ai_trace_id"] + assert "$ai_span_id" in second_span_props + assert second_span_props["$ai_output_state"][0].content == "Bar" + + # second model assert second_generation_args["event"] == "$ai_generation" assert "distinct_id" in second_generation_args assert "$ai_model" in second_generation_props @@ -328,40 +422,49 @@ def test_trace_id_for_multiple_chains(mock_client): assert second_generation_props["$ai_trace_id"] is not None assert isinstance(second_generation_props["$ai_latency"], float) - trace_args = mock_client.capture.call_args_list[2][1] - trace_props = trace_args["properties"] + # trace assert trace_args["event"] == "$ai_trace" assert "distinct_id" in trace_args - assert trace_props["$ai_input_state"] == {} + assert trace_props["$ai_input_state"] == {"var": "bar"} assert isinstance(trace_props["$ai_output_state"], AIMessage) assert trace_props["$ai_output_state"].content == "Bar" assert trace_props["$ai_trace_id"] is not None - assert trace_props["$ai_trace_name"] == "RunnableSequence" - - # Check that the trace_id is the same as the first call - assert first_call_props["$ai_trace_id"] == second_generation_props["$ai_trace_id"] - assert first_call_props["$ai_trace_id"] == trace_props["$ai_trace_id"] + assert trace_props["$ai_span_name"] == "RunnableSequence" def test_personless_mode(mock_client): prompt = ChatPromptTemplate.from_messages([("user", "Foo")]) chain = prompt | FakeMessagesListChatModel(responses=[AIMessage(content="Bar")]) chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client)]}) - assert mock_client.capture.call_count == 2 - generation_args = mock_client.capture.call_args_list[0][1] - trace_args = mock_client.capture.call_args_list[1][1] + assert mock_client.capture.call_count == 3 + span_args = mock_client.capture.call_args_list[0][1] + generation_args = 
mock_client.capture.call_args_list[1][1] + trace_args = mock_client.capture.call_args_list[2][1] + + # span + assert span_args["event"] == "$ai_span" + assert span_args["properties"]["$process_person_profile"] is False + # generation assert generation_args["event"] == "$ai_generation" assert generation_args["properties"]["$process_person_profile"] is False + # trace assert trace_args["event"] == "$ai_trace" assert trace_args["properties"]["$process_person_profile"] is False id = uuid.uuid4() chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]}) - assert mock_client.capture.call_count == 4 - generation_args = mock_client.capture.call_args_list[2][1] - trace_args = mock_client.capture.call_args_list[3][1] + assert mock_client.capture.call_count == 6 + span_args = mock_client.capture.call_args_list[3][1] + generation_args = mock_client.capture.call_args_list[4][1] + trace_args = mock_client.capture.call_args_list[5][1] + + # span + assert "$process_person_profile" not in span_args["properties"] + assert span_args["distinct_id"] == id + # generation assert "$process_person_profile" not in generation_args["properties"] assert generation_args["distinct_id"] == id + # trace assert "$process_person_profile" not in trace_args["properties"] assert trace_args["distinct_id"] == id @@ -372,22 +475,41 @@ def test_personless_mode_exception(mock_client): callbacks = CallbackHandler(mock_client) with pytest.raises(Exception): chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 2 - generation_args = mock_client.capture.call_args_list[0][1] - trace_args = mock_client.capture.call_args_list[1][1] + assert mock_client.capture.call_count == 3 + span_args = mock_client.capture.call_args_list[0][1] + generation_args = mock_client.capture.call_args_list[1][1] + trace_args = mock_client.capture.call_args_list[2][1] + + # span + assert span_args["event"] == "$ai_span" + assert span_args["properties"]["$process_person_profile"] is False + # generation assert generation_args["event"] == "$ai_generation" assert generation_args["properties"]["$process_person_profile"] is False + # trace assert trace_args["event"] == "$ai_trace" assert trace_args["properties"]["$process_person_profile"] is False id = uuid.uuid4() with pytest.raises(Exception): chain.invoke({}, config={"callbacks": [CallbackHandler(mock_client, distinct_id=id)]}) - assert mock_client.capture.call_count == 4 - generation_args = mock_client.capture.call_args_list[2][1] - trace_args = mock_client.capture.call_args_list[3][1] + assert mock_client.capture.call_count == 6 + span_args = mock_client.capture.call_args_list[3][1] + generation_args = mock_client.capture.call_args_list[4][1] + trace_args = mock_client.capture.call_args_list[5][1] + + # span + assert span_args["event"] == "$ai_span" + assert "$process_person_profile" not in span_args["properties"] + assert span_args["distinct_id"] == id + + # generation + assert generation_args["event"] == "$ai_generation" assert "$process_person_profile" not in generation_args["properties"] assert generation_args["distinct_id"] == id + + # trace + assert trace_args["event"] == "$ai_trace" assert "$process_person_profile" not in trace_args["properties"] assert trace_args["distinct_id"] == id @@ -411,9 +533,18 @@ def test_metadata(mock_client): result = chain.invoke({"plan": None}, config={"callbacks": callbacks}) assert result.content == "Bar" - assert mock_client.capture.call_count == 2 + assert mock_client.capture.call_count == 3 + + 
span_call_args = mock_client.capture.call_args_list[0][1] + span_call_props = span_call_args["properties"] + assert span_call_args["distinct_id"] == "test_id" + assert span_call_args["event"] == "$ai_span" + assert span_call_props["$ai_trace_id"] == "test-trace-id" + assert span_call_props["foo"] == "bar" + assert "$ai_parent_id" in span_call_props + assert "$ai_span_id" in span_call_props - generation_call_args = mock_client.capture.call_args_list[0][1] + generation_call_args = mock_client.capture.call_args_list[1][1] generation_call_props = generation_call_args["properties"] assert generation_call_args["distinct_id"] == "test_id" assert generation_call_args["event"] == "$ai_generation" @@ -424,12 +555,12 @@ def test_metadata(mock_client): assert generation_call_props["$ai_http_status"] == 200 assert isinstance(generation_call_props["$ai_latency"], float) - trace_call_args = mock_client.capture.call_args_list[1][1] + trace_call_args = mock_client.capture.call_args_list[2][1] trace_call_props = trace_call_args["properties"] assert trace_call_args["distinct_id"] == "test_id" assert trace_call_args["event"] == "$ai_trace" assert trace_call_props["$ai_trace_id"] == "test-trace-id" - assert trace_call_props["$ai_trace_name"] == "RunnableSequence" + assert trace_call_props["$ai_span_name"] == "RunnableSequence" assert trace_call_props["foo"] == "bar" assert trace_call_props["$ai_input_state"] == {"plan": None} assert isinstance(trace_call_props["$ai_output_state"], AIMessage) @@ -476,10 +607,8 @@ def test_graph_state(mock_client): graph.add_edge("fake_plain", "fake_llm") graph.add_edge("fake_llm", END) - result = graph.compile().invoke( - {"messages": [HumanMessage(content="What's a bar?")], "xyz": None}, - config=config, - ) + initial_state = {"messages": [HumanMessage(content="What's a bar?")], "xyz": None} + result = graph.compile().invoke(initial_state, config=config) assert len(result["messages"]) == 3 assert isinstance(result["messages"][0], HumanMessage) @@ -489,26 +618,104 @@ def test_graph_state(mock_client): assert isinstance(result["messages"][2], AIMessage) assert result["messages"][2].content == "It's a type of greeble." - assert mock_client.capture.call_count == 2 - generation_args = mock_client.capture.call_args_list[0][1] - trace_args = mock_client.capture.call_args_list[1][1] - assert generation_args["event"] == "$ai_generation" + assert mock_client.capture.call_count == 11 + calls = [call[1] for call in mock_client.capture.call_args_list] + + trace_args = calls[10] + trace_props = calls[10]["properties"] + + # Events are captured in the reverse order. 
+ # Check all trace_ids + for call in calls: + assert call["properties"]["$ai_trace_id"] == trace_props["$ai_trace_id"] + + # First span, write the state + assert calls[0]["event"] == "$ai_span" + assert calls[0]["properties"]["$ai_parent_id"] == calls[2]["properties"]["$ai_span_id"] + assert "$ai_span_id" in calls[0]["properties"] + assert calls[0]["properties"]["$ai_input_state"] == initial_state + assert calls[0]["properties"]["$ai_output_state"] == initial_state + + # Second span, set the START node + assert calls[1]["event"] == "$ai_span" + assert calls[1]["properties"]["$ai_parent_id"] == calls[2]["properties"]["$ai_span_id"] + assert "$ai_span_id" in calls[1]["properties"] + assert calls[1]["properties"]["$ai_input_state"] == initial_state + assert calls[1]["properties"]["$ai_output_state"] == initial_state + + # Third span, finish initialization + assert calls[2]["event"] == "$ai_span" + assert "$ai_span_id" in calls[2]["properties"] + assert calls[2]["properties"]["$ai_span_name"] == START + assert calls[2]["properties"]["$ai_parent_id"] == trace_props["$ai_trace_id"] + assert calls[2]["properties"]["$ai_input_state"] == initial_state + assert calls[2]["properties"]["$ai_output_state"] == initial_state + + # Fourth span, save the value of fake_plain during its execution + second_state = { + "messages": [HumanMessage(content="What's a bar?"), AIMessage(content="Let's explore bar.")], + "xyz": "abc", + } + assert calls[3]["event"] == "$ai_span" + assert calls[3]["properties"]["$ai_parent_id"] == calls[4]["properties"]["$ai_span_id"] + assert "$ai_span_id" in calls[3]["properties"] + assert calls[3]["properties"]["$ai_input_state"] == second_state + assert calls[3]["properties"]["$ai_output_state"] == second_state + + # Fifth span, run the fake_plain node + assert calls[4]["event"] == "$ai_span" + assert "$ai_span_id" in calls[4]["properties"] + assert calls[4]["properties"]["$ai_span_name"] == "fake_plain" + assert calls[4]["properties"]["$ai_parent_id"] == trace_props["$ai_trace_id"] + assert calls[4]["properties"]["$ai_input_state"] == initial_state + assert calls[4]["properties"]["$ai_output_state"] == second_state + + # Sixth span, chat prompt template + assert calls[5]["event"] == "$ai_span" + assert calls[5]["properties"]["$ai_parent_id"] == calls[7]["properties"]["$ai_span_id"] + assert "$ai_span_id" in calls[5]["properties"] + assert calls[5]["properties"]["$ai_span_name"] == "ChatPromptTemplate" + + # 7. Generation, fake_llm + assert calls[6]["event"] == "$ai_generation" + assert calls[6]["properties"]["$ai_parent_id"] == calls[7]["properties"]["$ai_span_id"] + assert "$ai_span_id" in calls[6]["properties"] + assert calls[6]["properties"]["$ai_span_name"] == "FakeMessagesListChatModel" + + # 8. Span, RunnableSequence + assert calls[7]["event"] == "$ai_span" + assert calls[7]["properties"]["$ai_parent_id"] == calls[9]["properties"]["$ai_span_id"] + assert "$ai_span_id" in calls[7]["properties"] + assert calls[7]["properties"]["$ai_span_name"] == "RunnableSequence" + + # 9. Span, fake_llm write + assert calls[8]["event"] == "$ai_span" + assert calls[8]["properties"]["$ai_parent_id"] == calls[9]["properties"]["$ai_span_id"] + assert "$ai_span_id" in calls[8]["properties"] + + # 10. Span, fake_llm node + assert calls[9]["event"] == "$ai_span" + assert calls[9]["properties"]["$ai_parent_id"] == trace_props["$ai_trace_id"] + assert "$ai_span_id" in calls[9]["properties"] + assert calls[9]["properties"]["$ai_span_name"] == "fake_llm" + + # 11. 
Trace assert trace_args["event"] == "$ai_trace" - assert trace_args["properties"]["$ai_trace_name"] == "LangGraph" - - assert len(trace_args["properties"]["$ai_input_state"]["messages"]) == 1 - assert isinstance(trace_args["properties"]["$ai_input_state"]["messages"][0], HumanMessage) - assert trace_args["properties"]["$ai_input_state"]["messages"][0].content == "What's a bar?" - assert trace_args["properties"]["$ai_input_state"]["messages"][0].type == "human" - assert trace_args["properties"]["$ai_input_state"]["xyz"] is None - assert len(trace_args["properties"]["$ai_output_state"]["messages"]) == 3 - - assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][0], HumanMessage) - assert trace_args["properties"]["$ai_output_state"]["messages"][0].content == "What's a bar?" - assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][1], AIMessage) - assert trace_args["properties"]["$ai_output_state"]["messages"][1].content == "Let's explore bar." - assert isinstance(trace_args["properties"]["$ai_output_state"]["messages"][2], AIMessage) - assert trace_args["properties"]["$ai_output_state"]["messages"][2].content == "It's a type of greeble." + assert trace_props["$ai_span_name"] == "LangGraph" + + assert len(trace_props["$ai_input_state"]["messages"]) == 1 + assert isinstance(trace_props["$ai_input_state"]["messages"][0], HumanMessage) + assert trace_props["$ai_input_state"]["messages"][0].content == "What's a bar?" + assert trace_props["$ai_input_state"]["messages"][0].type == "human" + assert trace_props["$ai_input_state"]["xyz"] is None + assert len(trace_props["$ai_output_state"]["messages"]) == 3 + + assert isinstance(trace_props["$ai_output_state"]["messages"][0], HumanMessage) + assert trace_props["$ai_output_state"]["messages"][0].content == "What's a bar?" + assert isinstance(trace_props["$ai_output_state"]["messages"][1], AIMessage) + assert trace_props["$ai_output_state"]["messages"][1].content == "Let's explore bar." + assert isinstance(trace_props["$ai_output_state"]["messages"][2], AIMessage) + assert trace_props["$ai_output_state"]["messages"][2].content == "It's a type of greeble." 
assert trace_args["properties"]["$ai_output_state"]["xyz"] == "abc" @@ -530,9 +737,9 @@ def test_callbacks_logic(mock_client): def assert_intermediary_run(m): assert len(callbacks._runs) != 0 run = next(iter(callbacks._runs.values())) - assert run["name"] == "RunnableSequence" - assert run["input"] == {} - assert run["start_time"] is not None + assert run.name == "RunnableSequence" + assert run.input == {} + assert run.start_time is not None assert len(callbacks._parent_tree.items()) == 1 return [m] @@ -554,7 +761,7 @@ def runnable(_): assert mock_client.capture.call_count == 1 trace_call_args = mock_client.capture.call_args_list[0][1] assert trace_call_args["event"] == "$ai_trace" - assert trace_call_args["properties"]["$ai_trace_name"] == "runnable" + assert trace_call_args["properties"]["$ai_span_name"] == "runnable" def test_openai_error(mock_client): @@ -568,8 +775,8 @@ def test_openai_error(mock_client): assert callbacks._runs == {} assert callbacks._parent_tree == {} - assert mock_client.capture.call_count == 2 - generation_args = mock_client.capture.call_args_list[0][1] + assert mock_client.capture.call_count == 3 + generation_args = mock_client.capture.call_args_list[1][1] props = generation_args["properties"] assert props["$ai_http_status"] == 401 assert props["$ai_input"] == [{"role": "user", "content": "Foo"}] @@ -601,40 +808,40 @@ def test_openai_chain(mock_client): approximate_latency = math.floor(time.time() - start_time) assert result.content == "Bar" - assert mock_client.capture.call_count == 2 + assert mock_client.capture.call_count == 3 - first_call_args = mock_client.capture.call_args_list[0][1] - first_call_props = first_call_args["properties"] - assert first_call_args["event"] == "$ai_generation" - assert first_call_props["$ai_trace_id"] == "test-trace-id" - assert first_call_props["$ai_provider"] == "openai" - assert first_call_props["$ai_model"] == "gpt-4o-mini" - assert first_call_props["foo"] == "bar" + gen_args = mock_client.capture.call_args_list[1][1] + gen_props = gen_args["properties"] + assert gen_args["event"] == "$ai_generation" + assert gen_props["$ai_trace_id"] == "test-trace-id" + assert gen_props["$ai_provider"] == "openai" + assert gen_props["$ai_model"] == "gpt-4o-mini" + assert gen_props["foo"] == "bar" # langchain-openai for langchain v3 - if "max_completion_tokens" in first_call_props["$ai_model_parameters"]: - assert first_call_props["$ai_model_parameters"] == { + if "max_completion_tokens" in gen_props["$ai_model_parameters"]: + assert gen_props["$ai_model_parameters"] == { "temperature": 0.0, "max_completion_tokens": 1, "stream": False, } else: - assert first_call_props["$ai_model_parameters"] == { + assert gen_props["$ai_model_parameters"] == { "temperature": 0.0, "max_tokens": 1, "n": 1, "stream": False, } - assert first_call_props["$ai_input"] == [ + assert gen_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar", "refusal": None}] - assert first_call_props["$ai_http_status"] == 200 - assert isinstance(first_call_props["$ai_latency"], float) - assert min(approximate_latency - 1, 0) <= math.floor(first_call_props["$ai_latency"]) <= approximate_latency - assert first_call_props["$ai_input_tokens"] == 20 - assert first_call_props["$ai_output_tokens"] == 1 + assert gen_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar", "refusal": None}] + assert 
gen_props["$ai_http_status"] == 200 + assert isinstance(gen_props["$ai_latency"], float) + assert min(approximate_latency - 1, 0) <= math.floor(gen_props["$ai_latency"]) <= approximate_latency + assert gen_props["$ai_input_tokens"] == 20 + assert gen_props["$ai_output_tokens"] == 1 @pytest.mark.skipif(not OPENAI_API_KEY, reason="OpenAI API key not set") @@ -656,19 +863,19 @@ def test_openai_captures_multiple_generations(mock_client): result = chain.invoke({}, config={"callbacks": [callbacks]}) assert result.content == "Bar" - assert mock_client.capture.call_count == 2 + assert mock_client.capture.call_count == 3 - first_call_args = mock_client.capture.call_args_list[0][1] - first_call_props = first_call_args["properties"] - second_call_args = mock_client.capture.call_args_list[1][1] - second_call_props = second_call_args["properties"] + gen_args = mock_client.capture.call_args_list[1][1] + gen_props = gen_args["properties"] + trace_args = mock_client.capture.call_args_list[2][1] + trace_props = trace_args["properties"] - assert first_call_args["event"] == "$ai_generation" - assert first_call_props["$ai_input"] == [ + assert gen_args["event"] == "$ai_generation" + assert gen_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [ + assert gen_props["$ai_output_choices"] == [ {"role": "assistant", "content": "Bar", "refusal": None}, { "role": "assistant", @@ -677,25 +884,25 @@ def test_openai_captures_multiple_generations(mock_client): ] # langchain-openai for langchain v3 - if "max_completion_tokens" in first_call_props["$ai_model_parameters"]: - assert first_call_props["$ai_model_parameters"] == { + if "max_completion_tokens" in gen_props["$ai_model_parameters"]: + assert gen_props["$ai_model_parameters"] == { "temperature": 0.0, "max_completion_tokens": 1, "stream": False, "n": 2, } else: - assert first_call_props["$ai_model_parameters"] == { + assert gen_props["$ai_model_parameters"] == { "temperature": 0.0, "max_tokens": 1, "stream": False, "n": 2, } - assert first_call_props["$ai_http_status"] == 200 + assert gen_props["$ai_http_status"] == 200 - assert second_call_args["event"] == "$ai_trace" - assert second_call_props["$ai_input_state"] == {} - assert isinstance(second_call_props["$ai_output_state"], AIMessage) + assert trace_args["event"] == "$ai_trace" + assert trace_props["$ai_input_state"] == {} + assert isinstance(trace_props["$ai_output_state"], AIMessage) @pytest.mark.skipif(not OPENAI_API_KEY, reason="OpenAI API key not set") @@ -719,27 +926,27 @@ def test_openai_streaming(mock_client): result = sum(result[1:], result[0]) assert result.content == "Bar" - assert mock_client.capture.call_count == 2 + assert mock_client.capture.call_count == 3 - first_call_args = mock_client.capture.call_args_list[0][1] - first_call_props = first_call_args["properties"] - second_call_args = mock_client.capture.call_args_list[1][1] - second_call_props = second_call_args["properties"] + gen_args = mock_client.capture.call_args_list[1][1] + gen_props = gen_args["properties"] + trace_args = mock_client.capture.call_args_list[2][1] + trace_props = trace_args["properties"] - assert first_call_args["event"] == "$ai_generation" - assert first_call_props["$ai_model_parameters"]["stream"] - assert first_call_props["$ai_input"] == [ + assert gen_args["event"] == "$ai_generation" + assert gen_props["$ai_model_parameters"]["stream"] + assert gen_props["$ai_input"] == [ {"role": 
"system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] - assert first_call_props["$ai_http_status"] == 200 - assert first_call_props["$ai_input_tokens"] == 20 - assert first_call_props["$ai_output_tokens"] == 1 + assert gen_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] + assert gen_props["$ai_http_status"] == 200 + assert gen_props["$ai_input_tokens"] == 20 + assert gen_props["$ai_output_tokens"] == 1 - assert second_call_args["event"] == "$ai_trace" - assert second_call_props["$ai_input_state"] == {"input": ""} - assert isinstance(second_call_props["$ai_output_state"], AIMessage) + assert trace_args["event"] == "$ai_trace" + assert trace_props["$ai_input_state"] == {"input": ""} + assert isinstance(trace_props["$ai_output_state"], AIMessage) @pytest.mark.skipif(not OPENAI_API_KEY, reason="OpenAI API key not set") @@ -763,27 +970,27 @@ async def test_async_openai_streaming(mock_client): result = sum(result[1:], result[0]) assert result.content == "Bar" - assert mock_client.capture.call_count == 2 + assert mock_client.capture.call_count == 3 - first_call_args = mock_client.capture.call_args_list[0][1] - first_call_props = first_call_args["properties"] - second_call_args = mock_client.capture.call_args_list[1][1] - second_call_props = second_call_args["properties"] + gen_args = mock_client.capture.call_args_list[1][1] + gen_props = gen_args["properties"] + trace_args = mock_client.capture.call_args_list[2][1] + trace_props = trace_args["properties"] - assert first_call_args["event"] == "$ai_generation" - assert first_call_props["$ai_model_parameters"]["stream"] - assert first_call_props["$ai_input"] == [ + assert gen_args["event"] == "$ai_generation" + assert gen_props["$ai_model_parameters"]["stream"] + assert gen_props["$ai_input"] == [ {"role": "system", "content": 'You must always answer with "Bar".'}, {"role": "user", "content": "Foo"}, ] - assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] - assert first_call_props["$ai_http_status"] == 200 - assert first_call_props["$ai_input_tokens"] == 20 - assert first_call_props["$ai_output_tokens"] == 1 + assert gen_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}] + assert gen_props["$ai_http_status"] == 200 + assert gen_props["$ai_input_tokens"] == 20 + assert gen_props["$ai_output_tokens"] == 1 - assert second_call_args["event"] == "$ai_trace" - assert second_call_props["$ai_input_state"] == {"input": ""} - assert isinstance(second_call_props["$ai_output_state"], AIMessage) + assert trace_args["event"] == "$ai_trace" + assert trace_props["$ai_input_state"] == {"input": ""} + assert isinstance(trace_props["$ai_output_state"], AIMessage) def test_base_url_retrieval(mock_client): @@ -797,8 +1004,8 @@ def test_base_url_retrieval(mock_client): with pytest.raises(Exception): chain.invoke({}, config={"callbacks": [callbacks]}) - assert mock_client.capture.call_count == 2 - generation_call = mock_client.capture.call_args_list[0][1] + assert mock_client.capture.call_count == 3 + generation_call = mock_client.capture.call_args_list[1][1] assert generation_call["properties"]["$ai_base_url"] == "https://test.posthog.com" @@ -814,8 +1021,8 @@ def test_groups(mock_client): callbacks = CallbackHandler(mock_client, groups={"company": "test_company"}) chain.invoke({}, config={"callbacks": [callbacks]}) - assert 
mock_client.capture.call_count == 2
-    generation_call = mock_client.capture.call_args_list[0][1]
+    assert mock_client.capture.call_count == 3
+    generation_call = mock_client.capture.call_args_list[1][1]
     assert generation_call["groups"] == {"company": "test_company"}
@@ -831,8 +1038,8 @@ def test_privacy_mode_local(mock_client):
     callbacks = CallbackHandler(mock_client, privacy_mode=True)
     chain.invoke({}, config={"callbacks": [callbacks]})
-    assert mock_client.capture.call_count == 2
-    generation_call = mock_client.capture.call_args_list[0][1]
+    assert mock_client.capture.call_count == 3
+    generation_call = mock_client.capture.call_args_list[1][1]
     assert generation_call["properties"]["$ai_input"] is None
     assert generation_call["properties"]["$ai_output_choices"] is None
@@ -850,8 +1057,8 @@ def test_privacy_mode_global(mock_client):
     callbacks = CallbackHandler(mock_client)
     chain.invoke({}, config={"callbacks": [callbacks]})
-    assert mock_client.capture.call_count == 2
-    generation_call = mock_client.capture.call_args_list[0][1]
+    assert mock_client.capture.call_count == 3
+    generation_call = mock_client.capture.call_args_list[1][1]
     assert generation_call["properties"]["$ai_input"] is None
     assert generation_call["properties"]["$ai_output_choices"] is None
@@ -881,38 +1088,38 @@ def test_anthropic_chain(mock_client):
     approximate_latency = math.floor(time.time() - start_time)
     assert result.content == "Bar"
-    assert mock_client.capture.call_count == 2
+    assert mock_client.capture.call_count == 3
-    first_call_args = mock_client.capture.call_args_list[0][1]
-    first_call_props = first_call_args["properties"]
-    second_call_args = mock_client.capture.call_args_list[1][1]
-    second_call_props = second_call_args["properties"]
+    gen_args = mock_client.capture.call_args_list[1][1]
+    gen_props = gen_args["properties"]
+    trace_args = mock_client.capture.call_args_list[2][1]
+    trace_props = trace_args["properties"]
-    assert first_call_args["event"] == "$ai_generation"
-    assert first_call_props["$ai_trace_id"] == "test-trace-id"
-    assert first_call_props["$ai_provider"] == "anthropic"
-    assert first_call_props["$ai_model"] == "claude-3-opus-20240229"
-    assert first_call_props["foo"] == "bar"
+    assert gen_args["event"] == "$ai_generation"
+    assert gen_props["$ai_trace_id"] == "test-trace-id"
+    assert gen_props["$ai_provider"] == "anthropic"
+    assert gen_props["$ai_model"] == "claude-3-opus-20240229"
+    assert gen_props["foo"] == "bar"
-    assert first_call_props["$ai_model_parameters"] == {
+    assert gen_props["$ai_model_parameters"] == {
         "temperature": 0.0,
         "max_tokens": 1,
         "streaming": False,
     }
-    assert first_call_props["$ai_input"] == [
+    assert gen_props["$ai_input"] == [
         {"role": "system", "content": 'You must always answer with "Bar".'},
         {"role": "user", "content": "Foo"},
     ]
-    assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}]
-    assert first_call_props["$ai_http_status"] == 200
-    assert isinstance(first_call_props["$ai_latency"], float)
-    assert min(approximate_latency - 1, 0) <= math.floor(first_call_props["$ai_latency"]) <= approximate_latency
-    assert first_call_props["$ai_input_tokens"] == 17
-    assert first_call_props["$ai_output_tokens"] == 1
+    assert gen_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}]
+    assert gen_props["$ai_http_status"] == 200
+    assert isinstance(gen_props["$ai_latency"], float)
+    assert min(approximate_latency - 1, 0) <= math.floor(gen_props["$ai_latency"]) <= approximate_latency
+    assert gen_props["$ai_input_tokens"] == 17
+    assert gen_props["$ai_output_tokens"] == 1
-    assert second_call_args["event"] == "$ai_trace"
-    assert second_call_props["$ai_input_state"] == {}
-    assert isinstance(second_call_props["$ai_output_state"], AIMessage)
+    assert trace_args["event"] == "$ai_trace"
+    assert trace_props["$ai_input_state"] == {}
+    assert isinstance(trace_props["$ai_output_state"], AIMessage)
 @pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set")
@@ -936,29 +1143,29 @@ async def test_async_anthropic_streaming(mock_client):
     result = sum(result[1:], result[0])
     assert result.content == "Bar"
-    assert mock_client.capture.call_count == 2
+    assert mock_client.capture.call_count == 3
-    first_call_args = mock_client.capture.call_args_list[0][1]
-    first_call_props = first_call_args["properties"]
-    second_call_args = mock_client.capture.call_args_list[1][1]
-    second_call_props = second_call_args["properties"]
+    gen_args = mock_client.capture.call_args_list[1][1]
+    gen_props = gen_args["properties"]
+    trace_args = mock_client.capture.call_args_list[2][1]
+    trace_props = trace_args["properties"]
-    assert first_call_args["event"] == "$ai_generation"
-    assert first_call_props["$ai_model_parameters"]["streaming"]
-    assert first_call_props["$ai_input"] == [
+    assert gen_args["event"] == "$ai_generation"
+    assert gen_props["$ai_model_parameters"]["streaming"]
+    assert gen_props["$ai_input"] == [
         {"role": "system", "content": 'You must always answer with "Bar".'},
         {"role": "user", "content": "Foo"},
     ]
-    assert first_call_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}]
-    assert first_call_props["$ai_http_status"] == 200
-    assert first_call_props["$ai_input_tokens"] == 17
-    assert first_call_props["$ai_output_tokens"] is not None
+    assert gen_props["$ai_output_choices"] == [{"role": "assistant", "content": "Bar"}]
+    assert gen_props["$ai_http_status"] == 200
+    assert gen_props["$ai_input_tokens"] == 17
+    assert gen_props["$ai_output_tokens"] is not None
-    assert second_call_args["event"] == "$ai_trace"
-    assert second_call_props["$ai_input_state"] == {
+    assert trace_args["event"] == "$ai_trace"
+    assert trace_props["$ai_input_state"] == {
         "input": "",
     }
-    assert isinstance(second_call_props["$ai_output_state"], AIMessage)
+    assert isinstance(trace_props["$ai_output_state"], AIMessage)
 def test_tool_calls(mock_client):
@@ -986,8 +1193,8 @@ def test_tool_calls(mock_client):
     callbacks = CallbackHandler(mock_client)
     chain.invoke({}, config={"callbacks": [callbacks]})
-    assert mock_client.capture.call_count == 2
-    generation_call = mock_client.capture.call_args_list[0][1]
+    assert mock_client.capture.call_count == 3
+    generation_call = mock_client.capture.call_args_list[1][1]
     assert generation_call["properties"]["$ai_output_choices"][0]["tool_calls"] == [
         {
             "type": "function",
@@ -1018,14 +1225,122 @@ async def sleep(x):  # -> Any:
         chain2.ainvoke({}, config={"callbacks": [cb]}),
     )
     approximate_latency = math.floor(time.time() - start_time)
-    assert mock_client.capture.call_count == 3
+    assert mock_client.capture.call_count == 4
-    first_call, second_call, third_call = mock_client.capture.call_args_list
-    assert first_call[1]["event"] == "$ai_generation"
-    assert second_call[1]["event"] == "$ai_trace"
-    assert second_call[1]["properties"]["$ai_trace_name"] == "RunnableSequence"
+    first_call, second_call, third_call, fourth_call = mock_client.capture.call_args_list
+    assert first_call[1]["event"] == "$ai_span"
+    assert second_call[1]["event"] == "$ai_generation"
     assert third_call[1]["event"] == "$ai_trace"
-    assert third_call[1]["properties"]["$ai_trace_name"] == "sleep"
+    assert third_call[1]["properties"]["$ai_span_name"] == "RunnableSequence"
+    assert fourth_call[1]["event"] == "$ai_trace"
+    assert fourth_call[1]["properties"]["$ai_span_name"] == "sleep"
     assert (
         min(approximate_latency - 1, 0) <= math.floor(third_call[1]["properties"]["$ai_latency"]) <= approximate_latency
     )
+
+
+@pytest.mark.skipif(not OPENAI_API_KEY, reason="OPENAI_API_KEY is not set")
+def test_langgraph_agent(mock_client):
+    @tool
+    def get_weather(city: Literal["nyc", "sf"]):
+        """
+        Use this to get weather information.
+
+        Args:
+            city: The city to get weather information for.
+        """
+        if city == "sf":
+            return "It's always sunny in sf"
+        return "No info"
+
+    tools = [get_weather]
+    model = ChatOpenAI(api_key=OPENAI_API_KEY, model="gpt-4o-mini", temperature=0)
+    graph = create_react_agent(model, tools=tools)
+    inputs = {"messages": [("user", "what is the weather in sf")]}
+    cb = CallbackHandler(mock_client, trace_id="test-trace-id", distinct_id="test-distinct-id")
+    graph.invoke(inputs, config={"callbacks": [cb]})
+    calls = [call[1] for call in mock_client.capture.call_args_list]
+    assert len(calls) == 21
+    for call in calls:
+        assert call["properties"]["$ai_trace_id"] == "test-trace-id"
+    assert len([call for call in calls if call["event"] == "$ai_generation"]) == 2
+    assert len([call for call in calls if call["event"] == "$ai_span"]) == 18
+    assert len([call for call in calls if call["event"] == "$ai_trace"]) == 1
+
+
+@pytest.mark.parametrize("trace_id", ["test-trace-id", None])
+def test_span_set_parent_ids(mock_client, trace_id):
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", "You are a helpful assistant."),
+            ("user", "Who won the world series in 2020?"),
+        ]
+    )
+    model = FakeMessagesListChatModel(
+        responses=[AIMessage(content="The Los Angeles Dodgers won the World Series in 2020.")]
+    )
+    callbacks = [CallbackHandler(mock_client, trace_id=trace_id)]
+    chain = prompt | model
+    chain.invoke({}, config={"callbacks": callbacks})
+
+    assert mock_client.capture.call_count == 3
+
+    span_props = mock_client.capture.call_args_list[0][1]
+    assert span_props["properties"]["$ai_trace_id"] == span_props["properties"]["$ai_parent_id"]
+
+    generation_props = mock_client.capture.call_args_list[1][1]
+    assert generation_props["properties"]["$ai_trace_id"] == generation_props["properties"]["$ai_parent_id"]
+
+
+@pytest.mark.parametrize("trace_id", ["test-trace-id", None])
+def test_span_set_parent_ids_for_third_level_run(mock_client, trace_id):
+    def span_1(_):
+        def span_2(_):
+            def span_3(_):
+                return "span 3"
+
+            return RunnableLambda(span_3)
+
+        return RunnableLambda(span_2)
+
+    callbacks = [CallbackHandler(mock_client, trace_id=trace_id)]
+    chain = RunnableLambda(span_1)
+    chain.invoke({}, config={"callbacks": callbacks})
+
+    assert mock_client.capture.call_count == 3
+
+    span2, span1, trace = [call[1]["properties"] for call in mock_client.capture.call_args_list]
+    assert span2["$ai_parent_id"] == span1["$ai_span_id"]
+    assert span1["$ai_parent_id"] == trace["$ai_trace_id"]
+
+
+def test_captures_error_with_details_in_span(mock_client):
+    def span(_):
+        raise ValueError("test")
+
+    callbacks = [CallbackHandler(mock_client)]
+    chain = RunnableLambda(span) | RunnableLambda(lambda _: "foo")
+    try:
+        chain.invoke({}, config={"callbacks": callbacks})
+    except ValueError:
+        pass
+
+    assert mock_client.capture.call_count == 2
+    assert mock_client.capture.call_args_list[1][1]["properties"]["$ai_error"] == "ValueError: test"
+    assert mock_client.capture.call_args_list[1][1]["properties"]["$ai_is_error"]
+
+
+def test_captures_error_without_details_in_span(mock_client):
+    def span(_):
+        raise ValueError
+
+    callbacks = [CallbackHandler(mock_client)]
+    chain = RunnableLambda(span) | RunnableLambda(lambda _: "foo")
+    try:
+        chain.invoke({}, config={"callbacks": callbacks})
+    except ValueError:
+        pass
+
+    assert mock_client.capture.call_count == 2
+    assert mock_client.capture.call_args_list[1][1]["properties"]["$ai_error"] == "ValueError"
+    assert mock_client.capture.call_args_list[1][1]["properties"]["$ai_is_error"]
diff --git a/posthog/version.py b/posthog/version.py
index e557061..20ac2ac 100644
--- a/posthog/version.py
+++ b/posthog/version.py
@@ -1,4 +1,4 @@
-VERSION = "3.10.0"
+VERSION = "3.11.0"

 if __name__ == "__main__":
     print(VERSION, end="")  # noqa: T201