Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions spec/parser_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -905,4 +905,67 @@ describe("Parser /", function()
end)

end)

describe("Type declaration file / ", function()

local tdf = [[
typealias A = integer
record B
x: A
end
f: (A) -> B
g: (A, (A) -> B) -> (B, A)
tbl: { f1: integer, f2: B }
constant: float
typealias C = { x: A, y: B }
]]

it("should be parsed correctly", function()
local ast, errs = driver.compile_internal("__test__.d.pln", tdf, "ast")

assert(ast)
assert.falsy(errs)
assert(ast._tag == "ast.TypeFile.Decls")

local nodes_expected = {
"ast.TypeFile.Typealias",
"ast.TypeFile.Record",
"ast.TypeFile.Decl",
"ast.TypeFile.Decl",
"ast.TypeFile.Decl",
"ast.TypeFile.Decl",
"ast.TypeFile.Typealias"
}

for i, expected in ipairs(nodes_expected) do
local actual = ast.decls[i]._tag
assert.are.equal(expected, actual)
end
end)

-- TODO: Move this test to the type checker
it("should have the _type field added to nodes", function()
local ast, errs = driver.compile_internal("__test__.d.pln", tdf, "typechecker")

assert(ast)
assert.falsy(errs)
assert(ast._tag == "ast.TypeFile.Decls")

local nodes_types_expected = {
"types.T.Alias",
"types.T.Record",
"types.T.Function",
"types.T.Function",
"types.T.Table",
"types.T.Float",
"types.T.Alias",
}

for i, expected in ipairs(nodes_types_expected) do
local actual = ast.decls[i]._type._tag
assert.are.equal(expected, actual)
end
end)
end)

end)
7 changes: 7 additions & 0 deletions src/pallene/ast.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ define_union("Program", {
Program = {"loc", "ret_loc", "module_name", "tls", "type_regions"}
})

define_union("TypeFile", {
Decls = { "loc", "module_name", "decls" },
Typealias = { "loc", "name", "type" },
Record = { "loc", "name", "field_decls" },
Decl = { "loc", "name", "type" },
})

