Skip to content

Commit cd6c053

Browse files
committed
feat: use default for disjoint branches
1 parent 9141b15 commit cd6c053

File tree

3 files changed

+128
-17
lines changed

3 files changed

+128
-17
lines changed

src/BodyBuilder.hs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,14 +1681,20 @@ mergeForks (PrimFork var0 ty0 final0 branches0 dflt0) (PrimFork var1 ty1 final1
16811681
(mergeMaybeBodies dflt0 dflt1) vars
16821682
else App.empty
16831683
where
1684-
nothingDflts = all isNothing [dflt0, dflt1]
1684+
-- assumes branches are sorted, a safe assummption from the BkwdBuilder
16851685
mergeIndexedBranches [] [] vars = return (vars, [])
1686+
-- when either branch is "missing" a value, either use the default from the other branch,
1687+
-- or if there is no default then we can just keep that branch
16861688
mergeIndexedBranches bbs0@((v0,b0):bs0) bbs1@((v1,b1):bs1) vars@Factors{renamed=renamed}
1687-
| v0 < v1 && nothingDflts
1688-
= ((v0,b0):) <$$> mergeIndexedBranches bs0 bbs1 vars
1689-
| v0 > v1 && nothingDflts
1690-
= ((v1,renameProcBody renamed b1):) <$$> mergeIndexedBranches bs0 bs1 vars
1691-
| v0 == v1
1689+
| v0 < v1
1690+
= case dflt1 of
1691+
Nothing -> ((v0,b0):) <$$> mergeIndexedBranches bs0 bbs1 vars
1692+
Just dflt1' -> mergeIndexedBranches bbs0 ((v0,dflt1'):bbs1) vars
1693+
| v0 > v1
1694+
= case dflt0 of
1695+
Nothing -> ((v0,b0):) <$$> mergeIndexedBranches bs0 bbs1 vars
1696+
Just dflt0' -> mergeIndexedBranches ((v1,dflt0'):bbs0) bbs1 vars
1697+
| otherwise
16921698
= combineMerged ((:) . (v0,))
16931699
(mergeBodies b0 b1)
16941700
(mergeIndexedBranches bs0 bs1)
@@ -1705,7 +1711,9 @@ mergeForks (MergedFork var0 ty0 final0 table0 branch0 dflt0) (MergedFork var1 ty
17051711
vars
17061712
else App.empty
17071713
where
1708-
mergeTables :: MergedForkTable -> MergedForkTable -> FactoredMerge MergedForkTable
1714+
-- assumes tables are created in the same order
1715+
-- will fail if not in the same order later on
1716+
-- XXX we can probably to better than this!
17091717
mergeTables [] [] vars = return (vars, [])
17101718
mergeTables (entry0@(var0, ty0, args0):table0) ((var1, ty1, args1):table1) vars@Factors{renamed=renamed}
17111719
| args0 == args1

test-cases/final-dump/merged_forks.exp

Lines changed: 93 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,51 @@ nested_different_default(b##0:wybe.bool, x##0:wybe.int)<{<<wybe.io.io>>}; {<<wyb
130130

131131

132132

133+
proc nested_disjoint_same_default > {noinline} (0 calls)
134+
0: merged_forks.nested_disjoint_same_default<0>
135+
nested_disjoint_same_default(b##0:wybe.bool, x##0:wybe.int)<{<<wybe.io.io>>}; {<<wybe.io.io>>}; {}>:
136+
AliasPairs: []
137+
InterestingCallProperties: []
138+
MultiSpeczDepInfo: [(2,(wybe.string.print<0>,fromList [NonAliasedParamCond 0 []]))]
139+
factored ~b##0:wybe.bool of
140+
?tmp#82##0:wybe.bool <- [ 0:wybe.bool, 1:wybe.bool ]
141+
?tmp#83##0:wybe.int <- [ 2:wybe.int, 3:wybe.int ]
142+
wybe.bool.fmt<0>(~tmp#82##0:wybe.bool, ?tmp#1##0:wybe.string) #0 @merged_forks:nn:nn
143+
wybe.string.,,<0>(~tmp#1##0:wybe.string, 1155:wybe.string, ?tmp#0##0:wybe.string) #1 @merged_forks:nn:nn
144+
wybe.string.print<0>[410bae77d3](~tmp#0##0:wybe.string)<{<<wybe.io.io>>}; {<<wybe.io.io>>}; {}> #2 @merged_forks:nn:nn
145+
case ~x##0:wybe.int of
146+
0:
147+
foreign lpvm load(<<wybe.io.io>>:wybe.phantom, ?%tmp#63##0:wybe.phantom) @merged_forks:nn:nn
148+
foreign c print_int(1:wybe.int, ~%tmp#63##0:wybe.phantom, ?%tmp#64##0:wybe.phantom) @merged_forks:nn:nn
149+
foreign c putchar('\n':wybe.char, ~tmp#64##0:wybe.phantom, ?tmp#65##0:wybe.phantom) @merged_forks:nn:nn
150+
foreign lpvm store(~%tmp#65##0:wybe.phantom, <<wybe.io.io>>:wybe.phantom) @merged_forks:nn:nn
151+
152+
2:
153+
foreign lpvm load(<<wybe.io.io>>:wybe.phantom, ?%tmp#55##0:wybe.phantom) @merged_forks:nn:nn
154+
foreign c print_int(2:wybe.int, ~%tmp#55##0:wybe.phantom, ?%tmp#56##0:wybe.phantom) @merged_forks:nn:nn
155+
foreign c putchar('\n':wybe.char, ~tmp#56##0:wybe.phantom, ?tmp#57##0:wybe.phantom) @merged_forks:nn:nn
156+
foreign lpvm store(~%tmp#57##0:wybe.phantom, <<wybe.io.io>>:wybe.phantom) @merged_forks:nn:nn
157+
158+
3:
159+
foreign lpvm load(<<wybe.io.io>>:wybe.phantom, ?%tmp#51##0:wybe.phantom) @merged_forks:nn:nn
160+
foreign c print_int(3:wybe.int, ~%tmp#51##0:wybe.phantom, ?%tmp#52##0:wybe.phantom) @merged_forks:nn:nn
161+
foreign c putchar('\n':wybe.char, ~tmp#52##0:wybe.phantom, ?tmp#53##0:wybe.phantom) @merged_forks:nn:nn
162+
foreign lpvm store(~%tmp#53##0:wybe.phantom, <<wybe.io.io>>:wybe.phantom) @merged_forks:nn:nn
163+
164+
4:
165+
foreign lpvm load(<<wybe.io.io>>:wybe.phantom, ?%tmp#59##0:wybe.phantom) @merged_forks:nn:nn
166+
foreign c print_int(4:wybe.int, ~%tmp#59##0:wybe.phantom, ?%tmp#60##0:wybe.phantom) @merged_forks:nn:nn
167+
foreign c putchar('\n':wybe.char, ~tmp#60##0:wybe.phantom, ?tmp#61##0:wybe.phantom) @merged_forks:nn:nn
168+
foreign lpvm store(~%tmp#61##0:wybe.phantom, <<wybe.io.io>>:wybe.phantom) @merged_forks:nn:nn
169+
170+
else:
171+
foreign lpvm load(<<wybe.io.io>>:wybe.phantom, ?%tmp#51##0:wybe.phantom) @merged_forks:nn:nn
172+
foreign c print_int(~tmp#83##0:wybe.int, ~%tmp#51##0:wybe.phantom, ?%tmp#52##0:wybe.phantom) @merged_forks:nn:nn
173+
foreign c putchar('\n':wybe.char, ~tmp#52##0:wybe.phantom, ?tmp#53##0:wybe.phantom) @merged_forks:nn:nn
174+
foreign lpvm store(~%tmp#53##0:wybe.phantom, <<wybe.io.io>>:wybe.phantom) @merged_forks:nn:nn
175+
176+
177+
133178
proc nested_same_default > {noinline} (8 calls)
134179
0: merged_forks.nested_same_default<0>
135180
nested_same_default(b##0:wybe.bool, x##0:wybe.int)<{<<wybe.io.io>>}; {<<wybe.io.io>>}; {}>:
@@ -206,9 +251,10 @@ target triple ????
206251

207252
@"array#0" = private unnamed_addr constant [ 2 x i1 ] [i1 0, i1 1]
208253
@"array#1" = private unnamed_addr constant [ 3 x i64 ] [i64 0, i64 1, i64 2]
209-
@"array#2" = private unnamed_addr constant [ 2 x i64 ] [i64 3, i64 2]
210-
@"array#3" = private unnamed_addr constant [ 2 x i64 ] [i64 8, i64 1]
211-
@"array#4" = private unnamed_addr constant [ 2 x i64 ] [i64 9, i64 1]
254+
@"array#2" = private unnamed_addr constant [ 2 x i64 ] [i64 2, i64 3]
255+
@"array#3" = private unnamed_addr constant [ 2 x i64 ] [i64 3, i64 2]
256+
@"array#4" = private unnamed_addr constant [ 2 x i64 ] [i64 8, i64 1]
257+
@"array#5" = private unnamed_addr constant [ 2 x i64 ] [i64 9, i64 1]
212258

213259
declare external fastcc i64 @"wybe.bool.fmt<0>"(i1)
214260
declare external fastcc void @"wybe.bool.print<0>"(i1)
@@ -248,7 +294,7 @@ if.else.0:
248294
define external fastcc void @"merged_forks.leq_2<0>"(i64 %"x##0") {
249295
%"tmp#27##0" = icmp ule i64 %"x##0", 2
250296
%"tmp#30##0" = zext i1 %"tmp#27##0" to i64
251-
%"tmp#29##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#4", i64 0, i64 %"tmp#30##0"
297+
%"tmp#29##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#5", i64 0, i64 %"tmp#30##0"
252298
%"tmp#28##0" = load i64, ptr %"tmp#29##0"
253299
call ccc void @print_int(i64 %"tmp#28##0")
254300
call ccc void @putchar(i8 10)
@@ -262,7 +308,7 @@ define external fastcc void @"merged_forks.nested<0>"(i1 %"b##0", i64 %"x##0") {
262308
tail call fastcc void @"wybe.bool.print<0>"(i1 %"tmp#59##0")
263309
%"tmp#55##0" = icmp ule i64 %"x##0", 2
264310
%"tmp#63##0" = zext i1 %"tmp#55##0" to i64
265-
%"tmp#62##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#4", i64 0, i64 %"tmp#63##0"
311+
%"tmp#62##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#5", i64 0, i64 %"tmp#63##0"
266312
%"tmp#58##0" = load i64, ptr %"tmp#62##0"
267313
call ccc void @print_int(i64 %"tmp#58##0")
268314
call ccc void @putchar(i8 10)
@@ -284,7 +330,7 @@ if.then.1:
284330
if.else.1:
285331
%"tmp#6##0" = icmp eq i64 %"x##0", 2
286332
%"tmp#61##0" = zext i1 %"tmp#6##0" to i64
287-
%"tmp#60##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#4", i64 0, i64 %"tmp#61##0"
333+
%"tmp#60##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#5", i64 0, i64 %"tmp#61##0"
288334
%"tmp#57##0" = load i64, ptr %"tmp#60##0"
289335
call ccc void @print_int(i64 %"tmp#57##0")
290336
call ccc void @putchar(i8 10)
@@ -302,19 +348,56 @@ if.then.2:
302348
if.else.2:
303349
%"tmp#8##0" = icmp eq i64 %"x##0", 3
304350
%"tmp#63##0" = zext i1 %"tmp#8##0" to i64
305-
%"tmp#62##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#3", i64 0, i64 %"tmp#63##0"
351+
%"tmp#62##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#4", i64 0, i64 %"tmp#63##0"
306352
%"tmp#56##0" = load i64, ptr %"tmp#62##0"
307353
call ccc void @print_int(i64 %"tmp#56##0")
308354
call ccc void @putchar(i8 10)
309355
ret void
310356
}
311357

358+
define external fastcc void @"merged_forks.nested_disjoint_same_default<0>"(i1 %"b##0", i64 %"x##0") {
359+
%"tmp#85##0" = zext i1 %"b##0" to i64
360+
%"tmp#84##0" = getelementptr inbounds [ 2 x i1 ], ptr @"array#0", i64 0, i64 %"tmp#85##0"
361+
%"tmp#82##0" = load i1, ptr %"tmp#84##0"
362+
%"tmp#87##0" = zext i1 %"b##0" to i64
363+
%"tmp#86##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#2", i64 0, i64 %"tmp#87##0"
364+
%"tmp#83##0" = load i64, ptr %"tmp#86##0"
365+
%"tmp#1##0" = tail call fastcc i64 @"wybe.bool.fmt<0>"(i1 %"tmp#82##0")
366+
%"tmp#0##0" = tail call fastcc i64 @"wybe.string.,,<0>"(i64 %"tmp#1##0", i64 1155)
367+
tail call fastcc void @"wybe.string.print<0>[410bae77d3]"(i64 %"tmp#0##0")
368+
switch i64 %"x##0", label %default.switch.0 [
369+
i64 0, label %case.0.switch.0
370+
i64 2, label %case.2.switch.0
371+
i64 3, label %case.3.switch.0
372+
i64 4, label %case.4.switch.0 ]
373+
case.0.switch.0:
374+
call ccc void @print_int(i64 1)
375+
call ccc void @putchar(i8 10)
376+
ret void
377+
case.2.switch.0:
378+
call ccc void @print_int(i64 2)
379+
call ccc void @putchar(i8 10)
380+
ret void
381+
case.3.switch.0:
382+
call ccc void @print_int(i64 3)
383+
call ccc void @putchar(i8 10)
384+
ret void
385+
case.4.switch.0:
386+
call ccc void @print_int(i64 4)
387+
call ccc void @putchar(i8 10)
388+
ret void
389+
default.switch.0:
390+
call ccc void @print_int(i64 %"tmp#83##0")
391+
call ccc void @putchar(i8 10)
392+
ret void
393+
}
394+
312395
define external fastcc void @"merged_forks.nested_same_default<0>"(i1 %"b##0", i64 %"x##0") {
313396
%"tmp#61##0" = zext i1 %"b##0" to i64
314397
%"tmp#60##0" = getelementptr inbounds [ 2 x i1 ], ptr @"array#0", i64 0, i64 %"tmp#61##0"
315398
%"tmp#58##0" = load i1, ptr %"tmp#60##0"
316399
%"tmp#63##0" = zext i1 %"b##0" to i64
317-
%"tmp#62##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#2", i64 0, i64 %"tmp#63##0"
400+
%"tmp#62##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#3", i64 0, i64 %"tmp#63##0"
318401
%"tmp#59##0" = load i64, ptr %"tmp#62##0"
319402
%"tmp#1##0" = tail call fastcc i64 @"wybe.bool.fmt<0>"(i1 %"tmp#58##0")
320403
%"tmp#0##0" = tail call fastcc i64 @"wybe.string.,,<0>"(i64 %"tmp#1##0", i64 1155)
@@ -328,7 +411,7 @@ if.then.0:
328411
if.else.0:
329412
%"tmp#6##0" = icmp eq i64 %"x##0", %"tmp#59##0"
330413
%"tmp#65##0" = zext i1 %"tmp#6##0" to i64
331-
%"tmp#64##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#4", i64 0, i64 %"tmp#65##0"
414+
%"tmp#64##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#5", i64 0, i64 %"tmp#65##0"
332415
%"tmp#57##0" = load i64, ptr %"tmp#64##0"
333416
call ccc void @print_int(i64 %"tmp#57##0")
334417
call ccc void @putchar(i8 10)
@@ -345,7 +428,7 @@ if.then.0:
345428
if.else.0:
346429
%"tmp#1##0" = icmp eq i64 %"x##0", 2
347430
%"tmp#22##0" = zext i1 %"tmp#1##0" to i64
348-
%"tmp#21##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#4", i64 0, i64 %"tmp#22##0"
431+
%"tmp#21##0" = getelementptr inbounds [ 2 x i64 ], ptr @"array#5", i64 0, i64 %"tmp#22##0"
349432
%"tmp#20##0" = load i64, ptr %"tmp#21##0"
350433
call ccc void @print_int(i64 %"tmp#20##0")
351434
call ccc void @putchar(i8 10)

test-cases/final-dump/merged_forks.wybe

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,26 @@ def {noinline} nested_different_default(b:bool, x:int) use !io {
8787
})
8888
}
8989
}
90+
def {noinline} nested_disjoint_same_default(b:bool, x:int) use !io {
91+
if {
92+
b ::
93+
!print("$b ")
94+
!println(case x in {
95+
0 :: 1
96+
| 4 :: 4
97+
| 2 :: 2
98+
| else :: 3
99+
})
100+
| else ::
101+
!print("$b ")
102+
!println(case x in {
103+
0 :: 1
104+
| 4 :: 4
105+
| 3 :: 3
106+
| else :: 2
107+
})
108+
}
109+
}
90110

91111
!nested_same_default(true, 0)
92112
!nested_same_default(true, 1)

0 commit comments

Comments
 (0)