-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathhandler.py
More file actions
517 lines (424 loc) · 16.9 KB
/
handler.py
File metadata and controls
517 lines (424 loc) · 16.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
"""LiteLLM callback handler that exports generations to Sigil."""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Any
from litellm.integrations.custom_logger import CustomLogger
from sigil_sdk import Client
from sigil_sdk.models import (
Generation,
GenerationMode,
GenerationStart,
Message,
MessageRole,
ModelRef,
Part,
PartKind,
TokenUsage,
ToolCall,
ToolDefinition,
ToolResult,
)
logger = logging.getLogger(__name__)
_CHAT_CALL_TYPES = frozenset(
{
"completion",
"acompletion",
"text_completion",
"atext_completion",
}
)
def _make_tool_call_part(*, call_id: str, name: str, arguments: str) -> Part:
    """Wrap one tool invocation in a Sigil ``TOOL_CALL`` part.

    Args:
        call_id: Identifier of the tool call as reported by the provider.
        name: Name of the invoked tool/function.
        arguments: Already-normalized argument string; stored as UTF-8 bytes.

    Returns:
        A ``Part`` of kind ``PartKind.TOOL_CALL`` carrying a ``ToolCall``.
    """
    payload = arguments.encode("utf-8")
    call = ToolCall(id=call_id, name=name, input_json=payload)
    return Part(kind=PartKind.TOOL_CALL, tool_call=call)
def _map_messages(messages: list[dict[str, Any]] | None) -> tuple[list[Message], str]:
"""Map OpenAI-format messages to Sigil Messages, extracting system prompt."""
if not messages:
return [], ""
out: list[Message] = []
system_chunks: list[str] = []
for msg in messages:
role = (msg.get("role") or "").lower()
content = _extract_text_content(msg.get("content"))
if role in {"system", "developer"}:
if content:
system_chunks.append(content)
continue
mapped_role = MessageRole.USER
if role == "assistant":
mapped_role = MessageRole.ASSISTANT
elif role == "tool":
mapped_role = MessageRole.TOOL
parts: list[Part] = []
if mapped_role == MessageRole.TOOL:
out.append(
_tool_result_message(
content=content,
tool_call_id=msg.get("tool_call_id", ""),
name=msg.get("name", ""),
)
)
continue
if content:
parts.append(Part(kind=PartKind.TEXT, text=content))
if mapped_role == MessageRole.ASSISTANT:
parts.extend(_map_tool_call_parts(msg.get("tool_calls")))
if not parts:
continue
out.append(Message(role=mapped_role, parts=parts))
return out, "\n\n".join(system_chunks)
def _map_tool_call_parts(tool_calls: list[dict[str, Any]] | None) -> list[Part]:
"""Map OpenAI-format tool_calls to Sigil ToolCall parts."""
if not tool_calls:
return []
out: list[Part] = []
for tc in tool_calls:
function = tc.get("function") if isinstance(tc, dict) else getattr(tc, "function", None)
if function is None:
continue
name = function.get("name", "") if isinstance(function, dict) else getattr(function, "name", "")
if not name:
continue
arguments = function.get("arguments", "") if isinstance(function, dict) else getattr(function, "arguments", "")
call_id = tc.get("id", "") if isinstance(tc, dict) else getattr(tc, "id", "")
out.append(_make_tool_call_part(call_id=call_id or "", name=name, arguments=arguments or ""))
return out
def _tool_result_message(*, content: str, tool_call_id: str, name: str) -> Message:
    """Build a TOOL-role Sigil message holding exactly one tool-result part.

    Args:
        content: Text returned by the tool.
        tool_call_id: Identifier of the tool call this result answers.
        name: Tool name.
    """
    result = ToolResult(tool_call_id=tool_call_id, name=name, content=content)
    result_part = Part(kind=PartKind.TOOL_RESULT, tool_result=result)
    return Message(role=MessageRole.TOOL, parts=[result_part])
