|
| 1 | +package wasm.miniwasm |
| 2 | + |
| 3 | +import wasm.ast._ |
| 4 | +import wasm.memory._ |
| 5 | + |
| 6 | +import scala.collection.mutable.ArrayBuffer |
| 7 | +import scala.collection.mutable.HashMap |
| 8 | +import Console.{GREEN, RED, RESET, YELLOW_B, UNDERLINED} |
| 9 | + |
| 10 | +case class EvaluatorFX(module: ModuleInstance) { |
| 11 | + import Primtives._ |
| 12 | + implicit val m: ModuleInstance = module |
| 13 | + |
| 14 | + type Stack = List[Value] |
| 15 | + |
| 16 | + trait Cont[A] { |
| 17 | + def apply(stack: Stack, trail: Trail[A], handler: Handlers[A]): A |
| 18 | + } |
| 19 | + type Trail[A] = List[(Cont[A], List[Int])] // trail items are pairs of continuation and tags |
| 20 | + type MCont[A] = Stack => A |
| 21 | + |
| 22 | + type Handler[A] = Stack => A |
| 23 | + type Handlers[A] = List[(Int, Handler[A])] |
| 24 | + |
| 25 | + case class ContV[A](k: (Stack, Cont[A], Trail[A], Handlers[A]) => A) extends Value { |
| 26 | + def tipe(implicit m: ModuleInstance): ValueType = ??? |
| 27 | + |
| 28 | + // override def toString: String = "ContV" |
| 29 | + } |
| 30 | + |
| 31 | + // initK is a continuation that simply returns the inputed stack |
| 32 | + def initK[Ans](s: Stack, trail: Trail[Ans], hs: Handlers[Ans]): Ans = |
| 33 | + trail match { |
| 34 | + // Currently, the last element of the Trail is the halt continuation |
| 35 | + // the exception will never be thrown |
| 36 | + case (k1, _) :: trail => k1(s, trail, hs) |
| 37 | + case Nil => throw new Exception("No halting continuation in trail") |
| 38 | + } |
| 39 | + |
| 40 | + def eval1[Ans](inst: Instr, stack: Stack, frame: Frame, kont: Cont[Ans], |
| 41 | + trail: Trail[Ans], brTable: List[Cont[Ans]], hs: Handlers[Ans]): Ans = { |
| 42 | + // System.err.println(f"[DEBUG] ${inst} | ${frame} | ${stack.reverse} | handlers: ${hs}"); |
| 43 | + inst match { |
| 44 | + case Drop => kont(stack.tail, trail, hs) |
| 45 | + case Select(_) => |
| 46 | + val I32V(cond) :: v2 :: v1 :: newStack = stack |
| 47 | + val value = if (cond == 0) v1 else v2 |
| 48 | + kont(value :: newStack, trail, hs) |
| 49 | + case LocalGet(i) => |
| 50 | + kont(frame.locals(i) :: stack, trail, hs) |
| 51 | + case LocalSet(i) => |
| 52 | + val value :: newStack = stack |
| 53 | + frame.locals(i) = value |
| 54 | + kont(newStack, trail, hs) |
| 55 | + case LocalTee(i) => |
| 56 | + val value :: newStack = stack |
| 57 | + frame.locals(i) = value |
| 58 | + kont(stack, trail, hs) |
| 59 | + case GlobalGet(i) => |
| 60 | + kont(module.globals(i).value :: stack, trail, hs) |
| 61 | + case GlobalSet(i) => |
| 62 | + val value :: newStack = stack |
| 63 | + module.globals(i).ty match { |
| 64 | + case GlobalType(tipe, true) if value.tipe == tipe => |
| 65 | + module.globals(i).value = value |
| 66 | + case GlobalType(_, true) => throw new Exception("Invalid type") |
| 67 | + case _ => throw new Exception("Cannot set immutable global") |
| 68 | + } |
| 69 | + kont(newStack, trail, hs) |
| 70 | + case MemorySize => |
| 71 | + kont(I32V(module.memory.head.size) :: stack, trail, hs) |
| 72 | + case MemoryGrow => |
| 73 | + val I32V(delta) :: newStack = stack |
| 74 | + val mem = module.memory.head |
| 75 | + val oldSize = mem.size |
| 76 | + mem.grow(delta) match { |
| 77 | + case Some(e) => kont(I32V(-1) :: newStack, trail, hs) |
| 78 | + case _ => kont(I32V(oldSize) :: newStack, trail, hs) |
| 79 | + } |
| 80 | + case MemoryFill => |
| 81 | + val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack |
| 82 | + if (memOutOfBound(module, 0, offset, size)) |
| 83 | + throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`? |
| 84 | + else { |
| 85 | + module.memory.head.fill(offset, size, value.toByte) |
| 86 | + kont(newStack, trail, hs) |
| 87 | + } |
| 88 | + case MemoryCopy => |
| 89 | + val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack |
| 90 | + if (memOutOfBound(module, 0, src, n) || memOutOfBound(module, 0, dest, n)) |
| 91 | + throw new Exception("Out of bounds memory access") |
| 92 | + else { |
| 93 | + module.memory.head.copy(dest, src, n) |
| 94 | + kont(newStack, trail, hs) |
| 95 | + } |
| 96 | + case Const(n) => kont(n :: stack, trail, hs) |
| 97 | + case Binary(op) => |
| 98 | + val v2 :: v1 :: newStack = stack |
| 99 | + kont(evalBinOp(op, v1, v2) :: newStack, trail, hs) |
| 100 | + case Unary(op) => |
| 101 | + val v :: newStack = stack |
| 102 | + kont(evalUnaryOp(op, v) :: newStack, trail, hs) |
| 103 | + case Compare(op) => |
| 104 | + val v2 :: v1 :: newStack = stack |
| 105 | + kont(evalRelOp(op, v1, v2) :: newStack, trail, hs) |
| 106 | + case Test(op) => |
| 107 | + val v :: newStack = stack |
| 108 | + kont(evalTestOp(op, v) :: newStack, trail, hs) |
| 109 | + case Store(StoreOp(align, offset, ty, None)) => |
| 110 | + val I32V(v) :: I32V(addr) :: newStack = stack |
| 111 | + module.memory(0).storeInt(addr + offset, v) |
| 112 | + kont(newStack, trail, hs) |
| 113 | + case Load(LoadOp(align, offset, ty, None, None)) => |
| 114 | + val I32V(addr) :: newStack = stack |
| 115 | + val value = module.memory(0).loadInt(addr + offset) |
| 116 | + kont(I32V(value) :: newStack, trail, hs) |
| 117 | + case Nop => kont(stack, trail, hs) |
| 118 | + case Unreachable => throw Trap() |
| 119 | + case Block(ty, inner) => |
| 120 | + val funcTy = getFuncType(ty) |
| 121 | + val (inputs, restStack) = stack.splitAt(funcTy.inps.size) |
| 122 | + val escape: Cont[Ans] = (s1, t1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, h1) |
| 123 | + evalList(inner, inputs, frame, escape, trail, escape::brTable, hs) |
| 124 | + case Loop(ty, inner) => |
| 125 | + val funcTy = getFuncType(ty) |
| 126 | + val (inputs, restStack) = stack.splitAt(funcTy.inps.size) |
| 127 | + val escape: Cont[Ans] = (s1, t1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, h1) |
| 128 | + def loop(retStack: List[Value], trail1: Trail[Ans], h1: Handlers[Ans]): Ans = |
| 129 | + evalList(inner, retStack.take(funcTy.inps.size), frame, escape, trail, (loop _ : Cont[Ans])::brTable, h1) |
| 130 | + loop(inputs, trail, hs) |
| 131 | + case If(ty, thn, els) => |
| 132 | + val funcTy = getFuncType(ty) |
| 133 | + val I32V(cond) :: newStack = stack |
| 134 | + val inner = if (cond != 0) thn else els |
| 135 | + val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) |
| 136 | + val escape: Cont[Ans] = (s1, t1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, h1) |
| 137 | + evalList(inner, inputs, frame, escape, trail, escape::brTable, hs) |
| 138 | + case Br(label) => |
| 139 | + brTable(label)(stack, trail, hs) |
| 140 | + case BrIf(label) => |
| 141 | + val I32V(cond) :: newStack = stack |
| 142 | + if (cond != 0) brTable(label)(newStack, trail, hs) |
| 143 | + else kont(newStack, trail, hs) |
| 144 | + case BrTable(labels, default) => |
| 145 | + val I32V(cond) :: newStack = stack |
| 146 | + val goto = if (cond < labels.length) labels(cond) else default |
| 147 | + brTable(goto)(newStack, trail, hs) |
| 148 | + case Return => |
| 149 | + brTable.last(stack, trail, hs) |
| 150 | + case Call(f) => evalCall1(f, stack, frame, kont, trail, brTable, hs, false) |
| 151 | + case ReturnCall(f) => |
| 152 | + // System.err.println(s"[DEBUG] return call: $f") |
| 153 | + evalCall1(f, stack, frame, kont, trail, brTable, hs, true) |
| 154 | + case RefFunc(f) => |
| 155 | + // TODO: RefFuncV stores an applicable function, instead of a syntactic structure |
| 156 | + kont(RefFuncV(f) :: stack, trail, hs) |
| 157 | + // WasmFX effect handlers: |
| 158 | + case ContNew(ty) => |
| 159 | + val RefFuncV(f) :: newStack = stack |
| 160 | + def kr(s: Stack, k1: Cont[Ans], t1: Trail[Ans], hs: Handlers[Ans]): Ans = { |
| 161 | + evalCall1(f, s, frame/*?*/, k1, t1, List(), hs, false) |
| 162 | + } |
| 163 | + kont(ContV(kr) :: newStack, trail, hs) |
| 164 | + case Suspend(tagId) => |
| 165 | + val FuncType(_, inps, out) = module.tags(tagId) |
| 166 | + val (inputs, restStack) = stack.splitAt(inps.size) |
| 167 | + // System.err.println(s"[DEBUG] handlers: $hs") |
| 168 | + // System.err.println(s"[DEBUG] trail: $trail") |
| 169 | + val kr = (s: Stack, _: Cont[Ans], t1: Trail[Ans], hs1: Handlers[Ans]) => { |
| 170 | + // construct a new trail by ignoring the default handler |
| 171 | + val index = trail.indexWhere { case (_, tags) => tags.contains(tagId) } |
| 172 | + val newTrail = if (index >= 0) trail.take(index) else trail |
| 173 | + // Q: `hs` are ignored here, don't we need prepend some thing from `hs` to `hs1`? |
| 174 | + // A: No, according to fig.3 in the paper, solely using the new handlers is just engough. |
| 175 | + // Q: Should we clear tags in the `newTrail`? Is that possible suspend target tag in hs1 but also in newTrail? |
| 176 | + // A: Yes, we should maintain the consistency between `hs1` and `newTrail + t1`. |
| 177 | + // mkont lost here, and it's safe if we never modify it |
| 178 | + kont(s ++ restStack, newTrail.map({ case (c, _) => (c, List()) }) ++ t1, hs1) |
| 179 | + } |
| 180 | + val newStack = ContV(kr) :: inputs |
| 181 | + hs.find(_._1 == tagId) match { |
| 182 | + case Some((_, handler)) => |
| 183 | + // we don't need to pass trail here, because handler's trail was determined when resuming |
| 184 | + handler(newStack) |
| 185 | + case None => throw new Exception(s"no handler for tag $tagId") |
| 186 | + } |
| 187 | + case Resume(tyId, handler) => |
| 188 | + val (f: ContV[Ans]) :: newStack = stack |
| 189 | + val ContType(funcTypeId) = module.types(tyId) |
| 190 | + val FuncType(_, inps, out) = module.types(funcTypeId) |
| 191 | + val (inputs, restStack) = newStack.splitAt(inps.size) |
| 192 | + val newHs: List[(Int, Handler[Ans])] = handler.map { |
| 193 | + case Handler(tagId, labelId) => |
| 194 | + val hh: Handler[Ans] = s1 => brTable(labelId)(s1, trail, hs) |
| 195 | + (tagId, hh) |
| 196 | + } |
| 197 | + val tags = handler.map(_.tag) |
| 198 | + // rather than push `kont` to meta-continuation, maybe we can push it to `trail`? |
| 199 | + f.k(inputs, initK, List((kont,tags)) ++ trail, newHs ++ hs) |
| 200 | + |
| 201 | + case ContBind(oldContTyId, newConTyId) => |
| 202 | + val (f: ContV[Ans]) :: newStack = stack |
| 203 | + // use oldParamTy - newParamTy to get how many values to pop from the stack |
| 204 | + val ContType(oldId) = module.types(oldContTyId) |
| 205 | + val FuncType(_, oldParamTy, _) = module.types(oldId) |
| 206 | + val ContType(newId) = module.types(newConTyId) |
| 207 | + val FuncType(_, newParamTy, _) = module.types(newId) |
| 208 | + // get oldParamTy - newParamTy (there's no type checking at all) |
| 209 | + val inputSize = oldParamTy.size - newParamTy.size |
| 210 | + val (inputs, restStack) = newStack.splitAt(inputSize) |
| 211 | + // partially apply the old continuation |
| 212 | + def kr(s: Stack, k1: Cont[Ans], t1: Trail[Ans], handlers: Handlers[Ans]): Ans = { |
| 213 | + f.k(s ++ inputs, k1, t1, handlers) |
| 214 | + } |
| 215 | + kont(ContV(kr) :: restStack, trail, hs) |
| 216 | + |
| 217 | + case CallRef(ty) => |
| 218 | + val RefFuncV(f) :: newStack = stack |
| 219 | + evalCall1(f, newStack, frame, kont, trail, brTable, hs, false) |
| 220 | + |
| 221 | + case _ => |
| 222 | + println(inst) |
| 223 | + throw new Exception(s"instruction $inst not implemented") |
| 224 | + } |
| 225 | + } |
| 226 | + |
| 227 | + def evalList[Ans](insts: List[Instr], stack: Stack, frame: Frame, kont: Cont[Ans], |
| 228 | + trail1: Trail[Ans], brTable: List[Cont[Ans]], hs: Handlers[Ans]): Ans = { |
| 229 | + insts match { |
| 230 | + case Nil => kont(stack, trail1, hs) |
| 231 | + case inst :: rest => |
| 232 | + val newKont: Cont[Ans] = (s1, t1, h1) => evalList(rest, s1, frame, kont, t1, brTable, h1) |
| 233 | + eval1(inst, stack, frame, newKont, trail1, brTable, hs) |
| 234 | + } |
| 235 | + } |
| 236 | + |
| 237 | + def evalCall1[Ans](funcIndex: Int, |
| 238 | + stack: List[Value], |
| 239 | + frame: Frame, |
| 240 | + kont: Cont[Ans], |
| 241 | + trail: Trail[Ans], |
| 242 | + brTable: List[Cont[Ans]], // can be removed |
| 243 | + h: Handlers[Ans], |
| 244 | + isTail: Boolean): Ans = |
| 245 | + module.funcs(funcIndex) match { |
| 246 | + case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => |
| 247 | + val args = stack.take(ty.inps.size).reverse |
| 248 | + val newStack = stack.drop(ty.inps.size) |
| 249 | + val frameLocals = args ++ locals.map(zero(_)) |
| 250 | + val newFrame = Frame(ArrayBuffer(frameLocals: _*)) |
| 251 | + if (isTail) { |
| 252 | + // when tail call, share the continuation for returning with the callee |
| 253 | + evalList(body, List(), newFrame, brTable.last, trail, List(brTable.last), h) |
| 254 | + } |
| 255 | + else { |
| 256 | + val restK: Cont[Ans] = (s1, t1, h1) => kont(s1.take(ty.out.size) ++ newStack, t1, h1) |
| 257 | + // We make a new brTable by `restK`, since function creates a new block to escape |
| 258 | + // (more or less like `return`) |
| 259 | + evalList(body, List(), newFrame, restK, trail, List(restK), h) |
| 260 | + } |
| 261 | + case Import("console", "log", _) => |
| 262 | + // println(s"[DEBUG] current stack: $stack") |
| 263 | + val I32V(v) :: newStack = stack |
| 264 | + println(v) |
| 265 | + kont(newStack, trail, h) |
| 266 | + case Import("spectest", "print_i32", _) => |
| 267 | + // println(s"[DEBUG] current stack: $stack") |
| 268 | + val I32V(v) :: newStack = stack |
| 269 | + println(v) |
| 270 | + kont(newStack, trail, h) |
| 271 | + case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") |
| 272 | + case _ => throw new Exception(s"Definition at $funcIndex is not callable") |
| 273 | + } |
| 274 | + |
| 275 | + // If `main` is given, then we use that function as the entry point of the program; |
| 276 | + // otherwise, we look up the top-level `start` instruction to locate the entry point. |
| 277 | + def evalTop[Ans](halt: Cont[Ans], main: Option[String] = None): Ans = { |
| 278 | + val instrs = main match { |
| 279 | + case Some(func_name) => |
| 280 | + module.defs.flatMap({ |
| 281 | + case Export(`func_name`, ExportFunc(fid)) => |
| 282 | + System.err.println(s"Entering function $main") |
| 283 | + module.funcs(fid) match { |
| 284 | + case FuncDef(_, FuncBodyDef(_, _, locals, body)) => body |
| 285 | + case _ => throw new Exception("Entry function has no concrete body") |
| 286 | + } |
| 287 | + case _ => List() |
| 288 | + }) |
| 289 | + case None => |
| 290 | + module.defs.flatMap({ |
| 291 | + case Start(id) => |
| 292 | + System.err.println(s"Entering unnamed function $id") |
| 293 | + module.funcs(id) match { |
| 294 | + case FuncDef(_, FuncBodyDef(_, _, locals, body)) => body |
| 295 | + case _ => |
| 296 | + throw new Exception("Entry function has no concrete body") |
| 297 | + } |
| 298 | + case _ => List() |
| 299 | + }) |
| 300 | + } |
| 301 | + val locals = main match { |
| 302 | + case Some(func_name) => |
| 303 | + module.defs.flatMap({ |
| 304 | + case Export(`func_name`, ExportFunc(fid)) => |
| 305 | + System.err.println(s"Entering function $main") |
| 306 | + module.funcs(fid) match { |
| 307 | + case FuncDef(_, FuncBodyDef(_, _, locals, _)) => locals |
| 308 | + case _ => throw new Exception("Entry function has no concrete body") |
| 309 | + } |
| 310 | + case _ => List() |
| 311 | + }) |
| 312 | + case None => |
| 313 | + module.defs.flatMap({ |
| 314 | + case Start(id) => |
| 315 | + System.err.println(s"Entering unnamed function $id") |
| 316 | + module.funcs(id) match { |
| 317 | + case FuncDef(_, FuncBodyDef(_, _, locals, body)) => locals |
| 318 | + case _ => |
| 319 | + throw new Exception("Entry function has no concrete body") |
| 320 | + } |
| 321 | + case _ => List() |
| 322 | + }) |
| 323 | + } |
| 324 | + if (instrs.isEmpty) println("Warning: nothing is executed") |
| 325 | + // initialized locals |
| 326 | + val frame = Frame(ArrayBuffer(locals.map(zero(_)): _*)) |
| 327 | + evalList(instrs, List(), frame, initK[Ans], List((halt, List())), List(initK: Cont[Ans]), List()) |
| 328 | + } |
| 329 | + |
| 330 | + def evalTop(m: ModuleInstance): Unit = |
| 331 | + evalTop(((stack, trail, _hs) => { |
| 332 | + if (!trail.isEmpty) { |
| 333 | + throw new Exception("Composing something after halt continuation") |
| 334 | + } |
| 335 | + }): Cont[Unit]) |
| 336 | +} |
| 337 | + |
0 commit comments