Skip to content

Commit 066f301

Browse files
waterlensLPTK
andauthored
Prettier LLIR and C++ output (#301)
Co-authored-by: Lionel Parreaux <[email protected]>
1 parent 53d2587 commit 066f301

File tree

10 files changed

+380
-380
lines changed

10 files changed

+380
-380
lines changed

hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala

+7-5
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ class CppCodeGen(builtinClassSymbols: Set[Local], tl: TraceLogger):
199199
else if op.unary && args.length === 1 then
200200
Expr.Unary(op2, toExpr(args(0)))
201201
else
202-
TODO(s"codegenOps ${op.nme} ${args.size} ${op.binary} ${op.unary} ${args.map(_.show)}")
202+
TODO(s"codegenOps ${op.nme} ${args.size} ${op.binary} ${op.unary} ${args.map(_.show).mkString(", ")}")
203203

204204

205205
def codegen(expr: IExpr)(using Ctx, Raise, Scope): Expr = expr match
@@ -241,8 +241,8 @@ class CppCodeGen(builtinClassSymbols: Set[Local], tl: TraceLogger):
241241
case Node.LetExpr(name, expr, body) =>
242242
val stmts2 = stmts ++ Ls(Stmt.AutoBind(Ls(name |> allocIfNew), codegen(expr)))
243243
codegen(body, storeInto)(using decls, stmts2)
244-
case Node.LetCall(names, bin: BuiltinSymbol, args, body) if bin.nme === "<builtin>" =>
245-
val stmts2 = stmts ++ codegenBuiltin(names, args.head.toString.replace("\"", ""), args.tail)
244+
case Node.LetCall(names, bin: BuiltinSymbol, IExpr.Literal(syntax.Tree.StrLit(s)) :: tail, body) if bin.nme === "<builtin>" =>
245+
val stmts2 = stmts ++ codegenBuiltin(names, s.replace("\"", ""), tail)
246246
codegen(body, storeInto)(using decls, stmts2)
247247
case Node.LetMethodCall(names, cls, method, IExpr.Ref(bin: BuiltinSymbol) :: args, body) if bin.nme === "<this>" =>
248248
val call = mlsThisCall(cls, method |> directName, args.map(toExpr))
@@ -285,13 +285,15 @@ class CppCodeGen(builtinClassSymbols: Set[Local], tl: TraceLogger):
285285
depgraph = depgraph.view.mapValues(_.filter(_ =/= node)).toMap
286286
degree = depgraph.view.mapValues(_.size).toMap
287287
val sorted = ListBuffer.empty[ClassInfo]
288-
var work = degree.filter(_._2 === 0).keys.toSet
288+
given Ordering[Local] with
289+
def compare(x: Local, y: Local): Int = x.nme.compareTo(y.nme)
290+
var work = degree.filter(_._2 === 0).keys.toSortedSet()
289291
while work.nonEmpty do
290292
val node = work.head
291293
work -= node
292294
prog.classes.find(x => (x.symbol) === node).foreach(sorted.addOne)
293295
removeNode(node)
294-
val next = degree.filter(_._2 === 0).keys
296+
val next = degree.filter(_._2 === 0).keys.toSortedSet()
295297
work ++= next
296298
if depgraph.nonEmpty then
297299
val cycle = depgraph.keys.mkString(", ")

hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala

+100-105
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ private def raw(x: String): Document = doc"$x"
2020

2121
final case class LowLevelIRError(message: String) extends Exception(message)
2222

23-
private def docSymWithUid(sym: Local): Document = doc"${sym.nme}$$${sym.uid.toString()}"
24-
2523
val hiddenPrefixes = Set("Tuple")
2624

2725
def defaultHidden(x: Str): Bool =
@@ -32,24 +30,7 @@ case class Program(
3230
defs: Set[Func],
3331
entry: Local,
3432
):
35-
override def toString: String =
36-
val t1 = classes.toArray
37-
val t2 = defs.toArray
38-
Sorting.quickSort(t1)
39-
Sorting.quickSort(t2)
40-
s"Program({${t1.mkString(",\n")}}, {\n${t2.mkString("\n")}\n},\n$entry)"
41-
42-
def show(hide: Str => Bool = defaultHidden) = toDocument(hide).toString
43-
def toDocument(hide: Str => Bool = defaultHidden) : Document =
44-
val t1 = classes.iterator.filterNot(c => hide(c.symbol.nme)).toArray
45-
val t2 = defs.toArray
46-
Sorting.quickSort(t1)
47-
Sorting.quickSort(t2)
48-
given Conversion[String, Document] = raw
49-
val docClasses = t1.map(_.toDocument).toList.mkDocument(doc" # ")
50-
val docDefs = t2.map(_.toDocument).toList.mkDocument(doc" # ")
51-
val docMain = doc"entry = ${entry.nme}$$${entry.uid.toString()}"
52-
doc" #{ $docClasses\n$docDefs\n$docMain #} "
33+
def show = LlirDebugPrinter.mkDocument(this).toString
5334

5435
implicit object ClassInfoOrdering extends Ordering[ClassInfo] {
5536
def compare(a: ClassInfo, b: ClassInfo) = a.id.compare(b.id)
@@ -63,20 +44,7 @@ case class ClassInfo(
6344
methods: Map[Local, Func],
6445
):
6546
override def hashCode: Int = id
66-
override def toString: String =
67-
s"ClassInfo($id, $symbol, [${fields mkString ","}], parents: ${parents mkString ","}, methods:\n${methods mkString ",\n"})"
68-
69-
def show = toDocument.toString
70-
def toDocument: Document =
71-
given Conversion[String, Document] = raw
72-
val ext = if parents.isEmpty then "" else " extends " + parents.map(_.nme).mkString(", ")
73-
if methods.isEmpty then
74-
doc"class ${symbol.nme}(${fields.map(docSymWithUid).mkString(",")})$ext"
75-
else
76-
val docFirst = doc"class ${symbol.nme}(${fields.map(docSymWithUid).mkString(",")})$ext {"
77-
val docMethods = methods.map { (_, func) => func.toDocument }.toList.mkDocument(doc" # ")
78-
val docLast = doc"}"
79-
doc"$docFirst #{ # $docMethods #} # $docLast"
47+
def show = LlirDebugPrinter.mkDocument(this).toString
8048

8149
class FuncRef(var func: Local):
8250
def name: String = func.nme
@@ -98,32 +66,18 @@ case class Func(
9866
):
9967
var recBoundary: Opt[Int] = None
10068
override def hashCode: Int = id
101-
102-
override def toString: String =
103-
val ps = params.map(_.toString).mkString("[", ",", "]")
104-
s"Def($id, $name, $ps, \n$resultNum, \n$body\n)"
105-
106-
def show = toDocument
107-
def toDocument: Document =
108-
given Conversion[String, Document] = raw
109-
val docFirst = doc"def ${docSymWithUid(name)}(${params.map(docSymWithUid).mkString(",")}) ="
110-
val docBody = body.toDocument
111-
doc"$docFirst #{ # $docBody #} "
69+
def show = LlirDebugPrinter.mkDocument(this).toString
11270

11371
sealed trait TrivialExpr:
11472
import Expr._
115-
override def toString: String
116-
def show: String
117-
def toDocument: Document
11873
def toExpr: Expr = this match { case x: Expr => x }
11974
def foldRef(f: Local => TrivialExpr): TrivialExpr = this match
12075
case Ref(sym) => f(sym)
12176
case _ => this
12277
def iterRef(f: Local => Unit): Unit = this match
12378
case Ref(sym) => f(sym)
12479
case _ => ()
125-
126-
private def showArguments(args: Ls[TrivialExpr]) = args map (_.show) mkString ","
80+
def show: String
12781

12882
enum Expr:
12983
case Ref(sym: Local) extends Expr, TrivialExpr
@@ -132,38 +86,12 @@ enum Expr:
13286
case Select(name: Local, cls: Local, field: Str)
13387
case BasicOp(name: BuiltinSymbol, args: Ls[TrivialExpr])
13488
case AssignField(assignee: Local, cls: Local, field: Str, value: TrivialExpr)
135-
136-
override def toString: String = show
137-
138-
def show: String = toDocument.toString
139-
140-
def toDocument: Document =
141-
given Conversion[String, Document] = raw
142-
this match
143-
case Ref(s) => docSymWithUid(s)
144-
case Literal(Tree.BoolLit(lit)) => s"$lit"
145-
case Literal(Tree.IntLit(lit)) => s"$lit"
146-
case Literal(Tree.DecLit(lit)) => s"$lit"
147-
case Literal(Tree.StrLit(lit)) => s"${lit.escaped}"
148-
case Literal(Tree.UnitLit(isNullNotUndefined)) =>
149-
if isNullNotUndefined then "null" else "undefined"
150-
case CtorApp(cls, args) =>
151-
doc"${docSymWithUid(cls)}(${args.map(_.toString).mkString(",")})"
152-
case Select(s, cls, fld) =>
153-
doc"${docSymWithUid(s)}.<${docSymWithUid(cls)}:$fld>"
154-
case BasicOp(sym, args) =>
155-
doc"${sym.nme}(${args.map(_.toString).mkString(",")})"
156-
case AssignField(assignee, clsInfo, fieldName, value) =>
157-
doc"${docSymWithUid(assignee)}.${fieldName} := ${value.toString}"
89+
def show = LlirDebugPrinter.mkDocument(this).toString
15890

15991
enum Pat:
16092
case Lit(lit: hkmc2.syntax.Literal)
16193
case Class(cls: Local)
16294

163-
override def toString: String = this match
164-
case Lit(lit) => s"$lit"
165-
case Class(cls) => s"${{docSymWithUid(cls)}}"
166-
16795
enum Node:
16896
// Terminal forms:
16997
case Result(res: Ls[TrivialExpr])
@@ -174,33 +102,100 @@ enum Node:
174102
case LetExpr(name: Local, expr: Expr, body: Node)
175103
case LetMethodCall(names: Ls[Local], cls: Local, method: Local, args: Ls[TrivialExpr], body: Node)
176104
case LetCall(names: Ls[Local], func: Local, args: Ls[TrivialExpr], body: Node)
177-
178-
override def toString: String = show
179-
180-
def show: String = toDocument.toString
181-
182-
def toDocument: Document =
105+
def show = LlirDebugPrinter.mkDocument(this).toString
106+
107+
abstract class LlirPrinting:
108+
import hkmc2.utils.*
109+
import hkmc2.semantics.Elaborator.State
110+
111+
def mkDocument(local: Local): Document
112+
def mkDocument(lit: Literal): Document = doc"${lit.idStr}"
113+
def mkDocument(texpr: TrivialExpr): Document = texpr match
114+
case Expr.Ref(sym) => mkDocument(sym)
115+
case Expr.Literal(lit) => mkDocument(lit)
116+
117+
def mkDocument(expr: Expr): Document =
118+
expr match
119+
case Expr.Ref(sym) => doc"${mkDocument(sym)}"
120+
case Expr.Literal(lit) => doc"${lit.idStr}"
121+
case Expr.CtorApp(cls, args) =>
122+
doc"${mkDocument(cls)}(${args.map(mkDocument).mkString(",")})"
123+
case Expr.Select(name, cls, field) =>
124+
doc"${mkDocument(name)}.<${mkDocument(cls)}:$field>"
125+
case Expr.BasicOp(sym, args) =>
126+
doc"${sym.nme}(${args.map(mkDocument).mkString(",")})"
127+
case Expr.AssignField(assignee, clsInfo, fieldName, value) =>
128+
doc"${mkDocument(assignee)}.${fieldName} := ${mkDocument(value)}"
129+
def mkDocument(node: Node): Document =
130+
node match
131+
case Node.Result(res) => doc"${res.map(mkDocument).mkString(",")}"
132+
case Node.Jump(func, args) =>
133+
doc"jump ${mkDocument(func)}(${args.map(mkDocument).mkString(",")})"
134+
case Node.Case(scrutinee, cases, default) =>
135+
val docFirst = doc"case ${mkDocument(scrutinee)} of"
136+
val docCases = cases.map {
137+
case (pat, node) => doc"${pat.toString} => #{ # ${mkDocument(node)} #} "
138+
}.mkDocument(doc" # ")
139+
default match
140+
case N => doc"$docFirst #{ # $docCases #} "
141+
case S(dc) =>
142+
val docDeft = doc"_ => #{ # ${mkDocument(dc)} #} "
143+
doc"$docFirst #{ # $docCases # $docDeft #} "
144+
case Node.Panic(msg) =>
145+
doc"panic ${s"\"$msg\""}"
146+
case Node.LetExpr(x, expr, body) =>
147+
doc"let ${mkDocument(x)} = ${mkDocument(expr)} in # ${mkDocument(body)}"
148+
case Node.LetMethodCall(xs, cls, method, args, body) =>
149+
doc"let ${xs.map(mkDocument).mkString(",")} = ${mkDocument(cls)}.${method.nme}(${args.map(mkDocument).mkString(",")}) in # ${mkDocument(body)}"
150+
case Node.LetCall(xs, func, args, body) =>
151+
doc"let* (${xs.map(mkDocument).mkString(",")}) = ${mkDocument(func)}(${args.map(mkDocument).mkString(",")}) in # ${mkDocument(body)}"
152+
def mkDocument(defn: Func): Document =
153+
def docParams(params: Ls[Local]): Document =
154+
params.map(mkDocument).mkString("(", ",", ")")
183155
given Conversion[String, Document] = raw
184-
this match
185-
case Result(res) => (res |> showArguments)
186-
case Jump(jp, args) =>
187-
doc"jump ${docSymWithUid(jp)}(${args |> showArguments})"
188-
case Case(x, cases, default) =>
189-
val docFirst = doc"case ${x.toString} of"
190-
val docCases = cases.map {
191-
case (pat, node) => doc"${pat.toString} => #{ # ${node.toDocument} #} "
192-
}.mkDocument(doc" # ")
193-
default match
194-
case N => doc"$docFirst #{ # $docCases #} "
195-
case S(dc) =>
196-
val docDeft = doc"_ => #{ # ${dc.toDocument} #} "
197-
doc"$docFirst #{ # $docCases # $docDeft #} "
198-
case Panic(msg) =>
199-
doc"panic ${s"\"$msg\""}"
200-
case LetExpr(x, expr, body) =>
201-
doc"let ${docSymWithUid(x)} = ${expr.toString} in # ${body.toDocument}"
202-
case LetMethodCall(xs, cls, method, args, body) =>
203-
doc"let ${xs.map(docSymWithUid).mkString(",")} = ${cls.nme}.${docSymWithUid(method)}(${args.map(_.toString).mkString(",")}) in # ${body.toDocument}"
204-
case LetCall(xs, func, args, body) =>
205-
doc"let* (${xs.map(docSymWithUid).mkString(",")}) = ${func.nme}(${args.map(_.toString).mkString(",")}) in # ${body.toDocument}"
156+
val docFirst = doc"def ${mkDocument(defn.name)}${docParams(defn.params)} ="
157+
val docBody = mkDocument(defn.body)
158+
doc"$docFirst #{ # $docBody #} "
159+
def mkDocument(cls: ClassInfo): Document =
160+
given Conversion[String, Document] = raw
161+
val ext = if cls.parents.isEmpty then "" else " extends " + cls.parents.map(mkDocument).mkString(", ")
162+
val docFirst = doc"class ${mkDocument(cls.symbol)}(${cls.fields.map(_.nme).mkString(",")})$ext"
163+
if cls.methods.isEmpty then
164+
doc"$docFirst"
165+
else
166+
val docMethods = cls.methods.map { (_, func) => mkDocument(func) }.toList.mkDocument(doc" # ")
167+
doc"$docFirst { #{ # $docMethods #} # }"
168+
def mkDocument(prog: Program, hide: Str => Bool = defaultHidden): Document =
169+
given Conversion[String, Document] = raw
170+
val t1 = prog.classes.iterator.filterNot(c => hide(c.symbol.nme)).toArray
171+
val t2 = prog.defs.toArray
172+
Sorting.quickSort(t1)
173+
Sorting.quickSort(t2)
174+
val docClasses = t1.filterNot(c => hide(c.symbol.nme)).map(mkDocument).toList.mkDocument(doc" # ")
175+
val docDefs = t2.map(mkDocument).toList.mkDocument(doc" # ")
176+
val docMain = doc"entry = ${mkDocument(prog.entry)}"
177+
doc" #{ $docClasses\n$docDefs\n$docMain #} "
206178

