Skip to content

Commit a7d0f1e

Browse files
authored
fix potential nil deref in waiter path matcher (#563)
1 parent e5c5ac3 commit a7d0f1e

File tree

2 files changed

+48
-17
lines changed

2 files changed

+48
-17
lines changed

codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGenerator.java

+16-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package software.amazon.smithy.go.codegen;
1717

1818
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
19+
import static software.amazon.smithy.go.codegen.SymbolUtils.isNilable;
1920
import static software.amazon.smithy.go.codegen.SymbolUtils.isPointable;
2021
import static software.amazon.smithy.go.codegen.SymbolUtils.sliceOf;
2122
import static software.amazon.smithy.go.codegen.util.ShapeUtil.BOOL_SHAPE;
@@ -280,7 +281,21 @@ private Variable visitProjection(ProjectionExpression expr, Variable current) {
280281

281282
private Variable visitSub(Subexpression expr, Variable current) {
282283
var left = visit(expr.getLeft(), current);
283-
return visit(expr.getRight(), left);
284+
if (!isNilable(left.type)) {
285+
return visit(expr.getRight(), left);
286+
}
287+
288+
var lookahead = new GoJmespathExpressionGenerator(ctx, new GoWriter(""))
289+
.generate(expr.getRight(), left);
290+
var ident = nextIdent();
291+
writer.write("var $L $P", ident, lookahead.type);
292+
writer.write("if $L != nil {", left.ident);
293+
writer.indent();
294+
var inner = visit(expr.getRight(), left);
295+
writer.write("$L = $L", ident, inner.ident);
296+
writer.dedent();
297+
writer.write("}");
298+
return new Variable(inner.shape, ident, inner.type);
284299
}
285300

286301
private Variable visitField(FieldExpression expr, Variable current) {

codegen/smithy-go-codegen/src/test/java/software/amazon/smithy/go/codegen/GoJmespathExpressionGeneratorTest.java

+32-16
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,11 @@ public void testSubexpression() {
124124
assertThat(actual.ident(), Matchers.equalTo("v2"));
125125
assertThat(writer.toString(), Matchers.containsString("""
126126
v1 := input.Nested
127-
v2 := v1.NestedField
127+
var v2 *string
128+
if v1 != nil {
129+
v3 := v1.NestedField
130+
v2 = v3
131+
}
128132
"""));
129133
}
130134

@@ -304,14 +308,18 @@ public void testComparatorStringLHSNil() {
304308
"input"
305309
));
306310
assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE));
307-
assertThat(actual.ident(), Matchers.equalTo("v4"));
311+
assertThat(actual.ident(), Matchers.equalTo("v5"));
308312
assertThat(writer.toString(), Matchers.containsString("""
309313
v1 := input.Nested
310-
v2 := v1.NestedField
311-
v3 := "foo"
312-
var v4 bool
314+
var v2 *string
315+
if v1 != nil {
316+
v3 := v1.NestedField
317+
v2 = v3
318+
}
319+
v4 := "foo"
320+
var v5 bool
313321
if v2 != nil {
314-
v4 = string(*v2) == string(v3)
322+
v5 = string(*v2) == string(v4)
315323
}
316324
"""));
317325
}
@@ -327,14 +335,18 @@ public void testComparatorStringRHSNil() {
327335
"input"
328336
));
329337
assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE));
330-
assertThat(actual.ident(), Matchers.equalTo("v4"));
338+
assertThat(actual.ident(), Matchers.equalTo("v5"));
331339
assertThat(writer.toString(), Matchers.containsString("""
332340
v1 := "foo"
333341
v2 := input.Nested
334-
v3 := v2.NestedField
335-
var v4 bool
342+
var v3 *string
343+
if v2 != nil {
344+
v4 := v2.NestedField
345+
v3 = v4
346+
}
347+
var v5 bool
336348
if v3 != nil {
337-
v4 = string(v1) == string(*v3)
349+
v5 = string(v1) == string(*v3)
338350
}
339351
"""));
340352
}
@@ -350,14 +362,18 @@ public void testComparatorStringBothNil() {
350362
"input"
351363
));
352364
assertThat(actual.shape(), Matchers.equalTo(ShapeUtil.BOOL_SHAPE));
353-
assertThat(actual.ident(), Matchers.equalTo("v4"));
365+
assertThat(actual.ident(), Matchers.equalTo("v5"));
354366
assertThat(writer.toString(), Matchers.containsString("""
355367
v1 := input.Nested
356-
v2 := v1.NestedField
357-
v3 := input.SimpleShape
358-
var v4 bool
359-
if v2 != nil && v3 != nil {
360-
v4 = string(*v2) == string(*v3)
368+
var v2 *string
369+
if v1 != nil {
370+
v3 := v1.NestedField
371+
v2 = v3
372+
}
373+
v4 := input.SimpleShape
374+
var v5 bool
375+
if v2 != nil && v4 != nil {
376+
v5 = string(*v2) == string(*v4)
361377
}
362378
"""));
363379
}

0 commit comments

Comments
 (0)