Skip to content

Commit 3f656c4

Browse files
Merge branch 'master' into spring-ai/readme-and-naming
2 parents ffba7f4 + 87c519e commit 3f656c4

3 files changed

Lines changed: 138 additions & 10 deletions

File tree

temporal-spring-ai/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ dependencies {
4646
testImplementation "org.mockito:mockito-core:${mockitoVersion}"
4747
testImplementation 'org.springframework.boot:spring-boot-starter-test'
4848
testImplementation 'org.springframework.ai:spring-ai-rag'
49+
// Needed only so McpPluginTest can mock/reference McpSyncClient directly.
50+
testImplementation 'org.springframework.ai:spring-ai-mcp'
4951
// Needed only so tests can reference Spring AI's NonTransientAiException to
5052
// verify the plugin's default retry classification.
5153
testImplementation 'org.springframework.ai:spring-ai-retry'

temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io.temporal.worker.Worker;
77
import java.util.ArrayList;
88
import java.util.List;
9+
import java.util.Map;
910
import javax.annotation.Nonnull;
1011
import org.slf4j.Logger;
1112
import org.slf4j.LoggerFactory;
@@ -39,22 +40,22 @@ public void setApplicationContext(ApplicationContext applicationContext) throws
3940
this.applicationContext = applicationContext;
4041
}
4142