179+
class LlirPrinter(using Raise, hkmc2.utils.Scope) extends LlirPrinting:
180+
import hkmc2.utils.*
181+
import hkmc2.semantics.Elaborator.State
182+
183+
def getVar(l: Local): String = l match
184+
case ts: hkmc2.semantics.TermSymbol =>
185+
ts.owner match
186+
case S(owner) => summon[Scope].lookup_!(ts)
187+
case N => summon[Scope].lookup_!(ts)
188+
case ts: hkmc2.semantics.InnerSymbol =>
189+
summon[Scope].lookup_!(ts)
190+
case _ => summon[Scope].lookup_!(l)
191+
def allocIfNew(l: Local): String =
192+
summon[Scope].lookup(l) match
193+
case S(_) => getVar(l)
194+
case N =>
195+
summon[Scope].allocateName(l)
196+
override def mkDocument(local: Local): Document = allocIfNew(local)
197+
198+
object LlirDebugPrinter extends LlirPrinting:
199+
import hkmc2.utils.*
200+
def docSymWithUid(sym: Local): Document = doc"${sym.nme}$$${sym.uid.toString()}"
201+
override def mkDocument(local: Local): Document = docSymWithUid(local)

hkmc2/shared/src/test/mlscript/llir/Classes.mls

+12-12
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,23 @@ fun main() =
2626
main()
2727
//│ LLIR:
2828
//│ class Base() {
29-
//│ def get$882() =
29+
//│ def get() =
3030
//│ 1
3131
//│ }
3232
//│ class Derived() extends Base {
33-
//│ def get$883() =
33+
//│ def get1() =
3434
//│ 2
3535
//│ }
36-
//│ def main$885() =
37-
//│ let x$906 = Derived$891() in
38-
//│ let x$907 = Base.get$882(x$906) in
39-
//│ let x$908 = Derived.get$883(x$906) in
40-
//│ let x$909 = *(x$907,x$908) in
41-
//│ x$909
42-
//│ def entry$911() =
43-
//│ let* (x$910) = main() in
44-
//│ x$910
45-
//│ entry = entry$911
36+
//│ def main() =
37+
//│ let x = Derived() in
38+
//│ let x1 = Base.get(x) in
39+
//│ let x2 = Derived.get(x) in
40+
//│ let x3 = *(x1,x2) in
41+
//│ x3
42+
//│ def entry() =
43+
//│ let* (x4) = main() in
44+
//│ x4
45+
//│ entry = entry
4646
//│
4747
//│ Interpreted:
4848
//│ 4

0 commit comments

Comments
 (0)