1
+ package wasm .miniwasm
2
+
3
+ import wasm .ast ._
4
+ import wasm .memory ._
5
+
6
+ import scala .collection .mutable .ArrayBuffer
7
+ import scala .collection .mutable .HashMap
8
+ import Console .{GREEN , RED , RESET , YELLOW_B , UNDERLINED }
9
+
10
+ case class Trap () extends Exception
11
+
12
+ case class ModuleInstance (
13
+ defs : List [Definition ],
14
+ types : List [FuncLikeType ],
15
+ tags : List [FuncType ],
16
+ funcs : HashMap [Int , Callable ],
17
+ memory : List [RTMemory ] = List (RTMemory ()),
18
+ globals : List [RTGlobal ] = List (),
19
+ exports : List [Export ] = List ()
20
+ )
21
+
22
+ case class Frame (locals : ArrayBuffer [Value ])
23
+
24
+ object ModuleInstance {
25
+ def apply (module : Module ): ModuleInstance = {
26
+ val types = module.definitions
27
+ .collect({
28
+ case TypeDef (_, ft) => ft
29
+ })
30
+ .toList
31
+ val tags = module.definitions
32
+ .collect({
33
+ case Tag (id, ty) => ty
34
+ })
35
+ .toList
36
+
37
+ val funcs = module.definitions
38
+ .collect({
39
+ case FuncDef (_, fndef @ FuncBodyDef (_, _, _, _)) => fndef
40
+ })
41
+ .toList
42
+
43
+ val globals = module.definitions
44
+ .collect({
45
+ case Global (_, GlobalValue (ty, e)) =>
46
+ (e.head) match {
47
+ case Const (c) => RTGlobal (ty, c)
48
+ // Q: What is the default behavior if case in non-exhaustive
49
+ case _ => ???
50
+ }
51
+ })
52
+ .toList
53
+
54
+ // TODO: correct the behavior for memory
55
+ val memory = module.definitions
56
+ .collect({
57
+ case Memory (id, MemoryType (min, max_opt)) =>
58
+ RTMemory (min, max_opt)
59
+ })
60
+ .toList
61
+
62
+ val exports = module.definitions
63
+ .collect({
64
+ case e @ Export (_, ExportFunc (_)) => e
65
+ })
66
+ .toList
67
+
68
+ ModuleInstance (module.definitions, types, tags, module.funcEnv, memory, globals, exports)
69
+ }
70
+ }
71
+
72
+ def evalBinOp (op : BinOp , lhs : Value , rhs : Value ): Value = op match
73
+ case Add (_) =>
74
+ (lhs, rhs) match
75
+ case (I32V (v1), I32V (v2)) => I32V (v1 + v2)
76
+ case (I64V (v1), I64V (v2)) => I64V (v1 + v2)
77
+ case (F32V (v1), F32V (v2)) => F32V (v1 + v2)
78
+ case (F64V (v1), F64V (v2)) => F64V (v1 + v2)
79
+ case _ => throw new Exception (" Invalid types" )
80
+ case Mul (_) =>
81
+ (lhs, rhs) match
82
+ case (I32V (v1), I32V (v2)) => I32V (v1 * v2)
83
+ case (I64V (v1), I64V (v2)) => I64V (v1 * v2)
84
+ case _ => throw new Exception (" Invalid types" )
85
+ case Sub (_) =>
86
+ (lhs, rhs) match
87
+ case (I32V (v1), I32V (v2)) => I32V (v1 - v2)
88
+ case (I64V (v1), I64V (v2)) => I64V (v1 - v2)
89
+ case _ => throw new Exception (" Invalid types" )
90
+ case Shl (_) =>
91
+ (lhs, rhs) match
92
+ case (I32V (v1), I32V (v2)) => I32V (v1 << v2)
93
+ case (I64V (v1), I64V (v2)) => I64V (v1 << v2)
94
+ case _ => throw new Exception (" Invalid types" )
95
+ case ShrU (_) =>
96
+ (lhs, rhs) match
97
+ case (I32V (v1), I32V (v2)) => I32V (v1 >>> v2)
98
+ case (I64V (v1), I64V (v2)) => I64V (v1 >>> v2)
99
+ case _ => throw new Exception (" Invalid types" )
100
+ case And (_) =>
101
+ (lhs, rhs) match
102
+ case (I32V (v1), I32V (v2)) => I32V (v1 & v2)
103
+ case (I64V (v1), I64V (v2)) => I64V (v1 & v2)
104
+ case _ => throw new Exception (" Invalid types" )
105
+ case _ => ???
106
+
107
+ def evalUnaryOp (op : UnaryOp , value : Value ) = op match
108
+ case Clz (_) =>
109
+ value match
110
+ case I32V (v) => I32V (Integer .numberOfLeadingZeros(v))
111
+ case I64V (v) => I64V (java.lang.Long .numberOfLeadingZeros(v))
112
+ case _ => throw new Exception (" Invalid types" )
113
+ case Ctz (_) =>
114
+ value match
115
+ case I32V (v) => I32V (Integer .numberOfTrailingZeros(v))
116
+ case I64V (v) => I64V (java.lang.Long .numberOfTrailingZeros(v))
117
+ case _ => throw new Exception (" Invalid types" )
118
+ case Popcnt (_) =>
119
+ value match
120
+ case I32V (v) => I32V (Integer .bitCount(v))
121
+ case I64V (v) => I64V (java.lang.Long .bitCount(v))
122
+ case _ => throw new Exception (" Invalid types" )
123
+ case _ => ???
124
+
125
+ def evalRelOp (op : RelOp , lhs : Value , rhs : Value ) = op match
126
+ case Eq (_) =>
127
+ (lhs, rhs) match
128
+ case (I32V (v1), I32V (v2)) => I32V (if (v1 == v2) 1 else 0 )
129
+ case (I64V (v1), I64V (v2)) => I32V (if (v1 == v2) 1 else 0 )
130
+ case _ => throw new Exception (" Invalid types" )
131
+ case Ne (_) =>
132
+ (lhs, rhs) match
133
+ case (I32V (v1), I32V (v2)) => I32V (if (v1 != v2) 1 else 0 )
134
+ case (I64V (v1), I64V (v2)) => I32V (if (v1 != v2) 1 else 0 )
135
+ case _ => throw new Exception (" Invalid types" )
136
+ case LtS (_) =>
137
+ (lhs, rhs) match
138
+ case (I32V (v1), I32V (v2)) => I32V (if (v1 < v2) 1 else 0 )
139
+ case (I64V (v1), I64V (v2)) => I32V (if (v1 < v2) 1 else 0 )
140
+ case _ => throw new Exception (" Invalid types" )
141
+ case LtU (_) =>
142
+ (lhs, rhs) match
143
+ case (I32V (v1), I32V (v2)) =>
144
+ I32V (if (Integer .compareUnsigned(v1, v2) < 0 ) 1 else 0 )
145
+ case (I64V (v1), I64V (v2)) =>
146
+ I32V (if (java.lang.Long .compareUnsigned(v1, v2) < 0 ) 1 else 0 )
147
+ case _ => throw new Exception (" Invalid types" )
148
+ case GtS (_) =>
149
+ (lhs, rhs) match
150
+ case (I32V (v1), I32V (v2)) => I32V (if (v1 > v2) 1 else 0 )
151
+ case (I64V (v1), I64V (v2)) => I32V (if (v1 > v2) 1 else 0 )
152
+ case _ => throw new Exception (" Invalid types" )
153
+ case GtU (_) =>
154
+ (lhs, rhs) match
155
+ case (I32V (v1), I32V (v2)) =>
156
+ I32V (if (Integer .compareUnsigned(v1, v2) > 0 ) 1 else 0 )
157
+ case (I64V (v1), I64V (v2)) =>
158
+ I32V (if (java.lang.Long .compareUnsigned(v1, v2) > 0 ) 1 else 0 )
159
+ case _ => throw new Exception (" Invalid types" )
160
+ case LeS (_) =>
161
+ (lhs, rhs) match
162
+ case (I32V (v1), I32V (v2)) => I32V (if (v1 <= v2) 1 else 0 )
163
+ case (I64V (v1), I64V (v2)) => I32V (if (v1 <= v2) 1 else 0 )
164
+ case _ => throw new Exception (" Invalid types" )
165
+ case LeU (_) =>
166
+ (lhs, rhs) match
167
+ case (I32V (v1), I32V (v2)) =>
168
+ I32V (if (Integer .compareUnsigned(v1, v2) <= 0 ) 1 else 0 )
169
+ case (I64V (v1), I64V (v2)) =>
170
+ I32V (if (java.lang.Long .compareUnsigned(v1, v2) <= 0 ) 1 else 0 )
171
+ case _ => throw new Exception (" Invalid types" )
172
+ case GeS (_) =>
173
+ (lhs, rhs) match
174
+ case (I32V (v1), I32V (v2)) => I32V (if (v1 >= v2) 1 else 0 )
175
+ case (I64V (v1), I64V (v2)) => I32V (if (v1 >= v2) 1 else 0 )
176
+ case _ => throw new Exception (" Invalid types" )
177
+ case GeU (_) =>
178
+ (lhs, rhs) match
179
+ case (I32V (v1), I32V (v2)) =>
180
+ I32V (if (Integer .compareUnsigned(v1, v2) >= 0 ) 1 else 0 )
181
+ case (I64V (v1), I64V (v2)) =>
182
+ I32V (if (java.lang.Long .compareUnsigned(v1, v2) >= 0 ) 1 else 0 )
183
+ case _ => throw new Exception (" Invalid types" )
184
+
185
+ def evalTestOp (op : TestOp , value : Value ) = op match
186
+ case Eqz (_) =>
187
+ value match
188
+ case I32V (v) => I32V (if (v == 0 ) 1 else 0 )
189
+ case I64V (v) => I32V (if (v == 0 ) 1 else 0 )
190
+ case _ => throw new Exception (" Invalid types" )
191
+
192
+ def memOutOfBound (module : ModuleInstance , memoryIndex : Int , offset : Int , size : Int ) = {
193
+ val memory = module.memory(memoryIndex)
194
+ offset + size > memory.size
195
+ }
196
+
197
+ def zero (t : ValueType ): Value = t match
198
+ case NumType (kind) =>
199
+ kind match
200
+ case I32Type => I32V (0 )
201
+ case I64Type => I64V (0 )
202
+ case F32Type => F32V (0 )
203
+ case F64Type => F64V (0 )
204
+ case VecType (kind) => ???
205
+ case RefType (kind) => RefNullV (kind)
206
+
207
+ def getFuncType (ty : BlockType ): FuncType =
208
+ ty match
209
+ case VarBlockType (_, None ) => ??? // TODO: fill this branch until we handle type index correctly
210
+ case VarBlockType (_, Some (tipe)) => tipe
211
+ case ValBlockType (Some (tipe)) => FuncType (List (), List (), List (tipe))
212
+ case ValBlockType (None ) => FuncType (List (), List (), List ())
213
+
214
+ def extractMainInstrs (module : ModuleInstance , main : Option [String ]): List [Instr ] =
215
+ main match
216
+ case Some (func_name) =>
217
+ module.defs.flatMap({
218
+ case Export (`func_name`, ExportFunc (fid)) =>
219
+ System .err.println(s " Entering function $main" )
220
+ module.funcs(fid) match
221
+ case FuncDef (_, FuncBodyDef (_, _, locals, body)) => body
222
+ case _ => throw new Exception (" Entry function has no concrete body" )
223
+ case _ => List ()
224
+ })
225
+ case None =>
226
+ module.defs.flatMap({
227
+ case Start (id) =>
228
+ System .err.println(s " Entering unnamed function $id" )
229
+ module.funcs(id) match
230
+ case FuncDef (_, FuncBodyDef (_, _, locals, body)) => body
231
+ case _ => throw new Exception (" Entry function has no concrete body" )
232
+ case _ => List ()
233
+ })
234
+
235
+ def extractLocals (module : ModuleInstance , main : Option [String ]): List [ValueType ] =
236
+ main match
237
+ case Some (func_name) =>
238
+ module.defs.flatMap({
239
+ case Export (`func_name`, ExportFunc (fid)) =>
240
+ System .err.println(s " Entering function $main" )
241
+ module.funcs(fid) match
242
+ case FuncDef (_, FuncBodyDef (_, _, locals, _)) => locals
243
+ case _ => throw new Exception (" Entry function has no concrete body" )
244
+ case _ => List ()
245
+ })
246
+ case None =>
247
+ module.defs.flatMap({
248
+ case Start (id) =>
249
+ System .err.println(s " Entering unnamed function $id" )
250
+ module.funcs(id) match
251
+ case FuncDef (_, FuncBodyDef (_, _, locals, body)) => locals
252
+ case _ => throw new Exception (" Entry function has no concrete body" )
253
+ case _ => List ()
254
+ })
0 commit comments