Skip to content

Commit d708d1d

Browse files
authored
coerce nil numbers to 0 in jmespath codegen (#565)
1 parent f2ae388 commit d708d1d

File tree

2 files changed

+159
-3
lines changed

2 files changed

+159
-3
lines changed

codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java

+50-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
package software.amazon.smithy.go.codegen;
1717

18+
import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate;
1819
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
1920
import static software.amazon.smithy.go.codegen.SymbolUtils.isNilable;
2021
import static software.amazon.smithy.go.codegen.SymbolUtils.isPointable;
@@ -407,11 +408,48 @@ private GoWriter.Writable compareVariables(String ident, Variable left, Variable
407408
return goTemplate("$1L := $5L($2L) $4L $5L($3L)", ident, left.ident, right.ident, cmp, cast);
408409
}
409410

411+
// undocumented jmespath behavior: null in numeric _ordering_ comparisons coerces to 0
412+
// this means the subsequent nil checks for numerics are moot, but it's either this or branch the codegen even
413+
// further for questionable benefit
414+
var nilCoerceLeft = emptyGoTemplate();
415+
var nilCoerceRight = emptyGoTemplate();
416+
if (isOrderComparator(cmp)) {
417+
if (isLPtr && left.shape instanceof NumberShape) {
418+
nilCoerceLeft = goTemplate("""
419+
if ($1L == nil) {
420+
$1L = new($2T)
421+
*$1L = 0
422+
}""", left.ident, left.type);
423+
}
424+
if (isRPtr && right.shape instanceof NumberShape) {
425+
nilCoerceRight = goTemplate("""
426+
if ($1L == nil) {
427+
$1L = new($2T)
428+
*$1L = 0
429+
}""", right.ident, right.type);
430+
}
431+
}
432+
433+
// also, if they're both pointers, and it's (in)equality, there's an additional true case where both are nil,
434+
// or both are different
435+
var elseCheckPtrs = emptyGoTemplate();
436+
if (isLPtr && isRPtr) {
437+
if (cmp == ComparatorType.EQUAL) {
438+
elseCheckPtrs = goTemplate("else { $L = $L == nil && $L == nil }",
439+
ident, left.ident, right.ident);
440+
} else if (cmp == ComparatorType.NOT_EQUAL) {
441+
elseCheckPtrs = goTemplate("else { $1L = ($2L == nil && $3L != nil) || ($2L != nil && $3L == nil) }",
442+
ident, left.ident, right.ident);
443+
}
444+
}
445+
410446
return goTemplate("""
411447
var $ident:L bool
448+
$nilCoerceLeft:W
449+
$nilCoerceRight:W
412450
if $lif:L $amp:L $rif:L {
413451
$ident:L = $cast:L($lhs:L) $cmp:L $cast:L($rhs:L)
414-
}""",
452+
}$elseCheckPtrs:W""",
415453
Map.of(
416454
"ident", ident,
417455
"lif", isLPtr ? left.ident + " != nil" : "",
@@ -420,10 +458,20 @@ private GoWriter.Writable compareVariables(String ident, Variable left, Variable
420458
"cmp", cmp,
421459
"lhs", isLPtr ? "*" + left.ident : left.ident,
422460
"rhs", isRPtr ? "*" + right.ident : right.ident,
423-
"cast", cast
461+
"cast", cast,
462+
"nilCoerceLeft", nilCoerceLeft,
463+
"nilCoerceRight", nilCoerceRight
464+
),
465+
Map.of(
466+
"elseCheckPtrs", elseCheckPtrs
424467
));
425468
}
426469

470+
private static boolean isOrderComparator(ComparatorType cmp) {
471+
return cmp == ComparatorType.GREATER_THAN || cmp == ComparatorType.LESS_THAN
472+
|| cmp == ComparatorType.GREATER_THAN_EQUAL || cmp == ComparatorType.LESS_THAN_EQUAL;
473+
}
474+
427475
/**
428476
* Represents a variable (input, intermediate, or final output) of a JMESPath traversal.
429477
* @param shape The underlying shape referenced by this variable. For certain jmespath expressions (e.g.

codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java

+109-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ public class GoJmespathExpressionGeneratorTest {
4343
objectList: ObjectList
4444
objectMap: ObjectMap
4545
nested: NestedStruct
46+
nullableIntegerA: Integer
47+
nullableIntegerB: Integer
4648
}
4749
4850
structure Object {
@@ -318,6 +320,7 @@ public void testComparatorStringLHSNil() {
318320
}
319321
v4 := "foo"
320322
var v5 bool
323+
321324
if v2 != nil {
322325
v5 = string(*v2) == string(v4)
323326
}
@@ -345,6 +348,7 @@ public void testComparatorStringRHSNil() {
345348
v3 = v4
346349
}
347350
var v5 bool
351+
348352
if v3 != nil {
349353
v5 = string(v1) == string(*v3)
350354
}
@@ -372,9 +376,10 @@ public void testComparatorStringBothNil() {
372376
}
373377
v4 := input.SimpleShape
374378
var v5 bool
379+
375380
if v2 != nil && v4 != nil {
376381
v5 = string(*v2) == string(*v4)
377-
}
382+
}else { v5 = v2 == nil && v4 == nil }
378383
"""));
379384
}
380385

