Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions spec/parser_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -905,4 +905,90 @@ describe("Parser /", function()
end)

end)

describe("Type declaration file / ", function()

local tdf = [[
typealias A = integer
record B
x: A
end
f: (A) -> B
g: (A, (A) -> B) -> (B, A)
tbl: { f1: integer, f2: B }
constant: float
typealias C = { x: A, y: B }
]]

it("should be parsed correctly", function()
local ast, errs = driver.compile_internal_d_pln("__test__.d.pln", tdf)

assert(ast)
assert.falsy(errs)
assert(ast._tag == "ast.TypeFile.TypeFile")

local nodes_expected = {
"ast.TypeFile.Typealias",
"ast.TypeFile.Record",
"ast.TypeFile.Decl",
"ast.TypeFile.Decl",
"ast.TypeFile.Decl",
"ast.TypeFile.Decl",
"ast.TypeFile.Typealias"
}

for i, expected in ipairs(nodes_expected) do
local actual = ast.decls[i]._tag
assert.are.equal(expected, actual)
end
end)

-- TODO: Move this test to the type checker
it("should have the _type field added to nodes", function()
local ast, errs = driver.compile_internal_d_pln("__test__.d.pln", tdf)

assert(ast)
assert.falsy(errs)
assert(ast._tag == "ast.TypeFile.TypeFile")

local nodes_types_expected = {
"types.T.Alias",
"types.T.Record",
"types.T.Function",
"types.T.Function",
"types.T.Table",
"types.T.Float",
"types.T.Alias",
}

for i, expected in ipairs(nodes_types_expected) do
local actual = ast.decls[i]._type._tag
assert.are.equal(expected, actual)
end
end)

it("can be empty", function()
local ast, errs = driver.compile_internal_d_pln("__test__.d.pln", "")

assert(ast)
assert.falsy(errs)
assert(ast._tag == "ast.TypeFile.TypeFile")
assert.are.equal(0, #ast.decls)
end)

it("should report unexpected token errors", function()
local unexpected_token = "#"
local file_content = tdf .. "\n" .. unexpected_token

local type_file_ast, errors = driver.compile_internal_d_pln("__test__.d.pln", file_content, "ast")

assert.falsy(type_file_ast)
assert.truthy(string.find(
errors[1],
string.format("unexpected '%s' while trying to parse a type declaration", unexpected_token)
, 1, true))
end)

end)

end)
7 changes: 7 additions & 0 deletions src/pallene/ast.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ define_union("Program", {
Program = {"loc", "ret_loc", "module_name", "tls", "type_regions"}
})

define_union("TypeFile", {
TypeFile = { "loc", "module_name", "decls" },
Typealias = { "loc", "name", "type" },
Record = { "loc", "name", "field_decls" },
Decl = { "loc", "name", "type" },
})

