Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
Expand Down Expand Up @@ -48,6 +49,7 @@
import software.amazon.smithy.java.mcp.test.model.McpEchoInput;
import software.amazon.smithy.java.mcp.test.model.McpEchoOutput;
import software.amazon.smithy.java.mcp.test.model.Shape;
import software.amazon.smithy.java.mcp.test.model.TestUnion;
import software.amazon.smithy.java.mcp.test.service.CalculateAreaOperation;
import software.amazon.smithy.java.mcp.test.service.McpEchoOperation;
import software.amazon.smithy.java.mcp.test.service.TestService;
Expand Down Expand Up @@ -1022,6 +1024,170 @@ void testUnionWithNestedOption() {
assertEquals(77, nested.getMember("innerNumber").asNumber().intValue());
}

@Test
void testUnionInList() {
initializeLatestProtocol();
var union1 = Document.of(Map.of("stringOption", Document.of("first")));
var union2 = Document.of(Map.of("integerOption", Document.of(42)));
var union3 = Document.of(Map.of("nestedOption",
Document.of(Map.of(
"innerString",
Document.of("nested"),
"innerNumber",
Document.of(99)))));
var unionList = Document.of(List.of(union1, union2, union3));

var echo = echoSingleField("unionList", unionList);
var resultList = echo.getMember("unionList").asList();

assertEquals(3, resultList.size());
assertEquals("first", resultList.get(0).getMember("stringOption").asString());
assertEquals(42, resultList.get(1).getMember("integerOption").asNumber().intValue());
assertEquals("nested", resultList.get(2).getMember("nestedOption").getMember("innerString").asString());

// Verify server correctly deserialized unions inside list
var lastInputUnionList = echoOperation.getLastInput().getEcho().getUnionList();
assertEquals(3, lastInputUnionList.size());
assertInstanceOf(TestUnion.StringOptionMember.class, lastInputUnionList.get(0));
assertEquals("first", ((TestUnion.StringOptionMember) lastInputUnionList.get(0)).stringOption());
assertInstanceOf(TestUnion.IntegerOptionMember.class, lastInputUnionList.get(1));
assertEquals(42, ((TestUnion.IntegerOptionMember) lastInputUnionList.get(1)).integerOption());
assertInstanceOf(TestUnion.NestedOptionMember.class, lastInputUnionList.get(2));
assertEquals("nested",
((TestUnion.NestedOptionMember) lastInputUnionList.get(2)).nestedOption().getInnerString());
}

@Test
void testUnionInMap() {
initializeLatestProtocol();
var union1 = Document.of(Map.of("stringOption", Document.of("value1")));
var union2 = Document.of(Map.of("integerOption", Document.of(123)));
var unionMap = Document.of(Map.of(
"key1",
union1,
"key2",
union2));

var echo = echoSingleField("unionMap", unionMap);
var resultMap = echo.getMember("unionMap").asStringMap();

assertEquals(2, resultMap.size());
assertEquals("value1", resultMap.get("key1").getMember("stringOption").asString());
assertEquals(123, resultMap.get("key2").getMember("integerOption").asNumber().intValue());

// Verify server correctly deserialized unions inside map
var lastInputUnionMap = echoOperation.getLastInput().getEcho().getUnionMap();
assertEquals(2, lastInputUnionMap.size());
assertInstanceOf(TestUnion.StringOptionMember.class, lastInputUnionMap.get("key1"));
assertEquals("value1", ((TestUnion.StringOptionMember) lastInputUnionMap.get("key1")).stringOption());
assertInstanceOf(TestUnion.IntegerOptionMember.class, lastInputUnionMap.get("key2"));
assertEquals(123, ((TestUnion.IntegerOptionMember) lastInputUnionMap.get("key2")).integerOption());
}

// ========== @oneOf Document in Collections Tests ==========

@Test
void testOneOfDocumentInList() {
initializeLatestProtocol();
// Input is in MCP wrapper format: {"circle": {"radius": 5}}
var shape1 = Document.of(Map.of("circle", Document.of(Map.of("radius", Document.of(10)))));
var shape2 = Document.of(Map.of("square", Document.of(Map.of("side", Document.of(20)))));
var shapeList = Document.of(List.of(shape1, shape2));

var echo = echoSingleField("shapeWithOneOfList", shapeList);
var resultList = echo.getMember("shapeWithOneOfList").asList();

// Output should be back in MCP wrapper format
assertEquals(2, resultList.size());
assertEquals(10, resultList.get(0).getMember("circle").getMember("radius").asNumber().intValue());
assertEquals(20, resultList.get(1).getMember("square").getMember("side").asNumber().intValue());

// Verify server received correctly transformed input (discriminated format internally)
var lastInputList = echoOperation.getLastInput().getEcho().getShapeWithOneOfList();
assertEquals(2, lastInputList.size());
// The internal format should have __type discriminator
assertEquals("smithy.java.mcp.test#Circle", lastInputList.get(0).getMember("__type").asString());
assertEquals(10, lastInputList.get(0).getMember("radius").asNumber().intValue());
assertEquals("smithy.java.mcp.test#Square", lastInputList.get(1).getMember("__type").asString());
assertEquals(20, lastInputList.get(1).getMember("side").asNumber().intValue());
}

