@@ -2,19 +2,17 @@ package mill.codesig
2
2
3
3
import mill .codesig .JvmModel .*
4
4
import mill .internal .{SpanningForest , Tarjans }
5
- import ujson .{Arr , Obj }
5
+ import ujson .{Obj , Arr }
6
6
import upickle .default .{Writer , writer }
7
7
8
8
import scala .collection .immutable .SortedMap
9
9
import scala .collection .mutable
10
10
11
11
class CallGraphAnalysis (
12
- localSummary : LocalSummary ,
13
- resolved : ResolvedCalls ,
14
- externalSummary : ExternalSummary ,
15
- ignoreCall : (Option [MethodDef ], MethodSig ) => Boolean ,
16
- logger : Logger ,
17
- prevTransitiveCallGraphHashesOpt : () => Option [Map [String , Int ]]
12
+ val localSummary : LocalSummary ,
13
+ val resolved : ResolvedCalls ,
14
+ val externalSummary : ExternalSummary ,
15
+ ignoreCall : (Option [MethodDef ], MethodSig ) => Boolean
18
16
)(implicit st : SymbolTable ) {
19
17
20
18
val methods : Map [MethodDef , LocalSummary .MethodInfo ] = for {
@@ -41,17 +39,13 @@ class CallGraphAnalysis(
41
39
lazy val methodCodeHashes : SortedMap [String , Int ] =
42
40
methods.map { case (k, vs) => (k.toString, vs.codeHash) }.to(SortedMap )
43
41
44
- logger.mandatoryLog(methodCodeHashes)
45
-
46
42
lazy val prettyCallGraph : SortedMap [String , Array [CallGraphAnalysis .Node ]] = {
47
43
indexGraphEdges.zip(indexToNodes).map { case (vs, k) =>
48
44
(k.toString, vs.map(indexToNodes))
49
45
}
50
46
.to(SortedMap )
51
47
}
52
48
53
- logger.mandatoryLog(prettyCallGraph)
54
-
55
49
def transitiveCallGraphValues [V : scala.reflect.ClassTag ](
56
50
nodeValues : Array [V ],
57
51
reduce : (V , V ) => V ,
@@ -79,45 +73,45 @@ class CallGraphAnalysis(
79
73
.collect { case (CallGraphAnalysis .LocalDef (d), v) => (d.toString, v) }
80
74
.to(SortedMap )
81
75
82
- logger.mandatoryLog(transitiveCallGraphHashes0)
83
- logger.log(transitiveCallGraphHashes)
84
-
85
- lazy val (spanningInvalidationTree : Obj , invalidClassNames : Arr ) = prevTransitiveCallGraphHashesOpt() match {
86
- case Some (prevTransitiveCallGraphHashes) =>
87
- CallGraphAnalysis .spanningInvalidationTree(
88
- prevTransitiveCallGraphHashes,
89
- transitiveCallGraphHashes0,
90
- indexToNodes,
91
- indexGraphEdges
92
- )
93
- case None => ujson.Obj () -> ujson.Arr ()
76
+ def calculateSpanningInvalidationTree (
77
+ prevTransitiveCallGraphHashesOpt : => Option [Map [String , Int ]]
78
+ ): Obj = {
79
+ prevTransitiveCallGraphHashesOpt match {
80
+ case Some (prevTransitiveCallGraphHashes) =>
81
+ CallGraphAnalysis .spanningInvalidationTree(
82
+ prevTransitiveCallGraphHashes,
83
+ transitiveCallGraphHashes0,
84
+ indexToNodes,
85
+ indexGraphEdges
86
+ )
87
+ case None => ujson.Obj ()
88
+ }
94
89
}
95
90
96
- logger.mandatoryLog(spanningInvalidationTree)
97
- logger.mandatoryLog(invalidClassNames)
91
+ def calculateInvalidatedClassNames (
92
+ prevTransitiveCallGraphHashesOpt : => Option [Map [String , Int ]]
93
+ ): Set [String ] = {
94
+ prevTransitiveCallGraphHashesOpt match {
95
+ case Some (prevTransitiveCallGraphHashes) =>
96
+ CallGraphAnalysis .invalidatedClassNames(
97
+ prevTransitiveCallGraphHashes,
98
+ transitiveCallGraphHashes0,
99
+ indexToNodes,
100
+ indexGraphEdges
101
+ )
102
+ case None => Set .empty
103
+ }
104
+ }
98
105
}
99
106
100
107
object CallGraphAnalysis {
101
108
102
- /**
103
- * Computes the minimal spanning forest of the that covers the nodes in the
104
- * call graph whose transitive call graph hashes has changed since the last
105
- * run, rendered as a JSON dictionary tree. This provides a great "debug
106
- * view" that lets you easily Cmd-F to find a particular node and then trace
107
- * it up the JSON hierarchy to figure out what upstream node was the root
108
- * cause of the change in the callgraph.
109
- *
110
- * There are typically multiple possible spanning forests for a given graph;
111
- * one is chosen arbitrarily. This is usually fine, since when debugging you
112
- * typically are investigating why there's a path to a node at all where none
113
- * should exist, rather than trying to fully analyse all possible paths
114
- */
115
- def spanningInvalidationTree (
109
+ private def getSpanningForest (
116
110
prevTransitiveCallGraphHashes : Map [String , Int ],
117
111
transitiveCallGraphHashes0 : Array [(CallGraphAnalysis .Node , Int )],
118
112
indexToNodes : Array [Node ],
119
113
indexGraphEdges : Array [Array [Int ]]
120
- ): (ujson. Obj , ujson. Arr ) = {
114
+ ) = {
121
115
val transitiveCallGraphHashes0Map = transitiveCallGraphHashes0.toMap
122
116
123
117
val nodesWithChangedHashes = indexGraphEdges
@@ -137,23 +131,62 @@ object CallGraphAnalysis {
137
131
val reverseGraphEdges =
138
132
indexGraphEdges.indices.map(reverseGraphMap.getOrElse(_, Array [Int ]())).toArray
139
133
140
- val spanningForest = SpanningForest .apply(reverseGraphEdges, nodesWithChangedHashes, false )
134
+ SpanningForest .apply(reverseGraphEdges, nodesWithChangedHashes, false )
135
+ }
141
136
142
- val spanningInvalidationTree = SpanningForest .spanningTreeToJsonTree(
143
- spanningForest,
137
+ /**
138
+ * Computes the minimal spanning forest of the that covers the nodes in the
139
+ * call graph whose transitive call graph hashes has changed since the last
140
+ * run, rendered as a JSON dictionary tree. This provides a great "debug
141
+ * view" that lets you easily Cmd-F to find a particular node and then trace
142
+ * it up the JSON hierarchy to figure out what upstream node was the root
143
+ * cause of the change in the callgraph.
144
+ *
145
+ * There are typically multiple possible spanning forests for a given graph;
146
+ * one is chosen arbitrarily. This is usually fine, since when debugging you
147
+ * typically are investigating why there's a path to a node at all where none
148
+ * should exist, rather than trying to fully analyse all possible paths
149
+ */
150
+ def spanningInvalidationTree (
151
+ prevTransitiveCallGraphHashes : Map [String , Int ],
152
+ transitiveCallGraphHashes0 : Array [(CallGraphAnalysis .Node , Int )],
153
+ indexToNodes : Array [Node ],
154
+ indexGraphEdges : Array [Array [Int ]]
155
+ ): ujson.Obj = {
156
+ SpanningForest .spanningTreeToJsonTree(
157
+ getSpanningForest(prevTransitiveCallGraphHashes, transitiveCallGraphHashes0, indexToNodes, indexGraphEdges),
144
158
k => indexToNodes(k).toString
145
159
)
160
+ }
146
161
147
- val invalidSet = invalidClassNameSet(
148
- spanningForest,
149
- indexToNodes.map {
150
- case LocalDef (call) => call.cls.name
151
- case Call (call) => call.cls.name
152
- case ExternalClsCall (cls) => cls.name
162
+ /**
163
+ * Get all class names that have their hashcode changed compared to prevTransitiveCallGraphHashes
164
+ */
165
+ def invalidatedClassNames (
166
+ prevTransitiveCallGraphHashes : Map [String , Int ],
167
+ transitiveCallGraphHashes0 : Array [(CallGraphAnalysis .Node , Int )],
168
+ indexToNodes : Array [Node ],
169
+ indexGraphEdges : Array [Array [Int ]]
170
+ ): Set [String ] = {
171
+ val rootNode = getSpanningForest(prevTransitiveCallGraphHashes, transitiveCallGraphHashes0, indexToNodes, indexGraphEdges)
172
+
173
+ val jsonValueQueue = mutable.ArrayDeque [(Int , SpanningForest .Node )]()
174
+ jsonValueQueue.appendAll(rootNode.values.toSeq)
175
+ val builder = Set .newBuilder[String ]
176
+
177
+ while (jsonValueQueue.nonEmpty) {
178
+ val (nodeIndex, node) = jsonValueQueue.removeHead()
179
+ node.values.foreach { case (childIndex, childNode) =>
180
+ jsonValueQueue.append((childIndex, childNode))
153
181
}
154
- )
182
+ indexToNodes(nodeIndex) match {
183
+ case CallGraphAnalysis .LocalDef (methodDef) => builder.addOne(methodDef.cls.name)
184
+ case CallGraphAnalysis .Call (methodCall) => builder.addOne(methodCall.cls.name)
185
+ case CallGraphAnalysis .ExternalClsCall (externalCls) => builder.addOne(externalCls.name)
186
+ }
187
+ }
155
188
156
- (spanningInvalidationTree, invalidSet )
189
+ builder.result( )
157
190
}
158
191
159
192
def indexGraphEdges (
@@ -278,24 +311,6 @@ object CallGraphAnalysis {
278
311
}
279
312
}
280
313
281
- private def invalidClassNameSet (
282
- spanningForest : SpanningForest .Node ,
283
- indexToClassName : Array [String ]
284
- ): Set [String ] = {
285
- val queue = mutable.ArrayBuffer .empty[(Int , SpanningForest .Node )]
286
- val result = mutable.Set .empty[String ]
287
-
288
- queue.appendAll(spanningForest.values)
289
-
290
- while (queue.nonEmpty) {
291
- val (index, node) = queue.remove(0 )
292
- result += indexToClassName(index)
293
- queue.appendAll(node.values)
294
- }
295
-
296
- result.toSet
297
- }
298
-
299
314
/**
300
315
* Represents the three types of nodes in our call graph. These are kept heterogeneous
301
316
* because flattening them out into a homogenous graph of MethodDef -> MethodDef edges
0 commit comments