diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java index ec11c078..89a87f62 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java @@ -15,6 +15,7 @@ package software.amazon.smithy.go.codegen; +import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate; import static software.amazon.smithy.go.codegen.GoWriter.goTemplate; import static software.amazon.smithy.go.codegen.SymbolUtils.isNilable; import static software.amazon.smithy.go.codegen.SymbolUtils.isPointable; @@ -407,11 +408,48 @@ private GoWriter.Writable compareVariables(String ident, Variable left, Variable return goTemplate("$1L := $5L($2L) $4L $5L($3L)", ident, left.ident, right.ident, cmp, cast); } + // undocumented jmespath behavior: null in numeric _ordering_ comparisons coerces to 0 + // this means the subsequent nil checks for numerics are moot, but it's either this or branch the codegen even + // further for questionable benefit + var nilCoerceLeft = emptyGoTemplate(); + var nilCoerceRight = emptyGoTemplate(); + if (isOrderComparator(cmp)) { + if (isLPtr && left.shape instanceof NumberShape) { + nilCoerceLeft = goTemplate(""" + if ($1L == nil) { + $1L = new($2T) + *$1L = 0 + }""", left.ident, left.type); + } + if (isRPtr && right.shape instanceof NumberShape) { + nilCoerceRight = goTemplate(""" + if ($1L == nil) { + $1L = new($2T) + *$1L = 0 + }""", right.ident, right.type); + } + } + + // also, if they're both pointers, and it's (in)equality, there's an additional true case where both are nil, + // or both are different + var elseCheckPtrs = emptyGoTemplate(); + if (isLPtr && isRPtr) { + if (cmp == ComparatorType.EQUAL) { + elseCheckPtrs = goTemplate("else { $L = $L == nil && $L == nil }", + ident, left.ident, right.ident); + } else if (cmp == ComparatorType.NOT_EQUAL) { + elseCheckPtrs = goTemplate("else { $1L = ($2L == nil && $3L != nil) || ($2L != nil && $3L == nil) }", + ident, left.ident, right.ident); + } + } + return goTemplate(""" var $ident:L bool + $nilCoerceLeft:W + $nilCoerceRight:W if $lif:L $amp:L $rif:L { $ident:L = $cast:L($lhs:L) $cmp:L $cast:L($rhs:L) - }""", + }$elseCheckPtrs:W""", Map.of( "ident", ident, "lif", isLPtr ? left.ident + " != nil" : "", @@ -420,10 +458,20 @@ private GoWriter.Writable compareVariables(String ident, Variable left, Variable "cmp", cmp, "lhs", isLPtr ? "*" + left.ident : left.ident, "rhs", isRPtr ? "*" + right.ident : right.ident, - "cast", cast + "cast", cast, + "nilCoerceLeft", nilCoerceLeft, + "nilCoerceRight", nilCoerceRight + ), + Map.of( + "elseCheckPtrs", elseCheckPtrs )); } + private static boolean isOrderComparator(ComparatorType cmp) { + return cmp == ComparatorType.GREATER_THAN || cmp == ComparatorType.LESS_THAN + || cmp == ComparatorType.GREATER_THAN_EQUAL || cmp == ComparatorType.LESS_THAN_EQUAL; + } + /** * Represents a variable (input, intermediate, or final output) of a JMESPath traversal. * @param shape The underlying shape referenced by this variable. For certain jmespath expressions (e.g. diff --git a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java index 764c3537..0a6b01a3 100644 --- a/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java +++ b/codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java @@ -43,6 +43,8 @@ public class GoJmespathExpressionGeneratorTest { objectList: ObjectList objectMap: ObjectMap nested: NestedStruct + nullableIntegerA: Integer + nullableIntegerB: Integer } structure Object { @@ -318,6 +320,7 @@ public void testComparatorStringLHSNil() { } v4 := "foo" var v5 bool + if v2 != nil { v5 = string(*v2) == string(v4) } @@ -345,6 +348,7 @@ public void testComparatorStringRHSNil() { v3 = v4 } var v5 bool + if v3 != nil { v5 = string(v1) == string(*v3) } @@ -372,9 +376,10 @@ public void testComparatorStringBothNil() { } v4 := input.SimpleShape var v5 bool + if v2 != nil && v4 != nil { v5 = string(*v2) == string(*v4) - } + }else { v5 = v2 == nil && v4 == nil } """)); } @@ -546,4 +551,107 @@ public void testMultiSelectFlatten() { } """)); } + + @Test + public void testOrderComparatorNumberCoercesLeftNullable() { + var expr = "nullableIntegerA > `9`"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean")); + assertThat(actual.ident(), Matchers.equalTo("v3")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.NullableIntegerA + v2 := 9 + var v3 bool + if (v1 == nil) { + v1 = new(int32) + *v1 = 0 + } + + if v1 != nil { + v3 = int64(*v1) > int64(v2) + } + """)); + } + + @Test + public void testOrderComparatorNumberCoercesBothNullable() { + var expr = "nullableIntegerA > nullableIntegerB"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean")); + assertThat(actual.ident(), Matchers.equalTo("v3")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.NullableIntegerA + v2 := input.NullableIntegerB + var v3 bool + if (v1 == nil) { + v1 = new(int32) + *v1 = 0 + } + if (v2 == nil) { + v2 = new(int32) + *v2 = 0 + } + if v1 != nil && v2 != nil { + v3 = int64(*v1) > int64(*v2) + } + """)); + } + + @Test + public void testEqualBothNullable() { + var expr = "nullableIntegerA == nullableIntegerB"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean")); + assertThat(actual.ident(), Matchers.equalTo("v3")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.NullableIntegerA + v2 := input.NullableIntegerB + var v3 bool + + if v1 != nil && v2 != nil { + v3 = int64(*v1) == int64(*v2) + }else { v3 = v1 == nil && v2 == nil } + """)); + } + + @Test + public void testNotEqualBothNullable() { + var expr = "nullableIntegerA != nullableIntegerB"; + + var writer = testWriter(); + var generator = new GoJmespathExpressionGenerator(testContext(), writer); + var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable( + TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")), + "input" + )); + assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean")); + assertThat(actual.ident(), Matchers.equalTo("v3")); + assertThat(writer.toString(), Matchers.containsString(""" + v1 := input.NullableIntegerA + v2 := input.NullableIntegerB + var v3 bool + + if v1 != nil && v2 != nil { + v3 = int64(*v1) != int64(*v2) + }else { v3 = (v1 == nil && v2 != nil) || (v1 != nil && v2 == nil) } + """)); + } }