define_union("Type", {
Nil = {"loc"},
Name = {"loc", "name"},
Expand Down
46 changes: 46 additions & 0 deletions src/pallene/driver.lua
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,50 @@ local function compile_pln_to_d_pln(input_ext, output_ext, input_file_name, base
return true, {}
end

function driver.compile_internal_d_pln(filename, input, stop_after)
stop_after = stop_after or "typechecker"

local errs

local function abort()
if type(errs) == "string" then errs = { errs } end
table.insert(errs, "compilation aborted due to previous error")
return false, errs
end

local lexer = Lexer.new(filename, input)
if stop_after == "lexer" then return lexer end

local ast
ast, errs = parser.parse_type_file(lexer)
if not ast then return abort() end
if stop_after == "ast" then return ast end

ast, errs = typechecker.check_type_file(ast)
if not ast then return abort() end
if stop_after == "typechecker" then return ast end
error("impossible")
end

function driver.compile_type_file(path)
local base_name, err = check_source_filename("pallenec", path, "d.pln")
if not base_name then
return false, { err }
end
local input, err = util.get_file_contents(path)
if not input then
return false, { err }
end

local ast
ast, err = driver.compile_internal_d_pln(path, input)
if not ast then
return false, { err }
end

return ast, {}
end


-- Compile the contents of [input_file_name] with extension [input_ext].
-- Writes the resulting output to [output_file_name] with extension [output_ext].
Expand All @@ -218,6 +262,8 @@ function driver.compile(argv0, opt_level, input_ext, output_ext,

if output_ext == "lua" then
return compile_pln_to_lua(input_ext, output_ext, input_file_name, output_base_name)
elseif output_ext == "d.pln" then
return compile_pln_to_d_pln(input_ext, output_ext, input_file_name, output_base_name)
else
local first_step = step_index[input_ext] or error("invalid extension")
local last_step = step_index[output_ext] or error("invalid extension")
Expand Down
2 changes: 2 additions & 0 deletions src/pallene/pallenec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ do
p:mutex(
p:flag("--emit-c", "Generate a .c file instead of an executable"),
p:flag("--emit-lua", "Generate a .lua file instead of an executable"),
p:flag("--emit-types", "Generate a .d.pln file instead of an executable"),
p:flag("--compile-c", "Compile a .c file generated by --emit-c"),
p:flag("--only-check", "Check for syntax or type errors, without compiling"),
p:flag("--print-ir", "Show the intermediate representation for a program"),
Expand Down Expand Up @@ -91,6 +92,7 @@ function pallenec.main()

if opts.emit_c then compile("pln", "c", flags)
elseif opts.emit_lua then compile("pln", "lua", flags)
elseif opts.emit_types then compile("pln", "d.pln", flags)
elseif opts.compile_c then compile("c" , "so", flags)
elseif opts.only_check then do_check()
elseif opts.print_ir then do_print_ir()
Expand Down
57 changes: 57 additions & 0 deletions src/pallene/parser.lua
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,40 @@ function Parser:Program()
start_loc, end_loc, modname, tls, self.type_regions)
end

function Parser:TypeDeclarationFile(modname)

local start_loc = self.next.loc

-- type declarations
local decls = {}

while not self:peek("EOF") do
if (self:peek(is_toplevel_keyword)) then
local tl = self:Toplevel()
if tl._tag == "ast.Toplevel.Typealias" then
tl = ast.TypeFile.Typealias(tl.loc, tl.name, tl.type)
elseif tl._tag == "ast.Toplevel.Record" then
tl = ast.TypeFile.Record(tl.loc, tl.name, tl.field_decls)
end
table.insert(decls, tl)
elseif self:peek("NAME") then
local decl = self:Decl()
if not decl.type then
self:recoverable_syntax_error(decl.loc,
"type annotation expected in type declaration file")
end
decl = ast.TypeFile.Decl(decl.loc, decl.name, decl.type)
table.insert(decls, decl)
else
self:unexpected_token_error("a type declaration")
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we encounter an invalid token, the while loop will exit and this function will return as if there was not a problem.

  1. I think the while should be "while not end-of-file"
  2. Please create test cases for the "expected a type declaration" errors

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the while should be "while not end-of-file"

Should we still throw an error if we don't find at least one declaration? If so, we can go with a repeat-until

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should allow empty files, in case the module exports nothing. (Also test that...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to write the error test in a assert_program_error() fashion, but I couldn't get it to work. So I wrote it another way (src/spec/parser_spec.lua:979-990)


return ast.TypeFile.TypeFile(
start_loc, modname, decls
)
end

local is_allowed_toplevel = Set [[
ast.Stat.Decl
ast.Stat.Assign
Expand Down Expand Up @@ -1191,4 +1225,27 @@ function parser.parse(lexer)
end
end

function parser.parse_type_file(lexer)
local p = Parser.new(lexer)

local ok, ret = trycatch.pcall(function()
return p:TypeDeclarationFile()
end)

-- Re-throw internal errors
if not ok and ret.tag ~= "syntax-error" then
error(ret)
end

if p.errors[1] then
-- Had syntax errors
return false, p.errors
else
-- No syntax errors
assert(ok)
local prog_ast = ret
return prog_ast, {}
end
end

return parser
67 changes: 67 additions & 0 deletions src/pallene/typechecker.lua
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,24 @@ function typechecker.check(prog_ast)
end
end

function typechecker.check_type_file(type_decl_ast)
local ok, ret = trycatch.pcall(function()
return Typechecker.new():check_type_file(type_decl_ast)
end)
if ok then
type_decl_ast = ret
return type_decl_ast, {}
else
if ret.tag == "typechecker" then
local err_msg = ret.msg
return false, { err_msg }
else
-- Internal error; re-throw
error(ret)
end
end
end

local function type_error(loc, fmt, ...)
local msg = "type error: " .. loc:format_error(fmt, ...)
trycatch.error("typechecker", msg)
Expand Down Expand Up @@ -317,6 +335,55 @@ function Typechecker:check_program(prog_ast)
return prog_ast
end

function Typechecker:check_type_file(prog_ast)

assert(prog_ast._tag == "ast.TypeFile.TypeFile")

-- 1) Add primitive types to the symbol table
self:add_type_symbol("any", types.T.Any)
self:add_type_symbol("boolean", types.T.Boolean)
self:add_type_symbol("float", types.T.Float)
self:add_type_symbol("integer", types.T.Integer)
self:add_type_symbol("string", types.T.String)

-- Check toplevel
for _, decl in ipairs(prog_ast.decls) do
local tag = decl._tag

if tag == "ast.TypeFile.Typealias" then
local typ = types.T.Alias(decl.name, self:from_ast_type(decl.type))
-- self:export_type_symbol(decl.name, typ, decl.loc)
self:add_type_symbol(decl.name, typ)
decl._type = typ

elseif tag == "ast.TypeFile.Record" then
local field_names = {}
local field_types = {}
for _, field_decl in ipairs(decl.field_decls) do
local field_name = field_decl.name
if field_types[field_name] then
type_error(decl.loc, "duplicate field name '%s' in record type", field_name)
end
table.insert(field_names, field_name)
field_types[field_name] = self:from_ast_type(field_decl.type)
end

local typ = types.T.Record(decl.name, field_names, field_types, false)
self:add_type_symbol(decl.name, typ)

decl._type = typ

elseif tag == "ast.TypeFile.Decl" then
local typ = self:from_ast_type(decl.type)
decl._type = typ
else
tagged_union.error(tag)
end
end

return prog_ast
end

-- If the last expression in @rhs is a function call that returns multiple values, add ExtraRet
-- nodes to the end of the list.
function Typechecker:expand_function_returns(rhs)
Expand Down
6 changes: 5 additions & 1 deletion src/pallene/util.lua
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end
--

function util.split_ext(file_name)
local name, ext = string.match(file_name, "(.*)%.(.*)")
local name, ext = string.match(file_name, "(.-)%.(.*)")
return name, ext
end

Expand Down Expand Up @@ -126,6 +126,10 @@ function util.Class()
return cls
end

--
-- General purpose utilities
--

function util.expand_type_aliases(ast_node, visited)
local types = require "pallene.types"
visited = visited or {}
Expand Down