Skip to content

Commit c9aa581

Browse files
authored
Rewrite to include Decl in the standard AST (#951)
* Register custom pprint function for generate-pprint.mc * TmDecl * LetAst -> LetDeclAst * Update comby definition to avoid keywords * TmLet -> TmDecl DeclLet * Fix type errors * Fix some further TmDecl updates
1 parent 6c9c6a6 commit c9aa581

File tree

122 files changed

+4196
-4981
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

122 files changed

+4196
-4981
lines changed

misc/mcore-comby.json

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
11
{
22
"user_defined_delimiters": [
3-
[ "lang", "end" ],
43
[ "(", ")" ],
54
[ "[", "]" ],
6-
[ "{", "}" ],
7-
[ "let", "in" ],
8-
[ "recursive", "in" ],
9-
[ "recursive", "end" ],
10-
[ "type", "in" ],
11-
[ "use", "in" ]
5+
[ "{", "}" ]
126
],
137
"escapable_string_literals": {
148
"delimiters": [ "\"" ],

misc/test-spec.mc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -866,7 +866,11 @@ testMain
866866
}
867867

868868
, { testColl "microbenchmark"
869-
with exclusions = lam api.
869+
with checkCondition = lam.
870+
if eqi 0 (command "ocamlfind query owl >/dev/null 2>&1")
871+
then ConditionsMet ()
872+
else ConditionsUnmet ()
873+
, exclusions = lam api.
870874
-- NOTE(vipa, 2023-05-16): These are tested via new tests instead
871875
api.mark noTasks (api.glob ["test", "microbenchmark"] (IncludeSubs ()) (SuffixFile ".mc"));
872876
-- TODO(vipa, 2024-11-08): Actually run this one, not just

src/main/mi-lite.mc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ lang MCoreLiteCompile =
2323
-- code size.
2424
sem stripUtests : Expr -> Expr
2525
sem stripUtests =
26-
| TmUtest t -> stripUtests t.next
26+
| TmDecl (x & {decl = DeclUtest _}) -> stripUtests x.inexpr
2727
| t -> smap_Expr_Expr stripUtests t
2828
end
2929

src/stdlib/c/compile.mc

Lines changed: 41 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -485,12 +485,12 @@ lang MExprCCompile = MExprCCompileBase + MExprTensorCCompile + RecordTypeUtils
485485
-----------------------
486486

487487
sem collectExternals (acc: Map Name ExtInfo) =
488-
| TmExt t ->
488+
| TmDecl (x & {decl = DeclExt t}) ->
489489
let str = nameGetStr t.ident in
490490
match mapLookup str externalsMap with Some e then
491491
let e: ExtInfo = e in -- TODO(dlunde,2021-10-25): Remove with more complete type system?
492492
let acc = mapInsert t.ident e acc in
493-
sfold_Expr_Expr collectExternals acc t.inexpr
493+
sfold_Expr_Expr collectExternals acc x.inexpr
494494
else errorSingle [t.info] "Unsupported external"
495495
| expr -> sfold_Expr_Expr collectExternals acc expr
496496

@@ -853,7 +853,7 @@ lang MExprCCompile = MExprCCompileBase + MExprTensorCCompile + RecordTypeUtils
853853

854854
sem compileTops (env: CompileCEnv) (accTop: [CTop]) (accInit: [CStmt]) =
855855

856-
| TmLet { ident = ident, tyBody = tyBody, body = body, inexpr = inexpr } ->
856+
| TmDecl {decl = DeclLet { ident = ident, tyBody = tyBody, body = body }, inexpr = inexpr } ->
857857

858858
-- Functions
859859
match body with TmLam _ then
@@ -885,8 +885,8 @@ lang MExprCCompile = MExprCCompileBase + MExprTensorCCompile + RecordTypeUtils
885885
compileTops env accTop accInit inexpr
886886
else never
887887

888-
| TmRecLets { bindings = bindings, inexpr = inexpr } ->
889-
let f = lam env. lam binding: RecLetBinding.
888+
| TmDecl {decl = DeclRecLets { bindings = bindings}, inexpr = inexpr } ->
889+
let f = lam env. lam binding: DeclLetRecord.
890890
match binding with { ident = ident, tyBody = tyBody, body = body } then
891891
compileFun env ident tyBody body
892892
else never
@@ -905,7 +905,7 @@ lang MExprCCompile = MExprCCompileBase + MExprTensorCCompile + RecordTypeUtils
905905
else never
906906

907907
-- Ignore externals (handled elsewhere)
908-
| TmExt { inexpr = inexpr } -> compileTops env accTop accInit inexpr
908+
| TmDecl {decl = DeclExt _, inexpr = inexpr} -> compileTops env accTop accInit inexpr
909909

910910
-- Set up initialization code (for use, e.g., in a main function)
911911
| rest ->
@@ -1111,7 +1111,7 @@ lang MExprCCompile = MExprCCompileBase + MExprTensorCCompile + RecordTypeUtils
11111111

11121112
sem compileStmts (env: CompileCEnv) (res: Result) (acc: [CStmt]) =
11131113

1114-
| TmLet { ident = ident, tyBody = tyBody, body = body, inexpr = inexpr } ->
1114+
| TmDecl {decl = DeclLet { ident = ident, tyBody = tyBody, body = body}, inexpr = inexpr } ->
11151115

11161116
-- Optimize direct allocations
11171117
match body with TmConApp _ | TmRecord _ | TmSeq _ then
@@ -1145,7 +1145,7 @@ lang MExprCCompile = MExprCCompileBase + MExprTensorCCompile + RecordTypeUtils
11451145
| TmNever _ -> (env, snoc acc (CSNop {}))
11461146

11471147
-- Ignore externals (handled elsewhere)
1148-
| TmExt { inexpr = inexpr } -> compileStmts env res acc inexpr
1148+
| TmDecl {decl = DeclExt _, inexpr = inexpr} -> compileStmts env res acc inexpr
11491149

11501150

11511151
-----------------
@@ -1275,10 +1275,10 @@ lang MExprCCompile = MExprCCompileBase + MExprTensorCCompile + RecordTypeUtils
12751275
else errorSingle [infoTm t] "ERROR: Records cannot be handled in compileExpr."
12761276

12771277
-- Should not occur after ANF and type lifting.
1278-
| (TmRecordUpdate _ | TmLet _
1279-
| TmRecLets _ | TmType _ | TmConDef _
1280-
| TmConApp _ | TmMatch _ | TmUtest _
1281-
| TmSeq _ | TmExt _) & t ->
1278+
| (TmRecordUpdate _ | TmDecl {decl = DeclLet _}
1279+
| TmDecl {decl = DeclRecLets _} | TmDecl {decl = DeclType _} | TmDecl {decl = DeclConDef _}
1280+
| TmConApp _ | TmMatch _ | TmDecl {decl = DeclUtest _}
1281+
| TmSeq _ | TmDecl {decl = DeclExt _}) & t ->
12821282
errorSingle [infoTm t] "ERROR: Term cannot be handled in compileExpr."
12831283

12841284
-- Literals
@@ -1430,9 +1430,8 @@ let testCompile32Bit : Expr -> String = lam expr.
14301430
printCompiledCProg (compile opts expr) in
14311431

14321432
let simpleLet = bindall_ [
1433-
ulet_ "x" (int_ 1),
1434-
int_ 0
1435-
] in
1433+
ulet_ "x" (int_ 1)]
1434+
(int_ 0) in
14361435
utest testCompile simpleLet with strJoin "\n" [
14371436
"#include <stdint.h>",
14381437
"#include <stdio.h>",
@@ -1447,9 +1446,8 @@ utest testCompile simpleLet with strJoin "\n" [
14471446
let simpleFun = bindall_ [
14481447
let_ "foo" (tyarrows_ [tyint_, tyint_, tyint_])
14491448
(ulam_ "a" (ulam_ "b" (addi_ (var_ "a") (var_ "b")))),
1450-
ulet_ "x" (appf2_ (var_ "foo") (int_ 1) (int_ 2)),
1451-
int_ 0
1452-
] in
1449+
ulet_ "x" (appf2_ (var_ "foo") (int_ 1) (int_ 2))]
1450+
(int_ 0) in
14531451
utest testCompile simpleFun with strJoin "\n" [
14541452
"#include <stdint.h>",
14551453
"#include <stdio.h>",
@@ -1476,11 +1474,10 @@ let constants = bindall_ [
14761474
ulet_ "t" (eqf_ (float_ 1.) (float_ 2.)),
14771475
ulet_ "t" (lti_ (int_ 1) (int_ 2)),
14781476
ulet_ "t" (ltf_ (float_ 1.) (float_ 2.)),
1479-
ulet_ "t" (negf_ (float_ 1.)),
1477+
ulet_ "t" (negf_ (float_ 1.))]
14801478
(print_ (str_ "Hello, world!"))
1481-
])),
1482-
int_ 0
1483-
] in
1479+
))]
1480+
(int_ 0) in
14841481
utest testCompile constants with strJoin "\n" [
14851482
"#include <stdint.h>",
14861483
"#include <stdio.h>",
@@ -1554,9 +1551,8 @@ let factorial = bindall_ [
15541551
(int_ 1)
15551552
(muli_ (var_ "n")
15561553
(app_ (var_ "factorial")
1557-
(subi_ (var_ "n") (int_ 1)))))),
1558-
int_ 0
1559-
] in
1554+
(subi_ (var_ "n") (int_ 1))))))]
1555+
(int_ 0) in
15601556
utest testCompile factorial with strJoin "\n" [
15611557
"#include <stdint.h>",
15621558
"#include <stdio.h>",
@@ -1599,9 +1595,8 @@ let oddEven = bindall_ [
15991595
false_
16001596
(app_ (var_ "odd")
16011597
(subi_ (var_ "x") (int_ 1))))))
1602-
],
1603-
int_ 0
1604-
] in
1598+
]]
1599+
(int_ 0) in
16051600
utest testCompile oddEven with strJoin "\n" [
16061601
"#include <stdint.h>",
16071602
"#include <stdio.h>",
@@ -1659,10 +1654,9 @@ let typedefs = bindall_ [
16591654
(tyarrow_ (tyrecord_ [("v", (tycon_ "Integer2"))]) (tycon_ "Tree")),
16601655
condef_ "Node" (tyarrow_
16611656
(tyrecord_ [("v", tyint_), ("l", (tycon_ "Tree")), ("r", (tycon_ "Tree"))])
1662-
(tycon_ "Tree")),
1657+
(tycon_ "Tree"))]
16631658

