Skip to content

Commit d42e47b

Browse files
Merge branch 'master' into spring-ai/sideeffect-replay-tests
2 parents c3a8a15 + f109c0c commit d42e47b

2 files changed

Lines changed: 256 additions & 1 deletion

File tree

temporal-spring-ai/src/main/java/io/temporal/springai/model/ActivityChatModel.java

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
import javax.annotation.Nullable;
1414
import org.springframework.ai.chat.messages.*;
1515
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
16+
import org.springframework.ai.chat.metadata.DefaultUsage;
17+
import org.springframework.ai.chat.metadata.RateLimit;
18+
import org.springframework.ai.chat.metadata.Usage;
1619
import org.springframework.ai.chat.model.ChatModel;
1720
import org.springframework.ai.chat.model.ChatResponse;
1821
import org.springframework.ai.chat.model.Generation;
@@ -387,11 +390,70 @@ private ChatResponse toResponse(ChatModelTypes.ChatModelActivityOutput output) {
387390

388391
var builder = ChatResponse.builder().generations(generations);
389392
if (output.metadata() != null) {
390-
builder.metadata(ChatResponseMetadata.builder().model(output.metadata().model()).build());
393+
builder.metadata(toResponseMetadata(output.metadata()));
391394
}
392395
return builder.build();
393396
}
394397

