Skip to content

Commit e29e935

Browse files
author
Guannan Wei
committed
evals
1 parent f9044c0 commit e29e935

File tree

5 files changed

+737
-10
lines changed

5 files changed

+737
-10
lines changed

src/main/scala/AST.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,14 @@ case object Drop extends Instr
6161
case object Alloc extends Instr
6262
case object Free extends Instr
6363
case class Select(ty: Option[List[ValueType]]) extends Instr
64-
case class Block(ty: BlockType, instrs: List[Instr]) extends Instr
64+
case class Block(ty: BlockType, instrs: List[Instr]) extends Instr {
65+
override def toString: String = s"Block(...)"
66+
}
6567
case class IdBlock(id: Int, ty: BlockType, instrs: List[Instr]) extends Instr
66-
case class Loop(ty: BlockType, instrs: List[Instr]) extends Instr
68+
case class Loop(ty: BlockType, instrs: List[Instr]) extends Instr {
69+
override def toString: String = s"Loop(...)"
70+
}
71+
case class ForLoop(init:List[Instr], cond: List[Instr], post: List[Instr], body: List[Instr]) extends Instr
6772
case class IdLoop(id: Int, ty: BlockType, instrs: List[Instr]) extends Instr
6873
case class If(ty: BlockType, thenInstrs: List[Instr], elseInstrs: List[Instr]) extends Instr
6974
case class IdIf(ty: BlockType, thenInstrs: IdBlock, elseInstrs: IdBlock) extends Instr
@@ -281,10 +286,11 @@ case class CmdModule(module: Module) extends Cmd
281286
// TODO: extend if needed
282287
case class CMdInstnace() extends Cmd
283288

284-
abstract class Action extends WIR
289+
abstract class Action extends Cmd
285290
case class Invoke(instName: Option[String], name: String, args: List[Value]) extends Action
286291

287292
abstract class Assertion extends Cmd
293+
case class AssertInvalid() extends Assertion
288294
case class AssertReturn(action: Action, expect: List[Num] /* TODO: support multiple expect result type*/)
289295
extends Assertion
290296
case class AssertTrap(action: Action, message: String) extends Assertion
@@ -324,10 +330,7 @@ case class RefFuncV(funcAddr: Int) extends Ref {
324330
case FuncDef(_, FuncBodyDef(ty, _, _, _)) => RefType(ty)
325331
}
326332
}
327-
// RefContV refers to a delimited continuation
328-
// case class RefContV(cont: List[Value] => List[Value]) extends Ref {
329-
// def tipe(implicit m: ModuleInstance): ValueType = ???
330-
// }
333+
331334
case class RefExternV(externAddr: Int) extends Ref {
332335
def tipe(implicit m: ModuleInstance): ValueType = ???
333336
}

