Skip to content

Commit 709406e

Browse files
authored
Merge pull request #1907 from DependencyTrack/simplify-cel-visitor
2 parents 09213bf + 6cf6a81 commit 709406e

File tree

2 files changed

+67
-63
lines changed

2 files changed

+67
-63
lines changed

apiserver/src/main/java/org/dependencytrack/policy/cel/CelPolicyScriptVisitor.java

Lines changed: 26 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -18,76 +18,58 @@
1818
*/
1919
package org.dependencytrack.policy.cel;
2020

21-
import alpine.common.logging.Logger;
2221
import com.google.api.expr.v1alpha1.Expr;
2322
import com.google.api.expr.v1alpha1.Type;
2423
import org.apache.commons.collections4.MultiValuedMap;
2524
import org.apache.commons.collections4.multimap.HashSetValuedHashMap;
2625

27-
import java.util.ArrayDeque;
28-
import java.util.Deque;
2926
import java.util.HashSet;
3027
import java.util.List;
3128
import java.util.Map;
3229
import java.util.Set;
3330

34-
class CelPolicyScriptVisitor {
35-
36-
private static final Logger LOGGER = Logger.getLogger(CelPolicyScriptVisitor.class);
31+
final class CelPolicyScriptVisitor {
3732

3833
record FunctionSignature(String function, Type targetType, List<Type> argumentTypes) {
3934
}
4035

41-
private final Map<Long, Type> types;
36+
private final Map<Long, Type> typeByExpressionId;
4237
private final MultiValuedMap<Type, String> accessedFieldsByType;
4338
private final Set<FunctionSignature> usedFunctionSignatures;
44-
private final Deque<String> callFunctionStack;
45-
private final Deque<String> selectFieldStack;
46-
private final Deque<Type> selectOperandTypeStack;
4739

48-
CelPolicyScriptVisitor(final Map<Long, Type> types) {
49-
this.types = types;
40+
CelPolicyScriptVisitor(Map<Long, Type> typeByExpressionId) {
41+
this.typeByExpressionId = typeByExpressionId;
5042
this.accessedFieldsByType = new HashSetValuedHashMap<>();
5143
this.usedFunctionSignatures = new HashSet<>();
52-
this.callFunctionStack = new ArrayDeque<>();
53-
this.selectFieldStack = new ArrayDeque<>();
54-
this.selectOperandTypeStack = new ArrayDeque<>();
5544
}
5645

57-
void visit(final Expr expr) {
46+
void visit(Expr expr) {
5847
switch (expr.getExprKindCase()) {
5948
case CALL_EXPR -> visitCall(expr);
6049
case COMPREHENSION_EXPR -> visitComprehension(expr);
61-
case CONST_EXPR -> visitConst(expr);
62-
case IDENT_EXPR -> visitIdent(expr);
6350
case LIST_EXPR -> visitList(expr);
6451
case SELECT_EXPR -> visitSelect(expr);
6552
case STRUCT_EXPR -> visitStruct(expr);
66-
case EXPRKIND_NOT_SET -> LOGGER.debug("Unknown expression: %s".formatted(expr));
53+
case CONST_EXPR, EXPRKIND_NOT_SET, IDENT_EXPR -> {
54+
}
6755
}
6856
}
6957

70-
private void visitCall(final Expr expr) {
71-
logExpr(expr);
58+
private void visitCall(Expr expr) {
7259
final Expr.Call callExpr = expr.getCallExpr();
7360

74-
final Type targetType = types.get(callExpr.getTarget().getId());
61+
final Type targetType = typeByExpressionId.get(callExpr.getTarget().getId());
7562
final List<Type> argumentTypes = callExpr.getArgsList().stream()
7663
.map(Expr::getId)
77-
.map(types::get)
64+
.map(typeByExpressionId::get)
7865
.toList();
7966
usedFunctionSignatures.add(new FunctionSignature(callExpr.getFunction(), targetType, argumentTypes));
8067

81-
callFunctionStack.push(callExpr.getFunction());
8268
visit(callExpr.getTarget());
83-
for (final Expr argExpr : callExpr.getArgsList()) {
84-
visit(argExpr);
85-
}
86-
callFunctionStack.pop();
69+
callExpr.getArgsList().forEach(this::visit);
8770
}
8871

89-
private void visitComprehension(final Expr expr) {
90-
logExpr(expr);
72+
private void visitComprehension(Expr expr) {
9173
final Expr.Comprehension comprehensionExpr = expr.getComprehensionExpr();
9274

9375
visit(comprehensionExpr.getAccuInit());
@@ -97,40 +79,26 @@ private void visitComprehension(final Expr expr) {
9779
visit(comprehensionExpr.getResult());
9880
}
9981

100-
private void visitConst(final Expr expr) {
101-
logExpr(expr);
102-
}
103-
104-
private void visitIdent(final Expr expr) {
105-
logExpr(expr);
106-
selectOperandTypeStack.push(types.get(expr.getId()));
107-
}
108-
109-
private void visitList(final Expr expr) {
110-
logExpr(expr);
82+
private void visitList(Expr expr) {
83+
expr.getListExpr().getElementsList().forEach(this::visit);
11184
}
11285

113-
private void visitSelect(final Expr expr) {
114-
logExpr(expr);
86+
private void visitSelect(Expr expr) {
11587
final Expr.Select selectExpr = expr.getSelectExpr();
116-
117-
selectFieldStack.push(selectExpr.getField());
118-
selectOperandTypeStack.push(types.get(expr.getId()));
88+
final Type operandType = typeByExpressionId.get(selectExpr.getOperand().getId());
89+
if (operandType != null) {
90+
accessedFieldsByType.put(operandType, selectExpr.getField());
91+
}
11992
visit(selectExpr.getOperand());
120-
accessedFieldsByType.put(selectOperandTypeStack.pop(), selectFieldStack.pop());
121-
}
122-
123-
private void visitStruct(final Expr expr) {
124-
logExpr(expr);
12593
}
12694

127-
private void logExpr(final Expr expr) {
128-
if (!LOGGER.isDebugEnabled()) {
129-
return;
130-
}
131-
132-
LOGGER.debug("Visiting %s (id=%d, fieldStack=%s, fieldTypeStack=%s, functionStack=%s)"
133-
.formatted(expr.getExprKindCase(), expr.getId(), selectFieldStack, selectOperandTypeStack, callFunctionStack));
95+
private void visitStruct(Expr expr) {
96+
expr.getStructExpr().getEntriesList().forEach(entry -> {
97+
if (entry.hasMapKey()) {
98+
visit(entry.getMapKey());
99+
}
100+
visit(entry.getValue());
101+
});
134102
}
135103

136104
MultiValuedMap<Type, String> getAccessedFieldsByType() {

apiserver/src/test/java/org/dependencytrack/policy/cel/CelPolicyScriptHostTest.java

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@
3838
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
3939
import static org.junit.jupiter.api.Assertions.assertThrows;
4040

41-
public class CelPolicyScriptHostTest {
41+
class CelPolicyScriptHostTest {
4242

4343
@Test
44-
public void testCompileWithCache() throws Exception {
44+
void testCompileWithCache() throws Exception {
4545
final var scriptSrc = """
4646
component.name == "foo"
4747
""";
@@ -55,7 +55,7 @@ public void testCompileWithCache() throws Exception {
5555
}
5656

5757
@Test
58-
public void testCompileWithoutCache() throws Exception {
58+
void testCompileWithoutCache() throws Exception {
5959
final var scriptSrc = """
6060
component.name == "foo"
6161
""";
@@ -69,7 +69,7 @@ public void testCompileWithoutCache() throws Exception {
6969
}
7070

7171
@Test
72-
public void testRequirementsAnalysis() throws Exception {
72+
void testRequirementsAnalysis() throws Exception {
7373
final CelPolicyScript compiledScript = CelPolicyScriptHost.getInstance(CelPolicyType.COMPONENT).compile("""
7474
component.resolved_license.groups.exists(licenseGroup, licenseGroup.name == "Permissive")
7575
&& vulns.exists(vuln, vuln.severity in ["HIGH", "CRITICAL"] && has(vuln.aliases))
@@ -96,7 +96,43 @@ public void testRequirementsAnalysis() throws Exception {
9696
}
9797

9898
@Test
99-
public void testVisitVersRangeCheck() {
99+
void testRequirementsAnalysisWithFieldAccessInList() throws Exception {
100+
final CelPolicyScript compiledScript = CelPolicyScriptHost.getInstance(CelPolicyType.COMPONENT).compile("""
101+
[component.name, project.name].exists(name, name == "foo")
102+
""", CacheMode.NO_CACHE);
103+
104+
final Map<Type, Collection<String>> requirements = compiledScript.getRequirements().asMap();
105+
assertThat(requirements).containsOnlyKeys(TYPE_COMPONENT, TYPE_PROJECT);
106+
assertThat(requirements.get(TYPE_COMPONENT)).containsOnly("name");
107+
assertThat(requirements.get(TYPE_PROJECT)).containsOnly("name");
108+
}
109+
110+
@Test
111+
void testRequirementsAnalysisWithFieldAccessInStructValue() throws Exception {
112+
final CelPolicyScript compiledScript = CelPolicyScriptHost.getInstance(CelPolicyType.COMPONENT).compile("""
113+
project.depends_on(v1.Component{name: component.name})
114+
""", CacheMode.NO_CACHE);
115+
116+
final Map<Type, Collection<String>> requirements = compiledScript.getRequirements().asMap();
117+
assertThat(requirements).containsOnlyKeys(TYPE_COMPONENT, TYPE_PROJECT);
118+
assertThat(requirements.get(TYPE_COMPONENT)).containsOnly("name");
119+
assertThat(requirements.get(TYPE_PROJECT)).containsOnly("uuid");
120+
}
121+
122+
@Test
123+
void testRequirementsAnalysisWithFieldAccessInMapKey() throws Exception {
124+
final CelPolicyScript compiledScript = CelPolicyScriptHost.getInstance(CelPolicyType.COMPONENT).compile("""
125+
{component.name: project.name}.size() > 0
126+
""", CacheMode.NO_CACHE);
127+
128+
final Map<Type, Collection<String>> requirements = compiledScript.getRequirements().asMap();
129+
assertThat(requirements).containsOnlyKeys(TYPE_COMPONENT, TYPE_PROJECT);
130+
assertThat(requirements.get(TYPE_COMPONENT)).containsOnly("name");
131+
assertThat(requirements.get(TYPE_PROJECT)).containsOnly("name");
132+
}
133+
134+
@Test
135+
void testVisitVersRangeCheck() {
100136
var exception = assertThrows(ScriptCreateException.class, () -> CelPolicyScriptHost.getInstance(CelPolicyType.COMPONENT).compile("""
101137
project.name == "foo" && project.matches_range("vers:generic<1")
102138
&& project.depends_on(v1.Component{

0 commit comments

Comments
 (0)