1664-
int_ 0
1665-
] in
1659+
(int_ 0) in
16661660
utest testCompile typedefs with strJoin "\n" [
16671661
"#include <stdint.h>",
16681662
"#include <stdio.h>",
@@ -1686,9 +1680,8 @@ utest testCompile typedefs with strJoin "\n" [
16861680
-- Potentially tricky case with type aliases
16871681
let alias = bindall_ [
16881682
type_ "MyRec" [] (tyrecord_ [("k", tyint_)]),
1689-
let_ "myRec" (tycon_ "MyRec") (urecord_ [("k", int_ 0)]),
1690-
int_ 0
1691-
] in
1683+
let_ "myRec" (tycon_ "MyRec") (urecord_ [("k", int_ 0)])]
1684+
(int_ 0) in
16921685
utest testCompile alias with strJoin "\n" [
16931686
"#include <stdint.h>",
16941687
"#include <stdio.h>",
@@ -1705,9 +1698,8 @@ utest testCompile alias with strJoin "\n" [
17051698
-- Externals test
17061699
let ext = bindall_ [
17071700
ext_ "externalLog" false (tyarrow_ tyfloat_ tyfloat_),
1708-
let_ "x" (tyfloat_) (app_ (var_ "externalLog") (float_ 2.)),
1709-
int_ 0
1710-
] in
1701+
let_ "x" (tyfloat_) (app_ (var_ "externalLog") (float_ 2.))]
1702+
(int_ 0) in
17111703
utest testCompile ext with strJoin "\n" [
17121704
"#include <stdint.h>",
17131705
"#include <stdio.h>",
@@ -1756,10 +1748,9 @@ let trees = bindall_ [
17561748
(var_ "v") never_))
17571749
),
17581750

1759-
ulet_ "sum" (app_ (var_ "treeRec") (var_ "tree")),
1751+
ulet_ "sum" (app_ (var_ "treeRec") (var_ "tree"))]
17601752

