Skip to content

Commit 99c65cb

Browse files
committed
Correctly adapt oneOf unions nested in oneOfUnions
1 parent d0b8e6c commit 99c65cb

File tree

3 files changed

+235
-7
lines changed

3 files changed

+235
-7
lines changed

mcp/mcp-server/src/it/java/software/amazon/smithy/java/mcp/server/McpServerIntegrationTest.java

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.junit.jupiter.api.Assertions.assertEquals;
99
import static org.junit.jupiter.api.Assertions.assertFalse;
10+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
1011
import static org.junit.jupiter.api.Assertions.assertNotNull;
1112
import static org.junit.jupiter.api.Assertions.assertNull;
1213
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
@@ -48,6 +49,7 @@
4849
import software.amazon.smithy.java.mcp.test.model.McpEchoInput;
4950
import software.amazon.smithy.java.mcp.test.model.McpEchoOutput;
5051
import software.amazon.smithy.java.mcp.test.model.Shape;
52+
import software.amazon.smithy.java.mcp.test.model.TestUnion;
5153
import software.amazon.smithy.java.mcp.test.service.CalculateAreaOperation;
5254
import software.amazon.smithy.java.mcp.test.service.McpEchoOperation;
5355
import software.amazon.smithy.java.mcp.test.service.TestService;
@@ -1022,6 +1024,170 @@ void testUnionWithNestedOption() {
10221024
assertEquals(77, nested.getMember("innerNumber").asNumber().intValue());
10231025
}
10241026