42-
@SuppressWarnings("unchecked")
4343
private List<McpSyncClient> getMcpClients() {
4444
if (!mcpClients.isEmpty()) {
4545
return mcpClients;
4646
}
47+
if (applicationContext == null) {
48+
return mcpClients;
49+
}
4750

48-
if (applicationContext != null && applicationContext.containsBean("mcpSyncClients")) {
49-
try {
50-
Object bean = applicationContext.getBean("mcpSyncClients");
51-
if (bean instanceof List<?> clientList && !clientList.isEmpty()) {
52-
mcpClients = (List<McpSyncClient>) clientList;
53-
log.info("Found {} MCP client(s) in ApplicationContext", mcpClients.size());
54-
}
55-
} catch (Exception e) {
56-
log.debug("Failed to get mcpSyncClients bean: {}", e.getMessage());
51+
try {
52+
Map<String, McpSyncClient> beans = applicationContext.getBeansOfType(McpSyncClient.class);
53+
if (!beans.isEmpty()) {
54+
mcpClients = List.copyOf(beans.values());
55+
log.info("Discovered {} MCP client bean(s): {}", beans.size(), beans.keySet());
5756
}
57+
} catch (Exception e) {
58+
log.debug("Failed to look up McpSyncClient beans: {}", e.getMessage());
5859
}
5960

6061
return mcpClients;
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package io.temporal.springai.plugin;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
import static org.junit.jupiter.api.Assertions.assertTrue;
6+
import static org.mockito.ArgumentMatchers.any;
7+
import static org.mockito.Mockito.atLeastOnce;
8+
import static org.mockito.Mockito.mock;
9+
import static org.mockito.Mockito.verify;
10+
import static org.mockito.Mockito.verifyNoInteractions;
11+
import static org.mockito.Mockito.when;
12+
13+
import io.modelcontextprotocol.client.McpSyncClient;
14+
import io.modelcontextprotocol.spec.McpSchema;
15+
import io.temporal.springai.mcp.McpClientActivityImpl;
16+
import io.temporal.worker.Worker;
17+
import java.util.LinkedHashMap;
18+
import java.util.Map;
19+
import org.junit.jupiter.api.Test;
20+
import org.mockito.ArgumentCaptor;
21+
import org.springframework.context.ApplicationContext;
22+
23+
class McpPluginTest {
24+
25+
@Test
26+
void discoversMcpClientBeansByType() {
27+
McpSyncClient clientA = mockClientNamed("alpha");
28+
McpSyncClient clientB = mockClientNamed("beta");
29+
30+
// Spring's getBeansOfType keeps insertion order via LinkedHashMap; use that for determinism.
31+
Map<String, McpSyncClient> beans = new LinkedHashMap<>();
32+
beans.put("mcpClientAlpha", clientA);
33+
beans.put("mcpClientBeta", clientB);
34+
35+
ApplicationContext ctx = mock(ApplicationContext.class);
36+
when(ctx.getBeansOfType(McpSyncClient.class)).thenReturn(beans);
37+
38+
McpPlugin plugin = new McpPlugin();
39+
plugin.setApplicationContext(ctx);
40+
41+
Worker worker = mock(Worker.class);
42+
plugin.initializeWorker("mcp-tq", worker);
43+
44+
ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
45+
verify(worker, atLeastOnce()).registerActivitiesImplementations(captor.capture());
46+
Object registered = captor.getValue();
47+
assertEquals(McpClientActivityImpl.class, registered.getClass());
48+
}
49+
50+
@Test
51+
void twoMcpBeans_duplicateClientInfoNames_throws() {
52+
// Two distinct beans that both report the same clientInfo().name() — the activity impl
53+
// has to reject this because it keys its internal client map by that name.
54+
McpSyncClient clientA = mockClientNamed("shared");
55+
McpSyncClient clientB = mockClientNamed("shared");
56+
57+
Map<String, McpSyncClient> beans = new LinkedHashMap<>();
58+
beans.put("mcpClientA", clientA);
59+
beans.put("mcpClientB", clientB);
60+
61+
ApplicationContext ctx = mock(ApplicationContext.class);
62+
when(ctx.getBeansOfType(McpSyncClient.class)).thenReturn(beans);
63+
64+
McpPlugin plugin = new McpPlugin();
65+
plugin.setApplicationContext(ctx);
66+
67+
IllegalArgumentException thrown =
68+
assertThrows(
69+
IllegalArgumentException.class,
70+
() -> plugin.initializeWorker("mcp-tq", mock(Worker.class)));
71+
assertTrue(
72+
thrown.getMessage().contains("shared"),
73+
"expected duplicate name in message, got: " + thrown.getMessage());
74+
}
75+
76+
@Test
77+
void noMcpBeans_defersWorker_thenClearsAfterSingletonsInstantiated() {
78+
ApplicationContext ctx = mock(ApplicationContext.class);
79+
when(ctx.getBeansOfType(McpSyncClient.class)).thenReturn(Map.of());
80+
81+
McpPlugin plugin = new McpPlugin();
82+
plugin.setApplicationContext(ctx);
83+
84+
Worker worker = mock(Worker.class);
85+
plugin.initializeWorker("mcp-tq", worker);
86+
87+
// No beans → nothing registered yet, worker queued for deferred attempt.
88+
verifyNoInteractions(worker);
89+
90+
plugin.afterSingletonsInstantiated();
91+
92+
// Still no beans — the deferred attempt also finds nothing and doesn't crash.
93+
verify(worker, org.mockito.Mockito.never()).registerActivitiesImplementations((Object[]) any());
94+
}
95+
96+
@Test
97+
void beansAppearLate_registeredViaAfterSingletonsInstantiated() {
98+
ApplicationContext ctx = mock(ApplicationContext.class);
99+
// First lookup returns empty (Spring AI MCP bean hasn't been created yet when
100+
// initializeWorker runs).
101+
when(ctx.getBeansOfType(McpSyncClient.class))
102+
.thenReturn(Map.of())
103+
.thenReturn(Map.of("mcpClient", mockClientNamed("late")));
104+
105+
McpPlugin plugin = new McpPlugin();
106+
plugin.setApplicationContext(ctx);
107+
108+
Worker worker = mock(Worker.class);
109+
plugin.initializeWorker("mcp-tq", worker);
110+
verifyNoInteractions(worker);
111+
112+
plugin.afterSingletonsInstantiated();
113+
114+
ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
115+
verify(worker, atLeastOnce()).registerActivitiesImplementations(captor.capture());
116+
assertEquals(McpClientActivityImpl.class, captor.getValue().getClass());
117+
}
118+
119+
private static McpSyncClient mockClientNamed(String name) {
120+
McpSyncClient client = mock(McpSyncClient.class);
121+
McpSchema.Implementation info = new McpSchema.Implementation(name, "1.0.0");
122+
when(client.getClientInfo()).thenReturn(info);
123+
return client;
124+
}
125+
}

0 commit comments

Comments
 (0)