@@ -1,5 +1,6 @@
 from enum import Enum
 from functools import wraps
+from types import SimpleNamespace
 from typing import (
     Any,
     Callable,
@@ -512,7 +513,7 @@ def _finalize_sync_streaming_span(span: trace_api.Span, stream: CustomStreamWrap
         )

         if usage_stats:
-            _set_token_counts_from_usage(span, usage_stats)
+            _set_token_counts_from_usage(span, SimpleNamespace(usage=usage_stats))
     except Exception as e:
         span.record_exception(e)
         raise
@@ -553,10 +554,12 @@ async def _finalize_streaming_span(span: trace_api.Span, stream: CustomStreamWra
                 span, f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{idx}.{key}", value
             )
         if usage_stats:
-            _set_token_counts_from_usage(span, usage_stats)
+            _set_token_counts_from_usage(span, SimpleNamespace(usage=usage_stats))
     except Exception as e:
         span.record_exception(e)
         raise
+    else:
+        _set_span_status(span, aggregated_output)
     finally:
         span.end()

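The change in both hunks above wraps the aggregated `usage_stats` in `SimpleNamespace(usage=usage_stats)` before handing it to `_set_token_counts_from_usage`. A minimal sketch of the assumption behind that wrapping, namely that the helper reads token counts off a response-like object's `usage` attribute (the body below is illustrative, not the package's actual implementation; the attribute names follow the OpenInference token-count conventions):

```python
from types import SimpleNamespace

# Sketch only: the real helper lives in the instrumentation module. The working
# assumption is that it pulls token counts from `response.usage`, which is why the
# streaming finalizers wrap the bare usage stats so they look like a full response.
def _set_token_counts_from_usage(span, response):
    usage = getattr(response, "usage", None)
    if usage is None:
        return
    span.set_attribute("llm.token_count.prompt", usage.prompt_tokens)
    span.set_attribute("llm.token_count.completion", usage.completion_tokens)
    span.set_attribute("llm.token_count.total", usage.total_tokens)

# Streaming path: only the aggregated usage stats are available, so they are wrapped.
# _set_token_counts_from_usage(span, SimpleNamespace(usage=usage_stats))
```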
@@ -657,6 +657,153 @@ async def test_acompletion(
        )


@pytest.mark.parametrize("use_context_attributes", [False, True])
async def test_acompletion_stream(
    in_memory_span_exporter: InMemorySpanExporter,
    setup_litellm_instrumentation: Any,
    use_context_attributes: bool,
    session_id: str,
    user_id: str,
    metadata: Dict[str, Any],
    tags: List[str],
    prompt_template: str,
    prompt_template_version: str,
    prompt_template_variables: Dict[str, Any],
) -> None:
    in_memory_span_exporter.clear()

    input_messages = [{"content": "What's the capital of China?", "role": "user"}]
    if use_context_attributes:
        with using_attributes(
            session_id=session_id,
            user_id=user_id,
            metadata=metadata,
            tags=tags,
            prompt_template=prompt_template,
            prompt_template_version=prompt_template_version,
            prompt_template_variables=prompt_template_variables,
        ):
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=input_messages,
                mock_response="Beijing",
                stream=True,
            )
            async for chunk in response:
                print(chunk)
    else:
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=input_messages,
            mock_response="Beijing",
            stream=True,
        )
        async for chunk in response:
            print(chunk)

    spans = in_memory_span_exporter.get_finished_spans()
    assert len(spans) == 1
    span = spans[0]
    assert span.name == "acompletion"
    attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
    assert attributes.get(SpanAttributes.LLM_MODEL_NAME) == "gpt-3.5-turbo"
    assert attributes.get(SpanAttributes.INPUT_VALUE) == safe_json_dumps(
        {"messages": input_messages}
    )
    assert attributes.get(SpanAttributes.INPUT_MIME_TYPE) == "application/json"

    assert "Beijing" == attributes.get(SpanAttributes.OUTPUT_VALUE)
    assert span.status.status_code == StatusCode.OK

    if use_context_attributes:
        _check_context_attributes(
            attributes,
            session_id,
            user_id,
            metadata,
            tags,
            prompt_template,
            prompt_template_version,
            prompt_template_variables,
        )


@pytest.mark.parametrize("use_context_attributes", [False, True])
async def test_acompletion_stream_token_count(
    in_memory_span_exporter: InMemorySpanExporter,
    setup_litellm_instrumentation: Any,
    use_context_attributes: bool,
    session_id: str,
    user_id: str,
    metadata: Dict[str, Any],
    tags: List[str],
    prompt_template: str,
    prompt_template_version: str,
    prompt_template_variables: Dict[str, Any],
) -> None:
    in_memory_span_exporter.clear()

    input_messages = [{"content": "What's the capital of China?", "role": "user"}]
    if use_context_attributes:
        with using_attributes(
            session_id=session_id,
            user_id=user_id,
            metadata=metadata,
            tags=tags,
            prompt_template=prompt_template,
            prompt_template_version=prompt_template_version,
            prompt_template_variables=prompt_template_variables,
        ):
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=input_messages,
                mock_response="Beijing",
                stream=True,
                stream_options={"include_usage": True},
            )
            async for chunk in response:
                print(chunk)
    else:
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=input_messages,
            mock_response="Beijing",
            stream=True,
            stream_options={"include_usage": True},
        )
        async for chunk in response:
            print(chunk)

    spans = in_memory_span_exporter.get_finished_spans()
    assert len(spans) == 1
    span = spans[0]
    assert span.name == "acompletion"
    attributes = dict(cast(Mapping[str, AttributeValue], span.attributes))
    assert attributes.get(SpanAttributes.LLM_MODEL_NAME) == "gpt-3.5-turbo"
    assert attributes.get(SpanAttributes.INPUT_VALUE) == safe_json_dumps(
        {"messages": input_messages}
    )
    assert attributes.get(SpanAttributes.INPUT_MIME_TYPE) == "application/json"

    assert "Beijing" == attributes.get(SpanAttributes.OUTPUT_VALUE)
    assert attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT) == 14
    assert attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION) == 2
    assert attributes.get(SpanAttributes.LLM_TOKEN_COUNT_TOTAL) == 16
    assert span.status.status_code == StatusCode.OK

    if use_context_attributes:
        _check_context_attributes(
            attributes,
            session_id,
            user_id,
            metadata,
            tags,
            prompt_template,
            prompt_template_version,
            prompt_template_variables,
        )

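The second test additionally passes `stream_options={"include_usage": True}`, which asks LiteLLM to report usage stats for the stream so the instrumentation can set the token-count attributes asserted above. A hedged, standalone sketch of where those numbers come from outside the test harness; whether `usage` appears on the final chunk and the exact counts (14/2/16 here) depend on LiteLLM's mock streaming and token counting:

```python
import asyncio

import litellm


async def main() -> None:
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"content": "What's the capital of China?", "role": "user"}],
        mock_response="Beijing",
        stream=True,
        stream_options={"include_usage": True},
    )
    async for chunk in response:
        usage = getattr(chunk, "usage", None)
        if usage is not None:
            # These values are what the span's token-count attributes should mirror.
            print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)


asyncio.run(main())
```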

async def test_acompletion_with_invalid_model_triggers_exception_event(
    in_memory_span_exporter: InMemorySpanExporter,
    setup_litellm_instrumentation: None,