src/main/scala/MiniWasmFX.scala

Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
package wasm.miniwasm
2+
3+
import wasm.ast._
4+
import wasm.memory._
5+
6+
import scala.collection.mutable.ArrayBuffer
7+
import scala.collection.mutable.HashMap
8+
import Console.{GREEN, RED, RESET, YELLOW_B, UNDERLINED}
9+
10+
case class EvaluatorFX(module: ModuleInstance) {
11+
import Primtives._
12+
implicit val m: ModuleInstance = module
13+
14+
type Stack = List[Value]
15+
16+
trait Cont[A] {
17+
def apply(stack: Stack, trail: Trail[A], handler: Handlers[A]): A
18+
}
19+
type Trail[A] = List[(Cont[A], List[Int])] // trail items are pairs of continuation and tags
20+
type MCont[A] = Stack => A
21+
22+
type Handler[A] = Stack => A
23+
type Handlers[A] = List[(Int, Handler[A])]
24+
25+
case class ContV[A](k: (Stack, Cont[A], Trail[A], Handlers[A]) => A) extends Value {
26+
def tipe(implicit m: ModuleInstance): ValueType = ???
27+
28+
// override def toString: String = "ContV"
29+
}
30+
31+
// initK is a continuation that simply returns the inputed stack
32+
def initK[Ans](s: Stack, trail: Trail[Ans], hs: Handlers[Ans]): Ans =
33+
trail match {
34+
// Currently, the last element of the Trail is the halt continuation
35+
// the exception will never be thrown
36+
case (k1, _) :: trail => k1(s, trail, hs)
37+
case Nil => throw new Exception("No halting continuation in trail")
38+
}
39+
40+
def eval1[Ans](inst: Instr, stack: Stack, frame: Frame, kont: Cont[Ans],
41+
trail: Trail[Ans], brTable: List[Cont[Ans]], hs: Handlers[Ans]): Ans = {
42+
// System.err.println(f"[DEBUG] ${inst} | ${frame} | ${stack.reverse} | handlers: ${hs}");
43+
inst match {
44+
case Drop => kont(stack.tail, trail, hs)
45+
case Select(_) =>
46+
val I32V(cond) :: v2 :: v1 :: newStack = stack
47+
val value = if (cond == 0) v1 else v2
48+
kont(value :: newStack, trail, hs)
49+
case LocalGet(i) =>
50+
kont(frame.locals(i) :: stack, trail, hs)
51+
case LocalSet(i) =>
52+
val value :: newStack = stack
53+
frame.locals(i) = value
54+
kont(newStack, trail, hs)
55+
case LocalTee(i) =>
56+
val value :: newStack = stack
57+
frame.locals(i) = value
58+
kont(stack, trail, hs)
59+
case GlobalGet(i) =>
60+
kont(module.globals(i).value :: stack, trail, hs)
61+
case GlobalSet(i) =>
62+
val value :: newStack = stack
63+
module.globals(i).ty match {
64+
case GlobalType(tipe, true) if value.tipe == tipe =>
65+
module.globals(i).value = value
66+
case GlobalType(_, true) => throw new Exception("Invalid type")
67+
case _ => throw new Exception("Cannot set immutable global")
68+
}
69+
kont(newStack, trail, hs)
70+
case MemorySize =>
71+
kont(I32V(module.memory.head.size) :: stack, trail, hs)
72+
case MemoryGrow =>
73+
val I32V(delta) :: newStack = stack
74+
val mem = module.memory.head
75+
val oldSize = mem.size
76+
mem.grow(delta) match {
77+
case Some(e) => kont(I32V(-1) :: newStack, trail, hs)
78+
case _ => kont(I32V(oldSize) :: newStack, trail, hs)
79+
}
80+
case MemoryFill =>
81+
val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack
82+
if (memOutOfBound(module, 0, offset, size))
83+
throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`?
84+
else {
85+
module.memory.head.fill(offset, size, value.toByte)
86+
kont(newStack, trail, hs)
87+
}
88+
case MemoryCopy =>
89+
val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack
90+
if (memOutOfBound(module, 0, src, n) || memOutOfBound(module, 0, dest, n))
91+
throw new Exception("Out of bounds memory access")
92+
else {
93+
module.memory.head.copy(dest, src, n)
94+
kont(newStack, trail, hs)
95+
}
96+
case Const(n) => kont(n :: stack, trail, hs)
97+
case Binary(op) =>
98+
val v2 :: v1 :: newStack = stack
99+
kont(evalBinOp(op, v1, v2) :: newStack, trail, hs)
100+
case Unary(op) =>
101+
val v :: newStack = stack
102+
kont(evalUnaryOp(op, v) :: newStack, trail, hs)
103+
case Compare(op) =>
104+
val v2 :: v1 :: newStack = stack
105+
kont(evalRelOp(op, v1, v2) :: newStack, trail, hs)
106+
case Test(op) =>
107+
val v :: newStack = stack
108+
kont(evalTestOp(op, v) :: newStack, trail, hs)
109+
case Store(StoreOp(align, offset, ty, None)) =>
110+
val I32V(v) :: I32V(addr) :: newStack = stack
111+
module.memory(0).storeInt(addr + offset, v)
112+
kont(newStack, trail, hs)
113+
case Load(LoadOp(align, offset, ty, None, None)) =>
114+
val I32V(addr) :: newStack = stack
115+
val value = module.memory(0).loadInt(addr + offset)
116+
kont(I32V(value) :: newStack, trail, hs)
117+
case Nop => kont(stack, trail, hs)
118+
case Unreachable => throw Trap()
119+
case Block(ty, inner) =>
120+
val funcTy = getFuncType(ty)
121+
val (inputs, restStack) = stack.splitAt(funcTy.inps.size)
122+
val escape: Cont[Ans] = (s1, t1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, h1)
123+
evalList(inner, inputs, frame, escape, trail, escape::brTable, hs)
124+
case Loop(ty, inner) =>
125+
val funcTy = getFuncType(ty)
126+
val (inputs, restStack) = stack.splitAt(funcTy.inps.size)
127+
val escape: Cont[Ans] = (s1, t1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, h1)
128+
def loop(retStack: List[Value], trail1: Trail[Ans], h1: Handlers[Ans]): Ans =
129+
evalList(inner, retStack.take(funcTy.inps.size), frame, escape, trail, (loop _ : Cont[Ans])::brTable, h1)
130+
loop(inputs, trail, hs)
131+
case If(ty, thn, els) =>
132+
val funcTy = getFuncType(ty)
133+
val I32V(cond) :: newStack = stack
134+
val inner = if (cond != 0) thn else els
135+
val (inputs, restStack) = newStack.splitAt(funcTy.inps.size)
136+
val escape: Cont[Ans] = (s1, t1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, h1)
137+
evalList(inner, inputs, frame, escape, trail, escape::brTable, hs)
138+
case Br(label) =>
139+
brTable(label)(stack, trail, hs)
140+
case BrIf(label) =>
141+
val I32V(cond) :: newStack = stack
142+
if (cond != 0) brTable(label)(newStack, trail, hs)
143+
else kont(newStack, trail, hs)
144+
case BrTable(labels, default) =>
145+
val I32V(cond) :: newStack = stack
146+
val goto = if (cond < labels.length) labels(cond) else default
147+
brTable(goto)(newStack, trail, hs)
148+
case Return =>
149+
brTable.last(stack, trail, hs)
150+
case Call(f) => evalCall1(f, stack, frame, kont, trail, brTable, hs, false)
151+
case ReturnCall(f) =>
152+
// System.err.println(s"[DEBUG] return call: $f")
153+
evalCall1(f, stack, frame, kont, trail, brTable, hs, true)
154+
case RefFunc(f) =>
155+
// TODO: RefFuncV stores an applicable function, instead of a syntactic structure
156+
kont(RefFuncV(f) :: stack, trail, hs)
157+
// WasmFX effect handlers:
158+
case ContNew(ty) =>
159+
val RefFuncV(f) :: newStack = stack
160+
def kr(s: Stack, k1: Cont[Ans], t1: Trail[Ans], hs: Handlers[Ans]): Ans = {
161+
evalCall1(f, s, frame/*?*/, k1, t1, List(), hs, false)
162+
}
163+
kont(ContV(kr) :: newStack, trail, hs)
164+
case Suspend(tagId) =>
165+
val FuncType(_, inps, out) = module.tags(tagId)
166+
val (inputs, restStack) = stack.splitAt(inps.size)
167+
// System.err.println(s"[DEBUG] handlers: $hs")
168+
// System.err.println(s"[DEBUG] trail: $trail")
169+
val kr = (s: Stack, _: Cont[Ans], t1: Trail[Ans], hs1: Handlers[Ans]) => {
170+
// construct a new trail by ignoring the default handler
171+
val index = trail.indexWhere { case (_, tags) => tags.contains(tagId) }
172+
val newTrail = if (index >= 0) trail.take(index) else trail
173+
// Q: `hs` are ignored here, don't we need prepend some thing from `hs` to `hs1`?
174+
// A: No, according to fig.3 in the paper, solely using the new handlers is just engough.
175+
// Q: Should we clear tags in the `newTrail`? Is that possible suspend target tag in hs1 but also in newTrail?
176+
// A: Yes, we should maintain the consistency between `hs1` and `newTrail + t1`.
177+
// mkont lost here, and it's safe if we never modify it
178+
kont(s ++ restStack, newTrail.map({ case (c, _) => (c, List()) }) ++ t1, hs1)
179+
}
180+
val newStack = ContV(kr) :: inputs
181+
hs.find(_._1 == tagId) match {
182+
case Some((_, handler)) =>
183+
// we don't need to pass trail here, because handler's trail was determined when resuming
184+
handler(newStack)
185+
case None => throw new Exception(s"no handler for tag $tagId")
186+
}
187+
case Resume(tyId, handler) =>
188+
val (f: ContV[Ans]) :: newStack = stack
189+
val ContType(funcTypeId) = module.types(tyId)
190+
val FuncType(_, inps, out) = module.types(funcTypeId)
191+
val (inputs, restStack) = newStack.splitAt(inps.size)
192+
val newHs: List[(Int, Handler[Ans])] = handler.map {
193+
case Handler(tagId, labelId) =>
194+
val hh: Handler[Ans] = s1 => brTable(labelId)(s1, trail, hs)
195+
(tagId, hh)
196+
}
197+
val tags = handler.map(_.tag)
198+
// rather than push `kont` to meta-continuation, maybe we can push it to `trail`?
199+
f.k(inputs, initK, List((kont,tags)) ++ trail, newHs ++ hs)
200+
201+
case ContBind(oldContTyId, newConTyId) =>
202+
val (f: ContV[Ans]) :: newStack = stack
203+
// use oldParamTy - newParamTy to get how many values to pop from the stack
204+
val ContType(oldId) = module.types(oldContTyId)
205+
val FuncType(_, oldParamTy, _) = module.types(oldId)
206+
val ContType(newId) = module.types(newConTyId)
207+
val FuncType(_, newParamTy, _) = module.types(newId)
208+
// get oldParamTy - newParamTy (there's no type checking at all)
209+
val inputSize = oldParamTy.size - newParamTy.size
210+
val (inputs, restStack) = newStack.splitAt(inputSize)
211+
// partially apply the old continuation
212+
def kr(s: Stack, k1: Cont[Ans], t1: Trail[Ans], handlers: Handlers[Ans]): Ans = {
213+
f.k(s ++ inputs, k1, t1, handlers)
214+
}
215+
kont(ContV(kr) :: restStack, trail, hs)
216+
217+
case CallRef(ty) =>
218+
val RefFuncV(f) :: newStack = stack
219+
evalCall1(f, newStack, frame, kont, trail, brTable, hs, false)
220+
221+
case _ =>
222+
println(inst)
223+
throw new Exception(s"instruction $inst not implemented")
224+
}
225+
}
226+
227+
def evalList[Ans](insts: List[Instr], stack: Stack, frame: Frame, kont: Cont[Ans],
228+
trail1: Trail[Ans], brTable: List[Cont[Ans]], hs: Handlers[Ans]): Ans = {
229+
insts match {
230+
case Nil => kont(stack, trail1, hs)
231+
case inst :: rest =>
232+
val newKont: Cont[Ans] = (s1, t1, h1) => evalList(rest, s1, frame, kont, t1, brTable, h1)
233+
eval1(inst, stack, frame, newKont, trail1, brTable, hs)
234+
}
235+
}
236+
237+
def evalCall1[Ans](funcIndex: Int,
238+
stack: List[Value],
239+
frame: Frame,
240+
kont: Cont[Ans],
241+
trail: Trail[Ans],
242+
brTable: List[Cont[Ans]], // can be removed
243+
h: Handlers[Ans],
244+
isTail: Boolean): Ans =
245+
module.funcs(funcIndex) match {
246+
case FuncDef(_, FuncBodyDef(ty, _, locals, body)) =>
247+
val args = stack.take(ty.inps.size).reverse
248+
val newStack = stack.drop(ty.inps.size)
249+
val frameLocals = args ++ locals.map(zero(_))
250+
val newFrame = Frame(ArrayBuffer(frameLocals: _*))
251+
if (isTail) {
252+
// when tail call, share the continuation for returning with the callee
253+
evalList(body, List(), newFrame, brTable.last, trail, List(brTable.last), h)
254+
}
255+
else {
256+
val restK: Cont[Ans] = (s1, t1, h1) => kont(s1.take(ty.out.size) ++ newStack, t1, h1)
257+
// We make a new brTable by `restK`, since function creates a new block to escape
258+
// (more or less like `return`)
259+
evalList(body, List(), newFrame, restK, trail, List(restK), h)
260+
}
261+
case Import("console", "log", _) =>
262+
// println(s"[DEBUG] current stack: $stack")
263+
val I32V(v) :: newStack = stack
264+
println(v)
265+
kont(newStack, trail, h)
266+
case Import("spectest", "print_i32", _) =>
267+
// println(s"[DEBUG] current stack: $stack")
268+
val I32V(v) :: newStack = stack
269+
println(v)
270+
kont(newStack, trail, h)
271+
case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex")
272+
case _ => throw new Exception(s"Definition at $funcIndex is not callable")
273+
}
274+
275+
// If `main` is given, then we use that function as the entry point of the program;
276+
// otherwise, we look up the top-level `start` instruction to locate the entry point.
277+
def evalTop[Ans](halt: Cont[Ans], main: Option[String] = None): Ans = {
278+
val instrs = main match {
279+
case Some(func_name) =>
280+
module.defs.flatMap({
281+
case Export(`func_name`, ExportFunc(fid)) =>
282+
System.err.println(s"Entering function $main")
283+
module.funcs(fid) match {
284+
case FuncDef(_, FuncBodyDef(_, _, locals, body)) => body
285+
case _ => throw new Exception("Entry function has no concrete body")
286+
}
287+
case _ => List()
288+
})
289+
case None =>
290+
module.defs.flatMap({
291+
case Start(id) =>
292+
System.err.println(s"Entering unnamed function $id")
293+
module.funcs(id) match {
294+
case FuncDef(_, FuncBodyDef(_, _, locals, body)) => body
295+
case _ =>
296+
throw new Exception("Entry function has no concrete body")
297+
}
298+
case _ => List()
299+
})
300+
}
301+
val locals = main match {
302+
case Some(func_name) =>
303+
module.defs.flatMap({
304+
case Export(`func_name`, ExportFunc(fid)) =>
305+
System.err.println(s"Entering function $main")
306+
module.funcs(fid) match {
307+
case FuncDef(_, FuncBodyDef(_, _, locals, _)) => locals
308+
case _ => throw new Exception("Entry function has no concrete body")
309+
}
310+
case _ => List()
311+
})
312+
case None =>
313+
module.defs.flatMap({
314+
case Start(id) =>
315+
System.err.println(s"Entering unnamed function $id")
316+
module.funcs(id) match {
317+
case FuncDef(_, FuncBodyDef(_, _, locals, body)) => locals
318+
case _ =>
319+
throw new Exception("Entry function has no concrete body")
320+
}
321+
case _ => List()
322+
})
323+
}
324+
if (instrs.isEmpty) println("Warning: nothing is executed")
325+
// initialized locals
326+
val frame = Frame(ArrayBuffer(locals.map(zero(_)): _*))
327+
evalList(instrs, List(), frame, initK[Ans], List((halt, List())), List(initK: Cont[Ans]), List())
328+
}
329+
330+
def evalTop(m: ModuleInstance): Unit =
331+
evalTop(((stack, trail, _hs) => {
332+
if (!trail.isEmpty) {
333+
throw new Exception("Composing something after halt continuation")
334+
}
335+
}): Cont[Unit])
336+
}
337+

0 commit comments

Comments
 (0)