def _map_response_output(response: Any) -> list[Message]:
"""Map SLO response to Sigil output Messages.
Reads from the StandardLoggingPayload ``response`` field (dict or str)
so that LiteLLM redaction settings are honoured.
"""
if response is None:
return []
if isinstance(response, str):
if not response:
return []
return [Message(role=MessageRole.ASSISTANT, parts=[Part(kind=PartKind.TEXT, text=response)])]
if not isinstance(response, dict):
return []
choices = response.get("choices")
if not choices:
return []
out: list[Message] = []
for choice in choices:
if not isinstance(choice, dict):
continue
response_message = choice.get("message")
if not isinstance(response_message, dict):
continue
content = response_message.get("content") or ""
parts: list[Part] = []
if content:
parts.append(Part(kind=PartKind.TEXT, text=content))
parts.extend(_map_tool_call_parts(response_message.get("tool_calls")))
if not parts:
continue
out.append(Message(role=MessageRole.ASSISTANT, parts=parts))
return out
def _extract_text_content(content: Any) -> str:
"""Extract text from OpenAI message content (string or content parts array)."""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
texts = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
texts.append(item.get("text", ""))
elif isinstance(item, str):
texts.append(item)
return " ".join(texts)
return str(content)
def _epoch_to_utc(epoch: float | None) -> datetime | None:
"""Convert epoch seconds to UTC datetime."""
if epoch is None or epoch == 0:
return None
return datetime.fromtimestamp(epoch, tz=timezone.utc)
def _datetime_to_utc(dt: datetime | None) -> datetime | None:
"""Ensure a datetime is UTC.
Naive datetimes are assumed to be local time (matching datetime.now()
which LiteLLM uses to create start_time/end_time).
"""
if dt is None:
return None
return dt.astimezone(timezone.utc)
def _extract_stop_reason(response: Any) -> str:
"""Extract finish_reason from the SLO response dict."""
if not isinstance(response, dict):
return ""
choices = response.get("choices")
if not choices:
return ""
first_choice = choices[0]
if not isinstance(first_choice, dict):
return ""
return first_choice.get("finish_reason") or ""
def _map_tool_definitions(kwargs: dict[str, Any]) -> list[ToolDefinition]:
"""Extract tool schemas from optional_params."""
optional_params = kwargs.get("optional_params") or {}
tools = optional_params.get("tools")
if not tools or not isinstance(tools, list):
return []
out: list[ToolDefinition] = []
for tool in tools:
if not isinstance(tool, dict):
continue
tool_type = tool.get("type", "")
function = tool.get("function") or {}
name = function.get("name", "")
if not name:
continue
description = function.get("description", "")
parameters = function.get("parameters")
schema_json = json.dumps(parameters).encode("utf-8") if parameters else b""
out.append(
ToolDefinition(
name=name,
description=description,
type=tool_type,
input_schema_json=schema_json,
)
)
return out
def _safe_cast(params: dict[str, Any], key: str, cast: type) -> Any:
"""Safely cast a model parameter, returning None on missing or invalid values."""
if key not in params:
return None
try:
return cast(params[key])
except (ValueError, TypeError):
return None
def _extract_detailed_usage(response_obj: Any, slo: dict[str, Any]) -> TokenUsage:
    """Build a TokenUsage from the logging payload plus response usage details.

    Prompt/completion/total counts come from the StandardLoggingPayload, with
    missing or falsy values recorded as 0. When ``response_obj.usage`` exposes
    ``prompt_tokens_details`` / ``completion_tokens_details``, integer
    ``cached_tokens``, ``cache_creation_tokens`` and ``reasoning_tokens``
    values are copied onto the result; non-int values are ignored.
    """
    usage = TokenUsage(
        input_tokens=slo.get("prompt_tokens") or 0,
        output_tokens=slo.get("completion_tokens") or 0,
        total_tokens=slo.get("total_tokens") or 0,
    )
    resp_usage = None if response_obj is None else getattr(response_obj, "usage", None)
    if resp_usage is None:
        return usage

    def _int_attr(holder: Any, attr: str) -> int | None:
        # Missing holders/attributes and non-int values are all treated as absent.
        value = getattr(holder, attr, None)
        return value if isinstance(value, int) else None

    prompt_details = getattr(resp_usage, "prompt_tokens_details", None)
    cached = _int_attr(prompt_details, "cached_tokens")
    if cached is not None:
        usage.cache_read_input_tokens = cached
    creation = _int_attr(prompt_details, "cache_creation_tokens")
    if creation is not None:
        usage.cache_creation_input_tokens = creation

    completion_details = getattr(resp_usage, "completion_tokens_details", None)
    reasoning = _int_attr(completion_details, "reasoning_tokens")
    if reasoning is not None:
        usage.reasoning_tokens = reasoning
    return usage
