Skip to content

Commit ae3d43e

Browse files
cushonError Prone Team
authored andcommitted
Suggest minimizing the amount of logic in assertThrows
PiperOrigin-RevId: 877896836
1 parent 9947a47 commit ae3d43e

File tree

5 files changed

+369
-2
lines changed

5 files changed

+369
-2
lines changed
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/*
2+
* Copyright 2026 The Error Prone Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.errorprone.bugpatterns;
18+
19+
import static com.google.common.collect.ImmutableList.toImmutableList;
20+
import static com.google.common.collect.ImmutableSet.toImmutableSet;
21+
import static com.google.common.collect.Iterables.getOnlyElement;
22+
import static com.google.errorprone.BugPattern.SeverityLevel.WARNING;
23+
import static com.google.errorprone.matchers.Description.NO_MATCH;
24+
import static com.google.errorprone.matchers.method.MethodMatchers.staticMethod;
25+
import static com.google.errorprone.util.ASTHelpers.getReceiver;
26+
import static com.google.errorprone.util.ASTHelpers.getStartPosition;
27+
import static com.google.errorprone.util.ASTHelpers.getSymbol;
28+
import static com.google.errorprone.util.ASTHelpers.getType;
29+
30+
import com.google.common.base.CaseFormat;
31+
import com.google.common.collect.ImmutableList;
32+
import com.google.common.collect.Streams;
33+
import com.google.errorprone.BugPattern;
34+
import com.google.errorprone.ErrorProneFlags;
35+
import com.google.errorprone.VisitorState;
36+
import com.google.errorprone.bugpatterns.BugChecker.MethodInvocationTreeMatcher;
37+
import com.google.errorprone.bugpatterns.threadsafety.ConstantExpressions;
38+
import com.google.errorprone.fixes.SuggestedFix;
39+
import com.google.errorprone.fixes.SuggestedFixes;
40+
import com.google.errorprone.matchers.Description;
41+
import com.google.errorprone.matchers.Matcher;
42+
import com.google.errorprone.util.FindIdentifiers;
43+
import com.sun.source.tree.BlockTree;
44+
import com.sun.source.tree.ExpressionStatementTree;
45+
import com.sun.source.tree.ExpressionTree;
46+
import com.sun.source.tree.IdentifierTree;
47+
import com.sun.source.tree.LambdaExpressionTree;
48+
import com.sun.source.tree.MemberSelectTree;
49+
import com.sun.source.tree.MethodInvocationTree;
50+
import com.sun.source.tree.StatementTree;
51+
import com.sun.tools.javac.code.Symbol.VarSymbol;
52+
import java.util.stream.IntStream;
53+
import java.util.stream.Stream;
54+
import javax.inject.Inject;
55+
import javax.lang.model.element.ElementKind;
56+
57+
/** A {@link BugChecker}; see the associated {@link BugPattern} annotation for details. */
58+
@BugPattern(summary = "Minimize the amount of logic in assertThrows", severity = WARNING)
59+
public class AssertThrowsMinimizer extends BugChecker implements MethodInvocationTreeMatcher {
60+
61+
private static final Matcher<ExpressionTree> MATCHER =
62+
staticMethod().onClass("org.junit.Assert").named("assertThrows");
63+
64+
private final ConstantExpressions constantExpressions;
65+
private final boolean useVarType;
66+
67+
@Inject
68+
AssertThrowsMinimizer(ConstantExpressions constantExpressions, ErrorProneFlags flags) {
69+
this.constantExpressions = constantExpressions;
70+
this.useVarType = flags.getBoolean("AssertThrowsMinimizer:UseVarType").orElse(false);
71+
}
72+
73+
record Hoist(ExpressionTree site, String name) {}
74+
75+
@Override
76+
public Description matchMethodInvocation(MethodInvocationTree tree, VisitorState state) {
77+
if (!MATCHER.matches(tree, state)) {
78+
return NO_MATCH;
79+
}
80+
if (!(state.getPath().getParentPath().getLeaf() instanceof StatementTree parent)) {
81+
// We need a scope to declare variables in, assertThrows is usually an expression statement or
82+
// a variable initializer
83+
return NO_MATCH;
84+
}
85+
if (!(tree.getArguments().getLast() instanceof LambdaExpressionTree lambdaExpressionTree)) {
86+
return NO_MATCH;
87+
}
88+
MethodInvocationTree runnable;
89+
switch (lambdaExpressionTree.getBody()) {
90+
case BlockTree blockTree -> {
91+
if (blockTree.getStatements().size() != 1) {
92+
return NO_MATCH;
93+
}
94+
if (!(getOnlyElement(blockTree.getStatements())
95+
instanceof ExpressionStatementTree expressionStatementTree
96+
&& expressionStatementTree.getExpression()
97+
instanceof MethodInvocationTree methodInvocationTree)) {
98+
return NO_MATCH;
99+
}
100+
runnable = methodInvocationTree;
101+
}
102+
case MethodInvocationTree methodInvocationTree -> runnable = methodInvocationTree;
103+
default -> {
104+
return NO_MATCH;
105+
}
106+
}
107+
ImmutableList<Hoist> toHoist =
108+
Streams.concat(
109+
Stream.ofNullable(getReceiver(runnable))
110+
.map(r -> new Hoist(r, receiverVariableName(r, state))),
111+
Streams.zip(
112+
runnable.getArguments().stream(),
113+
getSymbol(runnable).getParameters().stream(),
114+
(ExpressionTree a, VarSymbol p) -> new Hoist(a, p.getSimpleName().toString())))
115+
.filter(h -> needsHoisting(h.site(), state))
116+
.collect(toImmutableList());
117+
if (toHoist.isEmpty()) {
118+
return NO_MATCH;
119+
}
120+
SuggestedFix.Builder fix = SuggestedFix.builder();
121+
StringBuilder hoistedVariables = new StringBuilder();
122+
for (Hoist hoist : toHoist) {
123+
String identifier = avoidShadowing(hoist.name(), state);
124+
hoistedVariables.append(
125+
String.format(
126+
"%s %s = %s;\n",
127+
useVarType ? "var" : SuggestedFixes.qualifyType(state, fix, getType(hoist.site())),
128+
identifier,
129+
state.getSourceForNode(hoist.site())));
130+
fix.replace(hoist.site(), identifier);
131+
}
132+
fix.prefixWith(parent, hoistedVariables.toString());
133+
if (lambdaExpressionTree.getBody() instanceof BlockTree blockTree) {
134+
fix.replace(getStartPosition(blockTree), getStartPosition(runnable), "");
135+
fix.replace(state.getEndPosition(runnable), state.getEndPosition(blockTree), "");
136+
}
137+
return describeMatch(tree, fix.build());
138+
}
139+
140+
private static String receiverVariableName(ExpressionTree tree, VisitorState state) {
141+
return CaseFormat.UPPER_CAMEL.to(
142+
CaseFormat.LOWER_CAMEL, getType(tree).asElement().getSimpleName().toString());
143+
}
144+
145+
private boolean needsHoisting(ExpressionTree tree, VisitorState state) {
146+
boolean unqualifiedIdentifier =
147+
switch (tree) {
148+
case IdentifierTree identifierTree -> true;
149+
case MemberSelectTree memberSelectTree ->
150+
memberSelectTree.getExpression() instanceof IdentifierTree identifierTree
151+
&& identifierTree.getName().contentEquals("this");
152+
default -> false;
153+
};
154+
if (unqualifiedIdentifier && getSymbol(tree).getKind() == ElementKind.FIELD) {
155+
return false;
156+
}
157+
// This is an imperfect heuristic. These expressions aren't guaranteed not to throw, but may be
158+
// less valuable to hoist.
159+
return constantExpressions.constantExpression(tree, state).isEmpty();
160+
}
161+
162+
// Stolen from PatternMatchingInstanceof
163+
// TODO: cushon - add to SuggestedFixes?
164+
private static String avoidShadowing(String name, VisitorState state) {
165+
var idents =
166+
FindIdentifiers.findAllIdents(state).stream()
167+
.map(s -> s.getSimpleName().toString())
168+
.collect(toImmutableSet());
169+
return IntStream.iterate(1, i -> i + 1)
170+
.mapToObj(i -> i == 1 ? name : (name + i))
171+
.filter(n -> !idents.contains(n))
172+
.findFirst()
173+
.get();
174+
}
175+
}