define_union("Type", {
Nil = {"loc"},
Name = {"loc", "name"},
Expand Down
23 changes: 22 additions & 1 deletion src/pallene/driver.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ local type_extractor = require "pallene.type_extractor"
local driver = {}

local function check_source_filename(argv0, file_name, expected_ext)
local name, ext = util.split_ext(file_name)
local name, ext = util.split_pallene_ext(file_name)
if ext ~= expected_ext then
local msg = string.format("%s: %s does not have a .%s extension",
argv0, file_name, expected_ext)
Expand Down Expand Up @@ -200,6 +200,25 @@ local function compile_pln_to_d_pln(input_ext, output_ext, input_file_name, base
return true, {}
end

function driver.parse_type_file(path)
local base_name, err = check_source_filename("pallenec", path, "d.pln")
if not base_name then
return false, { err }
end
local input, err = util.get_file_contents(path)
if not input then
return false, { err }
end

local ast
ast, err = driver.compile_internal(path, input, "typechecker")
if not ast then
return false, { err }
end

return ast, {}
end


-- Compile the contents of [input_file_name] with extension [input_ext].
-- Writes the resulting output to [output_file_name] with extension [output_ext].
Expand All @@ -218,6 +237,8 @@ function driver.compile(argv0, opt_level, input_ext, output_ext,

if output_ext == "lua" then
return compile_pln_to_lua(input_ext, output_ext, input_file_name, output_base_name)
elseif output_ext == "d.pln" then
return compile_pln_to_d_pln(input_ext, output_ext, input_file_name, output_base_name)
else
local first_step = step_index[input_ext] or error("invalid extension")
local last_step = step_index[output_ext] or error("invalid extension")
Expand Down
2 changes: 2 additions & 0 deletions src/pallene/pallenec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ do
p:mutex(
p:flag("--emit-c", "Generate a .c file instead of an executable"),
p:flag("--emit-lua", "Generate a .lua file instead of an executable"),
p:flag("--emit-types", "Generate a .d.pln file instead of an executable"),
p:flag("--compile-c", "Compile a .c file generated by --emit-c"),
p:flag("--only-check", "Check for syntax or type errors, without compiling"),
p:flag("--print-ir", "Show the intermediate representation for a program"),
Expand Down Expand Up @@ -91,6 +92,7 @@ function pallenec.main()

if opts.emit_c then compile("pln", "c", flags)
elseif opts.emit_lua then compile("pln", "lua", flags)
elseif opts.emit_types then compile("pln", "d.pln", flags)
elseif opts.compile_c then compile("c" , "so", flags)
elseif opts.only_check then do_check()
elseif opts.print_ir then do_print_ir()
Expand Down
54 changes: 53 additions & 1 deletion src/pallene/parser.lua
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ local is_toplevel_first = Union({
is_stat_first,
})

local is_type_declaration_file_first = Union({
is_toplevel_keyword,
Set 'NAME',
})

--
-- Toplevel
--
Expand Down Expand Up @@ -303,6 +308,40 @@ function Parser:Program()
start_loc, end_loc, modname, tls, self.type_regions)
end

function Parser:TypeDeclarationFile(modname)

local start_loc = self.next.loc

-- type declarations
local decls = {}

while self:peek(is_type_declaration_file_first) do
if (self:peek(is_toplevel_keyword)) then
local tl = self:Toplevel()
if tl._tag == "ast.Toplevel.Typealias" then
tl = ast.TypeFile.Typealias(tl.loc, tl.name, tl.type)
elseif tl._tag == "ast.Toplevel.Record" then
tl = ast.TypeFile.Record(tl.loc, tl.name, tl.field_decls)
end
table.insert(decls, tl)
elseif self:peek("NAME") then
local decl = self:Decl()
if not decl.type then
self:recoverable_syntax_error(decl.loc,
"type annotation expected in type declaration file")
end
decl = ast.TypeFile.Decl(decl.loc, decl.name, decl.type)
table.insert(decls, decl)
else
self:unexpected_token_error("a type declaration")
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we encounter an invalid token, the while loop will exit and this function will return as if there was not a problem.

  1. I think the while should be "while not end-of-file"
  2. Please create test cases for the "expected a type declaration" errors

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the while should be "while not end-of-file"

Should we still throw an error if we don't find at least one declaration? If so, we can go with a repeat-until

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should allow empty files, in case the module exports nothing. (Also test that...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to write the error test in a assert_program_error() fashion, but I couldn't get it to work. So I wrote it another way (src/spec/parser_spec.lua:979-990)


return ast.TypeFile.Decls(
start_loc, modname, decls
)
end

local is_allowed_toplevel = Set [[
ast.Stat.Decl
ast.Stat.Assign
Expand Down Expand Up @@ -1171,8 +1210,21 @@ function parser.parse(lexer)

local p = Parser.new(lexer)

local filename = lexer.file_name
local name, ext = util.split_pallene_ext(filename)
local ok, ret = trycatch.pcall(function()
return p:Program()
if (ext == "pln") then
return p:Program()
elseif (ext == "d.pln") then
return p:TypeDeclarationFile(name)
else
local msg = string.format(
"file-error: unknown file extension in '%s'; expected '.pln' or '.d.pln'",
filename
)
table.insert(p.errors, msg)
trycatch.error("file-error")
end
end)

-- Re-throw internal errors
Expand Down
58 changes: 57 additions & 1 deletion src/pallene/typechecker.lua
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,14 @@ local Typechecker = util.Class()
-- On failure, returns false and a list of compilation errors
function typechecker.check(prog_ast)
local ok, ret = trycatch.pcall(function()
return Typechecker.new():check_program(prog_ast)
local type_checker = Typechecker.new()
if prog_ast._tag == "ast.Program.Program" then
return type_checker:check_program(prog_ast)
elseif prog_ast._tag == "ast.TypeFile.Decls" then
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This kind of if-else-if doesn't feel right. The type (ast.Program) should be the same in all branches.

I suspect that you actually want to have separate type checking functions instead of a catch-all typechecker.check

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same goes to Parser.parse()? Or is it ok for that case?

If I understand it correctly, you are suggesting that we have a new function in the typechecker and the responsibility to choose which one to call must be in driver.compile_internal(). Right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

return type_checker:check_type_file(prog_ast)
else
error("unexpected AST root: " .. tostring(prog_ast._tag))
end
end)
if ok then
prog_ast = ret
Expand Down Expand Up @@ -317,6 +324,55 @@ function Typechecker:check_program(prog_ast)
return prog_ast
end

function Typechecker:check_type_file(prog_ast)

assert(prog_ast._tag == "ast.TypeFile.Decls")

-- 1) Add primitive types to the symbol table
self:add_type_symbol("any", types.T.Any)
self:add_type_symbol("boolean", types.T.Boolean)
self:add_type_symbol("float", types.T.Float)
self:add_type_symbol("integer", types.T.Integer)
self:add_type_symbol("string", types.T.String)

-- Check toplevel
for _, decl in ipairs(prog_ast.decls) do
local tag = decl._tag

if tag == "ast.TypeFile.Typealias" then
local typ = types.T.Alias(decl.name, self:from_ast_type(decl.type))
-- self:export_type_symbol(decl.name, typ, decl.loc)
self:add_type_symbol(decl.name, typ)
decl._type = typ

elseif tag == "ast.TypeFile.Record" then
local field_names = {}
local field_types = {}
for _, field_decl in ipairs(decl.field_decls) do
local field_name = field_decl.name
if field_types[field_name] then
type_error(decl.loc, "duplicate field name '%s' in record type", field_name)
end
table.insert(field_names, field_name)
field_types[field_name] = self:from_ast_type(field_decl.type)
end

local typ = types.T.Record(decl.name, field_names, field_types, false)
self:add_type_symbol(decl.name, typ)

decl._type = typ

elseif tag == "ast.TypeFile.Decl" then
local typ = self:from_ast_type(decl.type)
decl._type = typ
else
tagged_union.error(tag)
end
end

return prog_ast
end

-- If the last expression in @rhs is a function call that returns multiple values, add ExtraRet
-- nodes to the end of the list.
function Typechecker:expand_function_returns(rhs)
Expand Down
54 changes: 54 additions & 0 deletions src/pallene/util.lua
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,56 @@ function util.split_ext(file_name)
return name, ext
end

--- Return a list consisting of the base name and all extensions of a file
--- e.g. "file.d.pln" -> { "pln", "d", "file" }
---
--- @param file_name string
--- @return [string]
---
function util.split_all_ext(file_name)
local name, ext = util.split_ext(file_name)
if not ext then
return { file_name }
else
local exts = util.split_all_ext(name)
table.insert(exts, ext)
return exts
end
end

local recognized_extensions = {
["pln"] = true,
["d.pln"] = true,
["c"] = true,
["lua"] = true,
["so"] = true,
}

--- Splits the file name into two parts: the base name, the Pallene extensions.
--- If the file is not known to Pallene, returns nil.
---
---
--- e.g.
--- - "file.pln" -> "file", "pln"
--- - "file.d.pln" -> "file", "d.pln"
--- - "dotted.file.d.pln" -> "dotted.file", "d.pln"
--- - "file.txt" -> nil
---
---@param file_name string
---@return string | nil
---@return string | nil
---
function util.split_pallene_ext(file_name)
local parts = util.split_all_ext(file_name)
if #parts >= 3 and parts[#parts] == "pln" and parts[#parts - 1] == "d" then
local base_name = table.concat(parts, ".", 1, #parts - 2)
return base_name, "d.pln"
elseif #parts >= 2 and recognized_extensions[parts[#parts]] then
local base_name = table.concat(parts, ".", 1, #parts - 1)
return base_name, parts[#parts]
end
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would be simpler to not split the string. We can could the string itself to see if it ends in ".pln" or ".d.pln".

Copy link
Contributor Author

@igrvlhlb igrvlhlb Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. Also now I see that logic is not necessary, as "." is not an allowed character in file names.

So what I plan to do is this:

  • get rid of that new function util.split_pallene_ext() (and split_all_ext)
  • keep using util.split_ext with the following modification
 function util.split_ext(file_name)
-    local name, ext = string.match(file_name, "(.*)%.(.*)")
+    local name, ext = string.match(file_name, "(.-)%.(.*)")
     return name, ext
 end

That way string.match() will match name only until the first '.', and ext will be the rest of the string. Is that ok?

function util.get_file_contents(file_name)
local f, err = io.open(file_name, "r")
if not f then
Expand Down Expand Up @@ -126,6 +176,10 @@ function util.Class()
return cls
end

--
-- General purpose utilities
--

function util.expand_type_aliases(ast_node, visited)
local types = require "pallene.types"
visited = visited or {}
Expand Down