@@ -65,24 +65,49 @@ final class ElabGuard(pipelineContext: PipelineContext) {
65
65
envAcc
66
66
}
67
67
68
- private def elabTest (test : Test , env : Env ): Env = {
68
+ def elabTestT (test : Test , upper : Type , env : Env ): Env = {
69
69
test match {
70
70
case TestVar (v) =>
71
- // safe because we assume no unbound vars
72
- val ty = env.getOrElse(
73
- v,
74
- AnyType ,
75
- )
76
- typeInfo.add(test.pos, ty)
77
- env
71
+ val testType = env.get(v) match {
72
+ case Some (vt) =>
73
+ narrow.meet(vt, upper)
74
+ case None => upper
75
+ }
76
+ typeInfo.add(test.pos, testType)
77
+ env + (v -> testType)
78
+ case TestCall (Id (pred, 1 ), List (arg)) if upper == trueType && elabPredicateType1.isDefinedAt(pred) =>
79
+ elabTestT(arg, elabPredicateType1(pred), env)
80
+ case TestCall (Id (pred, 2 ), List (arg1, arg2))
81
+ if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
82
+ elabTestT(arg1, elabPredicateType22(pred, arg2), env)
83
+ case TestCall (Id (pred, 2 ), List (arg1, arg2))
84
+ if upper == trueType && elabPredicateType21.isDefinedAt((pred, arg1)) =>
85
+ elabTestT(arg2, elabPredicateType21(pred, arg1), env)
86
+ case TestCall (Id (pred, 3 ), List (arg1, arg2, _))
87
+ if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
88
+ elabTestT(arg1, elabPredicateType22(pred, arg2), env)
89
+ case TestBinOp (" and" | " andalso" , arg1, arg2) if upper == trueType =>
90
+ val env1 = elabTestT(arg1, trueType, env)
91
+ elabTestT(arg2, trueType, env1)
92
+ case TestBinOp (" orelse" , arg1, arg2) if upper == trueType =>
93
+ val envTrue = elabTestT(arg1, trueType, env)
94
+ val envFalse = elabTestT(arg1, falseType, env)
95
+ val envFalse2 = elabTestT(arg2, trueType, envFalse)
96
+ subtype.joinEnvs(List (envTrue, envFalse2))
97
+ case TestBinOp (" or" , arg1, arg2) if upper == trueType =>
98
+ val envTrue = elabTestT(arg1, trueType, env)
99
+ val envTrue2 = elabTestT(arg2, booleanType, envTrue)
100
+ val envFalse = elabTestT(arg1, booleanType, env)
101
+ val envFalse2 = elabTestT(arg2, trueType, envFalse)
102
+ subtype.joinEnvs(List (envTrue2, envFalse2))
78
103
case TestAtom (_) =>
79
104
env
80
105
case TestNumber (_) =>
81
106
env
82
107
case TestTuple (elems) =>
83
108
var envAcc : Env = env
84
109
for (elem <- elems) {
85
- val elemEnv = elabTest (elem, envAcc)
110
+ val elemEnv = elabTestT (elem, AnyType , envAcc)
86
111
envAcc = elemEnv
87
112
}
88
113
envAcc
@@ -91,28 +116,28 @@ final class ElabGuard(pipelineContext: PipelineContext) {
91
116
case TestNil () =>
92
117
env
93
118
case TestCons (head, tail) =>
94
- val env1 = elabTest (head, env)
95
- val env2 = elabTest (tail, env1)
119
+ val env1 = elabTestT (head, AnyType , env)
120
+ val env2 = elabTestT (tail, AnyType , env1)
96
121
env2
97
122
case TestMapCreate (kvs) =>
98
123
var envAcc : Env = env
99
124
for ((k, v) <- kvs) {
100
- val kEnv = elabTest(k , envAcc)
101
- val vEnv = elabTest(v , kEnv)
125
+ val kEnv = elabTestT(k, AnyType , envAcc)
126
+ val vEnv = elabTestT(v, AnyType , kEnv)
102
127
envAcc = vEnv
103
128
}
104
129
envAcc
105
130
case TestCall (_, args) =>
106
131
var envAcc : Env = env
107
132
for (arg <- args) {
108
- val argEnv = elabTest (arg, envAcc)
133
+ val argEnv = elabTestT (arg, AnyType , envAcc)
109
134
envAcc = argEnv
110
135
}
111
136
env
112
137
case unOp : TestUnOp =>
113
- elabUnOp(unOp, env)
138
+ elabUnOp(unOp, upper, env)
114
139
case binOp : TestBinOp =>
115
- elabBinOp(binOp, env)
140
+ elabBinOp(binOp, upper, env)
116
141
case TestBinaryLit () =>
117
142
env
118
143
case TestRecordIndex (_, _) =>
@@ -162,54 +187,12 @@ final class ElabGuard(pipelineContext: PipelineContext) {
162
187
}
163
188
}
164
189
165
- def elabTestT (test : Test , upper : Type , env : Env ): Env = {
166
- test match {
167
- case TestVar (v) =>
168
- val testType = env.get(v) match {
169
- case Some (vt) =>
170
- narrow.meet(vt, upper)
171
- case None => upper
172
- }
173
- typeInfo.add(test.pos, testType)
174
- env + (v -> testType)
175
- case TestCall (Id (pred, 1 ), List (arg)) if upper == trueType && elabPredicateType1.isDefinedAt(pred) =>
176
- elabTestT(arg, elabPredicateType1(pred), env)
177
- case TestCall (Id (pred, 2 ), List (arg1, arg2))
178
- if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
179
- elabTestT(arg1, elabPredicateType22(pred, arg2), env)
180
- case TestCall (Id (pred, 2 ), List (arg1, arg2))
181
- if upper == trueType && elabPredicateType21.isDefinedAt((pred, arg1)) =>
182
- elabTestT(arg2, elabPredicateType21(pred, arg1), env)
183
- case TestCall (Id (pred, 3 ), List (arg1, arg2, _))
184
- if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
185
- elabTestT(arg1, elabPredicateType22(pred, arg2), env)
186
- case TestBinOp (" and" , arg1, arg2) =>
187
- val env1 = elabTestT(arg1, AtomLitType (" true" ), env)
188
- elabTestT(arg2, upper, env1)
189
- case TestBinOp (" andalso" , arg1, arg2) =>
190
- val env1 = elabTestT(arg1, AtomLitType (" true" ), env)
191
- elabTestT(arg2, upper, env1)
192
- case TestBinOp (" orelse" , arg1, arg2) =>
193
- val envTrue = elabTestT(arg1, trueType, env)
194
- val envFalse = elabTestT(arg2, upper, env)
195
- subtype.joinEnvs(List (envTrue, envFalse))
196
- case TestBinOp (" or" , arg1, arg2) =>
197
- val env1 = elabTestT(arg1, booleanType, env)
198
- // "or" is not short-circuiting
199
- elabTestT(arg2, booleanType, env1)
200
- case _ =>
201
- elabTest(test, env)
202
- }
203
- }
204
-
205
- def elabUnOp (unOp : TestUnOp , env : Env ): Env = {
190
+ private def elabUnOp (unOp : TestUnOp , upper : Type , env : Env ): Env = {
206
191
val TestUnOp (op, arg) = unOp
207
192
op match {
208
- case " not" =>
209
- arg match {
210
- case TestVar (_) => elabTestT(arg, booleanType, env)
211
- case _ => env
212
- }
193
+ case " not" if upper == trueType => elabTestT(arg, falseType, env)
194
+ case " not" if upper == falseType => elabTestT(arg, trueType, env)
195
+ case " not" => elabTestT(arg, booleanType, env)
213
196
case " bnot" | " +" | " -" =>
214
197
elabTestT(arg, NumberType , env)
215
198
case _ =>
@@ -226,84 +209,86 @@ final class ElabGuard(pipelineContext: PipelineContext) {
226
209
}
227
210
}
228
211
229
- private def elabComparison (binOp : TestBinOp , env : Env ): Env =
212
+ private def elabComparison (binOp : TestBinOp , upper : Type , env : Env ): Env =
230
213
binOp match {
231
- case TestBinOp (" =:=" | " ==" , TestVar (v), NumTest ()) =>
214
+ case TestBinOp (" =:=" | " ==" , TestVar (v), NumTest ()) if upper == trueType =>
232
215
env.get(v) match {
233
216
case Some (ty) =>
234
217
env + (v -> narrow.meet(ty, NumberType ))
235
218
case None =>
236
219
env
237
220
}
238
- case TestBinOp (" =:=" | " ==" , NumTest (), TestVar (v)) =>
221
+ case TestBinOp (" =:=" | " ==" , NumTest (), TestVar (v)) if upper == trueType =>
239
222
env.get(v) match {
240
223
case Some (ty) =>
241
224
env + (v -> narrow.meet(ty, NumberType ))
242
225
case None =>
243
226
env
244
227
}
245
- case TestBinOp (" =:=" | " ==" , TestVar (v), TestString ()) =>
228
+ case TestBinOp (" =:=" | " ==" , TestVar (v), TestString ()) if upper == trueType =>
246
229
env.get(v) match {
247
230
case Some (ty) =>
248
231
env + (v -> narrow.meet(ty, stringType))
249
232
case None =>
250
233
env
251
234
}
252
- case TestBinOp (" =:=" | " ==" , TestString (), TestVar (v)) =>
235
+ case TestBinOp (" =:=" | " ==" , TestString (), TestVar (v)) if upper == trueType =>
253
236
env.get(v) match {
254
237
case Some (ty) =>
255
238
env + (v -> narrow.meet(ty, stringType))
256
239
case None =>
257
240
env
258
241
}
259
- case TestBinOp (" =:=" | " ==" , TestVar (v), TestAtom (a)) =>
242
+ case TestBinOp (" =:=" | " ==" , TestVar (v), TestAtom (a)) if upper == trueType =>
260
243
env.get(v) match {
261
244
case Some (ty) =>
262
245
env + (v -> narrow.meet(ty, AtomLitType (a)))
263
246
case None =>
264
247
env
265
248
}
266
- case TestBinOp (" =:=" | " ==" , TestAtom (a), TestVar (v)) =>
249
+ case TestBinOp (" =:=" | " ==" , TestAtom (a), TestVar (v)) if upper == trueType =>
267
250
env.get(v) match {
268
251
case Some (ty) =>
269
252
env + (v -> narrow.meet(ty, AtomLitType (a)))
270
253
case None =>
271
254
env
272
255
}
273
- case TestBinOp (" =:=" | " ==" , TestCall (Id (" element" , 2 ), List (TestNumber (Some (i)), TestVar (v))), TestAtom (a)) =>
256
+ case TestBinOp (" =:=" | " ==" , TestCall (Id (" element" , 2 ), List (TestNumber (Some (i)), TestVar (v))), TestAtom (a))
257
+ if upper == trueType =>
274
258
env.get(v) match {
275
259
case Some (ty) =>
276
260
env + (v -> narrow.filterTupleType(ty, i, AtomLitType (a)))
277
261
case None =>
278
262
env
279
263
}
280
- case TestBinOp (" =:=" | " ==" , TestAtom (a), TestCall (Id (" element" , 2 ), List (TestNumber (Some (i)), TestVar (v)))) =>
264
+ case TestBinOp (" =:=" | " ==" , TestAtom (a), TestCall (Id (" element" , 2 ), List (TestNumber (Some (i)), TestVar (v))))
265
+ if upper == trueType =>
281
266
env.get(v) match {
282
267
case Some (ty) =>
283
268
env + (v -> narrow.filterTupleType(ty, i, AtomLitType (a)))
284
269
case None =>
285
270
env
286
271
}
287
- case TestBinOp (" =/=" | " /=" , TestVar (v), TestAtom (a)) =>
272
+ case TestBinOp (" =/=" | " /=" , TestVar (v), TestAtom (a)) if upper == trueType =>
288
273
env.get(v) match {
289
274
case Some (ty) =>
290
275
env + (v -> occurrence.remove(ty, AtomLitType (a)))
291
276
case None =>
292
277
env
293
278
}
294
- case TestBinOp (" =/=" | " /=" , TestAtom (a), TestVar (v)) =>
279
+ case TestBinOp (" =/=" | " /=" , TestAtom (a), TestVar (v)) if upper == trueType =>
295
280
env.get(v) match {
296
281
case Some (ty) =>
297
282
env + (v -> occurrence.remove(ty, AtomLitType (a)))
298
283
case None =>
299
284
env
300
285
}
301
286
case TestBinOp (_, arg1, arg2) =>
302
- val env1 = elabTest (arg1, env)
303
- elabTest (arg2, env1)
287
+ val env1 = elabTestT (arg1, AnyType , env)
288
+ elabTestT (arg2, AnyType , env1)
304
289
}
305
290
306
- private def elabBinOp (binOp : TestBinOp , env : Env ): Env = {
291
+ private def elabBinOp (binOp : TestBinOp , upper : Type , env : Env ): Env = {
307
292
val TestBinOp (op, arg1, arg2) = binOp
308
293
op match {
309
294
case " /" | " *" | " -" | " +" | " div" | " rem" | " band" | " bor" | " bxor" | " bsl" | " bsr" =>
@@ -313,13 +298,13 @@ final class ElabGuard(pipelineContext: PipelineContext) {
313
298
val env1 = elabTestT(arg1, booleanType, env)
314
299
elabTestT(arg2, booleanType, env1)
315
300
case " >=" | " >" | " =<" | " <" | " /=" | " =/=" | " ==" | " =:=" =>
316
- elabComparison(binOp, env)
301
+ elabComparison(binOp, upper, env)
317
302
case " andalso" =>
318
303
val env1 = elabTestT(arg1, booleanType, env)
319
- elabTest (arg2, env1)
304
+ elabTestT (arg2, upper , env1)
320
305
case " orelse" =>
321
306
val env1 = elabTestT(arg1, booleanType, env)
322
- elabTest (arg2, env1)
307
+ elabTestT (arg2, upper , env1)
323
308
case _ =>
324
309
throw new IllegalStateException (s " unexpected $op" )
325
310
}
0 commit comments