Skip to content

Commit 4fab18f

Browse files
authored
feat!: Split track_metrics_of into sync and async variants (#112)
- feat: Add optional graph_key to all LDAIConfigTracker track_* methods for graph correlation
- feat: Add track_tool_call/track_tool_calls to LDAIConfigTracker
- fix: make AIGraphTracker.track_total_tokens accept Optional[TokenUsage], skip when None or total <= 0
- feat: Add get_tool_calls_from_response and sum_token_usage_from_messages to langchain_helper
- feat: Add get_ai_usage_from_response to openai_helper
- fix!: Remove node-scoped methods from AIGraphTracker (track_node_invocation, track_tool_call, track_node_judge_response), use related AIConfigTracker methods instead
- fix: use time.perf_counter_ns() for sub-millisecond precision in duration calculations
1 parent 9b6f06a commit 4fab18f

File tree

11 files changed

+437
-134
lines changed

11 files changed

+437
-134
lines changed

packages/ai-providers/server-ai-langchain/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ provider = await LangChainProvider.create(config)
138138
async def invoke():
139139
return await provider.invoke_model(messages)
140140

141-
response = await config.tracker.track_metrics_of(
141+
response = await config.tracker.track_metrics_of_async(
142142
invoke,
143143
lambda r: r.metrics
144144
)

packages/ai-providers/server-ai-langchain/src/ldai_langchain/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
create_langchain_model,
44
get_ai_metrics_from_response,
55
get_ai_usage_from_response,
6+
get_tool_calls_from_response,
67
map_provider,
8+
sum_token_usage_from_messages,
79
)
810
from ldai_langchain.langchain_model_runner import LangChainModelRunner
911
from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory
@@ -18,5 +20,7 @@
1820
'create_langchain_model',
1921
'get_ai_metrics_from_response',
2022
'get_ai_usage_from_response',
23+
'get_tool_calls_from_response',
2124
'map_provider',
25+
'sum_token_usage_from_messages',
2226
]

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,41 @@ def get_ai_metrics_from_response(response: Any) -> LDAIMetrics:
115115
:return: LDAIMetrics with success status and token usage
116116
"""
117117
return LDAIMetrics(success=True, usage=get_ai_usage_from_response(response))
118+
119+
120+
def get_tool_calls_from_response(response: Any) -> List[str]:
    """
    Get tool call names from a LangChain provider response.

    Reads ``response.tool_calls`` (the LangChain ``AIMessage.tool_calls``
    convention: a list of dicts with a ``name`` key) and collects each
    entry's name in order.

    :param response: The response from the LangChain model
    :return: List of tool names in order, or empty list if none
    """
    names: List[str] = []
    tool_calls = getattr(response, 'tool_calls', None)
    if not isinstance(tool_calls, list):
        return names
    for tc in tool_calls:
        # Defensive: skip entries that are not dicts — the previous
        # implementation raised AttributeError on non-dict items.
        if not isinstance(tc, dict):
            continue
        n = tc.get('name')
        # Skip entries with a missing or empty name.
        if n:
            names.append(str(n))
    return names
134+
135+
136+
def sum_token_usage_from_messages(messages: List[Any]) -> Optional[TokenUsage]:
    """
    Sum token usage across LangChain messages using get_ai_usage_from_response per message.

    :param messages: List of message objects (e.g. from a graph state)
    :return: Aggregated TokenUsage, or None if no usage on any message
             (an all-zero aggregate is also treated as "no usage")
    """
    # Collect the per-message usages that actually exist.
    usages = [
        u for u in (get_ai_usage_from_response(m) for m in messages)
        if u is not None
    ]
    input_total = sum(u.input for u in usages)
    output_total = sum(u.output for u in usages)
    grand_total = sum(u.total for u in usages)
    # Nothing reported at all (or everything zero) -> signal "no usage".
    if not (input_total or output_total or grand_total):
        return None
    return TokenUsage(total=grand_total, input=input_total, output=output_total)

packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@
77

88
from ldai import LDMessage
99

10-
from ldai_langchain import LangChainModelRunner, LangChainRunnerFactory, convert_messages_to_langchain, get_ai_metrics_from_response, map_provider
10+
from ldai_langchain import (
11+
LangChainModelRunner,
12+
LangChainRunnerFactory,
13+
convert_messages_to_langchain,
14+
get_ai_metrics_from_response,
15+
get_tool_calls_from_response,
16+
map_provider,
17+
sum_token_usage_from_messages,
18+
)
1119

1220

1321
class TestConvertMessages:
@@ -237,6 +245,82 @@ async def test_returns_success_false_when_structured_model_invocation_throws_err
237245
assert result.metrics.usage is None
238246

239247

248+
class TestGetToolCallsFromResponse:
    """Tests for get_tool_calls_from_response."""

    def test_returns_tool_call_names_in_order(self):
        """Should return tool call names from response.tool_calls."""
        response = MagicMock()
        response.tool_calls = [
            {'name': 'search', 'args': {}},
            {'name': 'calculator', 'args': {}},
        ]

        result = get_tool_calls_from_response(response)

        assert result == ['search', 'calculator']

    def test_returns_empty_list_when_tool_calls_is_empty(self):
        """Should return empty list when tool_calls is an empty list."""
        response = MagicMock()
        response.tool_calls = []

        assert get_tool_calls_from_response(response) == []

    def test_returns_empty_list_when_no_tool_calls_attribute(self):
        """Should return empty list when response has no tool_calls attribute."""
        # spec=[] makes attribute access raise, simulating a bare response.
        response = MagicMock(spec=[])

        assert get_tool_calls_from_response(response) == []

    def test_returns_empty_list_when_tool_calls_is_not_a_list(self):
        """Should return empty list when tool_calls is not a list."""
        response = MagicMock()
        response.tool_calls = 'not-a-list'

        assert get_tool_calls_from_response(response) == []

    def test_skips_tool_calls_without_name(self):
        """Should skip tool calls that have no name."""
        response = MagicMock()
        response.tool_calls = [{'args': {}}, {'name': 'search', 'args': {}}]

        assert get_tool_calls_from_response(response) == ['search']
282+
283+
284+
class TestSumTokenUsageFromMessages:
    """Tests for sum_token_usage_from_messages."""

    def test_sums_usage_across_messages(self):
        """Should sum token usage from all messages."""
        first = AIMessage(content='a')
        first.usage_metadata = {'total_tokens': 10, 'input_tokens': 6, 'output_tokens': 4}
        second = AIMessage(content='b')
        second.usage_metadata = {'total_tokens': 20, 'input_tokens': 12, 'output_tokens': 8}

        summed = sum_token_usage_from_messages([first, second])

        assert summed is not None
        assert (summed.total, summed.input, summed.output) == (30, 18, 12)

    def test_returns_none_when_no_usage_on_any_message(self):
        """Should return None when no message has usage metadata."""
        bare = AIMessage(content='hello')

        assert sum_token_usage_from_messages([bare]) is None

    def test_returns_none_for_empty_list(self):
        """Should return None for an empty message list."""
        assert sum_token_usage_from_messages([]) is None

    def test_skips_messages_without_usage(self):
        """Should skip messages that have no usage and sum the rest."""
        without_usage = AIMessage(content='a')
        with_usage = AIMessage(content='b')
        with_usage.usage_metadata = {'total_tokens': 5, 'input_tokens': 3, 'output_tokens': 2}

        summed = sum_token_usage_from_messages([without_usage, with_usage])

        assert summed is not None
        assert (summed.total, summed.input, summed.output) == (5, 3, 2)
322+
323+
240324
class TestGetLlm:
241325
"""Tests for LangChainModelRunner.get_llm."""
242326

packages/sdk/server-ai/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ async def main():
150150
# Create LangChain model from configuration
151151
llm = await LangChainProvider.create_langchain_model(ai_config)
152152

153-
# Use with tracking
154-
response = await ai_config.tracker.track_metrics_of(
153+
# Use with tracking (sync invoke)
154+
response = ai_config.tracker.track_metrics_of(
155155
lambda: llm.invoke(messages),
156156
lambda result: LangChainProvider.get_ai_metrics_from_response(result)
157157
)
@@ -190,7 +190,7 @@ async def main():
190190
temperature=ai_config.model.get_parameter('temperature') if ai_config.model else 0.5,
191191
)
192192

193-
result = await ai_config.tracker.track_metrics_of(
193+
result = await ai_config.tracker.track_metrics_of_async(
194194
call_custom_provider,
195195
map_custom_provider_metrics
196196
)

packages/sdk/server-ai/src/ldai/judge/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ async def evaluate(
7171
messages = self._construct_evaluation_messages(input_text, output_text)
7272
assert self._evaluation_response_structure is not None
7373

74-
response = await self._ai_config_tracker.track_metrics_of(
74+
response = await self._ai_config_tracker.track_metrics_of_async(
7575
lambda: self._model_runner.invoke_structured_model(messages, self._evaluation_response_structure),
7676
lambda result: result.metrics,
7777
)

packages/sdk/server-ai/src/ldai/managed_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ async def invoke(self, prompt: str) -> ModelResponse:
4848
config_messages = self._ai_config.messages or []
4949
all_messages = config_messages + self._messages
5050

51-
response = await self._tracker.track_metrics_of(
51+
response = await self._tracker.track_metrics_of_async(
5252
lambda: self._model_runner.invoke_model(all_messages),
5353
lambda result: result.metrics,
5454
)

packages/sdk/server-ai/src/ldai/providers/runner_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _with_fallback(
7777
continue
7878
result = fn(provider_factory)
7979
if result is not None:
80-
log.debug(f"Successfully created capability using provider '{provider_type}'")
80+
log.debug(f"Successfully invoked create function with provider '{provider_type}'")
8181
return result
8282
except Exception as exc:
8383
log.warning(f"Provider '{provider_type}' failed: {exc}")

0 commit comments

Comments (0)