Skip to content

Commit 8f3e5f0

Browse files
committed
Rewrite AdtMemberTraitValidator
1 parent 4757ac3 commit 8f3e5f0

File tree

6 files changed

+208
-135
lines changed

6 files changed

+208
-135
lines changed

modules/protocol-tests/test/src/smithy4s/api/validation/AdtMemberTraitValidatorSpec.scala

+144-23
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,71 @@ object AdtMemberTraitValidatorSpec extends FunSuite {
9393
val expected = List(
9494
ValidationEvent
9595
.builder()
96-
.id("AdtValidator")
96+
.id("AdtMemberTrait")
9797
.shape(struct)
9898
.severity(Severity.ERROR)
9999
.message(
100-
"test#MyUnion must have exactly one member targeting test#struct"
100+
"This shape must be referenced by test#MyUnion because of its smithy4s.meta#adtMember trait"
101101
)
102102
.build()
103103
)
104104
expect(result == expected)
105105
}
106106

107-
test("return error when structure is targeted by multiple unions") {
107+
test("return no error when there are duplicate non-adtMember members") {
108+
val unionShapeId = ShapeId.fromParts("test", "MyUnion")
109+
val adtTrait = new AdtMemberTrait(unionShapeId)
110+
val structMember = MemberShape
111+
.builder()
112+
.id("test#struct$testing")
113+
.target("smithy.api#String")
114+
.build()
115+
116+
val struct =
117+
StructureShape
118+
.builder()
119+
.id("test#struct")
120+
.addTrait(adtTrait)
121+
.addMember(structMember)
122+
.build()
123+
124+
val unionMember = MemberShape
125+
.builder()
126+
.id(unionShapeId.withMember("unionMember"))
127+
.target(struct.getId)
128+
.build()
129+
130+
val unionMemberString1 = MemberShape
131+
.builder()
132+
.id(unionShapeId.withMember("unionMemberString1"))
133+
.target("smithy.api#String")
134+
.build()
135+
136+
val unionMemberString2 = MemberShape
137+
.builder()
138+
.id(unionShapeId.withMember("unionMemberString2"))
139+
.target("smithy.api#String")
140+
.build()
141+
142+
val union =
143+
UnionShape
144+
.builder()
145+
.id(unionShapeId)
146+
.addMember(unionMember)
147+
.addMember(unionMemberString1)
148+
.addMember(unionMemberString2)
149+
.build()
150+
151+
val model =
152+
Model.builder().addShapes(struct, union).build()
153+
154+
val result = validator.validate(model).asScala.toList
155+
156+
val expected = List.empty
157+
expect(result == expected)
158+
}
159+
160+
test("return error when structure is targeted by a union twice") {
108161
val unionShapeId = ShapeId.fromParts("test", "MyUnion")
109162
val adtTrait = new AdtMemberTrait(unionShapeId)
110163
val structMember = MemberShape
@@ -125,36 +178,106 @@ object AdtMemberTraitValidatorSpec extends FunSuite {
125178
.id(unionShapeId.withMember("unionMember"))
126179
.target(struct.getId)
127180
.build()
181+
182+
val unionMember2 = MemberShape
183+
.builder()
184+
.id(unionShapeId.withMember("unionMember2"))
185+
.target(struct.getId)
186+
.build()
187+
188+
val union =
189+
UnionShape
190+
.builder()
191+
.id(unionShapeId)
192+
.addMember(unionMember)
193+
.addMember(unionMember2)
194+
.build()
195+
196+
val model =
197+
Model.builder().addShapes(struct, union).build()
198+
199+
val result = validator.validate(model).asScala.toList
200+
201+
val expected = List(
202+
ValidationEvent
203+
.builder()
204+
.id("AdtMemberTrait")
205+
.shape(unionMember)
206+
.severity(Severity.ERROR)
207+
.message(
208+
"Duplicate reference to shape test#struct in container test#MyUnion - only one is allowed"
209+
)
210+
.build()
211+
)
212+
expect(result == expected)
213+
}
214+
215+
test("return error when structure is targeted by the wrong union") {
216+
val unionShapeId = ShapeId.fromParts("test", "MyUnion")
217+
val adtTrait = new AdtMemberTrait(unionShapeId)
218+
val stringShape = StringShape.builder().id("smithy.api#String").build()
219+
val structMember = MemberShape
220+
.builder()
221+
.id("test#struct$testing")
222+
.target("test#String")
223+
.build()
224+
225+
val struct =
226+
StructureShape
227+
.builder()
228+
.id("test#struct")
229+
.addTrait(adtTrait)
230+
.addMember(structMember)
231+
.build()
232+
233+
val unionMember = MemberShape
234+
.builder()
235+
.id(unionShapeId.withMember("unionMember"))
236+
.target(stringShape.getId)
237+
.build()
238+
128239
val union =
129240
UnionShape.builder().id(unionShapeId).addMember(unionMember).build()
130241

131242
val union2ShapeId = ShapeId.fromParts("test", "MyUnionTwo")
132-
val unionMember2 = unionMember.toBuilder
133-
.id(union2ShapeId.withMember("unionMemberTwo"))
243+
val union2Member = MemberShape
244+
.builder()
245+
.id(union2ShapeId.withMember("unionMember"))
246+
.target(struct.getId)
134247
.build()
248+
135249
val union2 =
136-
UnionShape.builder().id(union2ShapeId).addMember(unionMember2).build()
250+
UnionShape.builder().id(union2ShapeId).addMember(union2Member).build()
137251

138252
val model =
139-
Model.builder().addShapes(struct, union, union2).build()
253+
Model.builder().addShapes(struct, stringShape, union, union2).build()
140254

141255
val result = validator.validate(model).asScala.toList
142256

143257
val expected = List(
144258
ValidationEvent
145259
.builder()
146-
.id("AdtValidator")
147-
.shape(union2)
260+
.id("AdtMemberTrait")
261+
.shape(struct)
262+
.severity(Severity.ERROR)
263+
.message(
264+
"This shape must be referenced by test#MyUnion because of its smithy4s.meta#adtMember trait"
265+
)
266+
.build(),
267+
ValidationEvent
268+
.builder()
269+
.id("AdtMemberTrait")
270+
.shape(union2Member)
148271
.severity(Severity.ERROR)
149272
.message(
150-
"ADT member test#struct must not be referenced in any other shape but test#MyUnion"
273+
"Invalid reference to test#struct - due to its smithy4s.meta#adtMember trait, only test#MyUnion can reference it"
151274
)
152275
.build()
153276
)
154277
expect(result == expected)
155278
}
156279

157-
test("return error when structure is targeted by a union and a structure") {
280+
test("return error when structure is targeted by multiple unions") {
158281
val unionShapeId = ShapeId.fromParts("test", "MyUnion")
159282
val adtTrait = new AdtMemberTrait(unionShapeId)
160283
val structMember = MemberShape
@@ -178,29 +301,27 @@ object AdtMemberTraitValidatorSpec extends FunSuite {
178301
val union =
179302
UnionShape.builder().id(unionShapeId).addMember(unionMember).build()
180303

181-
val struct2ShapeId = ShapeId.fromParts("test", "MyStruct2")
182-
val structMember2 = unionMember.toBuilder
183-
.id(struct2ShapeId.withMember("structMember2"))
184-
.build()
185-
val struct2 = StructureShape
186-
.builder()
187-
.id(struct2ShapeId)
188-
.addMember(structMember2)
304+
val union2ShapeId = ShapeId.fromParts("test", "MyUnionTwo")
305+
val unionMember2 = unionMember.toBuilder
306+
.id(union2ShapeId.withMember("unionMemberTwo"))
189307
.build()
190308

309+
val union2 =
310+
UnionShape.builder().id(union2ShapeId).addMember(unionMember2).build()
311+
191312
val model =
192-
Model.builder().addShapes(struct, union, struct2).build()
313+
Model.builder().addShapes(struct, union, union2).build()
193314

194315
val result = validator.validate(model).asScala.toList
195316

196317
val expected = List(
197318
ValidationEvent
198319
.builder()
199-
.id("AdtValidator")
200-
.shape(struct2)
320+
.id("AdtMemberTrait")
321+
.shape(unionMember2)
201322
.severity(Severity.ERROR)
202323
.message(
203-
"ADT member test#struct must not be referenced in any other shape but test#MyUnion"
324+
"Invalid reference to test#struct - due to its smithy4s.meta#adtMember trait, only test#MyUnion can reference it"
204325
)
205326
.build()
206327
)

modules/protocol-tests/test/src/smithy4s/api/validation/AdtTraitValidatorSpec.scala

+1-3
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@ object AdtTraitValidatorSpec extends FunSuite {
6868
expect(result == expected)
6969
}
7070

71-
test(
72-
"AdtTrait - return error when union does not target the structure"
73-
) {
71+
test("AdtTrait - return error when union does not target the structure") {
7472
val unionShapeId = ShapeId.fromParts("test", "MyUnion")
7573
val adtTrait = new AdtTrait()
7674
val structMember = MemberShape

modules/protocol/resources/META-INF/smithy/smithy4s.meta.smithy

+4-1
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,13 @@ structure packedInputs {}
2727
@idRef(failWhenMissing: true, selector: "union")
2828
string adtMember
2929

30+
// note: technically, the structure test in the selector is redundant
31+
// as it only checks that there exists at least one member that targets a struct.
32+
// Keeping it for now to avoid model merging issues.
3033
/// Implies that all members of the union are annotated with the `adtMember` trait.
3134
/// Further signals that the `sealed trait` for this adt will extend the traits
3235
/// defined by any mixins that are present on all of the adt members.
33-
@trait(selector: ":test(union, :not([trait|mixin]))")
36+
@trait(selector: ":test(union :test(> member > structure), :not([trait|mixin]))")
3437
structure adt {}
3538

3639
// the indexedSeq trait can be added to list shapes in order for the generated collection

modules/protocol/src/smithy4s/meta/validation/AdtMemberTraitValidator.java

+54-7
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,75 @@
2020
import software.amazon.smithy.model.Model;
2121
import software.amazon.smithy.model.shapes.Shape;
2222
import software.amazon.smithy.model.shapes.ShapeId;
23+
import software.amazon.smithy.model.shapes.MemberShape;
2324
import software.amazon.smithy.model.validation.AbstractValidator;
2425
import software.amazon.smithy.model.validation.ValidationEvent;
26+
import software.amazon.smithy.model.selector.Selector;
2527

2628
import java.util.*;
29+
import java.util.stream.Stream;
2730
import java.util.stream.Collectors;
2831

2932
/**
3033
* All structures annotated with `@adtMember(SomeUnion)` are targeted in EXACTLY
3134
* ONE place: as a member of the union they reference in their idRef (SomeUnion
3235
* in this case)
3336
*
34-
* Also checks that structure is not empty (must have at least one member)
37+
* Doesn't check if the container is a union because the idRef on adtMember enforces that.
3538
*/
3639
public final class AdtMemberTraitValidator extends AbstractValidator {
3740

41+
private final Selector adtTargettingContainersSelector = Selector.parse(
42+
String.format(":test(> member > [trait|%s])", AdtMemberTrait.ID.toString())
43+
);
44+
45+
private Boolean targetIsAdtMember(MemberShape shape, Model model) {
46+
return model.getShape(shape.getTarget()).filter(mem -> mem.hasTrait(AdtMemberTrait.class)).isPresent();
47+
}
48+
3849
@Override
3950
public List<ValidationEvent> validate(Model model) {
40-
Set<Shape> adtMemberShapes = model.getShapesWithTrait(AdtMemberTrait.class);
41-
Map<ShapeId, List<Shape>> grouped = adtMemberShapes.stream()
42-
.collect(Collectors.groupingBy(mem -> mem.expectTrait(AdtMemberTrait.class).getValue()));
43-
return grouped.entrySet().stream().flatMap(entry -> {
44-
return AdtValidatorCommon.getReferenceEvents(model, entry.getValue(), model.expectShape(entry.getKey()));
45-
}).collect(Collectors.toList());
51+
52+
Stream<ValidationEvent> invalidCountErrors = model.getShapesWithTrait(AdtMemberTrait.class).stream()
53+
.flatMap(target -> {
54+
// we simply check if the shape contained in the adtMember trait actually refers to the annotated shape exactly once.
55+
// that covers all of the "not referenced from anywhere else", "no duplicate references", and "no reference" rules.
56+
57+
Shape expectedContainer = model.expectShape(target.expectTrait(AdtMemberTrait.class).getValue());
58+
59+
List<MemberShape> referencesToTargetInContainer = expectedContainer.getAllMembers().values().stream()
60+
.filter(mem -> mem.getTarget().equals(target.getId()))
61+
.collect(Collectors.toList());
62+
63+
switch(referencesToTargetInContainer.size()){
64+
case 0:
65+
// note: this may seem like a duplicate of the invalidReferenceErrors check below, but it's not:
66+
// this checks for "not referenced anywhere", and the other one checks for "referenced in the wrong place".
67+
return Stream.of(error(target, String.format("This shape must be referenced by %s because of its %s trait", expectedContainer.getId(), AdtMemberTrait.ID)));
68+
case 1:
69+
return Stream.empty(); // perfect - it's only referenced in the shape that actually should reference it
70+
default:
71+
return Stream.of(error(referencesToTargetInContainer.get(0), String.format("Duplicate reference to shape %s in container %s - only one is allowed", target.getId(), expectedContainer.getId())));
72+
}
73+
});
74+
75+
Stream<Shape> shapesTargettingAdtMembers = adtTargettingContainersSelector.shapes(model);
76+
77+
Stream<ValidationEvent> invalidReferenceErrors = shapesTargettingAdtMembers.flatMap(parent -> {
78+
List<MemberShape> adtMembersInParent = parent.getAllMembers()
79+
.values()
80+
.stream()
81+
.filter(mem -> targetIsAdtMember(mem, model))
82+
.collect(Collectors.toList());
83+
84+
Stream<MemberShape> invalidMembers = adtMembersInParent.stream().filter(mem -> !model.expectShape(mem.getTarget()).expectTrait(AdtMemberTrait.class).getValue().equals(parent.getId()));
85+
86+
return invalidMembers.map(mem -> {
87+
ShapeId expectedContainer = model.expectShape(mem.getTarget()).expectTrait(AdtMemberTrait.class).getValue();
88+
return error(mem, String.format("Invalid reference to %s - due to its %s trait, only %s can reference it", mem.getTarget(), AdtMemberTrait.ID, expectedContainer));
89+
});
90+
});
91+
92+
return Stream.concat(invalidCountErrors, invalidReferenceErrors).collect(Collectors.toList());
4693
}
4794
}

modules/protocol/src/smithy4s/meta/validation/AdtTraitValidator.java

+5-8
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,8 @@
2929
import software.amazon.smithy.model.selector.Selector;
3030

3131
/**
32-
* Unions marked with the adt trait must have at least one member. Also, the
33-
* structures that the union targets must NOT be used within any other union.
34-
*
35-
* Also checks that the structures targeted are not empty (they must have at
36-
* least one member).
32+
* All the members of an ADT union must be structures.
33+
* Also, each such structure can only be referenced once in the whole model (from said union).
3734
*/
3835
public final class AdtTraitValidator extends AbstractValidator {
3936
private final Selector adtTargetedMemberSelector = Selector.parse(
@@ -62,7 +59,7 @@ public List<ValidationEvent> validate(Model model) {
6259
.filter(union -> !union.getAllMembers().values().stream().allMatch(mem -> model.expectShape(mem.getTarget()).isStructureShape()))
6360
.map(union -> error(union, "All members of an adt union must be structures"));
6461

65-
List<ValidationEvent> dupes = adtTargetedMemberSelector.shapes(model).flatMap(parent -> {
62+
Stream<ValidationEvent> dupes = adtTargetedMemberSelector.shapes(model).flatMap(parent -> {
6663
return parent.getAllMembers().values().stream().map(mem -> new Reference(parent, model.expectShape(mem.getTarget())));
6764
})
6865
.collect(Collectors.groupingBy(ref -> ref.to))
@@ -82,9 +79,9 @@ public List<ValidationEvent> validate(Model model) {
8279
.collect(Collectors.joining(", "));
8380

8481
return error(targetWithDuplicateParents.getKey(), "This shape can only be referenced once and from one adt union, but it's referenced from " + targets);
85-
}).collect(Collectors.toList());
82+
});
8683

87-
return Stream.concat(nonStructTargets, dupes.stream()).collect(Collectors.toList());
84+
return Stream.concat(nonStructTargets, dupes).collect(Collectors.toList());
8885
}
8986

9087
}

0 commit comments

Comments
 (0)