class SigilLiteLLMLogger(CustomLogger):
    """LiteLLM callback logger that exports generations to Sigil.

    Uses the Sigil SDK recorder pattern directly. The SDK handles batching
    and export internally, so this extends CustomLogger (not
    CustomBatchLogger) to avoid double-batching.

    Per-request values in ``litellm_params["metadata"]`` (agent name/version,
    conversation id) take precedence over the static constructor values.
    """

    def __init__(
        self,
        *,
        client: Client,
        capture_inputs: bool = True,
        capture_outputs: bool = True,
        agent_name: str = "",
        agent_version: str = "",
        conversation_id: str = "",
        extra_tags: dict[str, str] | None = None,
        extra_metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Configure the exporter.

        Args:
            client: Sigil SDK client used to start generation recorders.
            capture_inputs: Export request messages and the system prompt.
            capture_outputs: Export response messages.
            agent_name: Fallback agent name when a request does not set one.
            agent_version: Fallback agent version when a request does not set one.
            conversation_id: Fallback conversation id when a request does not set one.
            extra_tags: Tags added to every generation; they override framework
                and LiteLLM request tags with the same key.
            extra_metadata: Metadata added to every generation.
            **kwargs: Forwarded to ``CustomLogger.__init__``.
        """
        super().__init__(**kwargs)
        self._client = client
        self._capture_inputs = capture_inputs
        self._capture_outputs = capture_outputs
        self._agent_name = agent_name
        self._agent_version = agent_version
        self._conversation_id = conversation_id
        # Copy so later mutation of the caller's dicts does not leak in.
        self._extra_tags = dict(extra_tags) if extra_tags else {}
        self._extra_metadata = dict(extra_metadata) if extra_metadata else {}

    # -- LiteLLM callback hooks -------------------------------------------

    def log_success_event(self, kwargs: dict, response_obj: Any, start_time: datetime, end_time: datetime) -> None:
        self._log_event(kwargs, response_obj, start_time, end_time, is_failure=False)

    def log_failure_event(self, kwargs: dict, response_obj: Any, start_time: datetime, end_time: datetime) -> None:
        self._log_event(kwargs, response_obj, start_time, end_time, is_failure=True)

    async def async_log_success_event(
        self, kwargs: dict, response_obj: Any, start_time: datetime, end_time: datetime
    ) -> None:
        # The SDK batches/exports internally, so recording is done inline.
        self._log_event(kwargs, response_obj, start_time, end_time, is_failure=False)

    async def async_log_failure_event(
        self, kwargs: dict, response_obj: Any, start_time: datetime, end_time: datetime
    ) -> None:
        self._log_event(kwargs, response_obj, start_time, end_time, is_failure=True)

    # -- internals ----------------------------------------------------------

    def _log_event(
        self,
        kwargs: dict,
        response_obj: Any,
        start_time: datetime,
        end_time: datetime,
        *,
        is_failure: bool,
    ) -> None:
        """Record one LiteLLM call without ever raising back into LiteLLM.

        Calls without a ``standard_logging_object`` are ignored. Any exception
        raised while recording is logged and swallowed.
        """
        slo = kwargs.get("standard_logging_object")
        if slo is None:
            return
        try:
            self._record_generation(kwargs, response_obj, slo, start_time, end_time, is_failure=is_failure)
        except Exception:
            logger.exception("sigil: failed to record LiteLLM generation")

    @staticmethod
    def _request_metadata(kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
        """Return ``(litellm_params, litellm_params["metadata"])``; missing/non-dict values become ``{}``."""
        litellm_params = kwargs.get("litellm_params")
        if not isinstance(litellm_params, dict):
            litellm_params = {}
        metadata = litellm_params.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}
        return litellm_params, metadata

    def _resolve_agent_name(self, kwargs: dict[str, Any]) -> str:
        """Resolve agent_name from per-request metadata, falling back to static."""
        _, metadata = self._request_metadata(kwargs)
        value = metadata.get("agent_name")
        return str(value) if value else self._agent_name

    def _resolve_agent_version(self, kwargs: dict[str, Any]) -> str:
        """Resolve agent_version from per-request metadata, falling back to static."""
        _, metadata = self._request_metadata(kwargs)
        value = metadata.get("agent_version")
        return str(value) if value else self._agent_version

    def _resolve_conversation_id(self, kwargs: dict[str, Any]) -> str:
        """Resolve conversation_id from per-request metadata, falling back to static.

        Checks metadata keys first (conversation_id, session_id, thread_id),
        then LiteLLM's built-in session tracking fields (litellm_session_id,
        litellm_trace_id) in both metadata and litellm_params.
        """
        litellm_params, metadata = self._request_metadata(kwargs)
        for key in ("conversation_id", "session_id", "thread_id"):
            value = metadata.get(key)
            if value:
                return str(value)
        for key in ("litellm_session_id", "litellm_trace_id"):
            value = metadata.get(key) or litellm_params.get(key)
            if value:
                return str(value)
        return self._conversation_id

    def _build_tags(self, slo: dict[str, Any]) -> dict[str, str]:
        """Framework tags, then one ``litellm.tag.<tag>`` per request tag, then extra_tags."""
        tags: dict[str, str] = {
            "sigil.framework.name": "litellm",
            "sigil.framework.source": "handler",
            "sigil.framework.language": "python",
        }
        request_tags = slo.get("request_tags") or []
        if isinstance(request_tags, str):
            # A bare string would otherwise be iterated character by character.
            request_tags = [request_tags]
        elif not isinstance(request_tags, (list, tuple, set, frozenset)):
            request_tags = []
        for tag_value in request_tags:
            tag_str = str(tag_value)
            tags[f"litellm.tag.{tag_str}"] = tag_str
        # extra_tags take precedence over framework and request tags.
        tags.update(self._extra_tags)
        return tags

    @staticmethod
    def _first_token_at(slo: dict[str, Any]) -> datetime | None:
        """Parse ``completionStartTime`` (epoch seconds or datetime); ``None`` if absent or invalid."""
        raw = slo.get("completionStartTime")
        if not raw:
            return None
        if isinstance(raw, datetime):
            return _datetime_to_utc(raw)
        try:
            return _epoch_to_utc(float(raw))
        except (TypeError, ValueError, OverflowError, OSError):
            # A malformed timestamp must not prevent the result from being recorded.
            return None

    @staticmethod
    def _call_error(slo: dict[str, Any]) -> Exception:
        """Build the error recorded for a failed call from the payload's ``error_str``."""
        error_str = slo.get("error_str") or ""
        # Without a fallback, a failure event lacking error_str would reach
        # the recorder with no call error at all.
        return RuntimeError(str(error_str) or "litellm call failed")

    def _record_generation(
        self,
        kwargs: dict[str, Any],
        response_obj: Any,
        slo: dict[str, Any],
        start_time: datetime,
        end_time: datetime,
        *,
        is_failure: bool,
    ) -> None:
        """Translate one StandardLoggingPayload into a Sigil generation.

        Payloads with a non-empty call_type outside ``_CHAT_CALL_TYPES`` are
        skipped. A streaming recorder is used when the payload's ``stream``
        flag is truthy. The recorder is always ended, and any recorder error
        is logged as a warning.
        """
        call_type = slo.get("call_type") or ""
        if call_type and call_type not in _CHAT_CALL_TYPES:
            return

        is_stream = bool(slo.get("stream"))

        model_params = slo.get("model_parameters")
        if not isinstance(model_params, dict):
            model_params = {}
        max_tokens = _safe_cast(model_params, "max_tokens", int)
        if max_tokens is None:
            # Newer OpenAI-style requests send max_completion_tokens instead.
            max_tokens = _safe_cast(model_params, "max_completion_tokens", int)

        system_prompt = ""
        input_messages: list[Message] = []
        if self._capture_inputs:
            raw_messages = slo.get("messages")
            if isinstance(raw_messages, list):
                input_messages, system_prompt = _map_messages(raw_messages)

        seed = GenerationStart(
            id=slo.get("id") or "",
            model=ModelRef(
                provider=str(slo.get("custom_llm_provider") or "").lower(),
                name=slo.get("model") or "",
            ),
            mode=GenerationMode.STREAM if is_stream else GenerationMode.SYNC,
            system_prompt=system_prompt,
            temperature=_safe_cast(model_params, "temperature", float),
            max_tokens=max_tokens,
            top_p=_safe_cast(model_params, "top_p", float),
            user_id=slo.get("end_user") or "",
            agent_name=self._resolve_agent_name(kwargs),
            agent_version=self._resolve_agent_version(kwargs),
            conversation_id=self._resolve_conversation_id(kwargs),
            tags=self._build_tags(slo),
            metadata=dict(self._extra_metadata),
            started_at=_datetime_to_utc(start_time),
            tools=_map_tool_definitions(kwargs),
        )

        if is_stream:
            recorder = self._client.start_streaming_generation(seed)
        else:
            recorder = self._client.start_generation(seed)

        try:
            if is_stream:
                first_token_at = self._first_token_at(slo)
                if first_token_at is not None:
                    recorder.set_first_token_at(first_token_at)

            if is_failure:
                recorder.set_call_error(self._call_error(slo))

            slo_response = slo.get("response")
            output_messages: list[Message] = []
            if self._capture_outputs:
                output_messages = _map_response_output(slo_response)

            recorder.set_result(
                generation=Generation(
                    input=input_messages,
                    output=output_messages,
                    usage=_extract_detailed_usage(response_obj, slo),
                    stop_reason=_extract_stop_reason(slo_response),
                    completed_at=_datetime_to_utc(end_time),
                ),
            )
        finally:
            # Always close the recorder, even if building the result failed.
            recorder.end()

        err = recorder.err()
        if err is not None:
            logger.warning("sigil: recorder error: %s", err)