Skip to content

Commit 8575cfa

Browse files
committed
import modules
1 parent c2e4cc3 commit 8575cfa

19 files changed

+164
-116
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_library(compiler_lib STATIC
2626
src/nex_lang/post_processing/visit_args.cc
2727
src/nex_lang/post_processing/visit_expr.cc
2828
src/nex_lang/post_processing/visit_fns.cc
29+
src/nex_lang/post_processing/visit_imports.cc
2930
src/nex_lang/post_processing/visit_params.cc
3031
src/nex_lang/post_processing/visit_s.cc
3132
src/nex_lang/post_processing/visit_stmts.cc

src/nex_lang/nex_lang.cc

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ DFA make_dfa() {
164164

165165
static std::map<std::string, Terminal> keywords = {
166166
{"mod", Terminal::MODULE},
167+
{"import", Terminal::IMPORT},
167168
{"fn", Terminal::FN},
168169
{"let", Terminal::LET},
169170
{"if", Terminal::IF},
@@ -263,11 +264,17 @@ static std::map<NonTerminal, std::vector<Production>> productions = {
263264
{NonTerminal::s,
264265
{{NonTerminal::s,
265266
{Terminal::BOFS,
266-
Terminal::MODULE,
267-
Terminal::ID,
268-
Terminal::SEMI,
267+
NonTerminal::module,
268+
NonTerminal::imports,
269269
NonTerminal::fns,
270270
Terminal::EOFS}}}},
271+
{NonTerminal::module,
272+
{{NonTerminal::module, {Terminal::MODULE, Terminal::ID, Terminal::SEMI}}}},
273+
{NonTerminal::imports,
274+
{{NonTerminal::imports, {NonTerminal::import, NonTerminal::imports}},
275+
{NonTerminal::imports, {}}}},
276+
{NonTerminal::import,
277+
{{NonTerminal::import, {Terminal::IMPORT, Terminal::ID, Terminal::SEMI}}}},
271278
{NonTerminal::fns,
272279
{{NonTerminal::fns, {NonTerminal::fn, NonTerminal::fns}},
273280
{NonTerminal::fns, {NonTerminal::fn}}}},
@@ -393,26 +400,20 @@ static std::map<NonTerminal, std::vector<Production>> productions = {
393400
{NonTerminal::exprp8,
394401
{NonTerminal::exprp8, Terminal::AS, NonTerminal::type}}}},
395402
{NonTerminal::exprp9,
396-
{{NonTerminal::exprp9, {Terminal::ID}},
397-
{NonTerminal::exprp9, {Terminal::NUM}},
398-
{NonTerminal::exprp9, {Terminal::AMPERSAND, Terminal::ID}},
399-
{NonTerminal::exprp9, {Terminal::STRLITERAL}},
400-
{NonTerminal::exprp9, {Terminal::CHARLITERAL}},
401-
{NonTerminal::exprp9,
402-
{Terminal::LPAREN, NonTerminal::expr, Terminal::RPAREN}},
403-
{NonTerminal::exprp9,
404-
{Terminal::ID,
405-
Terminal::LPAREN,
406-
NonTerminal::optargs,
407-
Terminal::RPAREN}},
408-
{NonTerminal::exprp9,
409-
{Terminal::ID,
410-
Terminal::COLON,
411-
Terminal::COLON,
412-
Terminal::ID,
413-
Terminal::LPAREN,
414-
NonTerminal::optargs,
415-
Terminal::RPAREN}}}},
403+
{
404+
{NonTerminal::exprp9, {Terminal::ID}},
405+
{NonTerminal::exprp9, {Terminal::NUM}},
406+
{NonTerminal::exprp9, {Terminal::AMPERSAND, Terminal::ID}},
407+
{NonTerminal::exprp9, {Terminal::STRLITERAL}},
408+
{NonTerminal::exprp9, {Terminal::CHARLITERAL}},
409+
{NonTerminal::exprp9,
410+
{Terminal::LPAREN, NonTerminal::expr, Terminal::RPAREN}},
411+
{NonTerminal::exprp9,
412+
{Terminal::ID,
413+
Terminal::LPAREN,
414+
NonTerminal::optargs,
415+
Terminal::RPAREN}},
416+
}},
416417
{NonTerminal::optargs,
417418
{{NonTerminal::optargs, {NonTerminal::args}}, {NonTerminal::optargs, {}}}},
418419
{NonTerminal::args,
@@ -435,7 +436,7 @@ static std::set<Terminal> terminals = {
435436
Terminal::AMPERSAND, Terminal::WHILE, Terminal::BOOL,
436437
Terminal::TRUE, Terminal::FALSE, Terminal::STRLITERAL,
437438
Terminal::CHARLITERAL, Terminal::CHAR, Terminal::AS,
438-
Terminal::MODULE,
439+
Terminal::MODULE, Terminal::IMPORT,
439440
};
440441

441442
static std::set<NonTerminal> non_terminals = {
@@ -446,7 +447,8 @@ static std::set<NonTerminal> non_terminals = {
446447
NonTerminal::exprp3, NonTerminal::exprp4, NonTerminal::exprp5,
447448
NonTerminal::exprp6, NonTerminal::exprp7, NonTerminal::exprp8,
448449
NonTerminal::exprp9, NonTerminal::optargs, NonTerminal::args,
449-
NonTerminal::stmtblock,
450+
NonTerminal::stmtblock, NonTerminal::module, NonTerminal::imports,
451+
NonTerminal::import,
450452
};
451453

452454
Grammar make_grammar() {

src/nex_lang/nex_lang_grammar.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
s BOFS MODULE ID SEMI fns EOFS
1+
s BOFS module imports fns EOFS
2+
module MODULE ID SEMI
3+
imports import imports
4+
imports
5+
import IMPORT ID SEMI
26
fns fn fns
37
fns fn
48
fn FN ID LPAREN optparams RPAREN ARROW type stmtblock
@@ -57,7 +61,6 @@ exprp9 STRLITERAL
5761
exprp9 CHARLITERAL
5862
exprp9 LPAREN expr RPAREN
5963
exprp9 ID LPAREN optargs RPAREN
60-
exprp9 ID COLON COLON ID LPAREN optargs RPAREN
6164
optargs args
6265
optargs
6366
args expr COMMA args

src/nex_lang/post_processing/extract_s.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,18 @@ void extract_s(ASTNode root, ModuleTable& module_table) {
2323
== std::vector<State> {
2424
NonTerminal::s,
2525
Terminal::BOFS,
26-
Terminal::MODULE,
27-
Terminal::ID,
28-
Terminal::SEMI,
26+
NonTerminal::module,
27+
NonTerminal::imports,
2928
NonTerminal::fns,
3029
Terminal::EOFS}) {
3130
// extract functions of program
3231

33-
ASTNode module = root.children.at(2);
34-
std::string name = module.lexeme;
32+
ASTNode module = root.children.at(1);
33+
std::string name = module.children.at(1).lexeme;
3534

3635
SymbolTable symbol_table;
3736

38-
ASTNode fns = root.children.at(4);
37+
ASTNode fns = root.children.at(3);
3938
extract_fns(fns, symbol_table);
4039

4140
module_table[name] = symbol_table;

src/nex_lang/post_processing/visit_expr.cc

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -189,61 +189,6 @@ TypedExpr visit_expr(
189189
} else {
190190
throw SymbolNotFoundError(name, id.line_no);
191191
}
192-
193-
} else if (prod == std::vector<State> {NonTerminal::exprp9, Terminal::ID, Terminal::COLON, Terminal::COLON, Terminal::ID, Terminal::LPAREN, NonTerminal::optargs, Terminal::RPAREN}) {
194-
ASTNode module_node = root.children.at(0);
195-
std::string module_name = module_node.lexeme;
196-
197-
ASTNode fn_node = root.children.at(3);
198-
std::string fn_name = fn_node.lexeme;
199-
200-
if (module_table.count(module_name)) {
201-
auto module_symbol_table = module_table[module_name];
202-
if (module_symbol_table.count(fn_name)) {
203-
if (auto typed_procedure =
204-
std::dynamic_pointer_cast<TypedProcedure>(
205-
module_symbol_table[fn_name]
206-
)) {
207-
ASTNode optargs = root.children.at(5);
208-
std::vector<TypedExpr> typed_args = visit_optargs(
209-
optargs,
210-
symbol_table,
211-
module_table,
212-
static_data
213-
);
214-
215-
if (typed_args.size() != typed_procedure->params.size()) {
216-
throw TypeMismatchError(
217-
"Mismatched number of parameters for function call.",
218-
fn_node.line_no
219-
);
220-
}
221-
222-
for (size_t i = 0; i < typed_args.size(); ++i) {
223-
if ((*typed_args.at(i).nl_type)
224-
!= (*typed_procedure->params.at(i)->nl_type)) {
225-
throw TypeMismatchError(
226-
"Type mismatch of parameters for function call.",
227-
fn_node.line_no
228-
);
229-
}
230-
}
231-
232-
std::vector<std::shared_ptr<Code>> args;
233-
for (auto typed_arg : typed_args) {
234-
args.push_back(typed_arg.code);
235-
}
236-
result = TypedExpr {
237-
make_call(typed_procedure->procedure, args),
238-
typed_procedure->ret_type};
239-
} else {
240-
throw SymbolNotFoundError(fn_name, fn_node.line_no);
241-
}
242-
} else {
243-
throw SymbolNotFoundError(fn_name, fn_node.line_no);
244-
}
245-
}
246-
247192
} else if (prod == std::vector<State> {NonTerminal::exprp8, NonTerminal::exprp8, Terminal::AS, NonTerminal::type}) {
248193
ASTNode expr = root.children.at(0);
249194
TypedExpr expr_code = visit_expr(
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
2+
#include "visit_imports.h"
3+
4+
#include <cassert>
5+
#include <iostream>
6+
7+
void visit_imports(
8+
ASTNode root,
9+
SymbolTable& symbol_table,
10+
ModuleTable& module_table
11+
) {
12+
assert(std::get<NonTerminal>(root.state) == NonTerminal::imports);
13+
14+
std::vector<State> prod = root.get_production();
15+
if (prod == std::vector<State> {NonTerminal::imports}) {
16+
// No more imports
17+
} else if (prod == std::vector<State> {NonTerminal::imports, NonTerminal::import, NonTerminal::imports}) {
18+
ASTNode import = root.children.at(0);
19+
visit_import(import, symbol_table, module_table);
20+
21+
ASTNode imports = root.children.at(1);
22+
visit_imports(imports, symbol_table, module_table);
23+
} else {
24+
std::cerr << "TODO" << std::endl;
25+
exit(1);
26+
}
27+
}
28+
29+
void visit_import(
30+
ASTNode root,
31+
SymbolTable& symbol_table,
32+
ModuleTable& module_table
33+
) {
34+
assert(std::get<NonTerminal>(root.state) == NonTerminal::import);
35+
36+
std::vector<State> prod = root.get_production();
37+
if (prod
38+
== std::vector<State> {
39+
NonTerminal::import,
40+
Terminal::IMPORT,
41+
Terminal::ID,
42+
Terminal::SEMI}) {
43+
ASTNode id = root.children.at(1);
44+
std::string name = id.lexeme;
45+
46+
SymbolTable& module_symbol_table = module_table.at(name);
47+
symbol_table.insert(
48+
module_symbol_table.begin(),
49+
module_symbol_table.end()
50+
);
51+
} else {
52+
std::cerr << "TODO" << std::endl;
53+
exit(1);
54+
}
55+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
#pragma once
3+
4+
#include "ast_node.h"
5+
#include "module_table.h"
6+
#include "symbol_table.h"
7+
8+
void visit_imports(
9+
ASTNode root,
10+
SymbolTable& symbol_table,
11+
ModuleTable& module_table
12+
);
13+
void visit_import(
14+
ASTNode root,
15+
SymbolTable& symbol_table,
16+
ModuleTable& module_table
17+
);

src/nex_lang/post_processing/visit_s.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "state.h"
1313
#include "symbol_table.h"
1414
#include "visit_fns.h"
15+
#include "visit_imports.h"
1516

1617
struct Code;
1718

@@ -28,19 +29,21 @@ std::vector<std::shared_ptr<TypedProcedure>> visit_s(
2829
== std::vector<State> {
2930
NonTerminal::s,
3031
Terminal::BOFS,
31-
Terminal::MODULE,
32-
Terminal::ID,
33-
Terminal::SEMI,
32+
NonTerminal::module,
33+
NonTerminal::imports,
3434
NonTerminal::fns,
3535
Terminal::EOFS}) {
3636
// extract functions of program
3737

38-
ASTNode module = root.children.at(2);
39-
std::string name = module.lexeme;
38+
ASTNode module = root.children.at(1);
39+
std::string name = module.children.at(1).lexeme;
4040

4141
SymbolTable symbol_table = module_table.at(name);
4242

43-
ASTNode fns = root.children.at(4);
43+
ASTNode imports = root.children.at(2);
44+
visit_imports(imports, symbol_table, module_table);
45+
46+
ASTNode fns = root.children.at(3);
4447
result = visit_fns(fns, symbol_table, module_table, static_data);
4548
} else {
4649
std::cerr << "Invalid production found while processing s."

src/utils/state.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ enum class Terminal {
77
BOFS,
88
EOFS,
99
MODULE,
10+
IMPORT,
1011
FN,
1112
ID,
1213
LPAREN,
@@ -61,6 +62,9 @@ enum class Terminal {
6162

6263
enum class NonTerminal {
6364
s,
65+
module,
66+
imports,
67+
import,
6468
fns,
6569
fn,
6670
optparams,

tests/nl_files/fibonacci_module.nl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11

22
mod fib;
33

4+
import print;
5+
46
fn calc_fibonacci(count: i32) {
57
let found: i32 = 0;
68
let curFib: i32 = 0;
79
let nextFib: i32 = 1;
810
while (found < count) {
9-
print::printnum(curFib);
10-
print::print(" ");
11+
printnum(curFib);
12+
print(" ");
1113
let newNextFib: i32 = curFib + nextFib;
1214
curFib = nextFib;
1315
nextFib = newNextFib;

0 commit comments

Comments
 (0)