Skip to content

Commit 9dfb2ec

Browse files
committed
Python code generator: functions with model parameters
1 parent 0a4088a commit 9dfb2ec

3 files changed

Lines changed: 33 additions & 30 deletions

File tree

keymaerax-webui/src/main/scala/edu/cmu/cs/ls/keymaerax/codegen/CodeGenerator.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package edu.cmu.cs.ls.keymaerax.codegen
77
import edu.cmu.cs.ls.keymaerax.core.{AssignAny, BaseVariable, Expression, Function, NamedSymbol, Program, Real, StaticSemantics, Tuple, Unit}
88
import edu.cmu.cs.ls.keymaerax.infrastruct.ExpressionTraversal.{ExpressionTraversalFunction, StopTraversal}
99
import edu.cmu.cs.ls.keymaerax.infrastruct.{ExpressionTraversal, PosInExpr}
10+
import edu.cmu.cs.ls.keymaerax.parser.InterpretedSymbols
1011

1112
object CodeGenerator {
1213
/**
@@ -15,8 +16,7 @@ object CodeGenerator {
1516
def getParameters(expr: Expression, exclude: Set[BaseVariable]): Set[NamedSymbol] =
1617
StaticSemantics.symbols(expr)
1718
.filter({
18-
case Function("abs", None, Real, Real, true) => false
19-
case Function("min" | "max", None, Tuple(Real, Real), Real, true) => false
19+
case InterpretedSymbols.absF | InterpretedSymbols.minF | InterpretedSymbols.maxF => false
2020
case Function(name, _, Unit, _, _) => !exclude.exists(v => v.name == name.stripSuffix("post"))
2121
case BaseVariable(name, _, _) => !exclude.exists(v => v.name == name.stripSuffix("post"))
2222
case _ => false //@note any other function or differential symbol

keymaerax-webui/src/main/scala/edu/cmu/cs/ls/keymaerax/codegen/PythonGenerator.scala

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,31 @@ class PythonGenerator(bodyGenerator: CodeGenerator, defs: Declaration = Declarat
7878
override def apply(expr: Expression, stateVars: Set[BaseVariable], inputVars: Set[BaseVariable], fileName: String): (String, String) =
7979
generateMonitoredCtrlCCode(expr, stateVars, inputVars, fileName)
8080

81+
/** The name of the monitor/control function argument representing monitor parameters. */
82+
private val FUNC_PARAMS_NAME = "params"
83+
84+
/** Compiles primitive expressions with the appropriate params/curr/pre struct location. */
85+
private def primitiveExprGenerator(parameters: Set[NamedSymbol]) = new PythonFormulaTermGenerator({
86+
case t: Variable =>
87+
if (parameters.contains(t)) FUNC_PARAMS_NAME + "."
88+
else ""
89+
case FuncOf(fn, Nothing) =>
90+
if (parameters.contains(fn)) FUNC_PARAMS_NAME + "."
91+
else throw new CodeGenerationException("Non-posterior, non-parameter function symbol " + fn.prettyString + " is not supported")
92+
})
93+
8194
/** Prints function definitions of symbols in `mentionedIn`. */
82-
private def printFuncDefs(mentionedIn: Expression, defs: Declaration): String = {
83-
val what = StaticSemantics.symbols(mentionedIn)
84-
defs.decls.
95+
private def printFuncDefs(mentionedIn: Expression, defs: Declaration, parameters: Set[NamedSymbol], printed: Set[NamedSymbol] = Set.empty): String = {
96+
val what = StaticSemantics.symbols(mentionedIn) -- printed
97+
val printing = defs.decls.
8598
filter({
8699
case (n, s@Signature(_, Real | Bool, Some(args), _, _)) => args.nonEmpty && what.contains(Declaration.asNamedSymbol(n, s))
87-
case _ => false }).
88-
map({
100+
case _ => false })
101+
printing.map({
89102
case (name, Signature(_, codomain, Some(args), interpretation, _)) =>
90103
def ptype(s: Sort): String = s match {
91104
case Real => "np.float64"
92-
case Bool => "Bool"
105+
case Bool => "bool"
93106
case _ => throw new IllegalArgumentException("Sort " + s + " not supported")
94107
}
95108
val pargs = args.map({ case (n, s) => s"${n.prettyString}: ${ptype(s)}" }).mkString(", ")
@@ -102,11 +115,14 @@ class PythonGenerator(bodyGenerator: CodeGenerator, defs: Declaration = Declarat
102115
val argsSubst = USubst(args.zipWithIndex.flatMap({ case ((Name(n, idx), s), i) =>
103116
(if (i == 0) List(SubstitutionPair(DotTerm(s, None), Variable(n, idx, s))) else Nil) :+
104117
SubstitutionPair(DotTerm(s, Some(i)), Variable(n, idx, s)) }))
105-
val body = interpretation match {
106-
case Some(i) => (new CFormulaTermGenerator(_ => ""))(argsSubst(i))._2
107-
case _ => PythonPrettyPrinter.numberLiteral(0.0) + " /* todo */"
118+
val (interpretationDefs, body) = interpretation match {
119+
case Some(i) =>
120+
(printFuncDefs(i, defs, parameters, printed ++ printing.map({ case (n, s) => Declaration.asNamedSymbol(n, s) }).toSet),
121+
primitiveExprGenerator(parameters)(argsSubst(i))._2)
122+
case _ => ("", PythonPrettyPrinter.numberLiteral(0.0) + " # todo")
108123
}
109-
s"""${name.prettyString}($pargs) -> ${ptype(codomain)}:
124+
s"""$interpretationDefs
125+
|def ${name.prettyString}($FUNC_PARAMS_NAME: Params, $pargs) -> ${ptype(codomain)}:
110126
| return $body
111127
|""".stripMargin
112128
}).mkString("\n\n")
@@ -117,7 +133,7 @@ class PythonGenerator(bodyGenerator: CodeGenerator, defs: Declaration = Declarat
117133
val names = StaticSemantics.symbols(expr).map(nameIdentifier)
118134
require(names.intersect(PythonGenerator.RESERVED_NAMES).isEmpty, "Unexpected reserved Python names encountered: " +
119135
names.intersect(PythonGenerator.RESERVED_NAMES).mkString(","))
120-
val parameters = CodeGenerator.getParameters(expr, stateVars)
136+
val parameters = CodeGenerator.getParameters(defs.exhaustiveSubst(expr), stateVars)
121137

122138
val (bodyBody, bodyDefs) = bodyGenerator(expr, stateVars, inputVars, fileName)
123139

@@ -127,7 +143,7 @@ class PythonGenerator(bodyGenerator: CodeGenerator, defs: Declaration = Declarat
127143
printStateDeclaration(stateVars) +
128144
printInputDeclaration(inputVars) +
129145
printVerdictDeclaration +
130-
printFuncDefs(expr, defs) +
146+
printFuncDefs(expr, defs, parameters) +
131147
bodyDefs, bodyBody)
132148
}
133149
}

keymaerax-webui/src/main/scala/edu/cmu/cs/ls/keymaerax/codegen/PythonMonitorGenerator.scala

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ import edu.cmu.cs.ls.keymaerax.btactics.{ModelPlex, SimplifierV3}
88
import edu.cmu.cs.ls.keymaerax.codegen.CFormulaTermGenerator.nameIdentifier
99
import edu.cmu.cs.ls.keymaerax.core._
1010
import edu.cmu.cs.ls.keymaerax.infrastruct.Augmentors._
11-
import edu.cmu.cs.ls.keymaerax.parser.{InterpretedSymbols, KeYmaeraXPrettyPrinter}
11+
import edu.cmu.cs.ls.keymaerax.parser.{Declaration, InterpretedSymbols, KeYmaeraXPrettyPrinter}
1212

1313
/**
1414
* Generates a monitor from a ModelPlex expression.
1515
* @author Stefan Mitsch
1616
*/
17-
class PythonMonitorGenerator(conjunctionsAs: Symbol) extends CodeGenerator {
17+
class PythonMonitorGenerator(conjunctionsAs: Symbol, defs: Declaration = Declaration(Map.empty)) extends CodeGenerator {
1818
override def apply(expr: Expression, stateVars: Set[BaseVariable], inputVars: Set[BaseVariable],
1919
modelName: String): (String, String) =
2020
generateMonitoredCtrlPythonCode(expr, stateVars)
@@ -27,7 +27,7 @@ class PythonMonitorGenerator(conjunctionsAs: Symbol) extends CodeGenerator {
2727
require(names.intersect(PythonGenerator.RESERVED_NAMES).isEmpty, "Unexpected reserved C names encountered: " +
2828
names.intersect(PythonGenerator.RESERVED_NAMES).mkString(","))
2929

30-
val parameters = getParameters(expr, stateVars)
30+
val parameters = CodeGenerator.getParameters(defs.exhaustiveSubst(expr), stateVars)
3131

3232
val monitorDistFuncHead =
3333
s"""def boundaryDist($MONITOR_PRE_STATE_NAME: State, $MONITOR_CURR_STATE_NAME: State, $MONITOR_PARAMS_NAME: Params) -> Verdict:
@@ -113,19 +113,6 @@ class PythonMonitorGenerator(conjunctionsAs: Symbol) extends CodeGenerator {
113113
errorIds(fml)
114114
}
115115

116-
/**
117-
* Returns a set of names (excluding names in `vars` and interpreted functions) that are immutable parameters of the
118-
* expression `expr`. */
119-
private def getParameters(expr: Expression, exclude: Set[BaseVariable]): Set[NamedSymbol] =
120-
StaticSemantics.symbols(expr)
121-
.filter({
122-
case Function("abs", None, Real, Real, true) => false
123-
case Function("min" | "max", None, Tuple(Real, Real), Real, true) => false
124-
case Function(name, _, Unit, _, _) => !exclude.exists(_.name == name.stripSuffix("post"))
125-
case _: Function => false
126-
case BaseVariable(name, _, _) => !exclude.exists(_.name == name.stripSuffix("post"))
127-
})
128-
129116
/** Compiles primitive expressions with the appropriate params/curr/pre struct location. */
130117
private def primitiveExprGenerator(parameters: Set[NamedSymbol]) = new PythonFormulaTermGenerator({
131118
case t: Variable if parameters.contains(t) => MONITOR_PARAMS_NAME + "."

0 commit comments

Comments
 (0)