diff --git a/LICENSE b/LICENSE index 482e394..5732c30 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2024 Claire Foster and contributors +Copyright (c) 2024 Julia Computing and contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Project.toml b/Project.toml index 8848348..d7b465c 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,9 @@ uuid = "f3c80556-a63f-4383-b822-37d64f81a311" authors = ["Claire Foster and contributors"] version = "1.0.0-DEV" +[deps] +JuliaSyntax = "70703baa-626e-46a2-a12c-08ffd08c73b4" + [compat] julia = "1" diff --git a/README.md b/README.md index 7d8408a..4a123fe 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,220 @@ # JuliaLowering [![Build Status](https://github.com/c42f/JuliaLowering.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/c42f/JuliaLowering.jl/actions/workflows/CI.yml?query=branch%3Amain) + +Experimental port of Julia's code "lowering" compiler passes into Julia. + +Lowering comprises four symbolic simplification steps +* Syntax desugaring - simplifying the rich surface syntax down to a small + number of forms. +* Scope analysis - analyzing identifier names used in the code to discover + local variables, closure captures, and associate global variables to the + appropriate module. +* Closure conversion - convert closures to types and deal with captured + variables efficiently where possible. +* Flattening to linear IR - convert code in hierarchical tree form to a + flat array of statements and control flow into gotos. + +## Goals + +This work is intended to +* Bring precise code provenance to Julia's lowered form (and eventually + downstream in type inference, stack traces, etc). This has many benefits + - Talk to users precisely about their code via character-precise error and + diagnostic messages from lowering + - Greatly simplify the implementation of critical tools like Revise.jl + which rely on analyzing how the user's source maps to the compiler's data + structures + - Allow tools like JuliaInterpreter to use type-inferred and optimized + code, with the potential for huge speed improvements. +* Bring improvements for macro authors + - Prototype "automatic hygiene" (no more need for `esc()`!) + - Precise author-defined error reporting from macros + - Sketch better interfaces for syntax trees (hopefully!) + +# Design Notes + +A disorganized collection of design notes :) + +## Syntax trees + +Want something something better than `JuliaSyntax.SyntaxNode`! `SyntaxTree` and +`SyntaxGraph` provide this. (These will probably end up in `JuliaSyntax`.) + +We want to allow arbitrary attributes to be attached to tree nodes by analysis +passes. This separates the analysis pass implementation from the data +structure, allowing passes which don't know about each other to act on a shared +data structure. + +Design and implementation inspiration comes in several analogies: + +Analogy 1: the ECS (Entity-Component-System) pattern for computer game design. +This pattern is highly successful because it separates game logic (systems) +from game objects (entities) by providing flexible storage +* Compiler passes are "systems" +* AST tree nodes are "entities" +* Node attributes are "components" + +Analogy 2: The AoS to SoA transformation. But here we've got a kind of +tree-of-structs-with-optional-attributes to struct-of-Dicts transformation. +The data alignment / packing efficiency and concrete type safe storage benefits +are similar. + +Analogy 3: Graph algorithms which represent graphs as a compact array of node +ids and edges with integer indices, rather than using a linked data structure. + +## Julia's existing lowering implementation + +### How does macro expansion work? + +`macroexpand(m::Module, x)` calls `jl_macroexpand` in ast.c: + +``` +jl_value_t *jl_macroexpand(jl_value_t *expr, jl_module_t *inmodule) +{ + expr = jl_copy_ast(expr); + expr = jl_expand_macros(expr, inmodule, NULL, 0, jl_world_counter, 0); + expr = jl_call_scm_on_ast("jl-expand-macroscope", expr, inmodule); + return expr; +} +``` + +First we copy the AST here. This is mostly a trivial deep copy of `Expr`s and +shallow copy of their non-`Expr` children, except for when they contain +embedded `CodeInfo/phi/phic` nodes which are also deep copied. + +Second we expand macros recursively by calling + +`jl_expand_macros(expr, inmodule, macroctx, onelevel, world, throw_load_error)` + +This relies on state indexed by `inmodule` and `world`, which gives it some +funny properties: +* `module` expressions can't be expanded: macro expansion depends on macro + lookup within the module, but we can't do that without `eval`. + +Expansion proceeds from the outermost to innermost macros. So macros see any +macro calls or quasiquote (`quote/$`) in their children as unexpanded forms. + +Things which are expanded: +* `quote` is expanded using flisp code in `julia-bq-macro` + - symbol / ssavalue -> `QuoteNode` (inert) + - atom -> itself + - at depth zero, `$` expands to its content + - Expressions `x` without `$` expand to `(copyast (inert x))` + - Other expressions containing a `$` expand to a call to `_expr` with all the + args mapped through `julia-bq-expand-`. Roughly! + - Special handling exists for multi-splatting arguments as in `quote quote $$(x...) end end` +* `macrocall` proceeds with + - Expand with `jl_invoke_julia_macro` + - Call `eval` on the macro name (!!) to get the macro function. Look up + the method. + - Set up arguments for the macro calling convention + - Wraps errors in macro invocation in `LoadError` + - Returns the expression, as well as the module at + which that method of that macro was defined and `LineNumberNode` where + the macro was invoked in the source. + - Deep copy the AST + - Recursively expand child macros in the context of the module where the + macrocall method was defined + - Wrap the result in `(hygienic-scope ,result ,newctx.m ,lineinfo)` (except + for special case optimizations) +* `hygenic-scope` expands `args[1]` with `jl_expand_macros`, with the module + of expansion set to `args[2]`. Ie, it's the `Expr` representation of the + module and expression arguments to `macroexpand`. The way this returns + either `hygenic-scope` or unwraps is a bit confusing. +* "`do` macrocalls" have their own special handling because the macrocall is + the child of the `do`. This seems like a mess!! + + +### Scope resolution + +Scopes are documented in the Juila documentation on [Scope of Variables](https://docs.julialang.org/en/v1/manual/variables-and-scoping/) + +This pass disambiguates variables which have the same name in different scopes +and fills in the list of local variables within each lambda. + +#### Which data is needed to define a scope? + +As scope is a collection of variable names by category: +* `argument` - arguments to a lambda +* `local` - variables declared local (at top level) or implicitly local (in lambdas) or desugared to local-def +* `global` - variables declared global (in lambdas) or implicitly global (at top level) +* `static-parameter` - lambda type arguments from `where` clauses + +#### How does scope resolution work? + +We traverse the AST starting at the root paying attention to certian nodes: +* Nodes representing identifiers (Identifier, operators, var) + - If a variable exists in the table, it's *replaced* with the value in the table. + - If it doesn't exist, it becomes an `outerref` +* Variable scoping constructs: `local`, `local-def` + - collected by scope-block + - removed during traversal +* Scope metadata `softscope`, `hardscope` - just removed +* New scopes + - `lambda` creates a new scope containing itself and its arguments, + otherwise copying the parent scope. It resolves the body with that new scope. + - `scope-block` is really complicated - see below +* Scope queries `islocal`, `locals` + - `islocal` - statically expand to true/false based on whether var name is a local var + - `locals` - return list of locals - see `@locals` + - `require-existing-local` - somewhat like `islocal`, but allows globals + too (whaa?! naming) and produces a lowering error immediately if variable + is not known. Should be called `require-in-scope` ?? +* `break-block`, `symbolicgoto`, `symboliclabel` need special handling because + one of their arguments is a non-quoted symbol. +* Add static parameters for generated functions `with-static-parameters` +* `method` - special handling for static params + +`scope-block` is the complicated bit. It's processed by +* Searching the expressions within the block for any `local`, `local-def`, + `global` and assigned vars. Searching doesn't recurse into `lambda`, + `scope-block`, `module` and `toplevel` +* Building lists of implicit locals or globals (depending on whether we're in a + top level thunk) +* Figuring out which local variables need to be renamed. This is any local variable + with a name which has already occurred in processing one of the previous scope blocks +* Check any conflicting local/global decls and soft/hard scope +* Build new scope with table of renames +* Resolve the body with the new scope, applying the renames + + +### Lowered IR + +See https://docs.julialang.org/en/v1/devdocs/ast/#Lowered-form + +#### CodeInfo + +```julia +mutable struct CodeInfo + code::Vector{Any} # IR statements + codelocs::Vector{Int32} # `length(code)` Vector of indices into `linetable` + ssavaluetypes::Any # `length(code)` or Vector of inferred types after opt + ssaflags::Vector{UInt32} # flag for every statement in `code` + # 0 if meta statement + # inbounds_flag - 1 bit (LSB) + # inline_flag - 1 bit + # noinline_flag - 1 bit + # ... other 8 flags which are defined in compiler/optimize.jl + # effects_flags - 9 bits + method_for_inference_limit_heuristics::Any + linetable::Any + slotnames::Vector{Symbol} # names of parameters and local vars used in the code + slotflags::Vector{UInt8} # vinfo flags from flisp + slottypes::Any # nothing (used by typeinf) + rettype::Any # Any (used by typeinf) + parent::Any # nothing (used by typeinf) + edges::Any + min_world::UInt64 + max_world::UInt64 + inferred::Bool + propagate_inbounds::Bool + has_fcall::Bool + nospecializeinfer::Bool + inlining::UInt8 + constprop::UInt8 + purity::UInt16 + inlining_cost::UInt16 +end +``` + diff --git a/src/JuliaLowering.jl b/src/JuliaLowering.jl index 65ab2a0..a4f6cd5 100644 --- a/src/JuliaLowering.jl +++ b/src/JuliaLowering.jl @@ -1,5 +1,18 @@ module JuliaLowering -# Write your package code here. +using JuliaSyntax + +using JuliaSyntax: SyntaxHead, highlight, Kind, GreenNode, @KSet_str +using JuliaSyntax: haschildren, children, child, numchildren, head, kind, flags +using JuliaSyntax: filename, first_byte, last_byte, source_location + +using JuliaSyntax: is_literal, is_number, is_operator, is_prec_assignment, is_infix_op_call, is_postfix_op_call + +include("syntax_graph.jl") +include("utils.jl") + +include("desugaring.jl") +include("scope_analysis.jl") +include("linear_ir.jl") end diff --git a/src/desugaring.jl b/src/desugaring.jl new file mode 100644 index 0000000..4529781 --- /dev/null +++ b/src/desugaring.jl @@ -0,0 +1,477 @@ +""" +Unique symbolic identity for a variable within a `DesugaringContext` +""" +const VarId = Int + +struct SSAVar + id::VarId +end + +struct LambdaInfo + # TODO: Make this concretely typed? + args::SyntaxList + ret_var::Union{Nothing,SyntaxTree} +end + +abstract type AbstractLoweringContext end + +struct DesugaringContext{GraphType} <: AbstractLoweringContext + graph::GraphType + next_var_id::Ref{VarId} +end + +function DesugaringContext() + graph = SyntaxGraph() + ensure_attributes!(graph, + kind=Kind, syntax_flags=UInt16, green_tree=GreenNode, + source_pos=Int, source=Union{SourceRef,NodeId}, + value=Any, name_val=String, + scope_type=Symbol, # :hard or :soft + var_id=VarId, + lambda_info=LambdaInfo) + DesugaringContext(freeze_attrs(graph), Ref{VarId}(1)) +end + +#------------------------------------------------------------------------------- +# AST creation utilities +_node_id(ex::NodeId) = ex +_node_id(ex::SyntaxTree) = ex.id + +_node_ids() = () +_node_ids(c, cs...) = (_node_id(c), _node_ids(cs...)...) + +function _makenode(graph::SyntaxGraph, srcref, head, children; attrs...) + id = newnode!(graph) + if kind(head) in (K"Identifier", K"core", K"top", K"SSAValue", K"Value", K"slot") || is_literal(head) + @assert length(children) == 0 + else + setchildren!(graph, id, children) + end + setattr!(graph, id; source=srcref.id, attrs...) + sethead!(graph, id, head) + return SyntaxTree(graph, id) +end + +function makenode(graph::SyntaxGraph, srcref, head, children...; attrs...) + _makenode(graph, srcref, head, children; attrs...) +end + +function makenode(ctx::AbstractLoweringContext, srcref, head, children::SyntaxTree...; attrs...) + _makenode(ctx.graph, srcref, head, _node_ids(children...); attrs...) +end + +function makenode(ctx::AbstractLoweringContext, srcref, head, children::SyntaxList; attrs...) + ctx.graph === children.graph || error("Mismatching graphs") + _makenode(ctx.graph, srcref, head, children.ids; attrs...) +end + +function mapchildren(f, ctx, ex) + cs = SyntaxList(ctx) + for e in children(ex) + push!(cs, f(e)) + end + ex2 = makenode(ctx, ex, head(ex), cs) + # Copy all attributes. + # TODO: Make this type stable and efficient + for v in values(ex.graph.attributes) + if haskey(v, ex.id) + v[ex2.id] = v[ex.id] + end + end + return ex2 +end + +function new_var_id(ctx::AbstractLoweringContext) + id = ctx.next_var_id[] + ctx.next_var_id[] += 1 + return id +end + +# Create a new SSA variable +function ssavar(ctx::AbstractLoweringContext, srcref) + id = makenode(ctx, srcref, K"SSAValue", var_id=new_var_id(ctx)) + return id +end + +# Assign `ex` to an SSA variable. +# Return (variable, assignment_node) +function assign_tmp(ctx::AbstractLoweringContext, ex) + var = ssavar(ctx, ex) + assign_var = makenode(ctx, ex, K"=", var, ex) + var, assign_var +end + +# Convenience functions to create leaf nodes referring to identifiers within +# the Core and Top modules. +core_ref(ctx, ex, name) = makenode(ctx, ex, K"core", name_val=name) +Any_type(ctx, ex) = core_ref(ctx, ex, "Any") +svec_type(ctx, ex) = core_ref(ctx, ex, "svec") +nothing_(ctx, ex) = core_ref(ctx, ex, "nothing") +unused(ctx, ex) = core_ref(ctx, ex, "UNUSED") + +top_ref(ctx, ex, name) = makenode(ctx, ex, K"top", name_val=name) + +#------------------------------------------------------------------------------- +# Predicates and accessors working on expression trees + +function is_quoted(ex) + kind(ex) in KSet"quote top core globalref outerref break inert + meta inbounds inline noinline loopinfo" +end + +function is_sym_decl(x) + k = kind(x) + k == K"Identifier" || k == K"::" +end + +# Identifier made of underscores +function is_placeholder(ex) + kind(ex) == K"Identifier" && all(==('_'), ex.name_val) +end + +function is_eventually_call(ex::SyntaxTree) + k = kind(ex) + return k == K"call" || ((k == K"where" || k == K"::") && is_eventually_call(ex[1])) +end + +function is_function_def(ex) + k = kind(ex) + return k == K"function" || k == K"->" || + (k == K"=" && numchildren(ex) == 2 && is_eventually_call(ex[1])) +end + +function identifier_name(ex) + kind(ex) == K"var" ? ex[1] : ex +end + +function is_valid_name(ex) + n = identifier_name(ex).name_val + n !== "ccall" && n !== "cglobal" +end + +function decl_var(ex) + kind(ex) == K"::" ? ex[1] : ex +end + +# given a complex assignment LHS, return the symbol that will ultimately be assigned to +function assigned_name(ex) + k = kind(ex) + if (k == K"call" || k == K"curly" || k == K"where") || (k == K"::" && is_eventually_call(ex)) + assigned_name(ex[1]) + else + ex + end +end + +#------------------------------------------------------------------------------- +# Lowering Pass 1 - basic desugaring +function expand_assignment(ctx, ex) +end + +function expand_condition(ctx, ex) + if head(ex) == K"block" || head(ex) == K"||" || head(ex) == K"&&" + # || and && get special lowering so that they compile directly to jumps + # rather than first computing a bool and then jumping. + error("TODO expand_condition") + end + expand_forms(ctx, ex) +end + +function expand_let(ctx, ex) + scope_type = get(ex, :scope_type, :hard) + blk = ex[2] + if numchildren(ex[1]) == 0 # TODO: Want to use !haschildren(ex[1]) but this doesn't work... + return makenode(ctx, ex, K"block", blk; + scope_type=scope_type) + end + for binding in Iterators.reverse(children(ex[1])) + kb = kind(binding) + if is_sym_decl(kb) + blk = makenode(ctx, ex, K"block", + makenode(ctx, ex, K"local", binding), + blk; + scope_type=scope_type + ) + elseif kb == K"=" && numchildren(binding) == 2 + lhs = binding[1] + rhs = binding[2] + if is_sym_decl(lhs) + tmp, tmpdef = assign_tmp(ctx, rhs) + blk = makenode(ctx, binding, K"block", + tmpdef, + makenode(ctx, ex, K"block", + makenode(ctx, lhs, K"local_def", lhs), # TODO: Use K"local" with attr? + makenode(ctx, rhs, K"=", decl_var(lhs), tmp), + blk; + scope_type=scope_type + ) + ) + else + TODO("Functions and multiple assignment") + end + else + throw(LoweringError(binding, "Invalid binding in let")) + continue + end + end + return blk +end + +function expand_call(ctx, ex) + cs = expand_forms(ctx, children(ex)) + if is_infix_op_call(ex) || is_postfix_op_call(ex) + cs[1], cs[2] = cs[2], cs[1] + end + # TODO: keywords + makenode(ctx, ex, K"call", cs...) +end + +# Strip variable type declarations from within a `local` or `global`, returning +# the stripped expression. Works recursively with complex left hand side +# assignments containing tuple destructuring. Eg, given +# (x::T, (y::U, z)) +# strip out stmts = (local x) (decl x T) (local x) (decl y U) (local z) +# and return (x, (y, z)) +function strip_decls!(ctx, stmts, declkind, ex) + k = kind(ex) + if k == K"Identifier" + push!(stmts, makenode(ctx, ex, declkind, ex)) + ex + elseif k == K"::" + @chk numchildren(ex) == 2 + name = ex[1] + @chk kind(name) == K"Identifier" + push!(stmts, makenode(ctx, ex, declkind, name)) + push!(stmts, makenode(ctx, ex, K"decl", name, ex[2])) + name + elseif k == K"tuple" || k == K"parameters" + cs = SyntaxList(ctx) + for e in children(ex) + push!(cs, strip_decls!(ctx, stmts, declkind, e)) + end + makenode(ctx, ex, k, cs) + end +end + +# local x, (y=2), z => local x; local y; y = 2; local z +function expand_decls(ctx, ex) + declkind = kind(ex) + stmts = SyntaxList(ctx) + for binding in children(ex) + kb = kind(binding) + if is_function_def(binding) + push!(stmts, makenode(ctx, binding, declkind, assigned_name(binding))) + push!(stmts, binding) + elseif is_prec_assignment(kb) + lhs = strip_decls!(ctx, stmts, declkind, binding[1]) + push!(stmts, makenode(ctx, binding, kb, lhs, binding[2])) + elseif is_sym_decl(binding) + strip_decls!(ctx, stmts, declkind, binding) + else + throw(LoweringError("invalid syntax in variable declaration")) + end + end + makenode(ctx, ex, K"block", stmts) +end + +function analyze_function_arg(full_ex) + name = nothing + type = nothing + default = nothing + is_slurp = false + is_nospecialize = false + ex = full_ex + while true + k = kind(ex) + if k == K"Identifier" || k == K"tuple" + name = ex + break + elseif k == K"::" + @chk numchildren(ex) in (1,2) + if numchildren(ex) == 1 + type = ex[1] + else + name = ex[1] + type = ex[2] + end + break + elseif k == K"..." + @chk full_ex !is_slurp + @chk numchildren(ex) == 1 + is_slurp = true + ex = ex[1] + elseif k == K"meta" + @chk ex[1].name_val == "nospecialize" + is_nospecialize = true + ex = ex[2] + elseif k == K"=" + @chk full_ex isnothing(default) && !is_slurp + default = ex[2] + ex = ex[1] + else + throw(LoweringError(ex, "Invalid function argument")) + end + end + return (name=name, + type=type, + default=default, + is_slurp=is_slurp, + is_nospecialize=is_nospecialize) +end + +function expand_function_def(ctx, ex) + @chk numchildren(ex) in (1,2) + name = ex[1] + if kind(name) == K"where" + TODO("where handling") + end + return_type = nothing + if kind(name) == K"::" + @chk numchildren(name) == 2 + return_type = name[2] + name = name[1] + end + if numchildren(ex) == 1 && is_identifier(name) # TODO: Or name as globalref + if !is_valid_name(name) + throw(LoweringError(name, "Invalid function name")) + end + return makenode(ctx, ex, K"method", identifier_name(name)) + elseif kind(name) == K"call" + callex = name + body = ex[2] + # TODO + # static params + # nospecialize + # argument destructuring + # dotop names + # overlays + + # Add self argument where necessary + args = name[2:end] + name = name[1] + if kind(name) == K"::" + if numchildren(name) == 1 + farg = makenode(ctx, name, K"::", + makenode(ctx, name, K"Identifier", name_val="#self#"), + name[1]) + else + TODO("Fixme type") + farg = name + end + function_name = nothing_(ctx, ex) + else + if !is_valid_name(name) + throw(LoweringError(name, "Invalid function name")) + end + farg = makenode(ctx, name, K"::", + makenode(ctx, name, K"Identifier", name_val="#self#"), + makenode(ctx, name, K"call", core_ref(ctx, name, "Typeof"), name)) + function_name = name + end + + # preamble is arbitrary code which computes + # svec(types, sparms, location) + + arg_names = SyntaxList(ctx) + arg_types = SyntaxList(ctx) + for (i,arg) in enumerate(args) + info = analyze_function_arg(arg) + aname = (isnothing(info.name) || is_placeholder(info.name)) ? + unused(ctx, arg) : info.name + push!(arg_names, aname) + atype = !isnothing(info.type) ? info.type : Any_type(ctx, arg) + @assert !info.is_nospecialize # TODO + @assert !isnothing(info.name) && kind(info.name) == K"Identifier" # TODO + if info.is_slurp + if i != length(args) + throw(LoweringError(arg, "`...` may only be used for the last function argument")) + end + atype = makenode(K"curly", core_ref(ctx, arg, "Vararg"), arg) + end + push!(arg_types, atype) + end + + preamble = makenode(ctx, ex, K"call", + svec_type(ctx, callex), + makenode(ctx, callex, K"call", + svec_type(ctx, name), + arg_types...), + makenode(ctx, callex, K"Value", value=source_location(LineNumberNode, callex)) + ) + if !isnothing(return_type) + ret_var, ret_assign = assign_tmp(ctx, return_type) + body = makenode(ctx, body, K"block", + ret_assign, + body, + scope_type=:hard) + else + ret_var = nothing + body = makenode(ctx, body, K"block", + body, + scope_type=:hard) + end + lambda = makenode(ctx, body, K"lambda", body, + lambda_info=LambdaInfo(arg_names, ret_var)) + makenode(ctx, ex, K"block", + makenode(ctx, ex, K"method", + function_name, + preamble, + lambda), + makenode(ctx, ex, K"unnecessary", function_name)) + elseif kind(name) == K"tuple" + TODO(name, "Anon function lowering") + else + throw(LoweringError(name, "Bad function definition")) + end +end + +function expand_forms(ctx::DesugaringContext, ex::SyntaxTree) + k = kind(ex) + if k == K"call" + expand_call(ctx, ex) + elseif k == K"function" + expand_forms(ctx, expand_function_def(ctx, ex)) + elseif k == K"let" + return expand_forms(ctx, expand_let(ctx, ex)) + elseif k == K"local" || k == K"global" + if numchildren(ex) == 1 && kind(ex[1]) == K"Identifier" + # Don't recurse when already simplified - `local x`, etc + ex + else + expand_forms(ctx, expand_decls(ctx, ex)) # FIXME + end + elseif is_operator(k) && !haschildren(ex) + return makenode(ctx, ex, K"Identifier", name_val=ex.name_val) + elseif k == K"char" || k == K"var" + @chk numchildren(ex) == 1 + return ex[1] + elseif k == K"string" + if numchildren(ex) == 1 && kind(ex[1]) == K"String" + return ex[1] + else + makenode(ctx, ex, K"call", top_ref(ctx, ex, "string"), expand_forms(children(ex))...) + end + elseif k == K"tuple" + # TODO: named tuples + makenode(ctx, ex, K"call", core_ref(ctx, ex, "tuple"), expand_forms(ctx, children(ex))...) + elseif !haschildren(ex) + return ex + else + if k == K"=" + @chk numchildren(ex) == 2 + if kind(ex[1]) ∉ (K"Identifier", K"SSAValue") + TODO(ex, "destructuring assignment") + end + end + mapchildren(e->expand_forms(ctx,e), ctx, ex) + end +end + +function expand_forms(ctx::DesugaringContext, exs::Union{Tuple,AbstractVector}) + res = SyntaxList(ctx) + for e in exs + push!(res, expand_forms(ctx, e)) + end + res +end + diff --git a/src/linear_ir.jl b/src/linear_ir.jl new file mode 100644 index 0000000..8746931 --- /dev/null +++ b/src/linear_ir.jl @@ -0,0 +1,504 @@ +#------------------------------------------------------------------------------- +# Lowering pass 4: Flatten to linear IR + +function is_simple_atom(ex) + k = kind(ex) + # FIXME +# (or (number? x) (string? x) (char? x) +# (and (pair? x) (memq (car x) '(ssavalue null true false thismodule))) +# (eq? (typeof x) 'julia_value))) + is_number(k) || k == K"String" || k == K"Char" +end + +# N.B.: This assumes that resolve-scopes has run, so outerref is equivalent to +# a global in the current scope. +function is_valid_ir_argument(ex) + k = kind(ex) + return is_simple_atom(ex) + # FIXME || + #(k == K"outerref" && nothrow_julia_global(ex[1])) || + #(k == K"globalref" && nothrow_julia_global(ex)) || + #(k == K"quote" || k = K"inert" || k == K"top" || + #k == K"core" || k == K"slot" || k = K"static_parameter") +end + +""" +Context for creating linear IR. + +One of these is created per lambda expression to flatten the body down to +linear IR. +""" +struct LinearIRContext{GraphType} <: AbstractLoweringContext + graph::GraphType + code::SyntaxList{GraphType, Vector{NodeId}} + next_var_id::Ref{Int} + return_type::Union{Nothing,NodeId} + var_info::Dict{VarId,VarInfo} + mod::Module +end + +function LinearIRContext(ctx::ScopeResolutionContext, mod, return_type) + LinearIRContext(ctx.graph, SyntaxList(ctx.graph), ctx.next_var_id, + return_type, ctx.var_info, mod) +end + +function LinearIRContext(ctx::LinearIRContext, return_type) + LinearIRContext(ctx.graph, SyntaxList(ctx.graph), ctx.next_var_id, + return_type, ctx.var_info, ctx.mod) +end + +function is_valid_body_ir_argument(ex) + is_valid_ir_argument(ex) && return true + return false + # FIXME + k = kind(ex) + return k == K"Identifier" && # Arguments are always defined slots + TODO("vinfo-table stuff") +end + +function is_simple_arg(ex) + k = kind(ex) + return is_simple_atom(ex) || k == K"Identifier" || k == K"quote" || k == K"inert" || + k == K"top" || k == K"core" || k == K"globalref" || k == K"outerref" +end + +function is_single_assign_var(ctx::LinearIRContext, ex) + return false # FIXME + id = ex.var_id + # return id in ctx.lambda_args || +end + +function is_const_read_arg(ctx, ex) + k = kind(ex) + return is_simple_atom(ex) || + is_single_assign_var(ctx, ex) || + k == K"quote" || k == K"inert" || k == K"top" || k == K"core" +end + +function is_valid_ir_rvalue(lhs, rhs) + return kind(lhs) == K"SSAValue" || + is_valid_ir_argument(rhs) || + (kind(lhs) == K"Identifier" && + # FIXME: add: splatnew isdefined invoke cfunction gc_preserve_begin copyast new_opaque_closure globalref outerref + kind(rhs) in KSet"new the_exception call foreigncall") +end + +# evaluate the arguments of a call, creating temporary locations as needed +function compile_args(ctx, args) + # First check if all the arguments as simple (and therefore side-effect free). + # Otherwise, we need to use ssa values for all arguments to ensure proper + # left-to-right evaluation semantics. + all_simple = all(is_simple_arg, args) + args_out = SyntaxList(ctx) + for arg in args + arg_val = compile(ctx, arg, true, false) + if (all_simple || is_const_read_arg(ctx, arg_val)) && is_valid_body_ir_argument(arg_val) + push!(args_out, arg_val) + else + push!(args_out, emit_assign_tmp(ctx, arg_val)) + end + end + return args_out +end + +function emit(ctx::LinearIRContext, ex) + push!(ctx.code, ex) + return ex +end + +function emit(ctx::LinearIRContext, srcref, k, args...) + emit(ctx, makenode(ctx, srcref, k, args...)) +end + +# Emit computation of ex, assigning the result to an ssavar and returning that +function emit_assign_tmp(ctx::LinearIRContext, ex) + # TODO: We could replace this with an index into the code array right away? + tmp = makenode(ctx, ex, K"SSAValue", var_id=ctx.next_var_id[]) + ctx.next_var_id[] += 1 + emit(ctx, ex, K"=", tmp, ex) + return tmp +end + +function emit_return(ctx, srcref, ex) + if isnothing(ex) + return + end + # TODO: return type handling + # TODO: exception stack handling + # returning lambda directly is needed for @generated + if !(is_valid_ir_argument(ex) || head(ex) == K"lambda") + ex = emit_assign_tmp(ctx, ex) + end + # TODO: if !isnothing(ctx.return_type) ... + emit(ctx, srcref, K"return", ex) +end + +function emit_assignment(ctx, srcref, lhs, rhs) + if !isnothing(rhs) + if is_valid_ir_rvalue(lhs, rhs) + emit(ctx, srcref, K"=", lhs, rhs) + else + r = emit_assign_tmp(ctx, rhs) + emit(ctx, srcref, K"=", lhs, r) + end + else + # in unreachable code (such as after return); still emit the assignment + # so that the structure of those uses is preserved + emit(ctx, rhs, K"=", lhs, nothing_(ctx, srcref)) + nothing + end +end + +# This pass behaves like an interpreter on the given code. +# To perform stateful operations, it calls `emit` to record that something +# needs to be done. In value position, it returns an expression computing +# the needed value. +# +# TODO: is it ok to return `nothing` if we have no value in some sense +function compile(ctx::LinearIRContext, ex, needs_value, in_tail_pos) + k = kind(ex) + if k == K"Identifier" || is_literal(k) || k == K"SSAValue" || k == K"quote" || k == K"inert" || + k == K"top" || k == K"core" || k == K"Value" + # TODO: other kinds: copyast the_exception $ globalref outerref thismodule cdecl stdcall fastcall thiscall llvmcall + if in_tail_pos + emit_return(ctx, ex, ex) + elseif needs_value + if is_placeholder(ex) + # TODO: ensure outterref, globalref work here + throw(LoweringError(ex, "all-underscore identifiers are write-only and their values cannot be used in expressions")) + end + ex + else + if k == K"Identifier" + emit(ctx, ex) # keep symbols for undefined-var checking + end + nothing + end + elseif k == K"call" + # TODO k ∈ splatnew foreigncall cfunction new_opaque_closure cglobal + args = compile_args(ctx, children(ex)) + callex = makenode(ctx, ex, k, args) + if in_tail_pos + emit_return(ctx, ex, callex) + elseif needs_value + callex + else + emit(ctx, callex) + nothing + end + elseif k == K"=" + lhs = ex[1] + # TODO: Handle underscore + rhs = compile(ctx, ex[2], true, false) + # TODO look up arg-map for renaming if lhs was reassigned + if needs_value && !isnothing(rhs) + r = emit_assign_tmp(ctx, rhs) + emit(ctx, ex, K"=", lhs, r) + if in_tail_pos + emit_return(ctx, ex, r) + else + r + end + else + emit_assignment(ctx, ex, lhs, rhs) + end + elseif k == K"block" + nc = numchildren(ex) + for i in 1:nc + islast = i == nc + compile(ctx, ex[i], islast && needs_value, islast && in_tail_pos) + end + elseif k == K"return" + compile(ctx, ex[1], true, true) + nothing + elseif k == K"method" + # TODO + # throw(LoweringError(ex, + # "Global method definition needs to be placed at the top level, or use `eval`")) + if numchildren(ex) == 1 + if in_tail_pos + emit_return(ctx, ex, ex) + elseif needs_value + ex + else + emit(ctx, ex) + end + else + @chk numchildren(ex) == 3 + fname = ex[1] + sig = compile(ctx, ex[2], true, false) + if !is_valid_ir_argument(sig) + sig = emit_assign_tmp(ctx, sig) + end + lam = ex[3] + if kind(lam) == K"lambda" + lam = compile_lambda(ctx, lam) + else + # lam = emit_assign_tmp(ctx, compile(ctx, lam, true, false)) + TODO(lam, "non-lambda method argument??") + end + emit(ctx, ex, K"method", fname, sig, lam) + @assert !needs_value && !in_tail_pos + nothing + end + elseif k == K"lambda" + lam = compile_lambda(ctx, ex) + if in_tail_pos + emit_return(ctx, ex, lam) + elseif needs_value + lam + else + emit(ctx, lam) + end + elseif k == K"global" + if needs_value + throw(LoweringError(ex, "misplaced `global` declaration")) + end + emit(ctx, ex) + nothing + elseif k == K"local_def" || k == K"local" + nothing + else + throw(LoweringError(ex, "Invalid syntax")) + end +end + + +#------------------------------------------------------------------------------- + +# Recursively renumber an expression within linear IR +# flisp: renumber-stuff +function _renumber(ctx, ssa_rewrites, slot_rewrites, label_table, ex) + k = kind(ex) + if k == K"Identifier" + id = ex.var_id + slot_id = get(slot_rewrites, id, nothing) + if !isnothing(slot_id) + makenode(ctx, ex, K"slot"; var_id=slot_id) + else + # TODO: look up any static parameters + ex + end + elseif k == K"outerref" || k == K"meta" + TODO(ex, "_renumber $k") + elseif is_literal(k) || is_quoted(k) || k == K"global" + ex + elseif k == K"SSAValue" + makenode(ctx, ex, K"SSAValue"; var_id=ssa_rewrites[ex.var_id]) + elseif k == K"goto" || k == K"enter" || k == K"gotoifnot" + TODO(ex, "_renumber $k") + # elseif k == K"lambda" + # renumber_lambda(ctx, ex) + else + mapchildren(ctx, ex) do e + _renumber(ctx, ssa_rewrites, slot_rewrites, label_table, e) + end + # TODO: foreigncall error check: + # "ccall function name and library expression cannot reference local variables" + end +end + +function _ir_to_expr() +end + +# flisp: renumber-lambda, compact-ir +function renumber_body(ctx, input_code, slot_rewrites) + # Step 1: Remove any assignments to SSA variables, record the indices of labels + ssa_rewrites = Dict{VarId,VarId}() + label_table = Dict{String,Int}() + code = SyntaxList(ctx) + for ex in input_code + k = kind(ex) + ex_out = nothing + if k == K"=" && kind(ex[1]) == K"SSAValue" + lhs_id = ex[1].var_id + if kind(ex[2]) == K"SSAValue" + # For SSA₁ = SSA₂, record that all uses of SSA₁ should be replaced by SSA₂ + ssa_rewrites[lhs_id] = ssa_rewrites[ex[2].var_id] + else + # Otherwise, record which `code` index this SSA value refers to + ssa_rewrites[lhs_id] = length(code) + 1 + ex_out = ex[2] + end + elseif k == K"label" + label_table[ex.name_val] = length(code) + 1 + else + ex_out = ex + end + if !isnothing(ex_out) + push!(code, ex_out) + end + end + + # Step 2: + # * Translate any SSA uses and labels into indices in the code table + # * Translate locals into slot indices + for i in 1:length(code) + code[i] = _renumber(ctx, ssa_rewrites, slot_rewrites, label_table, code[i]) + end + code +end + +function to_ir_expr(ex) + k = kind(ex) + if is_literal(k) + ex.value + elseif k == K"core" + GlobalRef(Core, Symbol(ex.name_val)) + elseif k == K"top" + GlobalRef(Base, Symbol(ex.name_val)) + elseif k == K"Identifier" + # Implicitly refers to name in parent module + # TODO: Should we even have plain identifiers at this point or should + # they all effectively be resolved into GlobalRef earlier? + Symbol(ex.name_val) + elseif k == K"slot" + Core.SlotNumber(ex.var_id) + elseif k == K"SSAValue" + Core.SSAValue(ex.var_id) + elseif k == K"return" + Core.ReturnNode(to_ir_expr(ex[1])) + elseif is_quoted(k) + TODO(ex, "Convert SyntaxTree to Expr") + else + # Allowed forms according to https://docs.julialang.org/en/v1/devdocs/ast/ + # + # call invoke static_parameter `=` method struct_type abstract_type + # primitive_type global const new splatnew isdefined the_exception + # enter leave pop_exception inbounds boundscheck loopinfo copyast meta + # foreigncall new_opaque_closure lambda + head = k == K"call" ? :call : + k == K"=" ? :(=) : + k == K"method" ? :method : + k == K"global" ? :global : + k == K"const" ? :const : + nothing + if isnothing(head) + TODO(ex, "Unhandled form") + end + Expr(head, map(to_ir_expr, children(ex))...) + end +end + +# Convert our data structures to CodeInfo +function to_code_info(input_code, mod, funcname, var_info, slot_rewrites) + # Convert code to Expr and record low res locations in table + num_stmts = length(input_code) + code = Vector{Any}(undef, num_stmts) + codelocs = Vector{Int32}(undef, num_stmts) + linetable_map = Dict{Tuple{Int,String}, Int32}() + linetable = Any[] + for i in 1:length(code) + code[i] = to_ir_expr(input_code[i]) + fname = filename(input_code[i]) + lineno, _ = source_location(input_code[i]) + loc = (lineno, fname) + codelocs[i] = get!(linetable_map, loc) do + inlined_at = 0 # FIXME: nonzero for expanded macros + full_loc = Core.LineInfoNode(mod, Symbol(funcname), Symbol(fname), + Int32(lineno), Int32(inlined_at)) + push!(linetable, full_loc) + length(linetable) + end + end + + # FIXME + ssaflags = zeros(UInt32, length(code)) + + nslots = length(slot_rewrites) + slotnames = Vector{Symbol}(undef, nslots) + slot_rename_inds = Dict{String,Int}() + slotflags = Vector{UInt8}(undef, nslots) + for (id,i) in slot_rewrites + info = var_info[id] + name = info.name + ni = get(slot_rename_inds, name, 0) + slot_rename_inds[name] = ni + 1 + if ni > 0 + name = "$name@$ni" + end + slotnames[i] = Symbol(name) + slotflags[i] = 0x00 # FIXME!! + end + + _CodeInfo( + code, + codelocs, + num_stmts, # ssavaluetypes (why put num_stmts in here??) + ssaflags, + nothing, # method_for_inference_limit_heuristics + linetable, + slotnames, + slotflags, + nothing, # slottypes + Any, # rettype + nothing, # parent + nothing, # edges + Csize_t(1), # min_world + typemax(Csize_t), # max_world + false, # inferred + false, # propagate_inbounds + false, # has_fcall + false, # nospecializeinfer + 0x00, # inlining + 0x00, # constprop + 0x0000, # purity + 0xffff, # inlining_cost + ) +end + +function renumber_lambda(ctx, lambda_info, code) + slot_rewrites = Dict{VarId,Tuple{Kind,Int}}() + # lambda arguments become K"slot"; type parameters become K"static_parameter" + info = ex.lambda_info + for (i,arg) in enumerate(info.args) + slot_rewrites[arg.var_id] = i + end + # TODO: add static_parameter here also + renumber_body(ctx, code, slot_rewrites) +end + +# flisp: compile-body +function compile_body(ctx, ex) + compile(ctx, ex, true, true) + # TODO: Fix any gotos + # TODO: Filter out any newvar nodes where the arg is definitely initialized +end + +function _add_slots!(slot_rewrites, var_info, var_ids) + n = length(slot_rewrites) + 1 + for id in var_ids + info = var_info[id] + if info.islocal + slot_rewrites[id] = n + n += 1 + end + end + slot_rewrites +end + +function compile_lambda(outer_ctx, ex) + info = ex.lambda_info + return_type = nothing # FIXME + # TODO: Add assignments for reassigned arguments to body using info.args + ctx = LinearIRContext(outer_ctx, return_type) + compile_body(ctx, ex[1]) + slot_rewrites = Dict{VarId,Int}() + _add_slots!(slot_rewrites, ctx.var_info, (a.var_id for a in info.args)) + _add_slots!(slot_rewrites, ctx.var_info, ex.lambda_vars) + code = renumber_body(ctx, ctx.code, slot_rewrites) + to_code_info(code, ctx.mod, "none", ctx.var_info, slot_rewrites) +end + +function compile_toplevel(outer_ctx, mod, ex) + return_type = nothing + ctx = LinearIRContext(outer_ctx, mod, return_type) + compile_body(ctx, ex) + slot_rewrites = Dict{VarId,Int}() + _add_slots!(slot_rewrites, ctx.var_info, ex.lambda_vars) + code = renumber_body(ctx, ctx.code, slot_rewrites) + to_code_info(code, mod, "top-level scope", ctx.var_info, slot_rewrites) + #var_info = nothing # FIXME + #makenode(ctx, ex, K"Value"; value=LambdaIR(SyntaxList(ctx), ctx.code, var_info)) +end + diff --git a/src/scope_analysis.jl b/src/scope_analysis.jl new file mode 100644 index 0000000..1bfdc7f --- /dev/null +++ b/src/scope_analysis.jl @@ -0,0 +1,274 @@ +# Lowering pass 2: analyze scopes (passes 2/3 in flisp code) +# +# This pass analyzes the names (variables/constants etc) used in scopes +# +# This pass records information about variables used by closure conversion. +# finds which variables are assigned or captured, and records variable +# type declarations. +# +# This info is recorded by setting the second argument of `lambda` expressions +# in-place to +# (var-info-lst captured-var-infos ssavalues static_params) +# where var-info-lst is a list of var-info records + +#------------------------------------------------------------------------------- +# AST traversal functions - useful for performing non-recursive AST traversals +function _schedule_traverse(stack, e) + push!(stack, e) + return nothing +end +function _schedule_traverse(stack, es::Union{Tuple,AbstractVector,Base.Generator}) + append!(stack, es) + return nothing +end + +function traverse_ast(f, exs) + todo = SyntaxList(first(exs).graph) + append!(todo, exs) + while !isempty(todo) + f(pop!(todo), e->_schedule_traverse(todo, e)) + end +end + +function traverse_ast(f, ex::SyntaxTree) + traverse_ast(f, (ex,)) +end + +function find_in_ast(f, ex::SyntaxTree) + todo = SyntaxList(ex.graph) + push!(todo, ex) + while !isempty(todo) + e1 = pop!(todo) + res = f(e1, e->_schedule_traverse(todo, e)) + if !isnothing(res) + return res + end + end + return nothing +end + +# NB: This only really works after expand_forms has already processed assignments. +function find_scope_vars(ex, children_only) + assigned_vars = Set{String}() + # TODO: + # local_vars + local_def_vars = Set{String}() + # global_vars + used_vars = Set{String}() + traverse_ast(children_only ? children(ex) : ex) do e, traverse + k = kind(e) + if k == K"Identifier" + push!(used_vars, e.name_val) + elseif !haschildren(e) || hasattr(e, :scope_type) || is_quoted(k) || + k in KSet"lambda module toplevel" + return + elseif k == K"local_def" + push!(local_def_vars, e[1].name_val) + # elseif k == K"method" TODO static parameters + elseif k == K"=" + v = decl_var(e[1]) + if !(kind(v) in KSet"SSAValue globalref outerref" || is_placeholder(v)) + push!(assigned_vars, v.name_val) + end + traverse(e[2]) + else + traverse(children(e)) + end + end + return assigned_vars, local_def_vars, used_vars +end + +function find_decls(decl_kind, ex) + vars = Vector{typeof(ex)}() + traverse_ast(ex) do e, traverse + k = kind(e) + if !haschildren(e) || is_quoted(k) || k in KSet"lambda scope_block module toplevel" + return + elseif k == decl_kind + v = decl_var(e[1]) + if !is_placeholder(v) + push!(vars, decl_var(v)) + end + else + traverse(children(e)) + end + end + var_names = [v.name_val for v in vars] + return unique(var_names) +end + +# Determine whether decl_kind is in the scope of `ex` +# +# flisp: find-scope-decl +function has_scope_decl(decl_kind, ex) + find_in_ast(ex) do e, traverse + k = kind(e) + if !haschildren(e) || is_quoted(k) || k in KSet"lambda scope_block module toplevel" + return + elseif k == decl_kind + return e + else + traverse(children(ex)) + end + end +end + +# struct LambdaVars +# # For analyze-variables pass +# # var_info_lst::Set{Tuple{Symbol,Symbol}} # ish? +# # captured_var_infos ?? +# # ssalabels::Set{SSAValue} +# # static_params::Set{Symbol} +# end + +# Mirror of flisp scope info structure +# struct ScopeInfo +# lambda_vars::Union{LambdaLocals,LambdaInfo} +# parent::Union{Nothing,ScopeBlockInfo} +# args::Set{Symbol} +# locals::Set{Symbol} +# globals::Set{Symbol} +# static_params::Set{Symbol} +# renames::Dict{Symbol,Symbol} +# implicit_globals::Set{Symbol} +# warn_vars::Set{Symbol} +# is_soft::Bool +# is_hard::Bool +# table::Dict{Symbol,Any} +# end + +""" +Metadata about a variable name - whether it's a local, etc +""" +struct VarInfo + name::String + islocal::Bool # Local variable (if unset, variable is global) + isarg::Bool # Is a function argument + is_single_assign::Bool # Single assignment +end + +struct ScopeResolutionContext{GraphType} <: AbstractLoweringContext + graph::GraphType + next_var_id::Ref{VarId} + # Stack of name=>id mappings for each scope, innermost scope last. + var_id_stack::Vector{Dict{String,VarId}} + # Stack of var `id`s for lambda (or toplevel thunk) being processed, innermost last. + lambda_vars::Vector{Set{VarId}} + # Metadata about variables. There's only one map for this, as var_id is is + # unique across the context, even for same-named vars in unrelated local + # scopes. + var_info::Dict{VarId,VarInfo} +end + +function ScopeResolutionContext(ctx::DesugaringContext) + graph = ensure_attributes(ctx.graph, lambda_vars=Set{VarId}) + ScopeResolutionContext(graph, ctx.next_var_id, + Vector{Dict{String,VarId}}(), + [Set{VarId}()], + Dict{VarId,VarInfo}()) +end + +function lookup_var(ctx, name) + for i in lastindex(ctx.var_id_stack):-1:1 + ids = ctx.var_id_stack[i] + id = get(ids, name, nothing) + if !isnothing(id) + return id + end + end + return nothing +end + +function new_var(ctx, name; isarg=false, islocal=isarg) + id = new_var_id(ctx) + ctx.var_info[id] = VarInfo(name, islocal, isarg, false) + push!(last(ctx.lambda_vars), id) + id +end + +function resolve_scope!(f::Function, ctx, ex, is_toplevel) + id_map = Dict{String,VarId}() + is_hard_scope = get(ex, :scope_type, :hard) == :hard + assigned, local_def_vars, used_vars = find_scope_vars(ex, !is_toplevel) + for name in local_def_vars + id_map[name] = new_var(ctx, name, islocal=true) + end + for name in assigned + if !haskey(id_map, name) && isnothing(lookup_var(ctx, name)) + # Previously unknown assigned vars are impicit locals or globals + id_map[name] = new_var(ctx, name, islocal=!is_toplevel) + end + end + outer_scope = is_toplevel ? id_map : ctx.var_id_stack[1] + for name in used_vars + if !haskey(id_map, name) && isnothing(lookup_var(ctx, name)) + # Identifiers which weren't discovered further up the stack are + # newly discovered globals + outer_scope[name] = new_var(ctx, name, islocal=false) + end + end + push!(ctx.var_id_stack, id_map) + res = f(ctx) + pop!(ctx.var_id_stack) + return res +end + +resolve_scopes!(ctx::DesugaringContext, ex) = resolve_scopes!(ScopeResolutionContext(ctx), ex) + +function resolve_scopes!(ctx::ScopeResolutionContext, ex) + resolve_scope!(ctx, ex, true) do cx + resolve_scopes_!(cx, ex) + end + setattr!(ctx.graph, ex.id, lambda_vars=only(ctx.lambda_vars)) + SyntaxTree(ctx.graph, ex.id) +end + +function resolve_scopes_!(ctx, ex) + k = kind(ex) + if k == K"Identifier" + if is_placeholder(ex) + return # FIXME - make these K"placeholder"? + end + # TODO: Maybe we shouldn't do this in place?? + setattr!(ctx.graph, ex.id, var_id=lookup_var(ctx, ex.name_val)) + elseif !haschildren(ex) || is_quoted(ex) || k == K"toplevel" + return + elseif k == K"global" + TODO("global") + elseif k == K"local" + TODO("local") + # TODO + # elseif require_existing_local + # elseif locals # return Dict of locals + # elseif islocal + elseif k == K"lambda" + # TODO: Lambda captures! + info = ex.lambda_info + id_map = Dict{String,VarId}() + for a in info.args + id_map[a.name_val] = new_var(ctx, a.name_val, isarg=true) + end + push!(ctx.var_id_stack, id_map) + for a in info.args + resolve_scopes!(ctx, a) + end + vars = Set{VarId}() + setattr!(ctx.graph, ex.id, lambda_vars=vars) + push!(ctx.lambda_vars, vars) + resolve_scopes_!(ctx, ex[1]) + pop!(ctx.lambda_vars) + pop!(ctx.var_id_stack) + elseif k == K"block" && hasattr(ex, :scope_type) + resolve_scope!(ctx, ex, false) do cx + for e in children(ex) + resolve_scopes_!(cx, e) + end + end + else + for e in children(ex) + resolve_scopes_!(ctx, e) + end + end + ex +end + diff --git a/src/syntax_graph.jl b/src/syntax_graph.jl new file mode 100644 index 0000000..f739ef2 --- /dev/null +++ b/src/syntax_graph.jl @@ -0,0 +1,432 @@ +const NodeId = Int + +""" +Directed graph with arbitrary attributes on nodes. Used here for representing +one or several syntax trees. +""" +struct SyntaxGraph{Attrs} + edge_ranges::Vector{UnitRange{Int}} + edges::Vector{NodeId} + attributes::Attrs +end + +SyntaxGraph() = SyntaxGraph{Dict{Symbol,Any}}(Vector{UnitRange{Int}}(), + Vector{NodeId}(), Dict{Symbol,Any}()) + +# "Freeze" attribute names and types, encoding them in the type of the returned +# SyntaxGraph. +function freeze_attrs(graph::SyntaxGraph) + frozen_attrs = (; pairs(graph.attributes)...) + SyntaxGraph(graph.edge_ranges, graph.edges, frozen_attrs) +end + +function _show_attrs(io, attributes::Dict) + show(io, MIME("text/plain"), attributes) +end +function _show_attrs(io, attributes::NamedTuple) + show(io, MIME("text/plain"), Dict(pairs(attributes)...)) +end + +function Base.show(io::IO, ::MIME"text/plain", graph::SyntaxGraph) + print(io, typeof(graph), + " with $(length(graph.edge_ranges)) vertices, $(length(graph.edges)) edges, and attributes:\n") + _show_attrs(io, graph.attributes) +end + +function ensure_attributes!(graph::SyntaxGraph; kws...) + for (k,v) in pairs(kws) + @assert k isa Symbol + @assert v isa Type + if haskey(graph.attributes, k) + v0 = valtype(graph.attributes[k]) + v == v0 || throw(ErrorException("Attribute type mismatch $v != $v0")) + else + graph.attributes[k] = Dict{NodeId,v}() + end + end +end + +function ensure_attributes(graph::SyntaxGraph; kws...) + g = SyntaxGraph(graph.edge_ranges, graph.edges, Dict(pairs(graph.attributes)...)) + ensure_attributes!(g; kws...) + freeze_attrs(g) +end + +function newnode!(graph::SyntaxGraph) + push!(graph.edge_ranges, 0:-1) # Invalid range start => leaf node + return length(graph.edge_ranges) +end + +function setchildren!(graph::SyntaxGraph, id, children::NodeId...) + setchildren!(graph, id, children) +end + +function setchildren!(graph::SyntaxGraph, id, children) + n = length(graph.edges) + graph.edge_ranges[id] = n+1:(n+length(children)) + # TODO: Reuse existing edges if possible + append!(graph.edges, children) +end + +function JuliaSyntax.haschildren(graph::SyntaxGraph, id) + first(graph.edge_ranges[id]) > 0 +end + +function JuliaSyntax.numchildren(graph::SyntaxGraph, id) + length(graph.edge_ranges[id]) +end + +function JuliaSyntax.children(graph::SyntaxGraph, id) + @view graph.edges[graph.edge_ranges[id]] +end + +function JuliaSyntax.child(graph::SyntaxGraph, id::NodeId, i::Integer) + graph.edges[graph.edge_ranges[id][i]] +end + +function getattr(graph::SyntaxGraph{<:Dict}, name::Symbol) + getfield(graph, :attributes)[name] +end + +function getattr(graph::SyntaxGraph{<:NamedTuple}, name::Symbol) + getfield(getfield(graph, :attributes), name) +end + +function getattr(graph::SyntaxGraph, name::Symbol, default) + get(getfield(graph, :attributes), name, default) +end + +# FIXME: Probably terribly non-inferrable? +function setattr!(graph::SyntaxGraph, id; attrs...) + for (k,v) in pairs(attrs) + getattr(graph, k)[id] = v + end +end + +function Base.getproperty(graph::SyntaxGraph, name::Symbol) + # FIXME: Remove access to internals + name === :edge_ranges && return getfield(graph, :edge_ranges) + name === :edges && return getfield(graph, :edges) + name === :attributes && return getfield(graph, :attributes) + return getattr(graph, name) +end + +function sethead!(graph, id::NodeId, h::SyntaxHead) + graph.kind[id] = kind(h) + f = flags(h) + if f != 0 + graph.syntax_flags[id] = f + end +end + +function sethead!(graph, id::NodeId, k::Kind) + graph.kind[id] = k +end + +function _convert_nodes(graph::SyntaxGraph, node::SyntaxNode) + id = newnode!(graph) + sethead!(graph, id, head(node)) + if !isnothing(node.val) + v = node.val + if v isa Symbol + setattr!(graph, id, name_val=string(v)) + else + setattr!(graph, id, value=v) + end + end + setattr!(graph, id, source=SourceRef(node.source, node.position, node.raw)) + if haschildren(node) + cs = map(children(node)) do n + _convert_nodes(graph, n) + end + setchildren!(graph, id, cs) + end + return id +end + +#------------------------------------------------------------------------------- +struct SyntaxTree{GraphType} + graph::GraphType + id::NodeId +end + +function Base.getproperty(tree::SyntaxTree, name::Symbol) + # FIXME: Remove access to internals + name === :graph && return getfield(tree, :graph) + name === :id && return getfield(tree, :id) + id = getfield(tree, :id) + return get(getproperty(getfield(tree, :graph), name), id) do + error("Property `$name[$id]` not found") + end +end + +function Base.propertynames(tree::SyntaxTree) + attrnames(tree) +end + +function Base.get(tree::SyntaxTree, name::Symbol, default) + attr = getattr(getfield(tree, :graph), name, nothing) + return isnothing(attr) ? default : + get(attr, getfield(tree, :id), default) +end + +function Base.getindex(tree::SyntaxTree, i::Integer) + child(tree, i) +end + +function Base.getindex(tree::SyntaxTree, r::UnitRange) + (child(tree, i) for i in r) +end + +Base.firstindex(tree::SyntaxTree) = 1 +Base.lastindex(tree::SyntaxTree) = numchildren(tree) + +function hasattr(tree::SyntaxTree, name::Symbol) + attr = getattr(tree.graph, name, nothing) + return !isnothing(attr) && haskey(attr, tree.id) +end + +function attrnames(tree::SyntaxTree) + attrs = tree.graph.attributes + [name for (name, value) in pairs(attrs) if haskey(value, tree.id)] +end + +# JuliaSyntax tree API + +function JuliaSyntax.haschildren(tree::SyntaxTree) + haschildren(tree.graph, tree.id) +end + +function JuliaSyntax.numchildren(tree::SyntaxTree) + numchildren(tree.graph, tree.id) +end + +function JuliaSyntax.children(tree::SyntaxTree) + SyntaxList(tree.graph, children(tree.graph, tree.id)) +end + +function JuliaSyntax.child(tree::SyntaxTree, i::Integer) + SyntaxTree(tree.graph, child(tree.graph, tree.id, i)) +end + +function JuliaSyntax.head(tree::SyntaxTree) + SyntaxHead(kind(tree), flags(tree)) +end + +function JuliaSyntax.kind(tree::SyntaxTree) + tree.kind +end + +function JuliaSyntax.flags(tree::SyntaxTree) + get(tree, :syntax_flags, 0x0000) +end + + +# Reference to bytes within a source file +struct SourceRef + file::SourceFile + first_byte::Int + # TODO: Do we need the green node, or would last_byte suffice? + green_tree::GreenNode +end + +JuliaSyntax.first_byte(src::SourceRef) = src.first_byte +JuliaSyntax.last_byte(src::SourceRef) = src.first_byte + span(src.green_tree) - 1 +JuliaSyntax.filename(src::SourceRef) = filename(src.file) +JuliaSyntax.source_location(::Type{LineNumberNode}, src::SourceRef) = source_location(LineNumberNode, src.file, src.first_byte) +JuliaSyntax.source_location(src::SourceRef) = source_location(src.file, src.first_byte) + +function Base.show(io::IO, ::MIME"text/plain", src::SourceRef) + highlight(io, src.file, first_byte(src):last_byte(src), note="these are the bytes you're looking for 😊", context_lines_inner=20) +end + +function sourceref(tree::SyntaxTree) + sources = tree.graph.source + id = tree.id + while true + s = sources[id] + if s isa SourceRef + return s + else + id = s::NodeId + end + end +end + +JuliaSyntax.filename(tree::SyntaxTree) = return filename(sourceref(tree)) +JuliaSyntax.source_location(::Type{LineNumberNode}, tree::SyntaxTree) = source_location(LineNumberNode, sourceref(tree)) +JuliaSyntax.source_location(tree::SyntaxTree) = source_location(sourceref(tree)) +JuliaSyntax.first_byte(tree::SyntaxTree) = first_byte(sourceref(tree)) +JuliaSyntax.last_byte(tree::SyntaxTree) = last_byte(sourceref(tree)) + +function SyntaxTree(graph::SyntaxGraph, node::SyntaxNode) + ensure_attributes!(graph, kind=Kind, syntax_flags=UInt16, source=Union{SourceRef,NodeId}, + value=Any, name_val=String) + id = _convert_nodes(graph, node) + return SyntaxTree(graph, id) +end + +function SyntaxTree(node::SyntaxNode) + return SyntaxTree(SyntaxGraph(), node) +end + +attrsummary(name, value) = string(name) +attrsummary(name, value::Number) = "$name=$value" + +function _value_string(ex) + k = kind(ex) + str = k == K"Identifier" || is_operator(k) ? ex.name_val : + k == K"SSAValue" ? "ssa" : + k == K"core" ? "core.$(ex.name_val)" : + k == K"top" ? "top.$(ex.name_val)" : + k == K"slot" ? "slot" : + repr(get(ex, :value, nothing)) + id = get(ex, :var_id, nothing) + if !isnothing(id) + idstr = replace(string(id), + "0"=>"₀", "1"=>"₁", "2"=>"₂", "3"=>"₃", "4"=>"₄", + "5"=>"₅", "6"=>"₆", "7"=>"₇", "8"=>"₈", "9"=>"₉") + str = "$(str).$idstr" + end + if k == K"slot" + # TODO: Ideally shouldn't need to rewrap the id here... + srcex = SyntaxTree(ex.graph, ex.source) + str = "$(str)/$(srcex.name_val)" + end + return str +end + +function _show_syntax_tree(io, current_filename, node, indent, show_byte_offsets) + if hasattr(node, :source) + fname = filename(node) + line, col = source_location(node) + posstr = "$(lpad(line, 4)):$(rpad(col,3))" + if show_byte_offsets + posstr *= "│$(lpad(first_byte(node),6)):$(rpad(last_byte(node),6))" + end + else + fname = nothing + posstr = " " + if show_byte_offsets + posstr *= "│ " + end + end + val = get(node, :value, nothing) + nodestr = haschildren(node) ? "[$(untokenize(head(node)))]" : _value_string(node) + + treestr = string(indent, nodestr) + + std_attrs = Set([:name_val,:value,:kind,:syntax_flags,:source,:var_id]) + attrstr = join([attrsummary(n, getproperty(node, n)) for n in attrnames(node) if n ∉ std_attrs], ",") + if !isempty(attrstr) + treestr = string(rpad(treestr, 40), "│ $attrstr") + end + + # Add filename if it's changed from the previous node + if fname != current_filename[] && !isnothing(fname) + #println(io, "# ", fname) + treestr = string(rpad(treestr, 80), "│$fname") + current_filename[] = fname + end + println(io, posstr, "│", treestr) + if haschildren(node) + new_indent = indent*" " + for n in children(node) + _show_syntax_tree(io, current_filename, n, new_indent, show_byte_offsets) + end + end +end + +function Base.show(io::IO, ::MIME"text/plain", tree::SyntaxTree; show_byte_offsets=false) + println(io, "line:col│ tree │ attributes | file_name") + _show_syntax_tree(io, Ref{Union{Nothing,String}}(nothing), tree, "", show_byte_offsets) +end + +function _show_syntax_tree_sexpr(io, ex) + if !haschildren(ex) + if is_error(ex) + print(io, "(", untokenize(head(ex)), ")") + else + print(io, _value_string(ex)) + end + else + print(io, "(", untokenize(head(ex))) + first = true + for n in children(ex) + print(io, ' ') + _show_syntax_tree_sexpr(io, n) + first = false + end + print(io, ')') + end +end + +function Base.show(io::IO, ::MIME"text/x.sexpression", node::SyntaxTree) + _show_syntax_tree_sexpr(io, node) +end + +function Base.show(io::IO, node::SyntaxTree) + _show_syntax_tree_sexpr(io, node) +end + +#------------------------------------------------------------------------------- +# Lightweight vector of nodes ids with associated pointer to graph stored separately. +struct SyntaxList{GraphType, NodeIdVecType} <: AbstractVector{SyntaxTree} + graph::GraphType + ids::NodeIdVecType +end + +function SyntaxList(graph::SyntaxGraph, ids::AbstractVector{NodeId}) + SyntaxList{typeof(graph), typeof(ids)}(graph, ids) +end + +SyntaxList(graph::SyntaxGraph) = SyntaxList(graph, Vector{NodeId}()) +SyntaxList(ctx) = SyntaxList(ctx.graph) + +Base.size(v::SyntaxList) = size(v.ids) + +Base.IndexStyle(::Type{<:SyntaxList}) = IndexLinear() + +Base.getindex(v::SyntaxList, i::Int) = SyntaxTree(v.graph, v.ids[i]) + +function Base.setindex!(v::SyntaxList, tree::SyntaxTree, i::Int) + v.graph === tree.graph || error("Mismatching syntax graphs") + v.ids[i] = tree.id +end + +function Base.setindex!(v::SyntaxList, id::NodeId, i::Int) + v.ids[i] = id +end + +function Base.push!(v::SyntaxList, tree::SyntaxTree) + v.graph === tree.graph || error("Mismatching syntax graphs") + push!(v.ids, tree.id) +end + +function Base.append!(v::SyntaxList, exs) + for e in exs + push!(v, e) + end + v +end + +function Base.append!(v::SyntaxList, exs::SyntaxList) + v.graph === exs.graph || error("Mismatching syntax graphs") + append!(v.ids, exs.ids) + v +end + +function Base.push!(v::SyntaxList, id::NodeId) + push!(v.ids, id) +end + +function Base.pop!(v::SyntaxList) + SyntaxTree(v.graph, pop!(v.ids)) +end + +#------------------------------------------------------------------------------- + +function JuliaSyntax.build_tree(::Type{SyntaxTree}, stream::JuliaSyntax.ParseStream; kws...) + SyntaxTree(build_tree(SyntaxNode, stream; kws...)) +end + diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..87136cb --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,112 @@ +# Error handling + +TODO(msg) = throw(ErrorException("Lowering TODO: $msg")) +TODO(ex, msg) = throw(LoweringError(ex, "Lowering TODO: $msg")) + +# Errors found during lowering will result in LoweringError being thrown to +# indicate the syntax causing the error. +struct LoweringError <: Exception + ex::SyntaxTree + msg::String +end + +function Base.showerror(io::IO, exc::LoweringError) + print(io, "LoweringError:\n") + src = sourceref(exc.ex) + highlight(io, src.file, first_byte(src):last_byte(src), note=exc.msg) +end + +function _chk_code(ex, cond) + cond_str = string(cond) + quote + ex = $(esc(ex)) + @assert ex isa SyntaxTree + try + ok = $(esc(cond)) + if !ok + throw(LoweringError(ex, "Expected `$($cond_str)`")) + end + catch + throw(LoweringError(ex, "Structure error evaluating `$($cond_str)`")) + end + end +end + +# Internal error checking macro. +# Check a condition involving an expression, throwing a LoweringError if it +# doesn't evaluate to true. Does some very simple pattern matching to attempt +# to extract the expression variable from the left hand side. +macro chk(cond) + ex = cond + while true + if ex isa Symbol + break + elseif ex.head == :call + ex = ex.args[2] + elseif ex.head == :ref + ex = ex.args[1] + elseif ex.head == :. + ex = ex.args[1] + elseif ex.head in (:(==), :(in), :<, :>) + ex = ex.args[1] + else + error("Can't analyze $cond") + end + end + _chk_code(ex, cond) +end + +macro chk(ex, cond) + _chk_code(ex, cond) +end + + +#------------------------------------------------------------------------------- +# CodeInfo constructor. TODO: Should be in Core? +function _CodeInfo(code, + codelocs, + ssavaluetypes, + ssaflags, + method_for_inference_limit_heuristics, + linetable, + slotnames, + slotflags, + slottypes, + rettype, + parent, + edges, + min_world, + max_world, + inferred, + propagate_inbounds, + has_fcall, + nospecializeinfer, + inlining, + constprop, + purity, + inlining_cost) + @eval $(Expr(:new, :(Core.CodeInfo), + convert(Vector{Any}, code), + convert(Vector{Int32}, codelocs), + convert(Any, ssavaluetypes), + convert(Vector{UInt32}, ssaflags), + convert(Any, method_for_inference_limit_heuristics), + convert(Any, linetable), + convert(Vector{Symbol}, slotnames), + convert(Vector{UInt8}, slotflags), + convert(Any, slottypes), + convert(Any, rettype), + convert(Any, parent), + convert(Any, edges), + convert(UInt64, min_world), + convert(UInt64, max_world), + convert(Bool, inferred), + convert(Bool, propagate_inbounds), + convert(Bool, has_fcall), + convert(Bool, nospecializeinfer), + convert(UInt8, inlining), + convert(UInt8, constprop), + convert(UInt16, purity), + convert(UInt16, inlining_cost))) +end + diff --git a/test/lowering.jl b/test/lowering.jl new file mode 100644 index 0000000..cce8863 --- /dev/null +++ b/test/lowering.jl @@ -0,0 +1,81 @@ +# Just some hacking + +using JuliaSyntax +using JuliaLowering + +using JuliaLowering: SyntaxGraph, SyntaxTree, ensure_attributes!, newnode!, setchildren!, haschildren, children, child, setattr!, sourceref + +src = """ +let + y = 1 + x = 2 + let x = sin(x) + y = x + end + (x, y) +end +""" + +# src = """ +# let +# local x, (y = 2), (w::T = ww), q::S +# end +# """ + +# src = """ +# function foo(x::f(T), y::w(let ; S end)) +# "a \$("b \$("c")")" +# end +# """ + +# src = """ +# let +# function f() Int end +# function foo(y::f(a)) +# y +# end +# end +# """ + + +# src = """ +# x + y +# """ + +t = parsestmt(SyntaxNode, src, filename="foo.jl") + +ctx = JuliaLowering.DesugaringContext() + +t2 = SyntaxTree(ctx.graph, t) + +t3 = JuliaLowering.expand_forms(ctx, t2) + +ctx2 = JuliaLowering.ScopeResolutionContext(ctx) + +t4 = JuliaLowering.resolve_scopes!(ctx2, t3) + +@info "Resolved scopes" t4 + +code = JuliaLowering.compile_toplevel(ctx2, Main, t4) + +@info "Code" code + + +# flisp parts to do +# let +# desugar/let => 76 +# desugar/func => ~100 (partial) +# desugar/call => 70 +# handle-scopes => 195 +# handle-scopes/scope-block => 99 +# handle-scopes/locals => 16 +# linear-ir => 250 (partial, approximate) +# linear-ir/func => 22 + + +# Syntax tree ideas: Want following to work? +# This can be fully inferrable! +# +# t2[3].bindings[1].lhs.string +# t2[3].body[1].signature +