398+
private ChatResponseMetadata toResponseMetadata(
399+
ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata md) {
400+
ChatResponseMetadata.Builder b = ChatResponseMetadata.builder().model(md.model());
401+
Usage usage = toUsage(md.usage());
402+
if (usage != null) {
403+
b.usage(usage);
404+
}
405+
RateLimit rateLimit = toRateLimit(md.rateLimit());
406+
if (rateLimit != null) {
407+
b.rateLimit(rateLimit);
408+
}
409+
return b.build();
410+
}
411+
412+
private Usage toUsage(ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata.Usage u) {
413+
if (u == null) {
414+
return null;
415+
}
416+
return new DefaultUsage(u.promptTokens(), u.completionTokens(), u.totalTokens());
417+
}
418+
419+
private RateLimit toRateLimit(
420+
ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata.RateLimit r) {
421+
if (r == null) {
422+
return null;
423+
}
424+
return new RateLimit() {
425+
@Override
426+
public Long getRequestsLimit() {
427+
return r.requestLimit();
428+
}
429+
430+
@Override
431+
public Long getRequestsRemaining() {
432+
return r.requestRemaining();
433+
}
434+
435+
@Override
436+
public java.time.Duration getRequestsReset() {
437+
return r.requestReset();
438+
}
439+
440+
@Override
441+
public Long getTokensLimit() {
442+
return r.tokenLimit();
443+
}
444+
445+
@Override
446+
public Long getTokensRemaining() {
447+
return r.tokenRemaining();
448+
}
449+
450+
@Override
451+
public java.time.Duration getTokensReset() {
452+
return r.tokenReset();
453+
}
454+
};
455+
}
456+
395457
private AssistantMessage toAssistantMessage(ChatModelTypes.Message message) {
396458
List<AssistantMessage.ToolCall> toolCalls = List.of();
397459
if (!CollectionUtils.isEmpty(message.toolCalls())) {
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
package io.temporal.springai;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import io.temporal.client.WorkflowClient;
6+
import io.temporal.client.WorkflowOptions;
7+
import io.temporal.springai.activity.ChatModelActivityImpl;
8+
import io.temporal.springai.model.ActivityChatModel;
9+
import io.temporal.testing.TestWorkflowEnvironment;
10+
import io.temporal.worker.Worker;
11+
import io.temporal.workflow.WorkflowInterface;
12+
import io.temporal.workflow.WorkflowMethod;
13+
import java.time.Duration;
14+
import java.util.List;
15+
import org.junit.jupiter.api.AfterEach;
16+
import org.junit.jupiter.api.BeforeEach;
17+
import org.junit.jupiter.api.Test;
18+
import org.springframework.ai.chat.messages.AssistantMessage;
19+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
20+
import org.springframework.ai.chat.metadata.DefaultUsage;
21+
import org.springframework.ai.chat.metadata.RateLimit;
22+
import org.springframework.ai.chat.metadata.Usage;
23+
import org.springframework.ai.chat.model.ChatModel;
24+
import org.springframework.ai.chat.model.ChatResponse;
25+
import org.springframework.ai.chat.model.Generation;
26+
import org.springframework.ai.chat.prompt.Prompt;
27+
28+
/**
29+
* Verifies that {@link Usage} and {@link RateLimit} metadata produced by the underlying chat model
30+
* survive the round-trip through the Temporal activity boundary.
31+
*/
32+
class ResponseMetadataTest {

  private static final String TASK_QUEUE = "test-spring-ai-response-metadata";

  // Fresh Temporal test environment and client per test for isolation.
  private TestWorkflowEnvironment testEnv;
  private WorkflowClient client;

  @BeforeEach
  void setUp() {
    testEnv = TestWorkflowEnvironment.newInstance();
    client = testEnv.getWorkflowClient();
  }

  @AfterEach
  void tearDown() {
    testEnv.close();
  }

  @Test
  void usageAndRateLimit_survivesActivityRoundTrip() {
    // Wire a worker with the metadata-producing stub model behind the chat activity.
    Worker worker = testEnv.newWorker(TASK_QUEUE);
    worker.registerWorkflowImplementationTypes(MetadataWorkflowImpl.class);
    worker.registerActivitiesImplementations(new ChatModelActivityImpl(new MetadataChatModel()));
    testEnv.start();

    MetadataWorkflow workflow =
        client.newWorkflowStub(
            MetadataWorkflow.class, WorkflowOptions.newBuilder().setTaskQueue(TASK_QUEUE).build());

    // Synchronous workflow call: runs the activity and flattens the metadata it sees.
    MetadataSnapshot snapshot = workflow.collect();

    // Model name: was already round-tripping, keep asserting so we don't regress.
    assertEquals("stub-model-v1", snapshot.model());

    // Usage: the prior code dropped this on the workflow side.
    assertEquals(
        Boolean.TRUE, snapshot.usagePresent(), "Usage should be rehydrated on the workflow side");
    assertEquals(10, snapshot.promptTokens());
    assertEquals(20, snapshot.completionTokens());
    assertEquals(30, snapshot.totalTokens());

    // RateLimit: ditto.
    assertEquals(
        Boolean.TRUE,
        snapshot.rateLimitPresent(),
        "RateLimit should be rehydrated on the workflow side");
    assertEquals(1000L, snapshot.requestsLimit());
    assertEquals(987L, snapshot.requestsRemaining());
    assertEquals(Duration.ofSeconds(60), snapshot.requestsReset());
    assertEquals(500_000L, snapshot.tokensLimit());
    assertEquals(493_210L, snapshot.tokensRemaining());
    assertEquals(Duration.ofSeconds(30), snapshot.tokensReset());
  }

  /**
   * Snapshot flattened to primitives/Strings — {@link Usage} and {@link RateLimit} are interfaces
   * and can't round-trip through the workflow-result serialization without extra type info, so the
   * workflow extracts the fields itself.
   */
  public record MetadataSnapshot(
      String model,
      Boolean usagePresent,
      Integer promptTokens,
      Integer completionTokens,
      Integer totalTokens,
      Boolean rateLimitPresent,
      Long requestsLimit,
      Long requestsRemaining,
      Duration requestsReset,
      Long tokensLimit,
      Long tokensRemaining,
      Duration tokensReset) {}

  @WorkflowInterface
  public interface MetadataWorkflow {
    @WorkflowMethod
    MetadataSnapshot collect();
  }

  // Workflow that invokes the chat model once and reports what metadata reached it.
  public static class MetadataWorkflowImpl implements MetadataWorkflow {
    @Override
    public MetadataSnapshot collect() {
      ActivityChatModel chatModel = ActivityChatModel.forDefault();
      ChatResponse response = chatModel.call(new Prompt("ping"));
      ChatResponseMetadata md = response.getMetadata();
      // All-null snapshot signals metadata was dropped entirely across the boundary.
      if (md == null) {
        return new MetadataSnapshot(
            null, false, null, null, null, false, null, null, null, null, null, null);
      }
      Usage u = md.getUsage();
      RateLimit r = md.getRateLimit();
      return new MetadataSnapshot(
          md.getModel(),
          u != null,
          u == null ? null : u.getPromptTokens(),
          u == null ? null : u.getCompletionTokens(),
          u == null ? null : u.getTotalTokens(),
          r != null,
          r == null ? null : r.getRequestsLimit(),
          r == null ? null : r.getRequestsRemaining(),
          r == null ? null : r.getRequestsReset(),
          r == null ? null : r.getTokensLimit(),
          r == null ? null : r.getTokensRemaining(),
          r == null ? null : r.getTokensReset());
    }
  }

  /**
   * Returns a ChatResponse with a known model, Usage, and RateLimit so the test can assert them.
   */
  private static class MetadataChatModel implements ChatModel {
    @Override
    public ChatResponse call(Prompt prompt) {
      ChatResponseMetadata md =
          ChatResponseMetadata.builder()
              .model("stub-model-v1")
              .usage(new DefaultUsage(10, 20, 30))
              .rateLimit(
                  // Anonymous impl: RateLimit is an interface with no stock value class here.
                  new RateLimit() {
                    @Override
                    public Long getRequestsLimit() {
                      return 1000L;
                    }

                    @Override
                    public Long getRequestsRemaining() {
                      return 987L;
                    }

                    @Override
                    public Duration getRequestsReset() {
                      return Duration.ofSeconds(60);
                    }

                    @Override
                    public Long getTokensLimit() {
                      return 500_000L;
                    }

                    @Override
                    public Long getTokensRemaining() {
                      return 493_210L;
                    }

                    @Override
                    public Duration getTokensReset() {
                      return Duration.ofSeconds(30);
                    }
                  })
              .build();
      return ChatResponse.builder()
          .generations(List.of(new Generation(new AssistantMessage("pong"))))
          .metadata(md)
          .build();
    }

    // Streaming is out of scope for this test.
    @Override
    public reactor.core.publisher.Flux<ChatResponse> stream(Prompt prompt) {
      throw new UnsupportedOperationException();
    }
  }
}

0 commit comments

Comments
 (0)