1027+
@Test
1028+
void testUnionInList() {
1029+
initializeLatestProtocol();
1030+
var union1 = Document.of(Map.of("stringOption", Document.of("first")));
1031+
var union2 = Document.of(Map.of("integerOption", Document.of(42)));
1032+
var union3 = Document.of(Map.of("nestedOption",
1033+
Document.of(Map.of(
1034+
"innerString",
1035+
Document.of("nested"),
1036+
"innerNumber",
1037+
Document.of(99)))));
1038+
var unionList = Document.of(List.of(union1, union2, union3));
1039+
1040+
var echo = echoSingleField("unionList", unionList);
1041+
var resultList = echo.getMember("unionList").asList();
1042+
1043+
assertEquals(3, resultList.size());
1044+
assertEquals("first", resultList.get(0).getMember("stringOption").asString());
1045+
assertEquals(42, resultList.get(1).getMember("integerOption").asNumber().intValue());
1046+
assertEquals("nested", resultList.get(2).getMember("nestedOption").getMember("innerString").asString());
1047+
1048+
// Verify server correctly deserialized unions inside list
1049+
var lastInputUnionList = echoOperation.getLastInput().getEcho().getUnionList();
1050+
assertEquals(3, lastInputUnionList.size());
1051+
assertInstanceOf(TestUnion.StringOptionMember.class, lastInputUnionList.get(0));
1052+
assertEquals("first", ((TestUnion.StringOptionMember) lastInputUnionList.get(0)).stringOption());
1053+
assertInstanceOf(TestUnion.IntegerOptionMember.class, lastInputUnionList.get(1));
1054+
assertEquals(42, ((TestUnion.IntegerOptionMember) lastInputUnionList.get(1)).integerOption());
1055+
assertInstanceOf(TestUnion.NestedOptionMember.class, lastInputUnionList.get(2));
1056+
assertEquals("nested",
1057+
((TestUnion.NestedOptionMember) lastInputUnionList.get(2)).nestedOption().getInnerString());
1058+
}
1059+
1060+
@Test
1061+
void testUnionInMap() {
1062+
initializeLatestProtocol();
1063+
var union1 = Document.of(Map.of("stringOption", Document.of("value1")));
1064+
var union2 = Document.of(Map.of("integerOption", Document.of(123)));
1065+
var unionMap = Document.of(Map.of(
1066+
"key1",
1067+
union1,
1068+
"key2",
1069+
union2));
1070+
1071+
var echo = echoSingleField("unionMap", unionMap);
1072+
var resultMap = echo.getMember("unionMap").asStringMap();
1073+
1074+
assertEquals(2, resultMap.size());
1075+
assertEquals("value1", resultMap.get("key1").getMember("stringOption").asString());
1076+
assertEquals(123, resultMap.get("key2").getMember("integerOption").asNumber().intValue());
1077+
1078+
// Verify server correctly deserialized unions inside map
1079+
var lastInputUnionMap = echoOperation.getLastInput().getEcho().getUnionMap();
1080+
assertEquals(2, lastInputUnionMap.size());
1081+
assertInstanceOf(TestUnion.StringOptionMember.class, lastInputUnionMap.get("key1"));
1082+
assertEquals("value1", ((TestUnion.StringOptionMember) lastInputUnionMap.get("key1")).stringOption());
1083+
assertInstanceOf(TestUnion.IntegerOptionMember.class, lastInputUnionMap.get("key2"));
1084+
assertEquals(123, ((TestUnion.IntegerOptionMember) lastInputUnionMap.get("key2")).integerOption());
1085+
}
1086+
1087+
// ========== @oneOf Document in Collections Tests ==========
1088+
1089+
@Test
1090+
void testOneOfDocumentInList() {
1091+
initializeLatestProtocol();
1092+
// Input is in MCP wrapper format: {"circle": {"radius": 5}}
1093+
var shape1 = Document.of(Map.of("circle", Document.of(Map.of("radius", Document.of(10)))));
1094+
var shape2 = Document.of(Map.of("square", Document.of(Map.of("side", Document.of(20)))));
1095+
var shapeList = Document.of(List.of(shape1, shape2));
1096+
1097+
var echo = echoSingleField("shapeWithOneOfList", shapeList);
1098+
var resultList = echo.getMember("shapeWithOneOfList").asList();
1099+
1100+
// Output should be back in MCP wrapper format
1101+
assertEquals(2, resultList.size());
1102+
assertEquals(10, resultList.get(0).getMember("circle").getMember("radius").asNumber().intValue());
1103+
assertEquals(20, resultList.get(1).getMember("square").getMember("side").asNumber().intValue());
1104+
1105+
// Verify server received correctly transformed input (discriminated format internally)
1106+
var lastInputList = echoOperation.getLastInput().getEcho().getShapeWithOneOfList();
1107+
assertEquals(2, lastInputList.size());
1108+
// The internal format should have __type discriminator
1109+
assertEquals("smithy.java.mcp.test#Circle", lastInputList.get(0).getMember("__type").asString());
1110+
assertEquals(10, lastInputList.get(0).getMember("radius").asNumber().intValue());
1111+
assertEquals("smithy.java.mcp.test#Square", lastInputList.get(1).getMember("__type").asString());
1112+
assertEquals(20, lastInputList.get(1).getMember("side").asNumber().intValue());
1113+
}
1114+
1115+
@Test
1116+
void testOneOfDocumentInMap() {
1117+
initializeLatestProtocol();
1118+
var shape1 = Document.of(Map.of("circle", Document.of(Map.of("radius", Document.of(15)))));
1119+
var shape2 = Document.of(Map.of("rectangle",
1120+
Document.of(Map.of(
1121+
"length",
1122+
Document.of(30),
1123+
"breadth",
1124+
Document.of(40)))));
1125+
var shapeMap = Document.of(Map.of("key1", shape1, "key2", shape2));
1126+
1127+
var echo = echoSingleField("shapeWithOneOfMap", shapeMap);
1128+
var resultMap = echo.getMember("shapeWithOneOfMap").asStringMap();
1129+
1130+
// Output should be back in MCP wrapper format
1131+
assertEquals(2, resultMap.size());
1132+
assertEquals(15, resultMap.get("key1").getMember("circle").getMember("radius").asNumber().intValue());
1133+
assertEquals(30, resultMap.get("key2").getMember("rectangle").getMember("length").asNumber().intValue());
1134+
1135+
// Verify server received correctly transformed input (discriminated format internally)
1136+
var lastInputMap = echoOperation.getLastInput().getEcho().getShapeWithOneOfMap();
1137+
assertEquals(2, lastInputMap.size());
1138+
assertEquals("smithy.java.mcp.test#Circle", lastInputMap.get("key1").getMember("__type").asString());
1139+
assertEquals(15, lastInputMap.get("key1").getMember("radius").asNumber().intValue());
1140+
assertEquals("smithy.java.mcp.test#Rectangle", lastInputMap.get("key2").getMember("__type").asString());
1141+
assertEquals(30, lastInputMap.get("key2").getMember("length").asNumber().intValue());
1142+
}
1143+
1144+
@Test
1145+
void testNestedOneOfDocumentInList() {
1146+
initializeLatestProtocol();
1147+
1148+
// Circle with nested shapes (list of @oneOf documents)
1149+
var nestedShape1 = Document.of(Map.of("square", Document.of(Map.of("side", Document.of(10)))));
1150+
var nestedShape2 = Document.of(Map.of("rectangle",
1151+
Document.of(Map.of(
1152+
"length",
1153+
Document.of(20),
1154+
"breadth",
1155+
Document.of(30)))));
1156+
var nestedShapeList = List.of(nestedShape1, nestedShape2);
1157+
1158+
var circleWithNested = Document.of(Map.of("circleWithNested",
1159+
Document.of(Map.of(
1160+
"radius",
1161+
Document.of(5),
1162+
"nestedShapes",
1163+
Document.of(nestedShapeList)))));
1164+
1165+
var echo = echoSingleField("nestedShapeWithOneOf", circleWithNested);
1166+
1167+
// Verify output is back in MCP wrapper format
1168+
var result = echo.getMember("nestedShapeWithOneOf");
1169+
var circleData = result.getMember("circleWithNested");
1170+
assertEquals(5, circleData.getMember("radius").asNumber().intValue());
1171+
1172+
var resultNestedList = circleData.getMember("nestedShapes").asList();
1173+
assertEquals(2, resultNestedList.size());
1174+
assertEquals(10, resultNestedList.get(0).getMember("square").getMember("side").asNumber().intValue());
1175+
assertEquals(20, resultNestedList.get(1).getMember("rectangle").getMember("length").asNumber().intValue());
1176+
1177+
// Verify server received correctly transformed input (discriminated format)
1178+
var lastInput = echoOperation.getLastInput().getEcho().getNestedShapeWithOneOf();
1179+
assertEquals("smithy.java.mcp.test#CircleWithNested", lastInput.getMember("__type").asString());
1180+
assertEquals(5, lastInput.getMember("radius").asNumber().intValue());
1181+
1182+
// Verify nested shapes were also transformed to discriminated format
1183+
var lastNestedList = lastInput.getMember("nestedShapes").asList();
1184+
assertEquals(2, lastNestedList.size());
1185+
assertEquals("smithy.java.mcp.test#Square", lastNestedList.get(0).getMember("__type").asString());
1186+
assertEquals(10, lastNestedList.get(0).getMember("side").asNumber().intValue());
1187+
assertEquals("smithy.java.mcp.test#Rectangle", lastNestedList.get(1).getMember("__type").asString());
1188+
assertEquals(20, lastNestedList.get(1).getMember("length").asNumber().intValue());
1189+
}
1190+
10251191
// ========== Input Deserialization Verification Tests ==========
10261192