core/src/main/java/com/google/errorprone/bugpatterns/PatternMatchingInstanceof.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ private static String generateVariableName(Type targetType, VisitorState state)
174174
return avoidShadowing(camelCased, state);
175175
}
176176

177+
// TODO: cushon - add to SuggestedFixes?
177178
private static String avoidShadowing(String name, VisitorState state) {
178179
var idents =
179180
FindIdentifiers.findAllIdents(state).stream()

core/src/main/java/com/google/errorprone/bugpatterns/threadsafety/ConstantExpressions.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,8 @@ public Optional<ConstantExpression> constantExpression(ExpressionTree tree, Visi
273273
return Optional.of(new ConstantExpression.ConstantEquals(lhs.get(), rhs.get()));
274274
}
275275
}
276-
Object value = constValue(tree);
277-
if (value != null && tree instanceof LiteralTree) {
276+
Object value = tree instanceof LiteralTree ? constValue(tree) : null;
277+
if (value != null || tree.getKind() == Kind.NULL_LITERAL) {
278278
return Optional.of(new ConstantExpression.Literal(value));
279279
}
280280
return symbolizeImmutableExpression(tree, state).map(x -> x);

core/src/main/java/com/google/errorprone/scanner/BuiltInCheckerSuppliers.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import com.google.errorprone.bugpatterns.ArrayToString;
3838
import com.google.errorprone.bugpatterns.ArraysAsListPrimitiveArray;
3939
import com.google.errorprone.bugpatterns.AssertFalse;
40+
import com.google.errorprone.bugpatterns.AssertThrowsMinimizer;
4041
import com.google.errorprone.bugpatterns.AssertThrowsMultipleStatements;
4142
import com.google.errorprone.bugpatterns.AssertionFailureIgnored;
4243
import com.google.errorprone.bugpatterns.AssignmentExpression;
@@ -905,6 +906,7 @@ public static ScannerSupplier warningChecks() {
905906
ArrayRecordComponent.class,
906907
AssertEqualsArgumentOrderChecker.class,
907908
AssertSameIncompatible.class,
909+
AssertThrowsMinimizer.class,
908910
AssertThrowsMultipleStatements.class,
909911
AssertionFailureIgnored.class,
910912
AssignmentExpression.class,

0 commit comments

Comments
 (0)