diff --git a/python/sigil_sdk/__init__.py b/python/sigil_sdk/__init__.py index 59fa014..f929bff 100644 --- a/python/sigil_sdk/__init__.py +++ b/python/sigil_sdk/__init__.py @@ -5,11 +5,13 @@ from .context import ( agent_name_from_context, agent_version_from_context, + content_capture_mode_from_context, conversation_id_from_context, conversation_title_from_context, user_id_from_context, with_agent_name, with_agent_version, + with_content_capture_mode, with_conversation_id, with_conversation_title, with_user_id, @@ -27,6 +29,7 @@ from .models import ( Artifact, ArtifactKind, + ContentCaptureMode, ConversationRating, ConversationRatingInput, ConversationRatingSummary, @@ -67,6 +70,7 @@ "Client", "ClientConfig", "ClientShutdownError", + "ContentCaptureMode", "ConversationRating", "ConversationRatingInput", "ConversationRatingSummary", @@ -101,6 +105,7 @@ "agent_name_from_context", "agent_version_from_context", "assistant_text_message", + "content_capture_mode_from_context", "conversation_id_from_context", "conversation_title_from_context", "text_part", @@ -112,6 +117,7 @@ "user_text_message", "with_agent_name", "with_agent_version", + "with_content_capture_mode", "with_conversation_id", "with_conversation_title", "with_user_id", diff --git a/python/sigil_sdk/client.py b/python/sigil_sdk/client.py index e543cfa..94c1805 100644 --- a/python/sigil_sdk/client.py +++ b/python/sigil_sdk/client.py @@ -4,10 +4,12 @@ import copy import json +import logging import re import secrets import threading -from dataclasses import dataclass, field +from collections.abc import Callable +from dataclasses import dataclass, field, replace from datetime import datetime, timezone from typing import Any from urllib import error as urllib_error @@ -20,8 +22,11 @@ from .config import ClientConfig, resolve_config from .context import ( + _pop_capture_mode, + _push_capture_mode, agent_name_from_context, agent_version_from_context, + content_capture_mode_from_context, conversation_id_from_context, 
conversation_title_from_context, user_id_from_context, @@ -36,6 +41,7 @@ ) from .exporters import GRPCGenerationExporter, HTTPGenerationExporter, NoopGenerationExporter from .models import ( + ContentCaptureMode, ConversationRating, ConversationRatingInput, ConversationRatingSummary, @@ -51,6 +57,7 @@ SubmitConversationRatingResponse, ToolExecutionEnd, ToolExecutionStart, + _metadata_key_content_capture_mode, ) from .proto_mapping import generation_to_proto from .validation import validate_embedding_result, validate_embedding_start, validate_generation @@ -128,6 +135,104 @@ _metadata_legacy_user_id_key = "user.id" +def _resolve_content_capture_mode(override: ContentCaptureMode, fallback: ContentCaptureMode) -> ContentCaptureMode: + """Returns the effective mode from an override and a fallback. DEFAULT falls through.""" + if override != ContentCaptureMode.DEFAULT: + return override + return fallback + + +def _resolve_client_content_capture_mode(mode: ContentCaptureMode) -> ContentCaptureMode: + """Resolves client-level mode. DEFAULT → NO_TOOL_CONTENT for backward compat.""" + if mode == ContentCaptureMode.DEFAULT: + return ContentCaptureMode.NO_TOOL_CONTENT + return mode + + +def _call_content_capture_resolver( + resolver: Callable | None, + metadata: dict[str, Any] | None, + logger: logging.Logger | None = None, +) -> ContentCaptureMode: + """Invokes resolver callback safely. Returns DEFAULT when None. 
Exceptions → METADATA_ONLY.""" + if resolver is None: + return ContentCaptureMode.DEFAULT + try: + result = resolver(metadata) + if not isinstance(result, ContentCaptureMode): + if logger is not None: + logger.warning( + "sigil: content capture resolver returned %s instead of ContentCaptureMode, " + "falling back to METADATA_ONLY", + type(result).__name__, + ) + return ContentCaptureMode.METADATA_ONLY + return result + except Exception: # noqa: BLE001 + if logger is not None: + logger.warning("sigil: content capture resolver failed, falling back to METADATA_ONLY", exc_info=True) + return ContentCaptureMode.METADATA_ONLY + + +def _should_include_tool_content( + tool_mode: ContentCaptureMode, + ctx_mode: ContentCaptureMode | None, + client_default: ContentCaptureMode, + legacy_include: bool, +) -> bool: + """Determines whether tool content should be included in span attributes.""" + resolved = _resolve_client_content_capture_mode(client_default) + if ctx_mode is not None and ctx_mode != ContentCaptureMode.DEFAULT: + resolved = ctx_mode + if tool_mode != ContentCaptureMode.DEFAULT: + resolved = tool_mode + if resolved == ContentCaptureMode.METADATA_ONLY: + return False + if resolved == ContentCaptureMode.FULL: + return True + # NO_TOOL_CONTENT / DEFAULT: honor legacy include_content opt-in. + return legacy_include + + +def _stamp_content_capture_metadata(generation: Generation, mode: ContentCaptureMode) -> None: + """Sets the content capture mode marker on the generation.""" + generation.metadata[_metadata_key_content_capture_mode] = mode.value + + +def _strip_content(generation: Generation, error_category: str) -> None: + """Strips sensitive content from a generation while preserving structure. + + Note: user-provided metadata and tags are NOT stripped. Callers are responsible + for ensuring these dicts do not contain sensitive content when using MetadataOnly mode. 
+ """ + generation.system_prompt = "" + generation.artifacts = [] + + if generation.call_error != "": + generation.call_error = error_category if error_category else "sdk_error" + generation.metadata.pop("call_error", None) + + for message in generation.input: + _strip_message_content(message) + for message in generation.output: + _strip_message_content(message) + for tool in generation.tools: + tool.description = "" + tool.input_schema_json = b"" + + +def _strip_message_content(message: Message) -> None: + """Strips all content from message parts (text, thinking, tool call input, tool result).""" + for part in message.parts: + part.text = "" + part.thinking = "" + if part.tool_call is not None: + part.tool_call.input_json = b"" + if part.tool_result is not None: + part.tool_result.content = "" + part.tool_result.content_json = b"" + + class Client: """Sigil client that records generations, tool spans, and exports in background.""" @@ -261,12 +366,20 @@ def start_tool_execution(self, start: ToolExecutionStart) -> ToolExecutionRecord ) _set_tool_span_attributes(span, seed) + # Resolve content capture: per-tool > context (parent generation) > resolver > client default. 
+ resolver_mode = _call_content_capture_resolver(self._config.content_capture_resolver, {}, self._config.logger) + effective_client_default = _resolve_content_capture_mode(resolver_mode, self._config.content_capture) + ctx_mode = content_capture_mode_from_context() + include_content = _should_include_tool_content( + seed.content_capture, ctx_mode, effective_client_default, seed.include_content + ) + return ToolExecutionRecorder( client=self, seed=seed, span=span, started_at=started_at, - include_content=seed.include_content, + include_content=include_content, ) def submit_conversation_rating( @@ -284,6 +397,15 @@ def submit_conversation_rating( if len(normalized_conversation_id) > _max_rating_conversation_id_len: raise ValidationError("sigil conversation rating validation failed: conversation_id is too long") + resolver_mode = _call_content_capture_resolver( + self._config.content_capture_resolver, rating.metadata, self._config.logger + ) + effective_mode = _resolve_content_capture_mode( + resolver_mode, _resolve_client_content_capture_mode(self._config.content_capture) + ) + if effective_mode == ContentCaptureMode.METADATA_ONLY: + rating = replace(rating, comment="") + normalized_rating = _normalize_conversation_rating_input(rating) endpoint = _conversation_rating_endpoint( self._config.api.endpoint, @@ -436,12 +558,27 @@ def _start_generation(self, start: GenerationStart, default_mode: GenerationMode ), ) - return GenerationRecorder( + # Resolve content capture mode: per-recording > context > resolver > client default. 
+ resolver_mode = _call_content_capture_resolver( + self._config.content_capture_resolver, seed.metadata, self._config.logger + ) + client_mode = _resolve_client_content_capture_mode( + _resolve_content_capture_mode(resolver_mode, self._config.content_capture) + ) + ctx_mode = content_capture_mode_from_context() + if ctx_mode is not None: + client_mode = _resolve_content_capture_mode(ctx_mode, client_mode) + cc_mode = _resolve_content_capture_mode(seed.content_capture, client_mode) + + recorder = GenerationRecorder( client=self, seed=seed, span=span, started_at=started_at, + _content_capture_mode=cc_mode, ) + _push_capture_mode(id(recorder), cc_mode) + return recorder def _enqueue_generation(self, generation: Generation) -> None: if self._shutting_down or self._closed: @@ -672,6 +809,7 @@ class GenerationRecorder: span: Span started_at: datetime + _content_capture_mode: ContentCaptureMode = ContentCaptureMode.NO_TOOL_CONTENT _lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) _ended: bool = False _call_error: Exception | None = None @@ -733,66 +871,80 @@ def end(self) -> None: result = copy.deepcopy(self._result) if self._result is not None else Generation() first_token_at = self._first_token_at - completed_at = _to_utc(self.client._now()) - generation = self._normalize_generation(result, completed_at, call_error) - _apply_trace_context_from_span(self.span, generation) + try: + completed_at = _to_utc(self.client._now()) + generation = self._normalize_generation(result, completed_at, call_error) + _apply_trace_context_from_span(self.span, generation) - self.span.update_name(_generation_span_name(generation.operation_name, generation.model.name)) - _set_generation_span_attributes(self.span, generation) + _stamp_content_capture_metadata(generation, self._content_capture_mode) + if self._content_capture_mode == ContentCaptureMode.METADATA_ONLY: + error_cat = _error_category_from_exception(call_error, fallback_sdk=True) if call_error 
else "" + _strip_content(generation, error_cat) - local_error: Exception | None = None - try: - validate_generation(generation) - except Exception as exc: # noqa: BLE001 - local_error = ValidationError(f"sigil: generation validation failed: {exc}") + self.span.update_name(_generation_span_name(generation.operation_name, generation.model.name)) + _set_generation_span_attributes(self.span, generation) - if local_error is None: + local_error: Exception | None = None try: - self.client._enqueue_generation(generation) - except QueueFullError as exc: - local_error = exc - except ClientShutdownError as exc: - local_error = exc + validate_generation(generation) except Exception as exc: # noqa: BLE001 - local_error = EnqueueError(f"sigil: generation enqueue failed: {exc}") - - if call_error is not None: - self.span.record_exception(call_error) - if mapping_error is not None: - self.span.record_exception(mapping_error) - if local_error is not None: - self.span.record_exception(local_error) - - error_type = "" - error_category = "" - if call_error is not None: - error_type = "provider_call_error" - error_category = _error_category_from_exception(call_error, fallback_sdk=True) - self.span.set_attribute(_span_attr_error_type, error_type) - self.span.set_attribute(_span_attr_error_category, error_category) - self.span.set_status(Status(StatusCode.ERROR, str(call_error))) - elif mapping_error is not None: - error_type = "mapping_error" - error_category = "sdk_error" - self.span.set_attribute(_span_attr_error_type, error_type) - self.span.set_attribute(_span_attr_error_category, error_category) - self.span.set_status(Status(StatusCode.ERROR, str(mapping_error))) - elif local_error is not None: - error_type = "validation_error" if isinstance(local_error, ValidationError) else "enqueue_error" - error_category = "sdk_error" - self.span.set_attribute(_span_attr_error_type, error_type) - self.span.set_attribute(_span_attr_error_category, error_category) - 
self.span.set_status(Status(StatusCode.ERROR, str(local_error))) - else: - self.span.set_status(Status(StatusCode.OK)) + local_error = ValidationError(f"sigil: generation validation failed: {exc}") + + if local_error is None: + try: + self.client._enqueue_generation(generation) + except QueueFullError as exc: + local_error = exc + except ClientShutdownError as exc: + local_error = exc + except Exception as exc: # noqa: BLE001 + local_error = EnqueueError(f"sigil: generation enqueue failed: {exc}") + + is_metadata_only = self._content_capture_mode == ContentCaptureMode.METADATA_ONLY + + if not is_metadata_only: + if call_error is not None: + self.span.record_exception(call_error) + if mapping_error is not None: + self.span.record_exception(mapping_error) + # SDK-internal errors (validation/enqueue) contain no user content — always record. + if local_error is not None: + self.span.record_exception(local_error) + + error_type = "" + error_category = "" + if call_error is not None: + error_type = "provider_call_error" + error_category = _error_category_from_exception(call_error, fallback_sdk=True) + self.span.set_attribute(_span_attr_error_type, error_type) + self.span.set_attribute(_span_attr_error_category, error_category) + self.span.set_status(Status(StatusCode.ERROR, error_category if is_metadata_only else str(call_error))) + elif mapping_error is not None: + error_type = "mapping_error" + error_category = "sdk_error" + self.span.set_attribute(_span_attr_error_type, error_type) + self.span.set_attribute(_span_attr_error_category, error_category) + self.span.set_status( + Status(StatusCode.ERROR, error_category if is_metadata_only else str(mapping_error)) + ) + elif local_error is not None: + error_type = "validation_error" if isinstance(local_error, ValidationError) else "enqueue_error" + error_category = "sdk_error" + self.span.set_attribute(_span_attr_error_type, error_type) + self.span.set_attribute(_span_attr_error_category, error_category) + 
self.span.set_status(Status(StatusCode.ERROR, error_category if is_metadata_only else str(local_error))) + else: + self.span.set_status(Status(StatusCode.OK)) - self.client._record_generation_metrics(generation, error_type, error_category, first_token_at) + self.client._record_generation_metrics(generation, error_type, error_category, first_token_at) - self.span.end(end_time=_datetime_to_ns(generation.completed_at or completed_at)) + self.span.end(end_time=_datetime_to_ns(generation.completed_at or completed_at)) - with self._lock: - self._last_generation = copy.deepcopy(generation) - self._final_error = local_error + with self._lock: + self._last_generation = copy.deepcopy(generation) + self._final_error = local_error + finally: + _pop_capture_mode(id(self)) def err(self) -> Exception | None: """Returns local validation/enqueue error after `end()`.""" diff --git a/python/sigil_sdk/config.py b/python/sigil_sdk/config.py index 908a7af..55d5c4b 100644 --- a/python/sigil_sdk/config.py +++ b/python/sigil_sdk/config.py @@ -8,12 +8,13 @@ from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime, timedelta +from typing import Any from opentelemetry.metrics import Meter from opentelemetry.trace import Tracer from .exporters.base import GenerationExporter -from .models import utc_now +from .models import ContentCaptureMode, utc_now TENANT_HEADER = "X-Scope-OrgID" AUTHORIZATION_HEADER = "Authorization" @@ -71,6 +72,8 @@ class ClientConfig: generation_export: GenerationExportConfig = field(default_factory=GenerationExportConfig) api: ApiConfig = field(default_factory=ApiConfig) embedding_capture: EmbeddingCaptureConfig = field(default_factory=EmbeddingCaptureConfig) + content_capture: ContentCaptureMode = ContentCaptureMode.DEFAULT + content_capture_resolver: Callable[[dict[str, Any]], ContentCaptureMode] | None = None tracer: Tracer | None = None meter: Meter | None = None logger: logging.Logger | None = None diff --git 
a/python/sigil_sdk/context.py b/python/sigil_sdk/context.py index d713810..bf41e7f 100644 --- a/python/sigil_sdk/context.py +++ b/python/sigil_sdk/context.py @@ -5,6 +5,10 @@ import contextvars from collections.abc import Iterator from contextlib import contextmanager +from typing import TYPE_CHECKING, NamedTuple + +if TYPE_CHECKING: + from .models import ContentCaptureMode _conversation_id: contextvars.ContextVar[str | None] = contextvars.ContextVar("sigil_conversation_id", default=None) _conversation_title: contextvars.ContextVar[str | None] = contextvars.ContextVar( @@ -13,6 +17,45 @@ _user_id: contextvars.ContextVar[str | None] = contextvars.ContextVar("sigil_user_id", default=None) _agent_name: contextvars.ContextVar[str | None] = contextvars.ContextVar("sigil_agent_name", default=None) _agent_version: contextvars.ContextVar[str | None] = contextvars.ContextVar("sigil_agent_version", default=None) +_content_capture_mode: contextvars.ContextVar[ContentCaptureMode | None] = contextvars.ContextVar( + "sigil_content_capture_mode", default=None +) + + +class _CaptureStackEntry(NamedTuple): + recorder_id: int + mode: ContentCaptureMode + + +# Stack of active GenerationRecorder capture modes. +# Used to correctly restore the ContextVar when overlapping generations end in +# non-LIFO order. +_capture_mode_stack: contextvars.ContextVar[tuple[_CaptureStackEntry, ...]] = contextvars.ContextVar( + "_sigil_capture_mode_stack", default=() +) +# Snapshot of _content_capture_mode before the first recorder pushed onto the +# stack. Restored when the stack empties so that a surrounding +# with_content_capture_mode() block is not clobbered. 
+_capture_mode_stack_base: contextvars.ContextVar[ContentCaptureMode | None] = contextvars.ContextVar( + "_sigil_capture_mode_stack_base", default=None +) + + +def _push_capture_mode(recorder_id: int, mode: ContentCaptureMode) -> None: + """Pushes a recorder's capture mode onto the stack and sets the ContextVar.""" + stack = _capture_mode_stack.get() + if not stack: + _capture_mode_stack_base.set(_content_capture_mode.get()) + _capture_mode_stack.set((*stack, _CaptureStackEntry(recorder_id, mode))) + _content_capture_mode.set(mode) + + +def _pop_capture_mode(recorder_id: int) -> None: + """Removes a recorder from the stack and restores the ContextVar.""" + stack = _capture_mode_stack.get() + new_stack = tuple(e for e in stack if e.recorder_id != recorder_id) + _capture_mode_stack.set(new_stack) + _content_capture_mode.set(new_stack[-1].mode if new_stack else _capture_mode_stack_base.get()) @contextmanager @@ -98,3 +141,20 @@ def user_id_from_context() -> str | None: """Returns the current user id from context variables.""" return _user_id.get() + + +@contextmanager +def with_content_capture_mode(mode: ContentCaptureMode) -> Iterator[None]: + """Sets the content capture mode within a context block.""" + + token = _content_capture_mode.set(mode) + try: + yield + finally: + _content_capture_mode.reset(token) + + +def content_capture_mode_from_context() -> ContentCaptureMode | None: + """Returns the content capture mode from context, or None if not set.""" + + return _content_capture_mode.get() diff --git a/python/sigil_sdk/models.py b/python/sigil_sdk/models.py index 5a1281f..17014e7 100644 --- a/python/sigil_sdk/models.py +++ b/python/sigil_sdk/models.py @@ -32,6 +32,22 @@ class PartKind(str, Enum): TOOL_RESULT = "tool_result" +class ContentCaptureMode(str, Enum): + """Controls what content is included in exported generation payloads and OTel span attributes. + + Note: user-provided metadata and tags are NOT stripped, even in METADATA_ONLY mode. 
+ Callers are responsible for ensuring these dicts do not contain sensitive content. + """ + + DEFAULT = "default" + FULL = "full" + NO_TOOL_CONTENT = "no_tool_content" + METADATA_ONLY = "metadata_only" + + +_metadata_key_content_capture_mode = "sigil.sdk.content_capture_mode" + + class ArtifactKind(str, Enum): """Allowed raw artifact kinds.""" @@ -175,6 +191,7 @@ class GenerationStart: tool_choice: str | None = None thinking_enabled: bool | None = None tools: list[ToolDefinition] = field(default_factory=list) + content_capture: ContentCaptureMode = ContentCaptureMode.DEFAULT tags: dict[str, str] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict) started_at: datetime | None = None @@ -256,6 +273,7 @@ class ToolExecutionStart: request_model: str = "" request_provider: str = "" include_content: bool = False + content_capture: ContentCaptureMode = ContentCaptureMode.DEFAULT started_at: datetime | None = None diff --git a/python/sigil_sdk/validation.py b/python/sigil_sdk/validation.py index 8c78681..1ac63e8 100644 --- a/python/sigil_sdk/validation.py +++ b/python/sigil_sdk/validation.py @@ -2,12 +2,29 @@ from __future__ import annotations -from .models import ArtifactKind, EmbeddingResult, EmbeddingStart, Generation, GenerationMode, MessageRole, PartKind +from .models import ( + ArtifactKind, + ContentCaptureMode, + EmbeddingResult, + EmbeddingStart, + Generation, + GenerationMode, + MessageRole, + PartKind, + _metadata_key_content_capture_mode, +) + + +def _is_content_stripped(generation: Generation) -> bool: + """Reports whether the generation has been through MetadataOnly stripping.""" + return generation.metadata.get(_metadata_key_content_capture_mode) == ContentCaptureMode.METADATA_ONLY.value def validate_generation(generation: Generation) -> None: """Raises ValueError when a generation payload is invalid.""" + content_stripped = _is_content_stripped(generation) + if generation.mode is not None and generation.mode not in 
(GenerationMode.SYNC, GenerationMode.STREAM): raise ValueError("generation.mode must be one of SYNC|STREAM") @@ -23,6 +40,7 @@ def validate_generation(generation: Generation) -> None: index, message.role.value if hasattr(message.role, "value") else str(message.role), message.parts, + content_stripped, ) for index, message in enumerate(generation.output): @@ -31,6 +49,7 @@ def validate_generation(generation: Generation) -> None: index, message.role.value if hasattr(message.role, "value") else str(message.role), message.parts, + content_stripped, ) for index, tool in enumerate(generation.tools): @@ -73,7 +92,7 @@ def validate_embedding_result(result: EmbeddingResult) -> None: raise ValueError("embedding.dimensions must be > 0") -def _validate_message(path: str, index: int, role: str, parts: list[object]) -> None: +def _validate_message(path: str, index: int, role: str, parts: list[object], content_stripped: bool = False) -> None: if role not in (MessageRole.USER.value, MessageRole.ASSISTANT.value, MessageRole.TOOL.value): raise ValueError(f"{path}[{index}].role must be one of user|assistant|tool") @@ -81,10 +100,12 @@ def _validate_message(path: str, index: int, role: str, parts: list[object]) -> raise ValueError(f"{path}[{index}].parts must not be empty") for part_index, part in enumerate(parts): - _validate_part(path, index, part_index, role, part) + _validate_part(path, index, part_index, role, part, content_stripped) -def _validate_part(path: str, message_index: int, part_index: int, role: str, part: object) -> None: +def _validate_part( + path: str, message_index: int, part_index: int, role: str, part: object, content_stripped: bool = False +) -> None: kind = part.kind.value if hasattr(part.kind, "value") else str(part.kind) if kind not in ( @@ -105,18 +126,20 @@ def _validate_part(path: str, message_index: int, part_index: int, role: str, pa if getattr(part, "tool_result", None) is not None: field_count += 1 - if field_count != 1: + # Stripped text/thinking 
parts have empty payloads — that's expected. + stripped_text_or_thinking = content_stripped and kind in (PartKind.TEXT.value, PartKind.THINKING.value) + if field_count != 1 and not stripped_text_or_thinking: raise ValueError(f"{path}[{message_index}].parts[{part_index}] must set exactly one payload field") if kind == PartKind.TEXT.value: - if getattr(part, "text", "").strip() == "": + if not content_stripped and getattr(part, "text", "").strip() == "": raise ValueError(f"{path}[{message_index}].parts[{part_index}].text is required") return if kind == PartKind.THINKING.value: if role != MessageRole.ASSISTANT.value: raise ValueError(f"{path}[{message_index}].parts[{part_index}].thinking only allowed for assistant role") - if getattr(part, "thinking", "").strip() == "": + if not content_stripped and getattr(part, "thinking", "").strip() == "": raise ValueError(f"{path}[{message_index}].parts[{part_index}].thinking is required") return diff --git a/python/tests/test_content_capture.py b/python/tests/test_content_capture.py new file mode 100644 index 0000000..8cdb996 --- /dev/null +++ b/python/tests/test_content_capture.py @@ -0,0 +1,1029 @@ +"""Tests for ContentCaptureMode: resolution, stripping, context propagation, tool spans.""" + +from __future__ import annotations + +import json +import threading +from datetime import timedelta +from http.server import BaseHTTPRequestHandler, HTTPServer + +import pytest +from conftest import CapturingGenerationExporter +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from sigil_sdk import ( + ApiConfig, + Client, + ClientConfig, + ContentCaptureMode, + ConversationRatingInput, + ConversationRatingValue, + Generation, + GenerationExportConfig, + GenerationStart, + Message, + MessageRole, + ModelRef, + Part, + PartKind, + TokenUsage, + ToolCall, + ToolDefinition, + 
ToolExecutionStart, + ToolResult, + validate_generation, +) +from sigil_sdk.context import content_capture_mode_from_context, with_content_capture_mode + +_METADATA_KEY = "sigil.sdk.content_capture_mode" + + +def _new_client(exporter: CapturingGenerationExporter, tracer=None, **overrides) -> Client: + generation_export = GenerationExportConfig( + batch_size=overrides.get("batch_size", 10), + flush_interval=overrides.get("flush_interval", timedelta(seconds=60)), + queue_size=overrides.get("queue_size", 10), + max_retries=overrides.get("max_retries", 1), + initial_backoff=overrides.get("initial_backoff", timedelta(milliseconds=1)), + max_backoff=overrides.get("max_backoff", timedelta(milliseconds=1)), + ) + return Client( + ClientConfig( + tracer=tracer, + generation_export=generation_export, + generation_exporter=exporter, + content_capture=overrides.get("content_capture", ContentCaptureMode.DEFAULT), + content_capture_resolver=overrides.get("content_capture_resolver", None), + ) + ) + + +def _seed(content_capture: ContentCaptureMode = ContentCaptureMode.DEFAULT) -> GenerationStart: + return GenerationStart( + model=ModelRef(provider="anthropic", name="claude-sonnet-4-5"), + content_capture=content_capture, + ) + + +def _full_generation() -> Generation: + return Generation( + system_prompt="You are helpful.", + input=[ + Message(role=MessageRole.USER, parts=[Part(kind=PartKind.TEXT, text="What is the weather?")]), + Message( + role=MessageRole.TOOL, + parts=[ + Part( + kind=PartKind.TOOL_RESULT, + tool_result=ToolResult( + tool_call_id="call_1", + name="weather", + content="sunny 18C", + content_json=b'{"temp":18}', + ), + ) + ], + ), + ], + output=[ + Message( + role=MessageRole.ASSISTANT, + parts=[ + Part(kind=PartKind.THINKING, thinking="let me think about weather"), + Part( + kind=PartKind.TOOL_CALL, + tool_call=ToolCall(name="weather", id="call_1", input_json=b'{"city":"Paris"}'), + ), + Part(kind=PartKind.TEXT, text="It's 18C and sunny in Paris."), + ], + ) + 
], + tools=[ + ToolDefinition( + name="weather", description="Get weather info", type="function", input_schema_json=b'{"type":"object"}' + ), + ], + usage=TokenUsage(input_tokens=120, output_tokens=42), + stop_reason="end_turn", + call_error="rate limit exceeded", + artifacts=[], + metadata={"sigil.sdk.name": "sdk-python", "call_error": "rate limit exceeded"}, + ) + + +# --------------------------------------------------------------------------- +# Mode resolution +# --------------------------------------------------------------------------- + + +class TestContentCaptureModeResolution: + @pytest.mark.parametrize( + "client_mode, gen_mode, want_marker", + [ + (ContentCaptureMode.DEFAULT, ContentCaptureMode.DEFAULT, "no_tool_content"), + (ContentCaptureMode.METADATA_ONLY, ContentCaptureMode.DEFAULT, "metadata_only"), + (ContentCaptureMode.FULL, ContentCaptureMode.METADATA_ONLY, "metadata_only"), + (ContentCaptureMode.METADATA_ONLY, ContentCaptureMode.FULL, "full"), + (ContentCaptureMode.FULL, ContentCaptureMode.DEFAULT, "full"), + ], + ) + def test_generation_content_capture(self, client_mode, gen_mode, want_marker): + exporter = CapturingGenerationExporter() + client = _new_client(exporter, content_capture=client_mode) + try: + rec = client.start_generation(_seed(gen_mode)) + rec.set_result( + Generation( + input=[Message(role=MessageRole.USER, parts=[Part(kind=PartKind.TEXT, text="Hello")])], + output=[Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text="Hi")])], + usage=TokenUsage(input_tokens=10, output_tokens=5), + ) + ) + rec.end() + assert rec.err() is None + + gen = rec.last_generation + assert gen.metadata[_METADATA_KEY] == want_marker + finally: + client.shutdown() + + def test_default_resolution_is_no_tool_content(self): + exporter = CapturingGenerationExporter() + client = _new_client(exporter) + try: + rec = client.start_generation(_seed()) + rec.set_result( + output=[Message(role=MessageRole.ASSISTANT, 
parts=[Part(kind=PartKind.TEXT, text="ok")])], + usage=TokenUsage(input_tokens=1, output_tokens=1), + ) + rec.end() + gen = rec.last_generation + assert gen.metadata[_METADATA_KEY] == "no_tool_content" + # Content should NOT be stripped under no_tool_content + assert gen.output[0].parts[0].text == "ok" + finally: + client.shutdown() + + +# --------------------------------------------------------------------------- +# Content stripping (METADATA_ONLY) +# --------------------------------------------------------------------------- + + +class TestContentStripping: + def test_metadata_only_strips_sensitive_content(self): + exporter = CapturingGenerationExporter() + client = _new_client(exporter, content_capture=ContentCaptureMode.METADATA_ONLY) + try: + rec = client.start_generation(_seed()) + rec.set_result(_full_generation()) + rec.end() + assert rec.err() is None + + gen = rec.last_generation + + # Stripped + assert gen.system_prompt == "" + assert gen.input[0].parts[0].text == "" + assert gen.output[0].parts[0].thinking == "" + assert gen.output[0].parts[1].tool_call.input_json == b"" + assert gen.output[0].parts[2].text == "" + assert gen.input[1].parts[0].tool_result.content == "" + assert gen.input[1].parts[0].tool_result.content_json == b"" + assert gen.tools[0].description == "" + assert gen.tools[0].input_schema_json == b"" + + # Preserved + assert len(gen.input) == 2 + assert len(gen.output) == 1 + assert len(gen.output[0].parts) == 3 + assert gen.input[0].role == MessageRole.USER + assert gen.output[0].parts[0].kind == PartKind.THINKING + assert gen.output[0].parts[1].tool_call.name == "weather" + assert gen.output[0].parts[1].tool_call.id == "call_1" + assert gen.input[1].parts[0].tool_result.tool_call_id == "call_1" + assert gen.input[1].parts[0].tool_result.name == "weather" + assert gen.tools[0].name == "weather" + assert gen.usage.input_tokens == 120 + assert gen.usage.output_tokens == 42 + assert gen.stop_reason == "end_turn" + assert 
gen.metadata["sigil.sdk.name"] == "sdk-python" + finally: + client.shutdown() + + def test_metadata_only_replaces_call_error_with_category(self): + exporter = CapturingGenerationExporter() + client = _new_client(exporter, content_capture=ContentCaptureMode.METADATA_ONLY) + try: + rec = client.start_generation(_seed()) + rec.set_call_error(RuntimeError("429 rate limit exceeded")) + rec.set_result( + Generation( + input=[Message(role=MessageRole.USER, parts=[Part(kind=PartKind.TEXT, text="hello")])], + output=[Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text="world")])], + usage=TokenUsage(input_tokens=1, output_tokens=1), + ) + ) + rec.end() + + gen = rec.last_generation + assert gen.call_error == "rate_limit" + assert "call_error" not in gen.metadata + finally: + client.shutdown() + + def test_metadata_only_falls_back_to_sdk_error(self): + exporter = CapturingGenerationExporter() + client = _new_client(exporter, content_capture=ContentCaptureMode.METADATA_ONLY) + try: + rec = client.start_generation(_seed()) + rec.set_call_error(RuntimeError("something went wrong")) + rec.set_result( + Generation( + input=[Message(role=MessageRole.USER, parts=[Part(kind=PartKind.TEXT, text="hello")])], + output=[Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text="world")])], + usage=TokenUsage(input_tokens=1, output_tokens=1), + ) + ) + rec.end() + + gen = rec.last_generation + assert gen.call_error == "sdk_error" + finally: + client.shutdown() + + def test_full_mode_preserves_all_content(self): + exporter = CapturingGenerationExporter() + client = _new_client(exporter, content_capture=ContentCaptureMode.FULL) + try: + rec = client.start_generation(_seed()) + rec.set_result( + Generation( + system_prompt="You are helpful.", + input=[Message(role=MessageRole.USER, parts=[Part(kind=PartKind.TEXT, text="Hello")])], + output=[Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text="Hi")])], + 
usage=TokenUsage(input_tokens=10, output_tokens=5), + ) + ) + rec.end() + gen = rec.last_generation + assert gen.metadata[_METADATA_KEY] == "full" + assert gen.system_prompt == "You are helpful." + assert gen.input[0].parts[0].text == "Hello" + assert gen.output[0].parts[0].text == "Hi" + finally: + client.shutdown() + + +# --------------------------------------------------------------------------- +# Per-generation override +# --------------------------------------------------------------------------- + + +class TestPerGenerationOverride: + def test_per_generation_full_overrides_client_metadata_only(self): + exporter = CapturingGenerationExporter() + client = _new_client(exporter, content_capture=ContentCaptureMode.METADATA_ONLY) + try: + rec = client.start_generation(_seed(ContentCaptureMode.FULL)) + rec.set_result( + Generation( + input=[Message(role=MessageRole.USER, parts=[Part(kind=PartKind.TEXT, text="Hello")])], + output=[Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text="Hi")])], + usage=TokenUsage(input_tokens=10, output_tokens=5), + ) + ) + rec.end() + gen = rec.last_generation + assert gen.metadata[_METADATA_KEY] == "full" + assert gen.input[0].parts[0].text == "Hello" + finally: + client.shutdown() + + def test_per_generation_metadata_only_overrides_client_full(self): + exporter = CapturingGenerationExporter() + client = _new_client(exporter, content_capture=ContentCaptureMode.FULL) + try: + rec = client.start_generation(_seed(ContentCaptureMode.METADATA_ONLY)) + rec.set_result( + Generation( + input=[Message(role=MessageRole.USER, parts=[Part(kind=PartKind.TEXT, text="Hello")])], + output=[Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text="Hi")])], + usage=TokenUsage(input_tokens=10, output_tokens=5), + ) + ) + rec.end() + gen = rec.last_generation + assert gen.metadata[_METADATA_KEY] == "metadata_only" + assert gen.input[0].parts[0].text == "" + finally: + client.shutdown() + + +# 
# ---------------------------------------------------------------------------
# Resolver callback
# ---------------------------------------------------------------------------


class TestResolverCallback:
    """Checks how ``content_capture_resolver`` interacts with the client mode and ``_seed`` overrides."""

    @staticmethod
    def _assert_resolved(client_mode, resolver, expected_marker, expected_input_text, *seed_args):
        """Record one hello/world generation and assert the resolved capture marker.

        ``seed_args`` are forwarded unchanged to ``_seed``. The assertions run
        while the client is still open; the client is always shut down afterwards.
        """
        exporter = CapturingGenerationExporter()
        client = _new_client(
            exporter,
            content_capture=client_mode,
            content_capture_resolver=resolver,
        )
        try:
            recorder = client.start_generation(_seed(*seed_args))
            user_turn = Message(role=MessageRole.USER, parts=[Part(kind=PartKind.TEXT, text="hello")])
            assistant_turn = Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text="world")])
            recorder.set_result(
                Generation(
                    input=[user_turn],
                    output=[assistant_turn],
                    usage=TokenUsage(input_tokens=10, output_tokens=5),
                )
            )
            recorder.end()
            recorded = recorder.last_generation
            assert recorded.metadata[_METADATA_KEY] == expected_marker
            assert recorded.input[0].parts[0].text == expected_input_text
        finally:
            client.shutdown()

    def test_resolver_metadata_only_overrides_client_full(self):
        self._assert_resolved(
            ContentCaptureMode.FULL,
            lambda _meta: ContentCaptureMode.METADATA_ONLY,
            "metadata_only",
            "",
        )

    def test_per_generation_overrides_resolver(self):
        self._assert_resolved(
            ContentCaptureMode.DEFAULT,
            lambda _meta: ContentCaptureMode.METADATA_ONLY,
            "full",
            "hello",
            ContentCaptureMode.FULL,
        )

    def test_resolver_default_defers_to_client(self):
        self._assert_resolved(
            ContentCaptureMode.METADATA_ONLY,
            lambda _meta: ContentCaptureMode.DEFAULT,
            "metadata_only",
            "",
        )

    def test_resolver_exception_fails_closed_to_metadata_only(self):
        def bad_resolver(_meta):
            raise RuntimeError("resolver bug")

        self._assert_resolved(ContentCaptureMode.FULL, bad_resolver, "metadata_only", "")

    def test_resolver_returning_wrong_type_fails_closed(self):
        """A resolver returning a plain string instead of ContentCaptureMode should fail closed."""
        self._assert_resolved(
            ContentCaptureMode.FULL,
            lambda _meta: "metadata_only",
            "metadata_only",
            "",
        )

    def test_resolver_full_overrides_client_metadata_only(self):
        self._assert_resolved(
            ContentCaptureMode.METADATA_ONLY,
            lambda _meta: ContentCaptureMode.FULL,
            "full",
            "hello",
        )


# ---------------------------------------------------------------------------
# Tool span content capture
# ---------------------------------------------------------------------------


class TestToolContentCapture:
    """Checks whether ``gen_ai.tool.call.arguments`` lands on the tool span for each mode combination."""

    def _make_tool_client(self, content_capture=ContentCaptureMode.DEFAULT, **kw):
        """Return ``(client, span_exporter, provider)`` wired to an in-memory span exporter."""
        span_exporter = InMemorySpanExporter()
        provider = TracerProvider()
        provider.add_span_processor(SimpleSpanProcessor(span_exporter))
        client = _new_client(
            CapturingGenerationExporter(),
            tracer=provider.get_tracer("sigil-test"),
            content_capture=content_capture,
            **kw,
        )
        return client, span_exporter, provider

    def _get_tool_span(self, span_exporter):
        """Return the first finished span whose name starts with ``execute_tool``; fail if none exists."""
        candidates = (s for s in span_exporter.get_finished_spans() if s.name.startswith("execute_tool"))
        found = next(candidates, None)
        if found is None:
            raise AssertionError("tool span not found")
        return found

    def _assert_arguments_attr(self, client_mode, present, **start_fields):
        """Run one tool execution and assert whether the arguments attribute was recorded.

        ``start_fields`` are passed to ``ToolExecutionStart`` alongside the fixed tool name.
        """
        client, span_exporter, provider = self._make_tool_client(client_mode)
        try:
            with client.start_tool_execution(ToolExecutionStart(tool_name="test_tool", **start_fields)) as rec:
                rec.set_result(arguments="args", result="result")

            arguments = self._get_tool_span(span_exporter).attributes.get("gen_ai.tool.call.arguments")
            if present:
                assert arguments is not None
            else:
                assert arguments is None
        finally:
            client.shutdown()
            provider.shutdown()

    def test_client_full_includes_content(self):
        self._assert_arguments_attr(ContentCaptureMode.FULL, True, include_content=False)

    def test_client_metadata_only_suppresses_content(self):
        self._assert_arguments_attr(ContentCaptureMode.METADATA_ONLY, False, include_content=True)

    def test_client_default_legacy_false_suppresses(self):
        self._assert_arguments_attr(ContentCaptureMode.DEFAULT, False, include_content=False)

    def test_client_default_legacy_true_includes(self):
        self._assert_arguments_attr(ContentCaptureMode.DEFAULT, True, include_content=True)

    def test_per_tool_full_overrides_client_metadata_only(self):
        self._assert_arguments_attr(
            ContentCaptureMode.METADATA_ONLY,
            True,
            content_capture=ContentCaptureMode.FULL,
            include_content=True,
        )

    def test_per_tool_metadata_only_overrides_client_full(self):
        self._assert_arguments_attr(
            ContentCaptureMode.FULL,
            False,
            content_capture=ContentCaptureMode.METADATA_ONLY,
            include_content=True,
        )

    def test_include_content_ignored_under_metadata_only(self):
        self._assert_arguments_attr(ContentCaptureMode.METADATA_ONLY, False, include_content=True)

    def test_context_default_defers_to_client_full(self):
        """with_content_capture_mode(DEFAULT) should fall through to client FULL, not suppress content."""
        client, span_exporter, provider = self._make_tool_client(ContentCaptureMode.FULL)
        try:
            with with_content_capture_mode(ContentCaptureMode.DEFAULT):
                with client.start_tool_execution(
                    ToolExecutionStart(tool_name="test_tool", include_content=False)
                ) as rec:
                    rec.set_result(arguments="args", result="result")

            tool_span = self._get_tool_span(span_exporter)
            assert tool_span.attributes.get("gen_ai.tool.call.arguments") is not None
        finally:
            client.shutdown()
            provider.shutdown()

    def test_context_default_defers_to_client_metadata_only(self):
        """with_content_capture_mode(DEFAULT) should fall through to client METADATA_ONLY, not re-enable via legacy."""
        client, span_exporter, provider = self._make_tool_client(ContentCaptureMode.METADATA_ONLY)
        try:
            with with_content_capture_mode(ContentCaptureMode.DEFAULT):
                with client.start_tool_execution(
                    ToolExecutionStart(tool_name="test_tool", include_content=True)
                ) as rec:
                    rec.set_result(arguments="args", result="result")

            tool_span = self._get_tool_span(span_exporter)
            assert tool_span.attributes.get("gen_ai.tool.call.arguments") is None
        finally:
            client.shutdown()
            provider.shutdown()


# ---------------------------------------------------------------------------
# Context propagation (parent generation → child tool)
# ---------------------------------------------------------------------------


class TestContextPropagation:
    """Checks that the capture mode is visible through the context inside generations and overrides."""

    @staticmethod
    def _traced_client(mode):
        """Return ``(client, span_exporter, provider)`` for a client created with ``content_capture=mode``."""
        exporter = CapturingGenerationExporter()
        span_exporter = InMemorySpanExporter()
        provider = TracerProvider()
        provider.add_span_processor(SimpleSpanProcessor(span_exporter))
        client = _new_client(exporter, tracer=provider.get_tracer("sigil-test"), content_capture=mode)
        return client, span_exporter, provider

    @staticmethod
    def _first_tool_arguments(span_exporter):
        """Return the arguments attribute of the first ``execute_tool`` span; fail if there is no such span."""
        for span in span_exporter.get_finished_spans():
            if span.name.startswith("execute_tool"):
                return span.attributes.get("gen_ai.tool.call.arguments")
        raise AssertionError("tool span not found")

    @staticmethod
    def _ok_result():
        """Keyword arguments for a minimal ``set_result`` call with a single "ok" assistant message."""
        return {
            "output": [Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text="ok")])],
            "usage": TokenUsage(input_tokens=1, output_tokens=1),
        }

    def test_generation_context_manager_sets_content_capture_mode(self):
        client, span_exporter, provider = self._traced_client(ContentCaptureMode.METADATA_ONLY)

        try:
            with client.start_generation(_seed()) as gen_rec:
                # Within the generation context, content capture mode should be set
                assert content_capture_mode_from_context() == ContentCaptureMode.METADATA_ONLY

                # Tool execution within this context should inherit the mode
                with client.start_tool_execution(
                    ToolExecutionStart(tool_name="test_tool", include_content=True)
                ) as tool_rec:
                    tool_rec.set_result(arguments="args", result="result")

                gen_rec.set_result(**self._ok_result())

            # Tool span should NOT have content (inherited MetadataOnly suppresses)
            assert self._first_tool_arguments(span_exporter) is None
        finally:
            client.shutdown()
            provider.shutdown()

    def test_generation_full_context_allows_tool_content(self):
        client, span_exporter, provider = self._traced_client(ContentCaptureMode.METADATA_ONLY)

        try:
            # Per-generation override to FULL
            with client.start_generation(_seed(ContentCaptureMode.FULL)) as gen_rec:
                with client.start_tool_execution(
                    ToolExecutionStart(tool_name="test_tool", include_content=True)
                ) as tool_rec:
                    tool_rec.set_result(arguments="args", result="result")

                gen_rec.set_result(**self._ok_result())

            assert self._first_tool_arguments(span_exporter) is not None
        finally:
            client.shutdown()
            provider.shutdown()

    def test_with_content_capture_mode_context_manager(self):
        assert content_capture_mode_from_context() is None

        with with_content_capture_mode(ContentCaptureMode.FULL):
            assert content_capture_mode_from_context() == ContentCaptureMode.FULL

        assert content_capture_mode_from_context() is None

    def test_recorder_inside_with_content_capture_mode_preserves_override(self):
        """A recorder starting and ending inside with_content_capture_mode must not clobber the user override."""
        client, span_exporter, provider = self._traced_client(ContentCaptureMode.METADATA_ONLY)

        try:
            with with_content_capture_mode(ContentCaptureMode.FULL):
                assert content_capture_mode_from_context() == ContentCaptureMode.FULL

                with client.start_generation(_seed()) as gen_rec:
                    gen_rec.set_result(**self._ok_result())

                # After the recorder ends, the context override must still be FULL
                assert content_capture_mode_from_context() == ContentCaptureMode.FULL

                # Tool execution here should still see FULL from the context
                with client.start_tool_execution(
                    ToolExecutionStart(tool_name="test_tool", include_content=False)
                ) as tool_rec:
                    tool_rec.set_result(arguments="args", result="result")

            assert self._first_tool_arguments(span_exporter) is not None

            assert content_capture_mode_from_context() is None
        finally:
            client.shutdown()
            provider.shutdown()

    def test_with_content_capture_mode_overrides_generation_to_metadata_only(self):
        """with_content_capture_mode(METADATA_ONLY) should strip generation content even when client is FULL."""
        client = _new_client(CapturingGenerationExporter(), content_capture=ContentCaptureMode.FULL)
        try:
            with with_content_capture_mode(ContentCaptureMode.METADATA_ONLY):
                recorder = client.start_generation(_seed())
                recorder.set_result(
                    Generation(
                        system_prompt="secret system prompt",
                        input=[Message(role=MessageRole.USER, parts=[Part(kind=PartKind.TEXT, text="Hello")])],
                        output=[Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text="Hi")])],
                        usage=TokenUsage(input_tokens=10, output_tokens=5),
                    )
                )
                recorder.end()
                assert recorder.err() is None

                recorded = recorder.last_generation
                assert recorded.metadata[_METADATA_KEY] == "metadata_only"
                assert recorded.system_prompt == ""
                assert recorded.input[0].parts[0].text == ""
                assert recorded.output[0].parts[0].text == ""
        finally:
            client.shutdown()

    def test_with_content_capture_mode_overrides_generation_to_full(self):
        """with_content_capture_mode(FULL) should preserve generation content even when client is METADATA_ONLY."""
        client = _new_client(CapturingGenerationExporter(), content_capture=ContentCaptureMode.METADATA_ONLY)
        try:
            with with_content_capture_mode(ContentCaptureMode.FULL):
                recorder = client.start_generation(_seed())
                recorder.set_result(
                    output=[Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text="Hello")])],
                    usage=TokenUsage(input_tokens=10, output_tokens=5),
                )
                recorder.end()
                assert recorder.err() is None

                recorded = recorder.last_generation
                assert recorded.metadata[_METADATA_KEY] == "full"
                assert recorded.output[0].parts[0].text == "Hello"
        finally:
            client.shutdown()

    def test_per_recording_override_takes_priority_over_context(self):
        """GenerationStart.content_capture should override with_content_capture_mode."""
        client = _new_client(CapturingGenerationExporter(), content_capture=ContentCaptureMode.FULL)
        try:
            with with_content_capture_mode(ContentCaptureMode.FULL):
                recorder = client.start_generation(_seed(ContentCaptureMode.METADATA_ONLY))
                recorder.set_result(
                    Generation(
                        system_prompt="secret",
                        input=[Message(role=MessageRole.USER, parts=[Part(kind=PartKind.TEXT, text="Hello")])],
                        output=[Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text="Hi")])],
                        usage=TokenUsage(input_tokens=10, output_tokens=5),
                    )
                )
                recorder.end()
                assert recorder.err() is None

                recorded = recorder.last_generation
                assert recorded.metadata[_METADATA_KEY] == "metadata_only"
                assert recorded.system_prompt == ""
        finally:
            client.shutdown()


# ---------------------------------------------------------------------------
# Validation accepts stripped content
# ---------------------------------------------------------------------------


class TestValidationWithStrippedContent:
    """Checks ``validate_generation`` against empty text with and without the capture-mode marker."""

    @staticmethod
    def _text_message(role, text):
        """Build a single-part text message for ``role``."""
        return Message(role=role, parts=[Part(kind=PartKind.TEXT, text=text)])

    def test_validation_accepts_stripped_generation(self):
        client = _new_client(CapturingGenerationExporter(), content_capture=ContentCaptureMode.METADATA_ONLY)
        try:
            recorder = client.start_generation(_seed())
            recorder.set_result(_full_generation())
            recorder.end()
            # No validation error
            assert recorder.err() is None
        finally:
            client.shutdown()

    def test_validation_accepts_stripped_text_and_thinking(self):
        """Directly test that validate_generation accepts empty text/thinking when metadata marker is set."""
        stripped_reply = Message(
            role=MessageRole.ASSISTANT,
            parts=[
                Part(kind=PartKind.THINKING, thinking=""),
                Part(kind=PartKind.TEXT, text=""),
            ],
        )
        stripped = Generation(
            model=ModelRef(provider="anthropic", name="claude-sonnet-4-5"),
            input=[self._text_message(MessageRole.USER, "")],
            output=[stripped_reply],
            usage=TokenUsage(input_tokens=1, output_tokens=1),
            metadata={_METADATA_KEY: "metadata_only"},
        )
        # Should not raise
        validate_generation(stripped)

    def test_validation_rejects_empty_text_without_stripped_marker(self):
        unmarked = Generation(
            model=ModelRef(provider="anthropic", name="claude-sonnet-4-5"),
            input=[self._text_message(MessageRole.USER, "")],
            output=[self._text_message(MessageRole.ASSISTANT, "ok")],
            usage=TokenUsage(input_tokens=1, output_tokens=1),
            metadata={},
        )
        with pytest.raises(ValueError):
            validate_generation(unmarked)


# ---------------------------------------------------------------------------
# Backward compatibility: include_content
# ---------------------------------------------------------------------------


class TestBackwardCompatibility:
    """Checks the legacy ``include_content`` flag on a client created without ``content_capture``."""

    @staticmethod
    def _default_traced_client():
        """Return ``(client, span_exporter, provider)`` for a client built with only a tracer."""
        span_exporter = InMemorySpanExporter()
        provider = TracerProvider()
        provider.add_span_processor(SimpleSpanProcessor(span_exporter))
        tracer = provider.get_tracer("sigil-test")
        client = _new_client(CapturingGenerationExporter(), tracer=tracer)
        return client, span_exporter, provider

    @staticmethod
    def _first_tool_span(span_exporter):
        """Return the first finished span named ``execute_tool…``, or ``None`` when there is none."""
        return next(
            (s for s in span_exporter.get_finished_spans() if s.name.startswith("execute_tool")),
            None,
        )

    def test_include_content_still_works_without_content_capture(self):
        """When no ContentCaptureMode is set, include_content=True should still include content."""
        client, span_exporter, provider = self._default_traced_client()

        try:
            with client.start_tool_execution(ToolExecutionStart(tool_name="test_tool", include_content=True)) as rec:
                rec.set_result(arguments="some args", result="some result")

            tool_span = self._first_tool_span(span_exporter)
            assert tool_span is not None
            assert tool_span.attributes.get("gen_ai.tool.call.arguments") is not None
            assert tool_span.attributes.get("gen_ai.tool.call.result") is not None
        finally:
            client.shutdown()
            provider.shutdown()

    def test_include_content_false_without_content_capture(self):
        """Default client + include_content=False → content suppressed."""
        client, span_exporter, provider = self._default_traced_client()

        try:
            with client.start_tool_execution(ToolExecutionStart(tool_name="test_tool", include_content=False)) as rec:
                rec.set_result(arguments="some args", result="some result")

            tool_span = self._first_tool_span(span_exporter)
            assert tool_span is not None
            assert tool_span.attributes.get("gen_ai.tool.call.arguments") is None
        finally:
            client.shutdown()
            provider.shutdown()
# ---------------------------------------------------------------------------
# Rating comment stripping
# ---------------------------------------------------------------------------


class TestRatingCommentStripping:
    """Checks whether a rating comment reaches the API payload under each capture mode.

    Each test posts a rating to a throwaway local HTTP server and inspects the
    JSON body that server received.
    """

    def _make_rating_handler(self, captured):
        """Build a request-handler class that records the POSTed JSON body.

        The decoded body is stored in ``captured["payload"]``; every request is
        answered with a fixed 200 JSON response containing a rating and summary.
        """

        class _Handler(BaseHTTPRequestHandler):
            def do_POST(self):  # noqa: N802
                length = int(self.headers.get("Content-Length", "0"))
                body = self.rfile.read(length)
                captured["payload"] = json.loads(body.decode("utf-8"))

                response = {
                    "rating": {
                        "rating_id": "rat-1",
                        "conversation_id": "conv-1",
                        "rating": "CONVERSATION_RATING_VALUE_BAD",
                        "created_at": "2026-04-10T12:00:00Z",
                    },
                    "summary": {
                        "total_count": 1,
                        "good_count": 0,
                        "bad_count": 1,
                        "latest_rating": "CONVERSATION_RATING_VALUE_BAD",
                        "latest_rated_at": "2026-04-10T12:00:00Z",
                        "has_bad_rating": True,
                    },
                }
                encoded = json.dumps(response).encode("utf-8")
                self.send_response(200)
                self.send_header("Content-Type", "application/json")
                self.send_header("Content-Length", str(len(encoded)))
                self.end_headers()
                self.wfile.write(encoded)

            def log_message(self, _format, *_args):  # noqa: A003
                # Silence the default per-request stderr logging.
                return

        return _Handler

    def _submit_rating(self, content_capture, comment):
        """Submit one BAD rating with ``comment`` and return the dict the server captured.

        The server is started first, so everything after it runs under a
        ``finally`` that stops and closes it. This holds even when building the
        client fails or ``client.shutdown()`` raises.
        """
        captured: dict = {}
        server = HTTPServer(("127.0.0.1", 0), self._make_rating_handler(captured))
        thread = threading.Thread(target=server.serve_forever, daemon=True)
        thread.start()

        try:
            client = Client(
                ClientConfig(
                    content_capture=content_capture,
                    generation_export=GenerationExportConfig(
                        protocol="none",
                        batch_size=1,
                        flush_interval=timedelta(seconds=60),
                    ),
                    api=ApiConfig(endpoint=f"http://127.0.0.1:{server.server_address[1]}"),
                )
            )
            try:
                client.submit_conversation_rating(
                    "conv-1",
                    ConversationRatingInput(
                        rating_id="rat-1",
                        rating=ConversationRatingValue.BAD,
                        comment=comment,
                    ),
                )
            finally:
                client.shutdown()
        finally:
            server.shutdown()
            server.server_close()
        return captured

    def test_metadata_only_strips_rating_comment(self):
        captured = self._submit_rating(ContentCaptureMode.METADATA_ONLY, "this is sensitive feedback")
        # Comment should have been stripped before sending
        assert "comment" not in captured["payload"] or captured["payload"].get("comment", "") == ""

    def test_full_mode_preserves_rating_comment(self):
        captured = self._submit_rating(ContentCaptureMode.FULL, "this should be preserved")
        assert captured["payload"]["comment"] == "this should be preserved"