@@ -78,18 +78,31 @@ class PythonGenerator(bodyGenerator: CodeGenerator, defs: Declaration = Declarat
7878 override def apply (expr : Expression , stateVars : Set [BaseVariable ], inputVars : Set [BaseVariable ], fileName : String ): (String , String ) =
7979 generateMonitoredCtrlCCode(expr, stateVars, inputVars, fileName)
8080
81+ /** The name of the monitor/control function argument representing monitor parameters. */
82+ private val FUNC_PARAMS_NAME = " params"
83+
84+ /** Compiles primitive expressions with the appropriate params/curr/pre struct location. */
85+ private def primitiveExprGenerator (parameters : Set [NamedSymbol ]) = new PythonFormulaTermGenerator ({
86+ case t : Variable =>
87+ if (parameters.contains(t)) FUNC_PARAMS_NAME + " ."
88+ else " "
89+ case FuncOf (fn, Nothing ) =>
90+ if (parameters.contains(fn)) FUNC_PARAMS_NAME + " ."
91+ else throw new CodeGenerationException (" Non-posterior, non-parameter function symbol " + fn.prettyString + " is not supported" )
92+ })
93+
8194 /** Prints function definitions of symbols in `mentionedIn`. */
82- private def printFuncDefs (mentionedIn : Expression , defs : Declaration ): String = {
83- val what = StaticSemantics .symbols(mentionedIn)
84- defs.decls.
95+ private def printFuncDefs (mentionedIn : Expression , defs : Declaration , parameters : Set [ NamedSymbol ], printed : Set [ NamedSymbol ] = Set .empty ): String = {
96+ val what = StaticSemantics .symbols(mentionedIn) -- printed
97+ val printing = defs.decls.
8598 filter({
8699 case (n, s@ Signature (_, Real | Bool , Some (args), _, _)) => args.nonEmpty && what.contains(Declaration .asNamedSymbol(n, s))
87- case _ => false }).
88- map({
100+ case _ => false })
101+ printing. map({
89102 case (name, Signature (_, codomain, Some (args), interpretation, _)) =>
90103 def ptype (s : Sort ): String = s match {
91104 case Real => " np.float64"
92- case Bool => " Bool "
105+ case Bool => " bool "
93106 case _ => throw new IllegalArgumentException (" Sort " + s + " not supported" )
94107 }
95108 val pargs = args.map({ case (n, s) => s " ${n.prettyString}: ${ptype(s)}" }).mkString(" , " )
@@ -102,11 +115,14 @@ class PythonGenerator(bodyGenerator: CodeGenerator, defs: Declaration = Declarat
102115 val argsSubst = USubst (args.zipWithIndex.flatMap({ case ((Name (n, idx), s), i) =>
103116 (if (i == 0 ) List (SubstitutionPair (DotTerm (s, None ), Variable (n, idx, s))) else Nil ) :+
104117 SubstitutionPair (DotTerm (s, Some (i)), Variable (n, idx, s)) }))
105- val body = interpretation match {
106- case Some (i) => (new CFormulaTermGenerator (_ => " " ))(argsSubst(i))._2
107- case _ => PythonPrettyPrinter .numberLiteral(0.0 ) + " /* todo */"
118+ val (interpretationDefs, body) = interpretation match {
119+ case Some (i) =>
120+ (printFuncDefs(i, defs, parameters, printed ++ printing.map({ case (n, s) => Declaration .asNamedSymbol(n, s) }).toSet),
121+ primitiveExprGenerator(parameters)(argsSubst(i))._2)
122+ case _ => (" " , PythonPrettyPrinter .numberLiteral(0.0 ) + " # todo" )
108123 }
109- s """ ${name.prettyString}( $pargs) -> ${ptype(codomain)}:
124+ s """ $interpretationDefs
125+ |def ${name.prettyString}( $FUNC_PARAMS_NAME: Params, $pargs) -> ${ptype(codomain)}:
110126 | return $body
111127 | """ .stripMargin
112128 }).mkString(" \n\n " )
@@ -117,7 +133,7 @@ class PythonGenerator(bodyGenerator: CodeGenerator, defs: Declaration = Declarat
117133 val names = StaticSemantics .symbols(expr).map(nameIdentifier)
118134 require(names.intersect(PythonGenerator .RESERVED_NAMES ).isEmpty, " Unexpected reserved Python names encountered: " +
119135 names.intersect(PythonGenerator .RESERVED_NAMES ).mkString(" ," ))
120- val parameters = CodeGenerator .getParameters(expr, stateVars)
136+ val parameters = CodeGenerator .getParameters(defs.exhaustiveSubst( expr) , stateVars)
121137
122138 val (bodyBody, bodyDefs) = bodyGenerator(expr, stateVars, inputVars, fileName)
123139
@@ -127,7 +143,7 @@ class PythonGenerator(bodyGenerator: CodeGenerator, defs: Declaration = Declarat
127143 printStateDeclaration(stateVars) +
128144 printInputDeclaration(inputVars) +
129145 printVerdictDeclaration +
130- printFuncDefs(expr, defs) +
146+ printFuncDefs(expr, defs, parameters ) +
131147 bodyDefs, bodyBody)
132148 }
133149}
0 commit comments