1761-
int_ 0
1762-
] in
1753+
(int_ 0) in
17631754

17641755
utest testCompile trees with strJoin "\n" [
17651756
"#include <stdint.h>",
@@ -1842,11 +1833,9 @@ utest testCompile trees with strJoin "\n" [
18421833
-- let leaf = match tree with node then leftnode else
18431834
let manyAllocs = bindall_ [
18441835

1845-
ulet_ "rec" (match_ (bool_ true) (pbool_ true) (urecord_ [("a",int_ 1)]) (urecord_ [("a",int_ 2)])),
1846-
1847-
int_ 0
1836+
ulet_ "rec" (match_ (bool_ true) (pbool_ true) (urecord_ [("a",int_ 1)]) (urecord_ [("a",int_ 2)]))]
18481837

1849-
] in
1838+
(int_ 0) in
18501839

18511840
utest testCompile manyAllocs with strJoin "\n" [
18521841
"#include <stdint.h>",
@@ -1871,13 +1860,13 @@ utest testCompile manyAllocs with strJoin "\n" [
18711860
-- NOTE(larshum, 2022-03-02): We use type-ascriptions so that the intrinsic
18721861
-- functions are treated as monomorphic, even though they are not.
18731862
let seq = bindall_ [
1874-
let_ "s" (tyseq_ tyint_) (seq_ [int_ 1, int_ 2, int_ 3]),
1875-
app_
1863+
let_ "s" (tyseq_ tyint_) (seq_ [int_ 1, int_ 2, int_ 3])]
1864+
(app_
18761865
(bind_
18771866
(let_ "len" (tyarrow_ (tyseq_ tyint_) tyint_) (uconst_ (CLength ())))
18781867
(var_ "len"))
18791868
(var_ "s")
1880-
] in
1869+
) in
18811870

18821871
utest testCompile seq with strJoin "\n" [
18831872
"#include <stdint.h>",
@@ -1924,9 +1913,8 @@ let tensor = bindall_ [
19241913
(bind_
19251914
(let_ "s" (tytensorshape_ tyint_) (uconst_ (CTensorShape ())))
19261915
(var_ "s"))
1927-
(var_ "t"))),
1928-
int_ 0
1929-
] in
1916+
(var_ "t")))]
1917+
(int_ 0) in
19301918