@Test
void testOneOfDocumentInMap() {
initializeLatestProtocol();
var shape1 = Document.of(Map.of("circle", Document.of(Map.of("radius", Document.of(15)))));
var shape2 = Document.of(Map.of("rectangle",
Document.of(Map.of(
"length",
Document.of(30),
"breadth",
Document.of(40)))));
var shapeMap = Document.of(Map.of("key1", shape1, "key2", shape2));

var echo = echoSingleField("shapeWithOneOfMap", shapeMap);
var resultMap = echo.getMember("shapeWithOneOfMap").asStringMap();

// Output should be back in MCP wrapper format
assertEquals(2, resultMap.size());
assertEquals(15, resultMap.get("key1").getMember("circle").getMember("radius").asNumber().intValue());
assertEquals(30, resultMap.get("key2").getMember("rectangle").getMember("length").asNumber().intValue());

// Verify server received correctly transformed input (discriminated format internally)
var lastInputMap = echoOperation.getLastInput().getEcho().getShapeWithOneOfMap();
assertEquals(2, lastInputMap.size());
assertEquals("smithy.java.mcp.test#Circle", lastInputMap.get("key1").getMember("__type").asString());
assertEquals(15, lastInputMap.get("key1").getMember("radius").asNumber().intValue());
assertEquals("smithy.java.mcp.test#Rectangle", lastInputMap.get("key2").getMember("__type").asString());
assertEquals(30, lastInputMap.get("key2").getMember("length").asNumber().intValue());
}

@Test
void testNestedOneOfDocumentInList() {
initializeLatestProtocol();

// Circle with nested shapes (list of @oneOf documents)
var nestedShape1 = Document.of(Map.of("square", Document.of(Map.of("side", Document.of(10)))));
var nestedShape2 = Document.of(Map.of("rectangle",
Document.of(Map.of(
"length",
Document.of(20),
"breadth",
Document.of(30)))));
var nestedShapeList = List.of(nestedShape1, nestedShape2);

var circleWithNested = Document.of(Map.of("circleWithNested",
Document.of(Map.of(
"radius",
Document.of(5),
"nestedShapes",
Document.of(nestedShapeList)))));

var echo = echoSingleField("nestedShapeWithOneOf", circleWithNested);

// Verify output is back in MCP wrapper format
var result = echo.getMember("nestedShapeWithOneOf");
var circleData = result.getMember("circleWithNested");
assertEquals(5, circleData.getMember("radius").asNumber().intValue());

var resultNestedList = circleData.getMember("nestedShapes").asList();
assertEquals(2, resultNestedList.size());
assertEquals(10, resultNestedList.get(0).getMember("square").getMember("side").asNumber().intValue());
assertEquals(20, resultNestedList.get(1).getMember("rectangle").getMember("length").asNumber().intValue());

// Verify server received correctly transformed input (discriminated format)
var lastInput = echoOperation.getLastInput().getEcho().getNestedShapeWithOneOf();
assertEquals("smithy.java.mcp.test#CircleWithNested", lastInput.getMember("__type").asString());
assertEquals(5, lastInput.getMember("radius").asNumber().intValue());

// Verify nested shapes were also transformed to discriminated format
var lastNestedList = lastInput.getMember("nestedShapes").asList();
assertEquals(2, lastNestedList.size());
assertEquals("smithy.java.mcp.test#Square", lastNestedList.get(0).getMember("__type").asString());
assertEquals(10, lastNestedList.get(0).getMember("side").asNumber().intValue());
assertEquals("smithy.java.mcp.test#Rectangle", lastNestedList.get(1).getMember("__type").asString());
assertEquals(20, lastNestedList.get(1).getMember("length").asNumber().intValue());
}

// ========== Input Deserialization Verification Tests ==========

@Test
Expand Down
59 changes: 59 additions & 0 deletions mcp/mcp-server/src/it/resources/META-INF/smithy/main.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,28 @@ union Shape {
])
document ShapeWithOneOf

