Method call analysis based testQuick command #4731

Draft
wants to merge 4 commits into
base: main
60 changes: 46 additions & 14 deletions core/codesig/src/mill/codesig/CodeSig.scala
@@ -3,31 +3,63 @@ package mill.codesig
import mill.codesig.JvmModel.*

object CodeSig {
def compute(
classFiles: Seq[os.Path],
upstreamClasspath: Seq[os.Path],
ignoreCall: (Option[MethodDef], MethodSig) => Boolean,
logger: Logger,
prevTransitiveCallGraphHashesOpt: () => Option[Map[String, Int]]
): CallGraphAnalysis = {
implicit val st: SymbolTable = new SymbolTable()

private def callGraphAnalysis(
classFiles: Seq[os.Path],
upstreamClasspath: Seq[os.Path],
ignoreCall: (Option[MethodDef], MethodSig) => Boolean
)(implicit st: SymbolTable): CallGraphAnalysis = {
val localSummary = LocalSummary.apply(classFiles.iterator.map(os.read.inputStream(_)))
logger.log(localSummary)

val externalSummary = ExternalSummary.apply(localSummary, upstreamClasspath)
logger.log(externalSummary)

val resolvedMethodCalls = ResolvedCalls.apply(localSummary, externalSummary)
logger.log(resolvedMethodCalls)

new CallGraphAnalysis(
localSummary,
resolvedMethodCalls,
externalSummary,
ignoreCall,
logger,
prevTransitiveCallGraphHashesOpt
ignoreCall
)
}

def getCallGraphAnalysis(
classFiles: Seq[os.Path],
upstreamClasspath: Seq[os.Path],
ignoreCall: (Option[MethodDef], MethodSig) => Boolean
): CallGraphAnalysis = {
implicit val st: SymbolTable = new SymbolTable()

callGraphAnalysis(classFiles, upstreamClasspath, ignoreCall)
}

def compute(
classFiles: Seq[os.Path],
upstreamClasspath: Seq[os.Path],
ignoreCall: (Option[MethodDef], MethodSig) => Boolean,
logger: Logger,
prevTransitiveCallGraphHashesOpt: () => Option[Map[String, Int]]
): CallGraphAnalysis = {
implicit val st: SymbolTable = new SymbolTable()

val callAnalysis = callGraphAnalysis(classFiles, upstreamClasspath, ignoreCall)

logger.log(callAnalysis.localSummary)
logger.log(callAnalysis.externalSummary)
logger.log(callAnalysis.resolved)

logger.mandatoryLog(callAnalysis.methodCodeHashes)
logger.mandatoryLog(callAnalysis.prettyCallGraph)
Comment on lines +51 to +52
Member

Are these two mandatoryLogs necessary? It seems we only use transitiveCallGraphHashes and spanningInvalidationTree, at least as far as I can tell

Contributor Author

I have no idea either. I put them here to mimic the original code, which logs everything like this.

logger.mandatoryLog(callAnalysis.transitiveCallGraphHashes0)

logger.log(callAnalysis.transitiveCallGraphHashes)

val spanningInvalidationTree = callAnalysis.calculateSpanningInvalidationTree {
prevTransitiveCallGraphHashesOpt()
}

logger.mandatoryLog(spanningInvalidationTree)

callAnalysis
}
}
6 changes: 4 additions & 2 deletions core/codesig/src/mill/codesig/ExternalSummary.scala
@@ -5,6 +5,7 @@ import mill.codesig.JvmModel.*
import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Opcodes}

import java.net.URLClassLoader
import scala.util.Try

case class ExternalSummary(
directMethods: Map[JCls, Map[MethodSig, Boolean]],
@@ -47,7 +48,8 @@ object ExternalSummary {

def load(cls: JCls): Unit = methodsPerCls.getOrElse(cls, load0(cls))

def load0(cls: JCls): Unit = {
// Some macro implementations make the ClassReader fail; we can skip them
Member

What kind of error do those macro implementations produce? We should try and be specific about what errors we catch here, to avoid silencing unexpected errors that may indicate real issues
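
To make the concern concrete, here is a minimal sketch of catching one expected failure instead of wrapping the whole body in a blanket `Try`. The exception type is an assumption: this thread has not yet established what `load0` actually throws, so `FileNotFoundException` is only a placeholder, and `readClassBytes` and `loadOrSkip` are hypothetical helpers, not part of codesig.

```scala
import java.io.FileNotFoundException

object NarrowCatchSketch {
  // Hypothetical stand-in for reading a class file from the upstream classpath.
  // It throws FileNotFoundException when the class is absent (assumed failure mode).
  def readClassBytes(classpath: Map[String, Array[Byte]], name: String): Array[Byte] =
    classpath.getOrElse(name, throw new FileNotFoundException(name))

  // Skip only the failure we expect; any other exception still propagates.
  def loadOrSkip(classpath: Map[String, Array[Byte]], name: String): Option[Array[Byte]] =
    try Some(readClassBytes(classpath, name))
    catch { case _: FileNotFoundException => None }
}
```

With this shape, an unexpected error (for example a malformed class file) would still surface rather than being silently dropped.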

Contributor Author

If you run core.define.compile and then check class/mill/define/Cross$Factory$.class in the out folder, you will see something like this:

// Source code is decompiled from a .class file using FernFlower decompiler.
package mill.define;

import java.io.Serializable;
import mill.define.internal.CrossMacros;
import mill.define.internal.CrossMacros.;
import scala.runtime.ModuleSerializationProxy;

public final class Cross$Factory$ implements Serializable {
   public static final Cross$Factory$ MODULE$ = new Cross$Factory$();

   public Cross$Factory$() {
   }

   private Object writeReplace() {
      return new ModuleSerializationProxy(Cross$Factory$.class);
   }

   public CrossMacros inline$CrossMacros$i1(final internal x$0) {
      return .MODULE$;
   }
}

The call retrieved from ASM yields something like this:
def mill.define.Cross$Factory$#inline$CrossMacros$i1(mill.define.internal)mill.define.internal

As you can see, mill/define/internal is not a valid class, so the class loader throws when trying to load it.

I don't know why the code looks like this. I will try to reproduce it and look around for clues.
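
As a rough illustration of why such a name cannot be loaded, the sketch below applies the same name-to-path mapping that `load0` uses (`cls.name.replace('.', '/') + ".class"`) and probes a class loader for the result. `ClassResourceSketch`, `resourcePath`, and `isLoadable` are hypothetical names for this example only; the real code goes through `os.resource`, whose exact failure behaviour is not shown in this thread.

```scala
object ClassResourceSketch {
  // Same name-to-path mapping as in load0: dots become slashes, plus ".class".
  def resourcePath(className: String): String =
    className.replace('.', '/') + ".class"

  // ClassLoader.getResourceAsStream returns null when the resource does not exist.
  def isLoadable(className: String, loader: ClassLoader): Boolean = {
    val stream = loader.getResourceAsStream(resourcePath(className))
    if (stream == null) false
    else { stream.close(); true }
  }
}
```

A package name such as `mill.define.internal` maps to `mill/define/internal.class`, which is not a file any compiler emits, so a lookup like this would come back empty.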

def load0(cls: JCls): Unit = Try {
val visitor = new MyClassVisitor()
val resourcePath =
os.resource(upstreamClassloader) / os.SubPath(cls.name.replace('.', '/') + ".class")
@@ -61,7 +63,7 @@ object ExternalSummary {
methodsPerCls(cls) = visitor.methods
ancestorsPerCls(cls) = visitor.ancestors
ancestorsPerCls(cls).foreach(load)
}
}.getOrElse(())

(allDirectAncestors ++ allMethodCallParamClasses)
.filter(!localSummary.contains(_))
128 changes: 88 additions & 40 deletions core/codesig/src/mill/codesig/ReachabilityAnalysis.scala
@@ -2,18 +2,17 @@ package mill.codesig

import mill.codesig.JvmModel.*
import mill.internal.{SpanningForest, Tarjans}
import ujson.Obj
import ujson.{Obj, Arr}
import upickle.default.{Writer, writer}

import scala.collection.immutable.SortedMap
import scala.collection.mutable

class CallGraphAnalysis(
localSummary: LocalSummary,
resolved: ResolvedCalls,
externalSummary: ExternalSummary,
ignoreCall: (Option[MethodDef], MethodSig) => Boolean,
logger: Logger,
prevTransitiveCallGraphHashesOpt: () => Option[Map[String, Int]]
val localSummary: LocalSummary,
val resolved: ResolvedCalls,
val externalSummary: ExternalSummary,
ignoreCall: (Option[MethodDef], MethodSig) => Boolean
)(implicit st: SymbolTable) {

val methods: Map[MethodDef, LocalSummary.MethodInfo] = for {
@@ -40,17 +39,13 @@ class CallGraphAnalysis(
lazy val methodCodeHashes: SortedMap[String, Int] =
methods.map { case (k, vs) => (k.toString, vs.codeHash) }.to(SortedMap)

logger.mandatoryLog(methodCodeHashes)

lazy val prettyCallGraph: SortedMap[String, Array[CallGraphAnalysis.Node]] = {
indexGraphEdges.zip(indexToNodes).map { case (vs, k) =>
(k.toString, vs.map(indexToNodes))
}
.to(SortedMap)
}

logger.mandatoryLog(prettyCallGraph)

def transitiveCallGraphValues[V: scala.reflect.ClassTag](
nodeValues: Array[V],
reduce: (V, V) => V,
@@ -78,44 +73,45 @@ class CallGraphAnalysis(
.collect { case (CallGraphAnalysis.LocalDef(d), v) => (d.toString, v) }
.to(SortedMap)

logger.mandatoryLog(transitiveCallGraphHashes0)
logger.log(transitiveCallGraphHashes)

lazy val spanningInvalidationTree: Obj = prevTransitiveCallGraphHashesOpt() match {
case Some(prevTransitiveCallGraphHashes) =>
CallGraphAnalysis.spanningInvalidationTree(
prevTransitiveCallGraphHashes,
transitiveCallGraphHashes0,
indexToNodes,
indexGraphEdges
)
case None => ujson.Obj()
def calculateSpanningInvalidationTree(
prevTransitiveCallGraphHashesOpt: => Option[Map[String, Int]]
): Obj = {
prevTransitiveCallGraphHashesOpt match {
case Some(prevTransitiveCallGraphHashes) =>
CallGraphAnalysis.spanningInvalidationTree(
prevTransitiveCallGraphHashes,
transitiveCallGraphHashes0,
indexToNodes,
indexGraphEdges
)
case None => ujson.Obj()
}
}

logger.mandatoryLog(spanningInvalidationTree)
def calculateInvalidatedClassNames(
prevTransitiveCallGraphHashesOpt: => Option[Map[String, Int]]
): Set[String] = {
prevTransitiveCallGraphHashesOpt match {
case Some(prevTransitiveCallGraphHashes) =>
CallGraphAnalysis.invalidatedClassNames(
prevTransitiveCallGraphHashes,
transitiveCallGraphHashes0,
indexToNodes,
indexGraphEdges
)
case None => Set.empty
}
}
}

object CallGraphAnalysis {

/**
* Computes the minimal spanning forest that covers the nodes in the
* call graph whose transitive call graph hashes have changed since the last
* run, rendered as a JSON dictionary tree. This provides a great "debug
* view" that lets you easily Cmd-F to find a particular node and then trace
* it up the JSON hierarchy to figure out what upstream node was the root
* cause of the change in the callgraph.
*
* There are typically multiple possible spanning forests for a given graph;
* one is chosen arbitrarily. This is usually fine, since when debugging you
* typically are investigating why there's a path to a node at all where none
* should exist, rather than trying to fully analyse all possible paths
*/
def spanningInvalidationTree(
private def getSpanningForest(
prevTransitiveCallGraphHashes: Map[String, Int],
transitiveCallGraphHashes0: Array[(CallGraphAnalysis.Node, Int)],
indexToNodes: Array[Node],
indexGraphEdges: Array[Array[Int]]
): ujson.Obj = {
) = {
val transitiveCallGraphHashes0Map = transitiveCallGraphHashes0.toMap

val nodesWithChangedHashes = indexGraphEdges
@@ -135,12 +131,64 @@ object CallGraphAnalysis {
val reverseGraphEdges =
indexGraphEdges.indices.map(reverseGraphMap.getOrElse(_, Array[Int]())).toArray

SpanningForest.apply(reverseGraphEdges, nodesWithChangedHashes, false)
}

/**
* Computes the minimal spanning forest that covers the nodes in the
* call graph whose transitive call graph hashes have changed since the last
* run, rendered as a JSON dictionary tree. This provides a great "debug
* view" that lets you easily Cmd-F to find a particular node and then trace
* it up the JSON hierarchy to figure out what upstream node was the root
* cause of the change in the callgraph.
*
* There are typically multiple possible spanning forests for a given graph;
* one is chosen arbitrarily. This is usually fine, since when debugging you
* typically are investigating why there's a path to a node at all where none
* should exist, rather than trying to fully analyse all possible paths
*/
def spanningInvalidationTree(
prevTransitiveCallGraphHashes: Map[String, Int],
transitiveCallGraphHashes0: Array[(CallGraphAnalysis.Node, Int)],
indexToNodes: Array[Node],
indexGraphEdges: Array[Array[Int]]
): ujson.Obj = {
SpanningForest.spanningTreeToJsonTree(
SpanningForest.apply(reverseGraphEdges, nodesWithChangedHashes, false),
getSpanningForest(prevTransitiveCallGraphHashes, transitiveCallGraphHashes0, indexToNodes, indexGraphEdges),
k => indexToNodes(k).toString
)
}
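
The roots of that forest are the nodes whose hashes changed since the last run. The part of the diff that selects them is collapsed above, so the following is only one plausible way to do the comparison, assuming hashes keyed by node name; `ChangedHashesSketch` and `changed` are hypothetical names.

```scala
object ChangedHashesSketch {
  // A name counts as changed when it is new or its hash differs from the previous run.
  def changed(prev: Map[String, Int], current: Map[String, Int]): Set[String] =
    current.collect {
      case (name, hash) if !prev.get(name).contains(hash) => name
    }.toSet
}
```

For example, with `prev = Map("a" -> 1, "b" -> 2)` and `current = Map("a" -> 1, "b" -> 3, "c" -> 4)`, the changed set is `Set("b", "c")`.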

/**
* Gets the names of all classes whose transitive call graph hash has changed relative to prevTransitiveCallGraphHashes
*/
def invalidatedClassNames(
prevTransitiveCallGraphHashes: Map[String, Int],
transitiveCallGraphHashes0: Array[(CallGraphAnalysis.Node, Int)],
indexToNodes: Array[Node],
indexGraphEdges: Array[Array[Int]]
): Set[String] = {
val rootNode = getSpanningForest(prevTransitiveCallGraphHashes, transitiveCallGraphHashes0, indexToNodes, indexGraphEdges)

val jsonValueQueue = mutable.ArrayDeque[(Int, SpanningForest.Node)]()
jsonValueQueue.appendAll(rootNode.values.toSeq)
val builder = Set.newBuilder[String]

while (jsonValueQueue.nonEmpty) {
val (nodeIndex, node) = jsonValueQueue.removeHead()
node.values.foreach { case (childIndex, childNode) =>
jsonValueQueue.append((childIndex, childNode))
}
indexToNodes(nodeIndex) match {
case CallGraphAnalysis.LocalDef(methodDef) => builder.addOne(methodDef.cls.name)
case CallGraphAnalysis.Call(methodCall) => builder.addOne(methodCall.cls.name)
case CallGraphAnalysis.ExternalClsCall(externalCls) => builder.addOne(externalCls.name)
}
}

builder.result()
}
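
The traversal in `invalidatedClassNames` is a breadth-first walk over the spanning forest that maps each node index to a class name. A self-contained sketch of that idea follows, assuming a simplified tree type: `TreeNode` is a hypothetical stand-in for `SpanningForest.Node`, and `indexToClassName` stands in for the `indexToNodes` pattern match.

```scala
import scala.collection.mutable

object InvalidatedNamesSketch {
  // Hypothetical tree node: children keyed by node index.
  final case class TreeNode(children: Map[Int, TreeNode])

  def collectNames(root: TreeNode, indexToClassName: Int => String): Set[String] = {
    // Seed the queue with the forest roots, then process nodes level by level.
    val queue = mutable.ArrayDeque.from(root.children)
    val names = Set.newBuilder[String]
    while (queue.nonEmpty) {
      val (index, node) = queue.removeHead()
      queue.appendAll(node.children)
      names += indexToClassName(index)
    }
    names.result()
  }
}
```

Because the result is a `Set`, several changed methods in the same class collapse to a single class name.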

def indexGraphEdges(
indexToNodes: Array[Node],
methods: Map[MethodDef, LocalSummary.MethodInfo],
21 changes: 16 additions & 5 deletions core/define/src/mill/define/Task.scala
@@ -124,7 +124,16 @@ object Task extends TaskBase {
inline def Command[T](inline t: Result[T])(implicit
inline w: W[T],
inline ctx: mill.define.ModuleCtx
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, exclusive = '{ false }) }
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, exclusive = '{ false }, persistent = '{ false }) }


/**
* This version allows a [[Command]] to be persistent
*/
inline def Command[T](inline persistent: Boolean)(inline t: Result[T])(implicit
Contributor Author

Adds a Command overload with a persistent option.
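
To show the shape of the overload in isolation, here is a toy model using plain methods instead of inline macros. It is not Mill's API: `Cmd` and `TaskSketch` are hypothetical, and the only thing carried over from the diff is the idea of a leading `persistent` parameter list that defaults to false when omitted.

```scala
object PersistentFlagSketch {
  // Hypothetical stand-in for a command that records the persistent flag.
  final case class Cmd[T](body: () => T, persistent: Boolean)

  object TaskSketch {
    // Existing form: no flag, so the command is not persistent.
    def Command[T](t: => T): Cmd[T] = Cmd(() => t, persistent = false)

    // New form: an extra leading parameter list carries the flag.
    def Command[T](persistent: Boolean)(t: => T): Cmd[T] = Cmd(() => t, persistent)
  }
}
```

Under this model, `TaskSketch.Command { 42 }` yields a non-persistent command, while `TaskSketch.Command(persistent = true) { 42 }` yields a persistent one.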

inline w: W[T],
inline ctx: mill.define.ModuleCtx
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, exclusive = '{ false }, persistent = '{ persistent }) }

/**
* @param exclusive Exclusive commands run serially at the end of an evaluation,
@@ -142,7 +151,7 @@ object Task extends TaskBase {
inline def apply[T](inline t: Result[T])(implicit
inline w: W[T],
inline ctx: mill.define.ModuleCtx
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, '{ this.exclusive }) }
): Command[T] = ${ TaskMacros.commandImpl[T]('t)('w, 'ctx, '{ this.exclusive }, '{ false }) }
}

/**
@@ -396,7 +405,8 @@ class Command[+T](
val ctx0: mill.define.ModuleCtx,
val writer: W[?],
val isPrivate: Option[Boolean],
val exclusive: Boolean
val exclusive: Boolean,
override val persistent: Boolean
) extends NamedTask[T] {

override def asCommand: Some[Command[T]] = Some(this)
@@ -543,12 +553,13 @@ private object TaskMacros {
)(t: Expr[Result[T]])(
w: Expr[W[T]],
ctx: Expr[mill.define.ModuleCtx],
exclusive: Expr[Boolean]
exclusive: Expr[Boolean],
persistent: Expr[Boolean]
): Expr[Command[T]] = {
appImpl[Command, T](
(in, ev) =>
'{
new Command[T]($in, $ev, $ctx, $w, ${ taskIsPrivate() }, exclusive = $exclusive)
new Command[T]($in, $ev, $ctx, $w, ${ taskIsPrivate() }, exclusive = $exclusive, persistent = $persistent)
},
t
)
33 changes: 33 additions & 0 deletions example/javalib/testing/7-test-quick/build.mill
@@ -0,0 +1,33 @@
//// SNIPPET:BUILD1
package build
import mill._, javalib._
import os._

object foo extends JavaModule {
object test extends JavaTests {
def testFramework = "com.novocode.junit.JUnitFramework" // Use JUnit 4 framework interface
def mvnDeps = Seq(
mvn"junit:junit:4.13.2", // JUnit 4 itself
mvn"com.novocode:junit-interface:0.11" // sbt-compatible JUnit interface
)
}
// Utilities for replacing text in files
def replaceBar(args: String*) = Task.Command {
val relativePath = os.RelPath("../../../foo/src/Bar.java")
val filePath = Task.dest() / relativePath
os.write.over(filePath, os.read(filePath).replace(
"""return String.format("Hi, %s!", name);""",
"""return String.format("Ciao, %s!", name);"""
))
}
Comment on lines +15 to +22
Member

Let's move these into the /** Usage blocks using sed -i.bak rather than having them as part of the Scala code; you can grep for that string to see examples of this in the codebase.


def replaceFooTest2(args: String*) = Task.Command {
val relativePath = os.RelPath("../../../foo/test/src/FooTest2.java")
val filePath = Task.dest() / relativePath
os.write.over(filePath, os.read(filePath).replace(
"""assertEquals("Hi, " + name + "!", greeted);""",
"""assertEquals("Ciao, " + name + "!", greeted);""",
))
}
}
//// SNIPPET:END
11 changes: 11 additions & 0 deletions example/javalib/testing/7-test-quick/foo/src/Bar.java
@@ -0,0 +1,11 @@
package foo;

public class Bar {
public static String greet(String name) {
return String.format("Hello, %s!", name);
}

public static String greet2(String name) {
return String.format("Hi, %s!", name);
}
}