10271193
@Test

mcp/mcp-server/src/it/resources/META-INF/smithy/main.smithy

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,28 @@ union Shape {
5353
])
5454
document ShapeWithOneOf
5555

56+
/// Circle structure with optional nested shapes for testing recursive @oneOf adaptation
57+
structure CircleWithNested {
58+
@required
59+
radius : Integer
60+
61+
/// List of nested shapes (for testing recursive @oneOf document adaptation)
62+
nestedShapes: ShapeWithOneOfList
63+
}
64+
65+
/// @oneOf document with nested list of @oneOf documents
66+
@oneOf(discriminator: "__type", members: [
67+
{name: "circleWithNested", target: CircleWithNested},
68+
{name: "square", target: Square},
69+
{name: "rectangle", target: Rectangle}
70+
])
71+
document NestedShapeWithOneOf
72+
73+
/// List of nested @oneOf documents
74+
list NestedShapeWithOneOfList {
75+
member: NestedShapeWithOneOf
76+
}
77+
5678
structure Circle {
5779
@required
5880
radius : Integer
@@ -126,6 +148,21 @@ structure Echo {
126148
// Union type
127149
unionValue: TestUnion
128150

151+
// Union in collections (for testing nested union adaptation)
152+
unionList: UnionList
153+
unionMap: UnionMap
154+
155+
// @oneOf document in collections (for testing Document-based polymorphic types)
156+
shapeWithOneOfList: ShapeWithOneOfList
157+
shapeWithOneOfMap: ShapeWithOneOfMap
158+
159+
// Nested @oneOf documents (for testing recursive adaptation)
160+
nestedShapeWithOneOf: NestedShapeWithOneOf
161+
nestedShapeWithOneOfList: NestedShapeWithOneOfList
162+
163+
// Helper to make CircleWithNested reachable for schema generation
164+
circleWithNested: CircleWithNested
165+
129166
// Required field to test required validation
130167
@required
131168
requiredField: String
@@ -189,3 +226,25 @@ union TestUnion {
189226
integerOption: Integer
190227
nestedOption: NestedEcho
191228
}
229+
230+
/// List of unions for testing nested union adaptation
231+
list UnionList {
232+
member: TestUnion
233+
}
234+
235+
/// Map of unions for testing nested union adaptation
236+
map UnionMap {
237+
key: String
238+
value: TestUnion
239+
}
240+
241+
/// List of @oneOf documents for testing Document-based polymorphic types in collections
242+
list ShapeWithOneOfList {
243+
member: ShapeWithOneOf
244+
}
245+
246+
/// Map of @oneOf documents for testing Document-based polymorphic types in collections
247+
map ShapeWithOneOfMap {
248+
key: String
249+
value: ShapeWithOneOf
250+
}

mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,7 @@ private static String appendSentences(String first, String second) {
763763
return first + second;
764764
}
765765

766-
private static Document adaptDocument(Document doc, Schema schema) {
766+
private Document adaptDocument(Document doc, Schema schema) {
767767
if (doc == null) {
768768
return null;
769769
}
@@ -834,7 +834,7 @@ private static Document adaptDocument(Document doc, Schema schema) {
834834
};
835835
}
836836

837-
private static Document adaptDocumentWithOneOf(Document doc, Schema schema) {
837+
private Document adaptDocumentWithOneOf(Document doc, Schema schema) {
838838
var targetSchema = schema.isMember() ? schema.memberTarget() : schema;
839839
var oneOfTrait = targetSchema.getTrait(ONE_OF_TRAIT);
840840

@@ -850,9 +850,11 @@ private static Document adaptDocumentWithOneOf(Document doc, Schema schema) {
850850
if (memberDoc != null) {
851851
// Build the flat object with discriminator
852852
var flatMembers = new HashMap<String, Document>();
853-
flatMembers.put(discriminator, Document.of(memberDef.getTarget().toString()));
853+
var memberId = memberDef.getTarget();
854+
flatMembers.put(discriminator, Document.of(memberId.toString()));
854855
// Copy all fields from the inner object
855-
flatMembers.putAll(memberDoc.asStringMap());
856+
var memberSchema = schemaIndex.getSchema(memberId);
857+
flatMembers.putAll(adaptDocument(memberDoc, memberSchema).asStringMap());
856858
return Document.of(flatMembers);
857859
}
858860
}
@@ -943,13 +945,14 @@ private Document adaptOutputDocument(Document doc, Schema schema) {
943945
var discriminatorValue = doc.getMember(discriminator);
944946

945947
if (discriminatorValue != null) {
946-
var shapeIdStr = discriminatorValue.asString();
948+
var shapeId = ShapeId.from(discriminatorValue.asString());
947949
// Find the matching member definition
948950
for (var memberDef : oneOfTrait.getMembers()) {
949-
if (memberDef.getTarget().toString().equals(shapeIdStr)) {
951+
if (memberDef.getTarget().equals(shapeId)) {
950952
var memberName = memberDef.getName();
953+
var memberSchema = schemaIndex.getSchema(shapeId);
951954
// Build the inner object without the discriminator field
952-
var innerMembers = new HashMap<>(doc.asStringMap());
955+
var innerMembers = new HashMap<>(adaptOutputDocument(doc, memberSchema).asStringMap());
953956
innerMembers.remove(discriminator);
954957
// Return wrapper format
955958
yield Document.of(Map.of(memberName, Document.of(innerMembers)));

0 commit comments

Comments
 (0)