/// Circle structure with optional nested shapes for testing recursive @oneOf adaptation
structure CircleWithNested {
@required
radius : Integer

/// List of nested shapes (for testing recursive @oneOf document adaptation)
nestedShapes: ShapeWithOneOfList
}

/// @oneOf document with nested list of @oneOf documents
@oneOf(discriminator: "__type", members: [
{name: "circleWithNested", target: CircleWithNested},
{name: "square", target: Square},
{name: "rectangle", target: Rectangle}
])
document NestedShapeWithOneOf

/// List of nested @oneOf documents
list NestedShapeWithOneOfList {
member: NestedShapeWithOneOf
}

structure Circle {
@required
radius : Integer
Expand Down Expand Up @@ -126,6 +148,21 @@ structure Echo {
// Union type
unionValue: TestUnion

// Union in collections (for testing nested union adaptation)
unionList: UnionList
unionMap: UnionMap

// @oneOf document in collections (for testing Document-based polymorphic types)
shapeWithOneOfList: ShapeWithOneOfList
shapeWithOneOfMap: ShapeWithOneOfMap

// Nested @oneOf documents (for testing recursive adaptation)
nestedShapeWithOneOf: NestedShapeWithOneOf
nestedShapeWithOneOfList: NestedShapeWithOneOfList

// Helper to make CircleWithNested reachable for schema generation
circleWithNested: CircleWithNested

// Required field to test required validation
@required
requiredField: String
Expand Down Expand Up @@ -189,3 +226,25 @@ union TestUnion {
integerOption: Integer
nestedOption: NestedEcho
}

/// List of unions for testing nested union adaptation
list UnionList {
member: TestUnion
}

/// Map of unions for testing nested union adaptation
map UnionMap {
key: String
value: TestUnion
}

/// List of @oneOf documents for testing Document-based polymorphic types in collections
list ShapeWithOneOfList {
member: ShapeWithOneOf
}

/// Map of @oneOf documents for testing Document-based polymorphic types in collections
map ShapeWithOneOfMap {
key: String
value: ShapeWithOneOf
}
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ private static String appendSentences(String first, String second) {
return first + second;
}

private static Document adaptDocument(Document doc, Schema schema) {
private Document adaptDocument(Document doc, Schema schema) {
if (doc == null) {
return null;
}
Expand Down Expand Up @@ -834,7 +834,7 @@ private static Document adaptDocument(Document doc, Schema schema) {
};
}

private static Document adaptDocumentWithOneOf(Document doc, Schema schema) {
private Document adaptDocumentWithOneOf(Document doc, Schema schema) {
var targetSchema = schema.isMember() ? schema.memberTarget() : schema;
var oneOfTrait = targetSchema.getTrait(ONE_OF_TRAIT);

Expand All @@ -850,9 +850,11 @@ private static Document adaptDocumentWithOneOf(Document doc, Schema schema) {
if (memberDoc != null) {
// Build the flat object with discriminator
var flatMembers = new HashMap<String, Document>();
flatMembers.put(discriminator, Document.of(memberDef.getTarget().toString()));
var memberId = memberDef.getTarget();
flatMembers.put(discriminator, Document.of(memberId.toString()));
// Copy all fields from the inner object
flatMembers.putAll(memberDoc.asStringMap());
var memberSchema = schemaIndex.getSchema(memberId);
flatMembers.putAll(adaptDocument(memberDoc, memberSchema).asStringMap());
return Document.of(flatMembers);
}
}
Expand Down Expand Up @@ -943,13 +945,14 @@ private Document adaptOutputDocument(Document doc, Schema schema) {
var discriminatorValue = doc.getMember(discriminator);

if (discriminatorValue != null) {
var shapeIdStr = discriminatorValue.asString();
var shapeId = ShapeId.from(discriminatorValue.asString());
// Find the matching member definition
for (var memberDef : oneOfTrait.getMembers()) {
if (memberDef.getTarget().toString().equals(shapeIdStr)) {
if (memberDef.getTarget().equals(shapeId)) {
var memberName = memberDef.getName();
var memberSchema = schemaIndex.getSchema(shapeId);
// Build the inner object without the discriminator field
var innerMembers = new HashMap<>(doc.asStringMap());
var innerMembers = new HashMap<>(adaptOutputDocument(doc, memberSchema).asStringMap());
innerMembers.remove(discriminator);
// Return wrapper format
yield Document.of(Map.of(memberName, Document.of(innerMembers)));
Expand Down