19311919
utest testCompile tensor with strJoin "\n" [
19321920
"#include <stdint.h>",
@@ -2011,12 +1999,10 @@ utest testCompile tensor with strJoin "\n" [
20111999
let seqs = bindall_ [
20122000

20132001
-- Define nested sequence, and see how it is handled
2014-
ulet_ "seq" (seq_ [seq_ [int_ 1], seq_ [int_ 2]]),
2015-
2002+
ulet_ "seq" (seq_ [seq_ [int_ 1], seq_ [int_ 2]])]
20162003
-- Use "length" and "get" functions
20172004

2018-
int_ 0
2019-
2020-
] in
2005+
(int_ 0)
2006+
in
20212007

20222008
()

src/stdlib/cuda/inline-higher.mc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ lang CudaInlineHigherOrder = MExprAst
1212
sem inlinePartialFunctionsH inlineBodies =
1313
| TmVar t ->
1414
match mapLookup t.ident inlineBodies with Some body then body else TmVar t
15-
| TmLet (t & {body = !TmLam _}) ->
15+
| TmDecl (x & {decl = DeclLet (t & {body = !TmLam _})}) ->
1616
match t.tyBody with TyArrow _ then
1717
let inlineBodies = mapInsert t.ident t.body inlineBodies in
18-
inlinePartialFunctionsH inlineBodies t.inexpr
19-
else TmLet {t with body = inlinePartialFunctionsH inlineBodies t.body,
20-
inexpr = inlinePartialFunctionsH inlineBodies t.inexpr}
18+
inlinePartialFunctionsH inlineBodies x.inexpr
19+
else TmDecl {x with decl = DeclLet {t with body = inlinePartialFunctionsH inlineBodies t.body},
20+
inexpr = inlinePartialFunctionsH inlineBodies x.inexpr}
2121
| t -> smap_Expr_Expr (inlinePartialFunctionsH inlineBodies) t
2222
end

src/stdlib/cuda/lang-fix.mc

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ lang CudaLanguageFragmentFix = PMExprAst
1616
TmLam {t with body = _eliminateFailureCodeInSemanticFunctionBody t.body}
1717
| TmMatch t ->
1818
TmMatch {t with els = _eliminateFailureCodeInSemanticFunctionBody t.els}
19-
| TmLet {
20-
body = TmApp {lhs = TmConst {val = CDPrint _}},
19+
| TmDecl {decl = DeclLet {
20+
body = TmApp {lhs = TmConst {val = CDPrint _}}},
2121
inexpr = TmApp {lhs = TmConst {val = CError _},
2222
rhs = TmSeq _},
2323
info = info} ->
@@ -27,28 +27,21 @@ lang CudaLanguageFragmentFix = PMExprAst
2727
TmNever {ty = TyUnknown {info = info}, info = info}
2828
| t -> t
2929

30-
sem _eliminateFailureCodeInSemanticFunction : RecLetBinding -> RecLetBinding
30+
sem _eliminateFailureCodeInSemanticFunction : DeclLetRecord -> DeclLetRecord
3131
sem _eliminateFailureCodeInSemanticFunction =
3232
| recLetBinding ->
33-
let recLetBinding : RecLetBinding = recLetBinding in
33+
let recLetBinding : DeclLetRecord = recLetBinding in
3434
let body = _eliminateFailureCodeInSemanticFunctionBody recLetBinding.body in
3535
{recLetBinding with body = body}
3636

3737
sem fixLanguageFragmentSemanticFunction : Expr -> Expr
3838
sem fixLanguageFragmentSemanticFunction =
39-
| TmLet t ->
40-
TmLet {t with inexpr = fixLanguageFragmentSemanticFunction t.inexpr}
41-
| TmRecLets t ->
39+
| TmDecl x -> TmDecl {x with inexpr = fixLanguageFragmentSemanticFunction x.inexpr}
40+
| TmDecl (x & {decl = DeclRecLets t}) ->
4241
let bindings = map _eliminateFailureCodeInSemanticFunction t.bindings in
43-
TmRecLets {{t with bindings = bindings}
44-
with inexpr = fixLanguageFragmentSemanticFunction t.inexpr}
45-
| TmType t ->
46-
TmType {t with inexpr = fixLanguageFragmentSemanticFunction t.inexpr}
47-
| TmConDef t ->
48-
TmConDef {t with inexpr = fixLanguageFragmentSemanticFunction t.inexpr}
49-
| TmUtest t ->
50-
TmUtest {t with next = fixLanguageFragmentSemanticFunction t.next}
51-
| TmExt t ->
52-
TmExt {t with inexpr = fixLanguageFragmentSemanticFunction t.inexpr}
42+
TmDecl
43+
{ x with decl = DeclRecLets {t with bindings = bindings}
44+
, inexpr = fixLanguageFragmentSemanticFunction x.inexpr
45+
}
5346
| t -> t
5447
end

0 commit comments

Comments
 (0)