Skip to content

Commit 9d70d49

Browse files
author
Guannan Wei
committed
refactor and simplify code
1 parent f9a4d9c commit 9d70d49

File tree

6 files changed

+308
-467
lines changed

6 files changed

+308
-467
lines changed

src/main/scala/Base.scala

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
package wasm.miniwasm
2+
3+
import wasm.ast._
4+
import wasm.memory._
5+
6+
import scala.collection.mutable.ArrayBuffer
7+
import scala.collection.mutable.HashMap
8+
import Console.{GREEN, RED, RESET, YELLOW_B, UNDERLINED}
9+
10+
case class Trap() extends Exception
11+
12+
case class ModuleInstance(
13+
defs: List[Definition],
14+
types: List[FuncLikeType],
15+
tags: List[FuncType],
16+
funcs: HashMap[Int, Callable],
17+
memory: List[RTMemory] = List(RTMemory()),
18+
globals: List[RTGlobal] = List(),
19+
exports: List[Export] = List()
20+
)
21+
22+
case class Frame(locals: ArrayBuffer[Value])
23+
24+
object ModuleInstance {
25+
def apply(module: Module): ModuleInstance = {
26+
val types = module.definitions
27+
.collect({
28+
case TypeDef(_, ft) => ft
29+
})
30+
.toList
31+
val tags = module.definitions
32+
.collect({
33+
case Tag(id, ty) => ty
34+
})
35+
.toList
36+
37+
val funcs = module.definitions
38+
.collect({
39+
case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef
40+
})
41+
.toList
42+
43+
val globals = module.definitions
44+
.collect({
45+
case Global(_, GlobalValue(ty, e)) =>
46+
(e.head) match {
47+
case Const(c) => RTGlobal(ty, c)
48+
// Q: What is the default behavior if case in non-exhaustive
49+
case _ => ???
50+
}
51+
})
52+
.toList
53+
54+
// TODO: correct the behavior for memory
55+
val memory = module.definitions
56+
.collect({
57+
case Memory(id, MemoryType(min, max_opt)) =>
58+
RTMemory(min, max_opt)
59+
})
60+
.toList
61+
62+
val exports = module.definitions
63+
.collect({
64+
case e @ Export(_, ExportFunc(_)) => e
65+
})
66+
.toList
67+
68+
ModuleInstance(module.definitions, types, tags, module.funcEnv, memory, globals, exports)
69+
}
70+
}
71+
72+
def evalBinOp(op: BinOp, lhs: Value, rhs: Value): Value = op match
73+
case Add(_) =>
74+
(lhs, rhs) match
75+
case (I32V(v1), I32V(v2)) => I32V(v1 + v2)
76+
case (I64V(v1), I64V(v2)) => I64V(v1 + v2)
77+
case (F32V(v1), F32V(v2)) => F32V(v1 + v2)
78+
case (F64V(v1), F64V(v2)) => F64V(v1 + v2)
79+
case _ => throw new Exception("Invalid types")
80+
case Mul(_) =>
81+
(lhs, rhs) match
82+
case (I32V(v1), I32V(v2)) => I32V(v1 * v2)
83+
case (I64V(v1), I64V(v2)) => I64V(v1 * v2)
84+
case _ => throw new Exception("Invalid types")
85+
case Sub(_) =>
86+
(lhs, rhs) match
87+
case (I32V(v1), I32V(v2)) => I32V(v1 - v2)
88+
case (I64V(v1), I64V(v2)) => I64V(v1 - v2)
89+
case _ => throw new Exception("Invalid types")
90+
case Shl(_) =>
91+
(lhs, rhs) match
92+
case (I32V(v1), I32V(v2)) => I32V(v1 << v2)
93+
case (I64V(v1), I64V(v2)) => I64V(v1 << v2)
94+
case _ => throw new Exception("Invalid types")
95+
case ShrU(_) =>
96+
(lhs, rhs) match
97+
case (I32V(v1), I32V(v2)) => I32V(v1 >>> v2)
98+
case (I64V(v1), I64V(v2)) => I64V(v1 >>> v2)
99+
case _ => throw new Exception("Invalid types")
100+
case And(_) =>
101+
(lhs, rhs) match
102+
case (I32V(v1), I32V(v2)) => I32V(v1 & v2)
103+
case (I64V(v1), I64V(v2)) => I64V(v1 & v2)
104+
case _ => throw new Exception("Invalid types")
105+
case _ => ???
106+
107+
def evalUnaryOp(op: UnaryOp, value: Value) = op match
108+
case Clz(_) =>
109+
value match
110+
case I32V(v) => I32V(Integer.numberOfLeadingZeros(v))
111+
case I64V(v) => I64V(java.lang.Long.numberOfLeadingZeros(v))
112+
case _ => throw new Exception("Invalid types")
113+
case Ctz(_) =>
114+
value match
115+
case I32V(v) => I32V(Integer.numberOfTrailingZeros(v))
116+
case I64V(v) => I64V(java.lang.Long.numberOfTrailingZeros(v))
117+
case _ => throw new Exception("Invalid types")
118+
case Popcnt(_) =>
119+
value match
120+
case I32V(v) => I32V(Integer.bitCount(v))
121+
case I64V(v) => I64V(java.lang.Long.bitCount(v))
122+
case _ => throw new Exception("Invalid types")
123+
case _ => ???
124+
125+
def evalRelOp(op: RelOp, lhs: Value, rhs: Value) = op match
126+
case Eq(_) =>
127+
(lhs, rhs) match
128+
case (I32V(v1), I32V(v2)) => I32V(if (v1 == v2) 1 else 0)
129+
case (I64V(v1), I64V(v2)) => I32V(if (v1 == v2) 1 else 0)
130+
case _ => throw new Exception("Invalid types")
131+
case Ne(_) =>
132+
(lhs, rhs) match
133+
case (I32V(v1), I32V(v2)) => I32V(if (v1 != v2) 1 else 0)
134+
case (I64V(v1), I64V(v2)) => I32V(if (v1 != v2) 1 else 0)
135+
case _ => throw new Exception("Invalid types")
136+
case LtS(_) =>
137+
(lhs, rhs) match
138+
case (I32V(v1), I32V(v2)) => I32V(if (v1 < v2) 1 else 0)
139+
case (I64V(v1), I64V(v2)) => I32V(if (v1 < v2) 1 else 0)
140+
case _ => throw new Exception("Invalid types")
141+
case LtU(_) =>
142+
(lhs, rhs) match
143+
case (I32V(v1), I32V(v2)) =>
144+
I32V(if (Integer.compareUnsigned(v1, v2) < 0) 1 else 0)
145+
case (I64V(v1), I64V(v2)) =>
146+
I32V(if (java.lang.Long.compareUnsigned(v1, v2) < 0) 1 else 0)
147+
case _ => throw new Exception("Invalid types")
148+
case GtS(_) =>
149+
(lhs, rhs) match
150+
case (I32V(v1), I32V(v2)) => I32V(if (v1 > v2) 1 else 0)
151+
case (I64V(v1), I64V(v2)) => I32V(if (v1 > v2) 1 else 0)
152+
case _ => throw new Exception("Invalid types")
153+
case GtU(_) =>
154+
(lhs, rhs) match
155+
case (I32V(v1), I32V(v2)) =>
156+
I32V(if (Integer.compareUnsigned(v1, v2) > 0) 1 else 0)
157+
case (I64V(v1), I64V(v2)) =>
158+
I32V(if (java.lang.Long.compareUnsigned(v1, v2) > 0) 1 else 0)
159+
case _ => throw new Exception("Invalid types")
160+
case LeS(_) =>
161+
(lhs, rhs) match
162+
case (I32V(v1), I32V(v2)) => I32V(if (v1 <= v2) 1 else 0)
163+
case (I64V(v1), I64V(v2)) => I32V(if (v1 <= v2) 1 else 0)
164+
case _ => throw new Exception("Invalid types")
165+
case LeU(_) =>
166+
(lhs, rhs) match
167+
case (I32V(v1), I32V(v2)) =>
168+
I32V(if (Integer.compareUnsigned(v1, v2) <= 0) 1 else 0)
169+
case (I64V(v1), I64V(v2)) =>
170+
I32V(if (java.lang.Long.compareUnsigned(v1, v2) <= 0) 1 else 0)
171+
case _ => throw new Exception("Invalid types")
172+
case GeS(_) =>
173+
(lhs, rhs) match
174+
case (I32V(v1), I32V(v2)) => I32V(if (v1 >= v2) 1 else 0)
175+
case (I64V(v1), I64V(v2)) => I32V(if (v1 >= v2) 1 else 0)
176+
case _ => throw new Exception("Invalid types")
177+
case GeU(_) =>
178+
(lhs, rhs) match
179+
case (I32V(v1), I32V(v2)) =>
180+
I32V(if (Integer.compareUnsigned(v1, v2) >= 0) 1 else 0)
181+
case (I64V(v1), I64V(v2)) =>
182+
I32V(if (java.lang.Long.compareUnsigned(v1, v2) >= 0) 1 else 0)
183+
case _ => throw new Exception("Invalid types")
184+
185+
def evalTestOp(op: TestOp, value: Value) = op match
186+
case Eqz(_) =>
187+
value match
188+
case I32V(v) => I32V(if (v == 0) 1 else 0)
189+
case I64V(v) => I32V(if (v == 0) 1 else 0)
190+
case _ => throw new Exception("Invalid types")
191+
192+
def memOutOfBound(module: ModuleInstance, memoryIndex: Int, offset: Int, size: Int) = {
193+
val memory = module.memory(memoryIndex)
194+
offset + size > memory.size
195+
}
196+
197+
def zero(t: ValueType): Value = t match
198+
case NumType(kind) =>
199+
kind match
200+
case I32Type => I32V(0)
201+
case I64Type => I64V(0)
202+
case F32Type => F32V(0)
203+
case F64Type => F64V(0)
204+
case VecType(kind) => ???
205+
case RefType(kind) => RefNullV(kind)
206+
207+
def getFuncType(ty: BlockType): FuncType =
208+
ty match
209+
case VarBlockType(_, None) => ??? // TODO: fill this branch until we handle type index correctly
210+
case VarBlockType(_, Some(tipe)) => tipe
211+
case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe))
212+
case ValBlockType(None) => FuncType(List(), List(), List())
213+
214+
def extractMainInstrs(module: ModuleInstance, main: Option[String]): List[Instr] =
215+
main match
216+
case Some(func_name) =>
217+
module.defs.flatMap({
218+
case Export(`func_name`, ExportFunc(fid)) =>
219+
System.err.println(s"Entering function $main")
220+
module.funcs(fid) match
221+
case FuncDef(_, FuncBodyDef(_, _, locals, body)) => body
222+
case _ => throw new Exception("Entry function has no concrete body")
223+
case _ => List()
224+
})
225+
case None =>
226+
module.defs.flatMap({
227+
case Start(id) =>
228+
System.err.println(s"Entering unnamed function $id")
229+
module.funcs(id) match
230+
case FuncDef(_, FuncBodyDef(_, _, locals, body)) => body
231+
case _ => throw new Exception("Entry function has no concrete body")
232+
case _ => List()
233+
})
234+
235+
def extractLocals(module: ModuleInstance, main: Option[String]): List[ValueType] =
236+
main match
237+
case Some(func_name) =>
238+
module.defs.flatMap({
239+
case Export(`func_name`, ExportFunc(fid)) =>
240+
System.err.println(s"Entering function $main")
241+
module.funcs(fid) match
242+
case FuncDef(_, FuncBodyDef(_, _, locals, _)) => locals
243+
case _ => throw new Exception("Entry function has no concrete body")
244+
case _ => List()
245+
})
246+
case None =>
247+
module.defs.flatMap({
248+
case Start(id) =>
249+
System.err.println(s"Entering unnamed function $id")
250+
module.funcs(id) match
251+
case FuncDef(_, FuncBodyDef(_, _, locals, body)) => locals
252+
case _ => throw new Exception("Entry function has no concrete body")
253+
case _ => List()
254+
})

src/main/scala/Main.scala

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments
 (0)