Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ interface ExpVisitor<R> {
fun visitForAllEmbedding(e: ForAllEmbedding): R
fun visitPredicateAccessPermissions(e: PredicateAccessPermissions): R
fun visitCast(e: Cast): R
fun visitUpcast(e: Upcast): R
fun visitIs(e: Is): R
fun visitOld(e: Old): R
fun visitPrimitiveFieldAccess(e: PrimitiveFieldAccess): R
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.jetbrains.kotlin.formver.core.embeddings.expression.ExpEmbedding
import org.jetbrains.kotlin.formver.core.embeddings.expression.FunctionCall
import org.jetbrains.kotlin.formver.core.embeddings.expression.MethodCall
import org.jetbrains.kotlin.formver.core.embeddings.expression.PlaceholderVariableEmbedding
import org.jetbrains.kotlin.formver.core.embeddings.expression.withUpcast
import org.jetbrains.kotlin.formver.core.names.PlaceholderReturnVariableName
import org.jetbrains.kotlin.formver.viper.ast.Function
import org.jetbrains.kotlin.formver.viper.ast.Method
Expand All @@ -22,10 +23,13 @@ class NonInlineNamedFunction(val signature: FullNamedFunctionSignature, val hasP
args: List<ExpEmbedding>,
ctx: StmtConversionContext,
): ExpEmbedding {
val wrappedArgs = args.zip(callableType.formalArgTypes).map { (arg, formalType) ->
arg.withUpcast(formalType)
}
return if (hasPureAnnotation) {
FunctionCall(signature, args)
FunctionCall(signature, wrappedArgs)
} else {
MethodCall(signature, args)
MethodCall(signature, wrappedArgs)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ data class Assign(val lhs: VariableEmbedding, val rhs: ExpEmbedding) : UnitResul
override val type: TypeEmbedding = lhs.type

override fun toViperSideEffects(ctx: LinearizationContext) {
rhs.withType(lhs.type).toViperStoringIn(LinearizationVariableEmbedding(lhs.name, lhs.type), ctx)
rhs.withUpcast(lhs.type).toViperStoringIn(LinearizationVariableEmbedding(lhs.name, lhs.type), ctx)
}

context(nameResolver: NameResolver)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,45 @@ fun ExpEmbedding.withType(newType: TypeEmbedding): ExpEmbedding = if (type == ne

fun ExpEmbedding.withType(init: TypeBuilder.() -> PretypeBuilder): ExpEmbedding = withType(buildType(init))

/**
* Upcast this expression to [targetType], emitting the necessary predicate unfolds during linearization.
* Only created when [targetType] is a strict supertype in the class hierarchy and uses [Cast] otherwise.
*/
data class Upcast(override val inner: ExpEmbedding, override val type: TypeEmbedding) : UnaryDirectResultExpEmbedding {
override fun toViper(ctx: LinearizationContext): Exp {
val innerPretype = inner.type.pretype
val targetPretype = type.pretype
require(innerPretype is ClassTypeEmbedding && targetPretype is ClassTypeEmbedding) {
"Upcast can only be applied to classes, but got $innerPretype -> $targetPretype"
}

val innerViper = inner.toViper(ctx)
val innerWrapper = ExpWrapper(innerViper, inner.type)
val predicates = innerPretype.details.hierarchyPathTo(targetPretype)
.map { it.predicateAccess(innerWrapper, ctx.source) }.toList()
val nullGuard = if (inner.type.flags.nullable) innerWrapper.notNullCmp().toViperBuiltinType(ctx) else null

return ctx.applyUnfolding(predicates, innerViper, inner.type, nullGuard)
}

override fun <R> accept(v: ExpVisitor<R>): R = v.visitUpcast(this)
}

fun ExpEmbedding.withUpcast(targetType: TypeEmbedding): ExpEmbedding {
if (type == targetType) return this
if (type.flags.nullable || targetType.flags.nullable) return withType(targetType)
val innerPretype = type.pretype
val targetPretype = targetType.pretype
return if (
innerPretype is ClassTypeEmbedding && targetPretype is ClassTypeEmbedding
&& innerPretype.details.hierarchyPathTo(targetPretype).any() // i.e. targetPretype is strict supertype
) {
Upcast(this, targetType)
} else {
withType(targetType)
}
}


/**
* Implementation of "safe as".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
package org.jetbrains.kotlin.formver.core.linearization

import org.jetbrains.kotlin.KtSourceElement
import org.jetbrains.kotlin.formver.core.asPosition
import org.jetbrains.kotlin.formver.core.conversion.ReturnTarget
import org.jetbrains.kotlin.formver.core.embeddings.expression.AnonymousVariableEmbedding
import org.jetbrains.kotlin.formver.core.embeddings.expression.ExpEmbedding
import org.jetbrains.kotlin.formver.core.embeddings.expression.FieldAccess
import org.jetbrains.kotlin.formver.core.embeddings.expression.VariableEmbedding
import org.jetbrains.kotlin.formver.core.embeddings.types.ClassTypeEmbedding
import org.jetbrains.kotlin.formver.core.embeddings.types.PretypeBuilder
import org.jetbrains.kotlin.formver.core.embeddings.types.TypeBuilder
import org.jetbrains.kotlin.formver.core.embeddings.types.TypeEmbedding
import org.jetbrains.kotlin.formver.core.embeddings.types.buildType
import org.jetbrains.kotlin.formver.core.embeddings.types.predicateAccess
import org.jetbrains.kotlin.formver.viper.SymbolicName
import org.jetbrains.kotlin.formver.viper.ast.Declaration
import org.jetbrains.kotlin.formver.viper.ast.Exp
Expand Down Expand Up @@ -57,6 +60,28 @@ interface LinearizationContext {

fun addFieldAccessStoringIn(access: FieldAccess, result: VariableEmbedding)

/**
* Unfold [predicates] in the appropriate way for this context, then return [innerViper].
*
* When [nullGuard] is non-null, the unfolding is conditional on that expression being true (used for nullable
* upcasts, where the value may be null).
*
* - [Linearizer]: emits `Stmt.Unfold` for each predicate, optionally wrapped in `Stmt.If`
* - [PureFunBodyLinearizer]: registers predicates on an SSA variable so [SsaConverter]
* wraps the *usage* (FuncApp/FieldAccess) with `Exp.Unfolding`, not the argument
* - [PureExpLinearizer]: throws — correct placement requires SSA, which this linearizer
* does not have
*/
fun applyUnfolding(
predicates: List<Exp.PredicateAccess>,
innerViper: Exp,
innerType: TypeEmbedding,
nullGuard: Exp? = null,
): Exp {
val unfolded = predicates.foldRight(innerViper) { pred, acc -> Exp.Unfolding(pred, acc) }
return if (nullGuard != null) Exp.TernaryExp(nullGuard, unfolded, innerViper) else unfolded
}

fun addModifier(mod: StmtModifier)

fun resolveVariableName(name: SymbolicName): SymbolicName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ data class Linearizer(
}

override fun addReturn(returnExp: ExpEmbedding, target: ReturnTarget) {
returnExp.withType(target.variable.type)
returnExp.withUpcast(target.variable.type)
.toViperStoringIn(target.variable, this)
addStatement { target.label.toLink().toViperGoto(this) }
}
Expand All @@ -87,8 +87,8 @@ data class Linearizer(
) =
addStatement {
val condViper = condition.toViperBuiltinType(this)
val thenViper = asBlock { thenBranch.withType(type).toViperMaybeStoringIn(result, this) }
val elseViper = asBlock { elseBranch.withType(type).toViperMaybeStoringIn(result, this) }
val thenViper = asBlock { thenBranch.withUpcast(type).toViperMaybeStoringIn(result, this) }
val elseViper = asBlock { elseBranch.withUpcast(type).toViperMaybeStoringIn(result, this) }
Stmt.If(condViper, thenViper, elseViper, source.asPosition)
}

Expand All @@ -98,6 +98,23 @@ data class Linearizer(
return result.toViper(this)
}

override fun applyUnfolding(
predicates: List<Exp.PredicateAccess>,
innerViper: Exp,
innerType: TypeEmbedding,
nullGuard: Exp?,
): Exp {
val emitUnfolds: LinearizationContext.() -> Unit = {
for (pred in predicates) addStatement { Stmt.Unfold(pred, source.asPosition) }
}
if (nullGuard != null) {
addStatement { Stmt.If(nullGuard, asBlock(emitUnfolds), Stmt.Seqn(), source.asPosition) }
} else {
emitUnfolds()
}
return innerViper
}

override fun addModifier(mod: StmtModifier) {
stmtModifierTracker?.add(mod) ?: error("Not in a statement")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,18 @@ data class PureExpLinearizer(
access.unfoldingInImpl()


override fun applyUnfolding(
predicates: List<Exp.PredicateAccess>,
innerViper: Exp,
innerType: TypeEmbedding,
nullGuard: Exp?,
): Exp {
// Implementing unfolding for pure expressions requires some kind of SSA to move the unfolding in function
// calls or assignments before the actual usage of the expression itself. This is currently not supported here.
// TODO: Implement a behavior similar to PureFunBodyLinearizer
throw PureExpLinearizerMisuseException("applyUnfolding (upcast in specification context not yet supported)")
}

override fun addModifier(mod: StmtModifier) {
throw PureExpLinearizerMisuseException("addModifier")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ data class PureFunBodyLinearizer(
ssaConverter.addAssignment(lhs.name, rhs.toViper(this))

override fun addReturn(returnExp: ExpEmbedding, target: ReturnTarget) {
ssaConverter.addReturn(returnExp.toViper(this))
// When we reach a return, we must ensure that the actual type corresponds to the expected type, which might
// require upcasts with unfolding. This can only be applied to variables. Thus, we introduce a new var and
// reuse the existing logic for assignments.
val resultVar = freshAnonVar(target.variable.type)
ssaConverter.addAssignment(resultVar.name, returnExp.toViper(this))
ssaConverter.addReturn(resultVar.toViper(this))
}

override fun addBranch(
Expand All @@ -73,14 +78,18 @@ data class PureFunBodyLinearizer(
conditionExp,
{
if (result != null) {
resultThen = thenBranch.toViper(this)
val thenVar = freshAnonVar(type)
ssaConverter.addAssignment(thenVar.name, thenBranch.toViper(this))
resultThen = thenVar.toViper(this)
} else {
thenBranch.toViperUnusedResult(this)
}
},
{
if (result != null) {
resultElse = elseBranch.toViper(this)
val elseVar = freshAnonVar(type)
ssaConverter.addAssignment(elseVar.name, elseBranch.toViper(this))
resultElse = elseVar.toViper(this)
} else {
elseBranch.toViperUnusedResult(this)
}
Expand Down Expand Up @@ -117,6 +126,21 @@ data class PureFunBodyLinearizer(
return result.toViper(this)
}

override fun applyUnfolding(
predicates: List<Exp.PredicateAccess>,
innerViper: Exp,
innerType: TypeEmbedding,
nullGuard: Exp?,
): Exp {
// Store innerViper in a fresh SSA variable with the predicates as access invariants.
// SsaConverter.withAccessInvariants then propagates them to any FuncApp/FieldAccess that
// uses this variable, wrapping the usage with Exp.Unfolding.
if (predicates.isEmpty()) return innerViper
val result = freshAnonVar(innerType)
ssaConverter.addAssignment(result.name, innerViper, predicates)
return result.toViper(this)
}

override fun addModifier(mod: StmtModifier) {
throw PureFunBodyLinearizerMisuseException("addModifier")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ internal class ExprPurityVisitor(val declaredVariables: MutableSet<VariableEmbed
override fun visitPrimitiveFieldAccess(e: PrimitiveFieldAccess): Boolean = e.allChildrenPure(this)
override fun visitIs(e: Is) = e.allChildrenPure(this)
override fun visitCast(e: Cast): Boolean = e.allChildrenPure(this)
override fun visitUpcast(e: Upcast): Boolean = e.allChildrenPure(this)
override fun visitShared(e: Shared) = e.allChildrenPure(this)

/* ————— impure nodes ————— */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,12 @@ public void testAllFilesPresentInClasses() {
KtTestUtil.assertAllTestsPresentByMetadataWithExcluded(this.getClass(), new File("formver.compiler-plugin/testData/diagnostics/verification/classes"), Pattern.compile("^(.+)\\.kt$"), null, true);
}

@Test
@TestMetadata("conditional_subtype_passing.kt")
public void testConditional_subtype_passing() {
runTest("formver.compiler-plugin/testData/diagnostics/verification/classes/conditional_subtype_passing.kt");
}

@Test
@TestMetadata("multiple_interfaces.kt")
public void testMultiple_interfaces() {
Expand Down Expand Up @@ -906,6 +912,12 @@ public void testHeap_dependent_specifications() {
public void testPure_function_rely_on_branch() {
runTest("formver.compiler-plugin/testData/diagnostics/verification/pure_functions/pure_function_rely_on_branch.kt");
}

@Test
@TestMetadata("pure_upcast.kt")
public void testPure_upcast() {
runTest("formver.compiler-plugin/testData/diagnostics/verification/pure_functions/pure_upcast.kt");
}
}

@Nested
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ public void testAllFilesPresentInClasses() {
KtTestUtil.assertAllTestsPresentByMetadataWithExcluded(this.getClass(), new File("formver.compiler-plugin/testData/diagnostics/verification/classes"), Pattern.compile("^(.+)\\.kt$"), null, true);
}

@Test
@TestMetadata("conditional_subtype_passing.kt")
public void testConditional_subtype_passing() {
runTest("formver.compiler-plugin/testData/diagnostics/verification/classes/conditional_subtype_passing.kt");
}

@Test
@TestMetadata("multiple_interfaces.kt")
public void testMultiple_interfaces() {
Expand Down Expand Up @@ -238,11 +244,23 @@ public void testAllFilesPresentInPure_functions() {
KtTestUtil.assertAllTestsPresentByMetadataWithExcluded(this.getClass(), new File("formver.compiler-plugin/testData/diagnostics/verification/pure_functions"), Pattern.compile("^(.+)\\.kt$"), null, true);
}

@Test
@TestMetadata("heap_dependent_specifications.kt")
public void testHeap_dependent_specifications() {
runTest("formver.compiler-plugin/testData/diagnostics/verification/pure_functions/heap_dependent_specifications.kt");
}

@Test
@TestMetadata("pure_function_rely_on_branch.kt")
public void testPure_function_rely_on_branch() {
runTest("formver.compiler-plugin/testData/diagnostics/verification/pure_functions/pure_function_rely_on_branch.kt");
}

@Test
@TestMetadata("pure_upcast.kt")
public void testPure_upcast() {
runTest("formver.compiler-plugin/testData/diagnostics/verification/pure_functions/pure_upcast.kt");
}
}

@Nested
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ method f$callSuperMethod$TF$T$Bar$T$Int(p$bar: Ref) returns (ret$0: Ref)
ensures df$rt$isSubtype(df$rt$typeOf(ret$0), df$rt$intType())
{
inhale acc(p$c$Bar$shared(p$bar), wildcard)
unfold acc(p$c$Bar$shared(p$bar), wildcard)
ret$0 := f$c$Foo$getY$TF$T$Foo$T$Int(p$bar)
goto lbl$ret$0
label lbl$ret$0
Expand Down
Loading