Skip to content

Commit f109c0c

Browse files
temporal-spring-ai: preserve Usage and RateLimit in ChatResponse metadata (#2854)
* temporal-spring-ai: plan — preserve ChatResponse metadata (Usage + RateLimit) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * temporal-spring-ai: add ResponseMetadataTest (fails pre-impl) Asserts that Usage (prompt/completion/total tokens) and RateLimit (requests/tokens limit/remaining/reset) round-trip from a stub ChatModel's ChatResponseMetadata through the chat activity and back to workflow code. The workflow flattens to primitives because Usage and RateLimit are interfaces and can't Jackson-round-trip across the workflow result without concrete-type hints. Currently fails with token counts of 0 (Spring AI's EmptyUsage sentinel) because ActivityChatModel.toResponse only rehydrates md.getModel() — Usage and RateLimit are dropped. The implementation follows in a subsequent commit. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * temporal-spring-ai: preserve Usage and RateLimit in ChatResponse metadata ActivityChatModel.toResponse now rehydrates Usage and RateLimit onto the ChatResponseMetadata it returns to workflow code, not just the model name. The activity side (ChatModelActivityImpl) already serialized these into the output record; they were being silently discarded when the workflow side rebuilt the ChatResponse. Usage is rehydrated as a Spring AI DefaultUsage(promptTokens, completionTokens, totalTokens). RateLimit is an interface with no public default impl in spring-ai-model, so we return an anonymous implementation backed by the fields from the activity output record. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * temporal-spring-ai: drop PLAN.md Planning scratchpad — not part of the shipped artifact. Removed before merge. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 64d5b2e commit f109c0c

2 files changed

Lines changed: 256 additions & 1 deletion

File tree

temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
import javax.annotation.Nullable;
1414
import org.springframework.ai.chat.messages.*;
1515
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
16+
import org.springframework.ai.chat.metadata.DefaultUsage;
17+
import org.springframework.ai.chat.metadata.RateLimit;
18+
import org.springframework.ai.chat.metadata.Usage;
1619
import org.springframework.ai.chat.model.ChatModel;
1720
import org.springframework.ai.chat.model.ChatResponse;
1821
import org.springframework.ai.chat.model.Generation;
@@ -387,11 +390,70 @@ private ChatResponse toResponse(ChatModelTypes.ChatModelActivityOutput output) {
387390

388391
var builder = ChatResponse.builder().generations(generations);
389392
if (output.metadata() != null) {
390-
builder.metadata(ChatResponseMetadata.builder().model(output.metadata().model()).build());
393+
builder.metadata(toResponseMetadata(output.metadata()));
391394
}
392395
return builder.build();
393396
}
394397

398+
private ChatResponseMetadata toResponseMetadata(
399+
ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata md) {
400+
ChatResponseMetadata.Builder b = ChatResponseMetadata.builder().model(md.model());
401+
Usage usage = toUsage(md.usage());
402+
if (usage != null) {
403+
b.usage(usage);
404+
}
405+
RateLimit rateLimit = toRateLimit(md.rateLimit());
406+
if (rateLimit != null) {
407+
b.rateLimit(rateLimit);
408+
}
409+
return b.build();
410+
}
411+
412+
private Usage toUsage(ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata.Usage u) {
413+
if (u == null) {
414+
return null;
415+
}
416+
return new DefaultUsage(u.promptTokens(), u.completionTokens(), u.totalTokens());
417+
}
418+
419+
private RateLimit toRateLimit(
420+
ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata.RateLimit r) {
421+
if (r == null) {
422+
return null;
423+
}
424+
return new RateLimit() {
425+
@Override
426+
public Long getRequestsLimit() {
427+
return r.requestLimit();
428+
}
429+
430+
@Override
431+
public Long getRequestsRemaining() {
432+
return r.requestRemaining();
433+
}
434+
435+
@Override
436+
public java.time.Duration getRequestsReset() {
437+
return r.requestReset();
438+
}
439+
440+
@Override
441+
public Long getTokensLimit() {
442+
return r.tokenLimit();
443+
}
444+
445+
@Override
446+
public Long getTokensRemaining() {
447+
return r.tokenRemaining();
448+
}
449+
450+
@Override
451+
public java.time.Duration getTokensReset() {
452+
return r.tokenReset();
453+
}
454+
};
455+
}
456+
395457
private AssistantMessage toAssistantMessage(ChatModelTypes.Message message) {
396458
List<AssistantMessage.ToolCall> toolCalls = List.of();
397459
if (!CollectionUtils.isEmpty(message.toolCalls())) {
temporal-spring-ai/src/test/java/io/temporal/springai/ResponseMetadataTest.java

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
package io.temporal.springai;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import io.temporal.client.WorkflowClient;
6+
import io.temporal.client.WorkflowOptions;
7+
import io.temporal.springai.activity.ChatModelActivityImpl;
8+
import io.temporal.springai.model.ActivityChatModel;
9+
import io.temporal.testing.TestWorkflowEnvironment;
10+
import io.temporal.worker.Worker;
11+
import io.temporal.workflow.WorkflowInterface;
12+
import io.temporal.workflow.WorkflowMethod;
13+
import java.time.Duration;
14+
import java.util.List;
15+
import org.junit.jupiter.api.AfterEach;
16+
import org.junit.jupiter.api.BeforeEach;
17+
import org.junit.jupiter.api.Test;
18+
import org.springframework.ai.chat.messages.AssistantMessage;
19+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
20+
import org.springframework.ai.chat.metadata.DefaultUsage;
21+
import org.springframework.ai.chat.metadata.RateLimit;
22+
import org.springframework.ai.chat.metadata.Usage;
23+
import org.springframework.ai.chat.model.ChatModel;
24+
import org.springframework.ai.chat.model.ChatResponse;
25+
import org.springframework.ai.chat.model.Generation;
26+
import org.springframework.ai.chat.prompt.Prompt;
27+
28+
/**
29+
* Verifies that {@link Usage} and {@link RateLimit} metadata produced by the underlying chat model
30+
* survive the round-trip through the Temporal activity boundary.
31+
*/
32+
class ResponseMetadataTest {

  // Dedicated queue so this test's worker does not pick up tasks from other tests.
  private static final String TASK_QUEUE = "test-spring-ai-response-metadata";

  // Fresh in-memory Temporal environment per test; closed in tearDown.
  private TestWorkflowEnvironment testEnv;
  private WorkflowClient client;

  @BeforeEach
  void setUp() {
    testEnv = TestWorkflowEnvironment.newInstance();
    client = testEnv.getWorkflowClient();
  }

  @AfterEach
  void tearDown() {
    testEnv.close();
  }

  /**
   * End-to-end round trip: stub ChatModel -> activity serialization -> workflow-side
   * ActivityChatModel rehydration. Asserts every Usage and RateLimit field survives.
   */
  @Test
  void usageAndRateLimit_survivesActivityRoundTrip() {
    Worker worker = testEnv.newWorker(TASK_QUEUE);
    worker.registerWorkflowImplementationTypes(MetadataWorkflowImpl.class);
    // The activity wraps the stub model that emits known Usage/RateLimit values.
    worker.registerActivitiesImplementations(new ChatModelActivityImpl(new MetadataChatModel()));
    testEnv.start();

    MetadataWorkflow workflow =
        client.newWorkflowStub(
            MetadataWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build());

    MetadataSnapshot snapshot = workflow.collect();

    // Model name: was already round-tripping, keep asserting so we don't regress.
    assertEquals("stub-model-v1", snapshot.model());

    // Usage: the prior code dropped this on the workflow side.
    assertEquals(
        Boolean.TRUE, snapshot.usagePresent(), "Usage should be rehydrated on the workflow side");
    assertEquals(10, snapshot.promptTokens());
    assertEquals(20, snapshot.completionTokens());
    assertEquals(30, snapshot.totalTokens());

    // RateLimit: ditto.
    assertEquals(
        Boolean.TRUE,
        snapshot.rateLimitPresent(),
        "RateLimit should be rehydrated on the workflow side");
    assertEquals(1000L, snapshot.requestsLimit());
    assertEquals(987L, snapshot.requestsRemaining());
    assertEquals(Duration.ofSeconds(60), snapshot.requestsReset());
    assertEquals(500_000L, snapshot.tokensLimit());
    assertEquals(493_210L, snapshot.tokensRemaining());
    assertEquals(Duration.ofSeconds(30), snapshot.tokensReset());
  }

  /**
   * Snapshot flattened to primitives/Strings — {@link Usage} and {@link RateLimit} are interfaces
   * and can't round-trip through the workflow-result serialization without extra type info, so the
   * workflow extracts the fields itself.
   */
  public record MetadataSnapshot(
      String model,
      Boolean usagePresent,
      Integer promptTokens,
      Integer completionTokens,
      Integer totalTokens,
      Boolean rateLimitPresent,
      Long requestsLimit,
      Long requestsRemaining,
      Duration requestsReset,
      Long tokensLimit,
      Long tokensRemaining,
      Duration tokensReset) {}

  @WorkflowInterface
  public interface MetadataWorkflow {
    @WorkflowMethod
    MetadataSnapshot collect();
  }

  /** Calls the chat model through the activity boundary and flattens the response metadata. */
  public static class MetadataWorkflowImpl implements MetadataWorkflow {
    @Override
    public MetadataSnapshot collect() {
      // NOTE(review): forDefault() presumably builds a workflow-side model with default
      // activity options — confirm against ActivityChatModel's factory documentation.
      ActivityChatModel chatModel = ActivityChatModel.forDefault();
      ChatResponse response = chatModel.call(new Prompt("ping"));
      ChatResponseMetadata md = response.getMetadata();
      if (md == null) {
        // Defensive: all-absent snapshot so the test fails on assertions, not an NPE here.
        return new MetadataSnapshot(
            null, false, null, null, null, false, null, null, null, null, null, null);
      }
      Usage u = md.getUsage();
      RateLimit r = md.getRateLimit();
      return new MetadataSnapshot(
          md.getModel(),
          u != null,
          u == null ? null : u.getPromptTokens(),
          u == null ? null : u.getCompletionTokens(),
          u == null ? null : u.getTotalTokens(),
          r != null,
          r == null ? null : r.getRequestsLimit(),
          r == null ? null : r.getRequestsRemaining(),
          r == null ? null : r.getRequestsReset(),
          r == null ? null : r.getTokensLimit(),
          r == null ? null : r.getTokensRemaining(),
          r == null ? null : r.getTokensReset());
    }
  }

  /**
   * Returns a ChatResponse with a known model, Usage, and RateLimit so the test can assert them.
   */
  private static class MetadataChatModel implements ChatModel {
    @Override
    public ChatResponse call(Prompt prompt) {
      ChatResponseMetadata md =
          ChatResponseMetadata.builder()
              .model("stub-model-v1")
              .usage(new DefaultUsage(10, 20, 30))
              // RateLimit has no public default implementation in spring-ai-model,
              // hence the anonymous class with fixed values.
              .rateLimit(
                  new RateLimit() {
                    @Override
                    public Long getRequestsLimit() {
                      return 1000L;
                    }

                    @Override
                    public Long getRequestsRemaining() {
                      return 987L;
                    }

                    @Override
                    public Duration getRequestsReset() {
                      return Duration.ofSeconds(60);
                    }

                    @Override
                    public Long getTokensLimit() {
                      return 500_000L;
                    }

                    @Override
                    public Long getTokensRemaining() {
                      return 493_210L;
                    }

                    @Override
                    public Duration getTokensReset() {
                      return Duration.ofSeconds(30);
                    }
                  })
              .build();
      return ChatResponse.builder()
          .generations(List.of(new Generation(new AssistantMessage("pong"))))
          .metadata(md)
          .build();
    }

    // Streaming is out of scope for this metadata round-trip test.
    @Override
    public reactor.core.publisher.Flux<ChatResponse> stream(Prompt prompt) {
      throw new UnsupportedOperationException();
    }
  }
}

0 commit comments

Comments
 (0)