Skip to content

Commit 9b8a196

Browse files
VLanvinfacebook-github-bot
authored andcommitted
Rewrite ElabGuard
Summary: Several components to this change: - merge `elabTest` and `elabTestT`, propagate an upper bound everywhere (possibly `AnyType`) - gate a lot of "smart" logic behind condition `upper == trueType` - restore smart elaboration of `not` disabled in D58184266, gated behind upper bound condition - smarter elaboration of `or` and `orelse` Reviewed By: ilya-klyuchnikov Differential Revision: D60389213 fbshipit-source-id: cefa2945c4d7f08b0d0f5ee57a5de969e6f7f364
1 parent fe2b133 commit 9b8a196

File tree

1 file changed

+64
-79
lines changed

1 file changed

+64
-79
lines changed

eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabGuard.scala

+64-79
Original file line numberDiff line numberDiff line change
@@ -65,24 +65,49 @@ final class ElabGuard(pipelineContext: PipelineContext) {
6565
envAcc
6666
}
6767

68-
private def elabTest(test: Test, env: Env): Env = {
68+
def elabTestT(test: Test, upper: Type, env: Env): Env = {
6969
test match {
7070
case TestVar(v) =>
71-
// safe because we assume no unbound vars
72-
val ty = env.getOrElse(
73-
v,
74-
AnyType,
75-
)
76-
typeInfo.add(test.pos, ty)
77-
env
71+
val testType = env.get(v) match {
72+
case Some(vt) =>
73+
narrow.meet(vt, upper)
74+
case None => upper
75+
}
76+
typeInfo.add(test.pos, testType)
77+
env + (v -> testType)
78+
case TestCall(Id(pred, 1), List(arg)) if upper == trueType && elabPredicateType1.isDefinedAt(pred) =>
79+
elabTestT(arg, elabPredicateType1(pred), env)
80+
case TestCall(Id(pred, 2), List(arg1, arg2))
81+
if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
82+
elabTestT(arg1, elabPredicateType22(pred, arg2), env)
83+
case TestCall(Id(pred, 2), List(arg1, arg2))
84+
if upper == trueType && elabPredicateType21.isDefinedAt((pred, arg1)) =>
85+
elabTestT(arg2, elabPredicateType21(pred, arg1), env)
86+
case TestCall(Id(pred, 3), List(arg1, arg2, _))
87+
if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
88+
elabTestT(arg1, elabPredicateType22(pred, arg2), env)
89+
case TestBinOp("and" | "andalso", arg1, arg2) if upper == trueType =>
90+
val env1 = elabTestT(arg1, trueType, env)
91+
elabTestT(arg2, trueType, env1)
92+
case TestBinOp("orelse", arg1, arg2) if upper == trueType =>
93+
val envTrue = elabTestT(arg1, trueType, env)
94+
val envFalse = elabTestT(arg1, falseType, env)
95+
val envFalse2 = elabTestT(arg2, trueType, envFalse)
96+
subtype.joinEnvs(List(envTrue, envFalse2))
97+
case TestBinOp("or", arg1, arg2) if upper == trueType =>
98+
val envTrue = elabTestT(arg1, trueType, env)
99+
val envTrue2 = elabTestT(arg2, booleanType, envTrue)
100+
val envFalse = elabTestT(arg1, booleanType, env)
101+
val envFalse2 = elabTestT(arg2, trueType, envFalse)
102+
subtype.joinEnvs(List(envTrue2, envFalse2))
78103
case TestAtom(_) =>
79104
env
80105
case TestNumber(_) =>
81106
env
82107
case TestTuple(elems) =>
83108
var envAcc: Env = env
84109
for (elem <- elems) {
85-
val elemEnv = elabTest(elem, envAcc)
110+
val elemEnv = elabTestT(elem, AnyType, envAcc)
86111
envAcc = elemEnv
87112
}
88113
envAcc
@@ -91,28 +116,28 @@ final class ElabGuard(pipelineContext: PipelineContext) {
91116
case TestNil() =>
92117
env
93118
case TestCons(head, tail) =>
94-
val env1 = elabTest(head, env)
95-
val env2 = elabTest(tail, env1)
119+
val env1 = elabTestT(head, AnyType, env)
120+
val env2 = elabTestT(tail, AnyType, env1)
96121
env2
97122
case TestMapCreate(kvs) =>
98123
var envAcc: Env = env
99124
for ((k, v) <- kvs) {
100-
val kEnv = elabTest(k, envAcc)
101-
val vEnv = elabTest(v, kEnv)
125+
val kEnv = elabTestT(k, AnyType, envAcc)
126+
val vEnv = elabTestT(v, AnyType, kEnv)
102127
envAcc = vEnv
103128
}
104129
envAcc
105130
case TestCall(_, args) =>
106131
var envAcc: Env = env
107132
for (arg <- args) {
108-
val argEnv = elabTest(arg, envAcc)
133+
val argEnv = elabTestT(arg, AnyType, envAcc)
109134
envAcc = argEnv
110135
}
111136
env
112137
case unOp: TestUnOp =>
113-
elabUnOp(unOp, env)
138+
elabUnOp(unOp, upper, env)
114139
case binOp: TestBinOp =>
115-
elabBinOp(binOp, env)
140+
elabBinOp(binOp, upper, env)
116141
case TestBinaryLit() =>
117142
env
118143
case TestRecordIndex(_, _) =>
@@ -162,54 +187,12 @@ final class ElabGuard(pipelineContext: PipelineContext) {
162187
}
163188
}
164189

165-
def elabTestT(test: Test, upper: Type, env: Env): Env = {
166-
test match {
167-
case TestVar(v) =>
168-
val testType = env.get(v) match {
169-
case Some(vt) =>
170-
narrow.meet(vt, upper)
171-
case None => upper
172-
}
173-
typeInfo.add(test.pos, testType)
174-
env + (v -> testType)
175-
case TestCall(Id(pred, 1), List(arg)) if upper == trueType && elabPredicateType1.isDefinedAt(pred) =>
176-
elabTestT(arg, elabPredicateType1(pred), env)
177-
case TestCall(Id(pred, 2), List(arg1, arg2))
178-
if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
179-
elabTestT(arg1, elabPredicateType22(pred, arg2), env)
180-
case TestCall(Id(pred, 2), List(arg1, arg2))
181-
if upper == trueType && elabPredicateType21.isDefinedAt((pred, arg1)) =>
182-
elabTestT(arg2, elabPredicateType21(pred, arg1), env)
183-
case TestCall(Id(pred, 3), List(arg1, arg2, _))
184-
if upper == trueType && elabPredicateType22.isDefinedAt((pred, arg2)) =>
185-
elabTestT(arg1, elabPredicateType22(pred, arg2), env)
186-
case TestBinOp("and", arg1, arg2) =>
187-
val env1 = elabTestT(arg1, AtomLitType("true"), env)
188-
elabTestT(arg2, upper, env1)
189-
case TestBinOp("andalso", arg1, arg2) =>
190-
val env1 = elabTestT(arg1, AtomLitType("true"), env)
191-
elabTestT(arg2, upper, env1)
192-
case TestBinOp("orelse", arg1, arg2) =>
193-
val envTrue = elabTestT(arg1, trueType, env)
194-
val envFalse = elabTestT(arg2, upper, env)
195-
subtype.joinEnvs(List(envTrue, envFalse))
196-
case TestBinOp("or", arg1, arg2) =>
197-
val env1 = elabTestT(arg1, booleanType, env)
198-
// "or" is not short-circuiting
199-
elabTestT(arg2, booleanType, env1)
200-
case _ =>
201-
elabTest(test, env)
202-
}
203-
}
204-
205-
def elabUnOp(unOp: TestUnOp, env: Env): Env = {
190+
private def elabUnOp(unOp: TestUnOp, upper: Type, env: Env): Env = {
206191
val TestUnOp(op, arg) = unOp
207192
op match {
208-
case "not" =>
209-
arg match {
210-
case TestVar(_) => elabTestT(arg, booleanType, env)
211-
case _ => env
212-
}
193+
case "not" if upper == trueType => elabTestT(arg, falseType, env)
194+
case "not" if upper == falseType => elabTestT(arg, trueType, env)
195+
case "not" => elabTestT(arg, booleanType, env)
213196
case "bnot" | "+" | "-" =>
214197
elabTestT(arg, NumberType, env)
215198
case _ =>
@@ -226,84 +209,86 @@ final class ElabGuard(pipelineContext: PipelineContext) {
226209
}
227210
}
228211

229-
private def elabComparison(binOp: TestBinOp, env: Env): Env =
212+
private def elabComparison(binOp: TestBinOp, upper: Type, env: Env): Env =
230213
binOp match {
231-
case TestBinOp("=:=" | "==", TestVar(v), NumTest()) =>
214+
case TestBinOp("=:=" | "==", TestVar(v), NumTest()) if upper == trueType =>
232215
env.get(v) match {
233216
case Some(ty) =>
234217
env + (v -> narrow.meet(ty, NumberType))
235218
case None =>
236219
env
237220
}
238-
case TestBinOp("=:=" | "==", NumTest(), TestVar(v)) =>
221+
case TestBinOp("=:=" | "==", NumTest(), TestVar(v)) if upper == trueType =>
239222
env.get(v) match {
240223
case Some(ty) =>
241224
env + (v -> narrow.meet(ty, NumberType))
242225
case None =>
243226
env
244227
}
245-
case TestBinOp("=:=" | "==", TestVar(v), TestString()) =>
228+
case TestBinOp("=:=" | "==", TestVar(v), TestString()) if upper == trueType =>
246229
env.get(v) match {
247230
case Some(ty) =>
248231
env + (v -> narrow.meet(ty, stringType))
249232
case None =>
250233
env
251234
}
252-
case TestBinOp("=:=" | "==", TestString(), TestVar(v)) =>
235+
case TestBinOp("=:=" | "==", TestString(), TestVar(v)) if upper == trueType =>
253236
env.get(v) match {
254237
case Some(ty) =>
255238
env + (v -> narrow.meet(ty, stringType))
256239
case None =>
257240
env
258241
}
259-
case TestBinOp("=:=" | "==", TestVar(v), TestAtom(a)) =>
242+
case TestBinOp("=:=" | "==", TestVar(v), TestAtom(a)) if upper == trueType =>
260243
env.get(v) match {
261244
case Some(ty) =>
262245
env + (v -> narrow.meet(ty, AtomLitType(a)))
263246
case None =>
264247
env
265248
}
266-
case TestBinOp("=:=" | "==", TestAtom(a), TestVar(v)) =>
249+
case TestBinOp("=:=" | "==", TestAtom(a), TestVar(v)) if upper == trueType =>
267250
env.get(v) match {
268251
case Some(ty) =>
269252
env + (v -> narrow.meet(ty, AtomLitType(a)))
270253
case None =>
271254
env
272255
}
273-
case TestBinOp("=:=" | "==", TestCall(Id("element", 2), List(TestNumber(Some(i)), TestVar(v))), TestAtom(a)) =>
256+
case TestBinOp("=:=" | "==", TestCall(Id("element", 2), List(TestNumber(Some(i)), TestVar(v))), TestAtom(a))
257+
if upper == trueType =>
274258
env.get(v) match {
275259
case Some(ty) =>
276260
env + (v -> narrow.filterTupleType(ty, i, AtomLitType(a)))
277261
case None =>
278262
env
279263
}
280-
case TestBinOp("=:=" | "==", TestAtom(a), TestCall(Id("element", 2), List(TestNumber(Some(i)), TestVar(v)))) =>
264+
case TestBinOp("=:=" | "==", TestAtom(a), TestCall(Id("element", 2), List(TestNumber(Some(i)), TestVar(v))))
265+
if upper == trueType =>
281266
env.get(v) match {
282267
case Some(ty) =>
283268
env + (v -> narrow.filterTupleType(ty, i, AtomLitType(a)))
284269
case None =>
285270
env
286271
}
287-
case TestBinOp("=/=" | "/=", TestVar(v), TestAtom(a)) =>
272+
case TestBinOp("=/=" | "/=", TestVar(v), TestAtom(a)) if upper == trueType =>
288273
env.get(v) match {
289274
case Some(ty) =>
290275
env + (v -> occurrence.remove(ty, AtomLitType(a)))
291276
case None =>
292277
env
293278
}
294-
case TestBinOp("=/=" | "/=", TestAtom(a), TestVar(v)) =>
279+
case TestBinOp("=/=" | "/=", TestAtom(a), TestVar(v)) if upper == trueType =>
295280
env.get(v) match {
296281
case Some(ty) =>
297282
env + (v -> occurrence.remove(ty, AtomLitType(a)))
298283
case None =>
299284
env
300285
}
301286
case TestBinOp(_, arg1, arg2) =>
302-
val env1 = elabTest(arg1, env)
303-
elabTest(arg2, env1)
287+
val env1 = elabTestT(arg1, AnyType, env)
288+
elabTestT(arg2, AnyType, env1)
304289
}
305290

306-
private def elabBinOp(binOp: TestBinOp, env: Env): Env = {
291+
private def elabBinOp(binOp: TestBinOp, upper: Type, env: Env): Env = {
307292
val TestBinOp(op, arg1, arg2) = binOp
308293
op match {
309294
case "/" | "*" | "-" | "+" | "div" | "rem" | "band" | "bor" | "bxor" | "bsl" | "bsr" =>
@@ -313,13 +298,13 @@ final class ElabGuard(pipelineContext: PipelineContext) {
313298
val env1 = elabTestT(arg1, booleanType, env)
314299
elabTestT(arg2, booleanType, env1)
315300
case ">=" | ">" | "=<" | "<" | "/=" | "=/=" | "==" | "=:=" =>
316-
elabComparison(binOp, env)
301+
elabComparison(binOp, upper, env)
317302
case "andalso" =>
318303
val env1 = elabTestT(arg1, booleanType, env)
319-
elabTest(arg2, env1)
304+
elabTestT(arg2, upper, env1)
320305
case "orelse" =>
321306
val env1 = elabTestT(arg1, booleanType, env)
322-
elabTest(arg2, env1)
307+
elabTestT(arg2, upper, env1)
323308
case _ =>
324309
throw new IllegalStateException(s"unexpected $op")
325310
}

0 commit comments

Comments
 (0)