@@ -546,4 +551,107 @@ public void testMultiSelectFlatten() {
546551
}
547552
"""));
548553
}
554+
555+
@Test
556+
public void testOrderComparatorNumberCoercesLeftNullable() {
557+
var expr = "nullableIntegerA > `9`";
558+
559+
var writer = testWriter();
560+
var generator = new GoJmespathExpressionGenerator(testContext(), writer);
561+
var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable(
562+
TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")),
563+
"input"
564+
));
565+
assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean"));
566+
assertThat(actual.ident(), Matchers.equalTo("v3"));
567+
assertThat(writer.toString(), Matchers.containsString("""
568+
v1 := input.NullableIntegerA
569+
v2 := 9
570+
var v3 bool
571+
if (v1 == nil) {
572+
v1 = new(int32)
573+
*v1 = 0
574+
}
575+
576+
if v1 != nil {
577+
v3 = int64(*v1) > int64(v2)
578+
}
579+
"""));
580+
}
581+
582+
@Test
583+
public void testOrderComparatorNumberCoercesBothNullable() {
584+
var expr = "nullableIntegerA > nullableIntegerB";
585+
586+
var writer = testWriter();
587+
var generator = new GoJmespathExpressionGenerator(testContext(), writer);
588+
var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable(
589+
TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")),
590+
"input"
591+
));
592+
assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean"));
593+
assertThat(actual.ident(), Matchers.equalTo("v3"));
594+
assertThat(writer.toString(), Matchers.containsString("""
595+
v1 := input.NullableIntegerA
596+
v2 := input.NullableIntegerB
597+
var v3 bool
598+
if (v1 == nil) {
599+
v1 = new(int32)
600+
*v1 = 0
601+
}
602+
if (v2 == nil) {
603+
v2 = new(int32)
604+
*v2 = 0
605+
}
606+
if v1 != nil && v2 != nil {
607+
v3 = int64(*v1) > int64(*v2)
608+
}
609+
"""));
610+
}
611+
612+
@Test
613+
public void testEqualBothNullable() {
614+
var expr = "nullableIntegerA == nullableIntegerB";
615+
616+
var writer = testWriter();
617+
var generator = new GoJmespathExpressionGenerator(testContext(), writer);
618+
var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable(
619+
TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")),
620+
"input"
621+
));
622+
assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean"));
623+
assertThat(actual.ident(), Matchers.equalTo("v3"));
624+
assertThat(writer.toString(), Matchers.containsString("""
625+
v1 := input.NullableIntegerA
626+
v2 := input.NullableIntegerB
627+
var v3 bool
628+
629+
if v1 != nil && v2 != nil {
630+
v3 = int64(*v1) == int64(*v2)
631+
}else { v3 = v1 == nil && v2 == nil }
632+
"""));
633+
}
634+
635+
@Test
636+
public void testNotEqualBothNullable() {
637+
var expr = "nullableIntegerA != nullableIntegerB";
638+
639+
var writer = testWriter();
640+
var generator = new GoJmespathExpressionGenerator(testContext(), writer);
641+
var actual = generator.generate(JmespathExpression.parse(expr), new GoJmespathExpressionGenerator.Variable(
642+
TEST_MODEL.expectShape(ShapeId.from("smithy.go.test#Struct")),
643+
"input"
644+
));
645+
assertThat(actual.shape().toShapeId().toString(), Matchers.equalTo("smithy.api#PrimitiveBoolean"));
646+
assertThat(actual.ident(), Matchers.equalTo("v3"));
647+
assertThat(writer.toString(), Matchers.containsString("""
648+
v1 := input.NullableIntegerA
649+
v2 := input.NullableIntegerB
650+
var v3 bool
651+
652+
if v1 != nil && v2 != nil {
653+
v3 = int64(*v1) != int64(*v2)
654+
}else { v3 = (v1 == nil && v2 != nil) || (v1 != nil && v2 == nil) }
655+
"""));
656+
}
549657
}

0 commit comments

Comments
 (0)