Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,9 @@ class AstCreator(
case typedExpr: KtLambdaExpression => Seq(astForLambda(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtNameReferenceExpression if typedExpr.getReferencedNameElementType == KtTokens.IDENTIFIER =>
Seq(astForNameReference(typedExpr, argIdxMaybe, argNameMaybe, annotations))
// TODO: callable reference
case _: KtNameReferenceExpression => Seq()
case typedExpr: KtCallableReferenceExpression =>
Seq(astForCallableReferenceExpression(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtObjectLiteralExpression =>
Seq(astForObjectLiteralExpr(typedExpr, argIdxMaybe, argNameMaybe, annotations))
case typedExpr: KtParenthesizedExpression =>
Expand Down Expand Up @@ -331,7 +332,6 @@ class AstCreator(
case null =>
logDebugWithTestAndStackTrace("Received null expression! Skipping...")
Seq()
// TODO: handle `KtCallableReferenceExpression` like `this::baseTerrain`
case unknownExpr =>
logger.debug(
s"Creating empty AST node for unknown expression `${unknownExpr.getClass}` with text `${unknownExpr.getText}`."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.shiftleft.codepropertygraph.generated.Operators
import io.shiftleft.codepropertygraph.generated.nodes.NewMethodRef
import org.jetbrains.kotlin.descriptors.DescriptorVisibilities
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.descriptors.ClassDescriptor
import org.jetbrains.kotlin.lexer.KtToken
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.*
Expand Down Expand Up @@ -667,4 +668,57 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
.withChildren(annotations.map(astForAnnotationEntry))
}

def astForCallableReferenceExpression(
expr: KtCallableReferenceExpression,
argIdx: Option[Int],
argNameMaybe: Option[String],
annotations: Seq[KtAnnotationEntry] = Seq()
): Ast = {
// These represent constructs like:
// - ::func (unbound reference)
// - obj/this::func (bound reference with receiver)
// - Type::func (static reference)

val callableNameExpr = expr.getCallableReference
val methodName = callableNameExpr.getText

val receiverTypeFullName = Option(expr.getReceiverExpression) match {
case Some(receiver: KtNameReferenceExpression) if typeInfoProvider.isReferenceToClass(receiver) =>
val typeName = receiver.getText
val nameToClass = expr.getContainingKtFile.getDeclarations.asScala.collect { case c: KtClass =>
c.getName -> c
}.toMap

if (nameToClass.contains(typeName)) {
val klass = nameToClass(typeName)
val packageName = klass.getContainingKtFile.getPackageFqName.toString
if (packageName.isEmpty) Some(typeName) else Some(s"$packageName.$typeName")
} else {
exprTypeFullName(receiver)
}
case Some(receiver) =>
exprTypeFullName(receiver)
case None =>
None
}

val namespacePrefix = receiverTypeFullName.getOrElse(Defines.UnresolvedNamespace)

val funcDesc = bindingUtils.getCalledFunctionDesc(callableNameExpr)

val signature = funcDesc
.orElse(getAmbiguousFuncDescIfSignaturesEqual(callableNameExpr))
.flatMap(nameRenderer.funcDescSignature)
.getOrElse(Defines.UnresolvedSignature)

val methodFullName = nameRenderer.combineFunctionFullName(s"$namespacePrefix.$methodName", signature)

val methodRefTypeFullName = receiverTypeFullName.map(registerType).getOrElse(TypeConstants.Any)

val methodRefNode_ = methodRefNode(expr, expr.getText, methodFullName, methodRefTypeFullName)

val node = withArgumentIndex(methodRefNode_, argIdx).argumentName(argNameMaybe)

Ast(node).withChildren(annotations.map(astForAnnotationEntry))
}
}
Original file line number Diff line number Diff line change
@@ -1,30 +1,157 @@
package io.joern.kotlin2cpg.querying

import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.nodes.{Identifier, MethodRef}
import io.shiftleft.semanticcpg.language.*

class CallableReferenceTests extends KotlinCode2CpgFixture(withOssDataflow = false) {

"CPG for code with simple callback usage" should {
"resolved callable references as call argument should be handled correctly" in {
val cpg = code("""
|package com.test
|
|class Bar {
| fun bar(x: Int) {}
|}
|
|class Foo {
| fun doNothing(c: (Int) -> Unit) {}
|
| fun foo() {
| doNothing(Bar::bar)
| }
|}
|""".stripMargin)

inside(cpg.call.name("doNothing").argument.l) { case List(thisArg: Identifier, methodRef: MethodRef) =>
thisArg.name shouldBe "this"
Comment on lines +26 to +27
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing for this argument does not really belong to the thing we want to check here. Please remove it.


methodRef.methodFullName shouldBe "com.test.Bar.bar:void(int)"
methodRef.typeFullName shouldBe "com.test.Bar"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is wrong. You seem to have looked at Javasrc2cpg, but that is/was sadly not a good example and Johannes is fixing it there as we speak.

You basically have to create a new synthetic TYPE_DECL and a corresponding TYPE which is then referenced here in typeFullName. This new TYPE_DECL then needs bindings that reflect the implemented interface.

I suggest you and @johannescoetzee talk tomorrow so that he can tell you the details.
One pointer where most of this is already implemented and tested are the Lambda function tests in Javasrc2cpg here

methodRef.code shouldBe "Bar::bar"
}
}

"callable references with unresolved signature should be handled correctly" in {
val cpg = code("""
|fun isOdd(x: Int) = x % 2 != 0
|package com.test
|
|class Foo {
| fun doNothing(c: (Int) -> Unit) {}
|
| fun foo() {
| doNothing(Bar::bar)
| }
|}
|""".stripMargin)

// When Bar class doesn't exist, the callable reference should still be created
val methodRefs = cpg.methodRef.code("Bar::bar").l
methodRefs.size shouldBe 1
val methodRef = methodRefs.head

methodRef.methodFullName shouldBe "<unresolvedNamespace>.bar:<unresolvedSignature>"
methodRef.typeFullName shouldBe "ANY"
methodRef.code shouldBe "Bar::bar"
}

"unresolved callable references should be handled correctly" in {
val cpg = code("""
|fun main() {
| someFunction(::unknownFunction)
|}
|""".stripMargin)

val methodRefs = cpg.methodRef.code("::unknownFunction").l
methodRefs.size shouldBe 1
val methodRef = methodRefs.head

methodRef.methodFullName shouldBe "<unresolvedNamespace>.unknownFunction:<unresolvedSignature>"
methodRef.typeFullName shouldBe "ANY"
methodRef.code shouldBe "::unknownFunction"
}

"resolved instance method refs should be handled correctly" in {
val cpg = code("""
|package com.test
|
|fun firstOdd(x: Int): Int {
| val numbers = listOf(1, 2, x)
| val y = numbers.filter(::isOdd)[0]
| return y
|class Foo {
| fun doNothing(c: (Int) -> Unit) {}
|
| fun func(x: Int) {}
|
| fun foo() {
| val f = Foo()
| doNothing(f::func)
| }
|}
|""".stripMargin)

inside(cpg.call.name("doNothing").argument.l) { case List(thisArg: Identifier, methodRef: MethodRef) =>
thisArg.name shouldBe "this"

methodRef.methodFullName shouldBe "com.test.Foo.func:void(int)"
methodRef.typeFullName shouldBe "com.test.Foo"
methodRef.code shouldBe "f::func"
}
}

"instance method refs with 'this' receiver should be handled correctly" in {
val cpg = code("""
|package com.test
|
|class Foo {
| fun doNothing(c: (Int) -> Unit) {}
|
| fun func(x: Int) {}
|
|fun main(args : Array<String>) {
| println(firstOdd(3))
| fun foo() {
| doNothing(this::func)
| }
|}
|""".stripMargin)

"should have a non-0 number of CALL nodes" in {
cpg.call.size should not be 0
inside(cpg.call.name("doNothing").argument.l) { case List(thisArg: Identifier, methodRef: MethodRef) =>
thisArg.name shouldBe "this"

methodRef.methodFullName shouldBe "com.test.Foo.func:void(int)"
methodRef.typeFullName shouldBe "com.test.Foo"
methodRef.code shouldBe "this::func"
}
}

// TODO: add the rest of the test cases
"callable references with collection receiver should be handled correctly" in {
val cpg = code("""
|fun main() {
| val numbers = listOf(1, 2, 3)
| val reference = numbers::forEach
|}
|""".stripMargin)

val methodRefs = cpg.methodRef.code("numbers::forEach").l
methodRefs.size shouldBe 1
val methodRef = methodRefs.head

// The exact type will depend on Kotlin stdlib resolution
methodRef.methodFullName should include("forEach")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forEach is not a full name. I would have expected something like e.g. kotlin.collection.forEach(...).

methodRef.typeFullName should include("List")
methodRef.code shouldBe "numbers::forEach"
}

"unbound callable references to top-level functions should be handled correctly" in {
val cpg = code("""
|fun isOdd(x: Int) = x % 2 != 0
|
|fun main() {
| val numbers = listOf(1, 2, 3)
| numbers.filter(::isOdd)
|}
|""".stripMargin)

inside(cpg.call.name("filter").argument.l) { case List(receiver, methodRef: MethodRef) =>
methodRef.methodFullName shouldBe "<unresolvedNamespace>.isOdd:boolean(int)"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does isOdd really have <unresolvedNamespace>.isOdd:boolean(int) as method full name? If yes that is already a problem, if not the methodRef should use the same method full name as the referenced method.

methodRef.typeFullName shouldBe "ANY"
methodRef.code shouldBe "::isOdd"
}
}
}
Loading