diff --git a/TrimAnalyzer/Project.toml b/TrimAnalyzer/Project.toml
new file mode 100644
index 000000000..ae592a987
--- /dev/null
+++ b/TrimAnalyzer/Project.toml
@@ -0,0 +1,22 @@
+name = "TrimAnalyzer"
+uuid = "db0f0d6f-36c4-4e19-a1b7-72446e3087f7"
+version = "0.1.0"
+authors = ["Shuhei Kadowaki "]
+
+[deps]
+InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
+JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
+LSP = "880dcf91-6fde-4251-87fc-bfd84012291a"
+
+[sources]
+JET = {rev = "master", url = "https://github.com/aviatesk/JET.jl"}
+LSP = {path = "/Users/aviatesk/julia/packages/JETLS/LSP"} # NOTE(review): absolute local path; resolves only where this directory exists
+
+[compat]
+InteractiveUtils = "1.11.0"
+JET = "0.10.6"
+JSON3 = "1.14.3"
+LSP = "0.1"
+
+[apps.report-trim]
diff --git a/TrimAnalyzer/src/TrimAnalyzer.jl b/TrimAnalyzer/src/TrimAnalyzer.jl
new file mode 100644
index 000000000..2bda8de6b
--- /dev/null
+++ b/TrimAnalyzer/src/TrimAnalyzer.jl
@@ -0,0 +1,25 @@
+module TrimAnalyzer
+
+export report_trim, @report_trim
+
+include("TrimAnalyzerImpl.jl")
+using .TrimAnalyzerImpl: TrimAnalyzerImpl
+
+# Entry points
+# ============
+
+using InteractiveUtils: InteractiveUtils
+using JET: JET
+
+function report_trim(args...; jetconfigs...) # build a TrimAnalyzer from `jetconfigs`, then analyze the call described by `args`
+    analyzer = TrimAnalyzerImpl.TrimAnalyzer(; jetconfigs...)
+    return JET.analyze_and_report_call!(analyzer, args...; jetconfigs...) # `jetconfigs` is forwarded here as well
+end
+macro report_trim(ex0...) # macro form: expands to a `report_trim` call generated from the expressions in `ex0`
+    return InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :report_trim, ex0)
+end
+
+include("app.jl")
+using .TrimAnalyzerApp: main # `main` is the entry point defined with `@main` in app.jl
+
+end # module TrimAnalyzer
diff --git a/TrimAnalyzer/src/TrimAnalyzerImpl.jl b/TrimAnalyzer/src/TrimAnalyzerImpl.jl
new file mode 100644
index 000000000..0e8c0179e
--- /dev/null
+++ b/TrimAnalyzer/src/TrimAnalyzerImpl.jl
@@ -0,0 +1,283 @@
+module TrimAnalyzerImpl
+
+using Core.IR
+using JET.JETInterface
+using JET: JET, CC
+
+struct TrimAnalyzer <: ToplevelAbstractAnalyzer # analyzer whose method lookup goes through `TRIM_METHOD_TABLE`
+    state::AnalyzerState
+    analysis_token::AnalysisToken
+    method_table::CC.CachedMethodTable{CC.OverlayMethodTable}
+    function TrimAnalyzer(state::AnalyzerState, analysis_token::AnalysisToken)
+        method_table = CC.CachedMethodTable(CC.OverlayMethodTable(state.world, TRIM_METHOD_TABLE)) # overlay table for world `state.world`
+        return new(state, analysis_token, method_table)
+    end
+end
+function TrimAnalyzer(state::AnalyzerState) # looks up (or creates) the AnalysisToken for this state's inference parameters
+    analysis_cache_key = JET.compute_hash(state.inf_params)
+    analysis_token = get!(AnalysisToken, TRIM_ANALYZER_CACHE, analysis_cache_key) # `AnalysisToken()` is called only when the key is new
+    return TrimAnalyzer(state, analysis_token)
+end
+
+# AbstractInterpreter API
+# =======================
+
+# method lookup uses the analyzer's cached `CC.OverlayMethodTable` built over `TRIM_METHOD_TABLE`
+CC.method_table(analyzer::TrimAnalyzer) = analyzer.method_table
+
+# AbstractAnalyzer API
+# ====================
+
+JETInterface.AnalyzerState(analyzer::TrimAnalyzer) = analyzer.state
+function JETInterface.AbstractAnalyzer(analyzer::TrimAnalyzer, state::AnalyzerState) # new analyzer for `state`, keeping the existing token
+    return TrimAnalyzer(state, analyzer.analysis_token)
+end
+JETInterface.AnalysisToken(analyzer::TrimAnalyzer) = analyzer.analysis_token
+
+const TRIM_ANALYZER_CACHE = Dict{UInt, AnalysisToken}() # tokens keyed by `JET.compute_hash(state.inf_params)`
+
+# TRIM_METHOD_TABLE
+# =================
+
+using Base.Experimental: @overlay
+Base.Experimental.@MethodTable TRIM_METHOD_TABLE
+
+@eval @overlay TRIM_METHOD_TABLE Core.DomainError(@nospecialize(val), @nospecialize(msg::AbstractString)) = (@noinline; $(Expr(:new, :DomainError, :val, :msg)))
+
+@overlay TRIM_METHOD_TABLE (f::Base.RedirectStdStream)(io::Core.CoreSTDOUT) = Base._redirect_io_global(io, f.unix_fd)
+
+@overlay TRIM_METHOD_TABLE Base.depwarn(msg, funcsym; force::Bool=false) = nothing # overlay body does nothing
+@overlay TRIM_METHOD_TABLE Base._assert_tostring(msg) = "" # overlay always yields the empty string
+@overlay TRIM_METHOD_TABLE Base.reinit_stdio() = nothing
+@overlay TRIM_METHOD_TABLE Base.JuliaSyntax.enable_in_core!() = nothing
+@overlay TRIM_METHOD_TABLE Base.init_active_project() = Base.ACTIVE_PROJECT[] = nothing # only resets `ACTIVE_PROJECT[]`
+@overlay TRIM_METHOD_TABLE Base.set_active_project(projfile::Union{AbstractString,Nothing}) = Base.ACTIVE_PROJECT[] = projfile # only stores `projfile`
+@overlay TRIM_METHOD_TABLE Base.disable_library_threading() = nothing
+@overlay TRIM_METHOD_TABLE Base.start_profile_listener() = nothing
+@overlay TRIM_METHOD_TABLE Base.invokelatest(f, args...; kwargs...) = f(args...; kwargs...) # plain direct call of `f`
+@overlay TRIM_METHOD_TABLE function Base.sprint(f::F, args::Vararg{Any,N}; context=nothing, sizehint::Integer=0) where {F<:Function,N}
+    s = IOBuffer(sizehint=sizehint)
+    if context isa Tuple # a tuple of context entries is splatted into IOContext
+        f(IOContext(s, context...), args...)
+    elseif context !== nothing
+        f(IOContext(s, context), args...)
+    else
+        f(s, args...)
+    end
+    String(Base._unsafe_take!(s)) # buffer contents become the returned String
+end
+function show_typeish(io::IO, @nospecialize(T)) # plain helper (not an overlay) used for the members of a Union below
+    if T isa Type
+        show(io, T)
+    elseif T isa TypeVar
+        print(io, (T::TypeVar).name)
+    else
+        print(io, "?") # anything that is neither a Type nor a TypeVar
+    end
+end
+@overlay TRIM_METHOD_TABLE function Base.show(io::IO, T::Type) # handles DataType, Union and UnionAll; any other `Type` prints nothing
+    if T isa DataType
+        print(io, T.name.name)
+        if T !== T.name.wrapper && length(T.parameters) > 0 # parameters are printed only when T is not the bare wrapper
+            print(io, "{")
+            first = true
+            for p in T.parameters
+                if !first
+                    print(io, ", ")
+                end
+                first = false
+                if p isa Int
+                    show(io, p)
+                elseif p isa Type
+                    show(io, p)
+                elseif p isa Symbol
+                    print(io, ":")
+                    print(io, p)
+                elseif p isa TypeVar
+                    print(io, p.name)
+                else
+                    print(io, "?") # any other kind of parameter value
+                end
+            end
+            print(io, "}")
+        end
+    elseif T isa Union
+        print(io, "Union{")
+        show_typeish(io, T.a)
+        print(io, ", ")
+        show_typeish(io, T.b)
+        print(io, "}")
+    elseif T isa UnionAll
+        print(io, T.body::Type) # the body is asserted to be a Type
+        print(io, " where ")
+        print(io, T.var.name) # only the variable name is printed, not its bounds
+    end
+end
+@overlay TRIM_METHOD_TABLE Base.show_type_name(io::IO, tn::Core.TypeName) = print(io, tn.name) # prints the bare name only
+
+@overlay TRIM_METHOD_TABLE Base.mapreduce(f::F, op::F2, A::Base.AbstractArrayOrBroadcasted; dims=:, init=Base._InitialValue()) where {F, F2} =
+    Base._mapreduce_dim(f, op, init, A, dims) # NOTE(review): presumably mirrors Base's method with `F`/`F2` type parameters added — confirm against Base
+@overlay TRIM_METHOD_TABLE Base.mapreduce(f::F, op::F2, A::Base.AbstractArrayOrBroadcasted...; kw...) where {F, F2} =
+    reduce(op, map(f, A...); kw...)
+
+@overlay TRIM_METHOD_TABLE Base._mapreduce_dim(f::F, op::F2, nt, A::Base.AbstractArrayOrBroadcasted, ::Colon) where {F, F2} =
+    Base.mapfoldl_impl(f, op, nt, A)
+
+@overlay TRIM_METHOD_TABLE Base._mapreduce_dim(f::F, op::F2, ::Base._InitialValue, A::Base.AbstractArrayOrBroadcasted, ::Colon) where {F, F2} =
+    Base._mapreduce(f, op, IndexStyle(A), A)
+
+@overlay TRIM_METHOD_TABLE Base._mapreduce_dim(f::F, op::F2, nt, A::Base.AbstractArrayOrBroadcasted, dims) where {F, F2} =
+    Base.mapreducedim!(f, op, Base.reducedim_initarray(A, dims, nt), A)
+
+@overlay TRIM_METHOD_TABLE Base._mapreduce_dim(f::F, op::F2, ::Base._InitialValue, A::Base.AbstractArrayOrBroadcasted, dims) where {F,F2} =
+    Base.mapreducedim!(f, op, Base.reducedim_init(f, op, A, dims), A)
+
+@overlay TRIM_METHOD_TABLE Base.mapreduce_empty_iter(f::F, op::F2, itr, ItrEltype) where {F, F2} =
+    Base.reduce_empty_iter(Base.MappingRF(f, op), itr, ItrEltype)
+@overlay TRIM_METHOD_TABLE Base.mapreduce_first(f::F, op::F2, x) where {F,F2} = Base.reduce_first(op, f(x))
+
+@overlay TRIM_METHOD_TABLE Base._mapreduce(f::F, op::F2, A::Base.AbstractArrayOrBroadcasted) where {F,F2} = Base._mapreduce(f, op, Base.IndexStyle(A), A)
+@overlay TRIM_METHOD_TABLE Base.mapreduce_empty(::typeof(identity), op::F, T) where {F} = Base.reduce_empty(op, T)
+@overlay TRIM_METHOD_TABLE Base.mapreduce_empty(::typeof(abs), op::F, T) where {F} = abs(Base.reduce_empty(op, T))
+@overlay TRIM_METHOD_TABLE Base.mapreduce_empty(::typeof(abs2), op::F, T) where {F} = abs2(Base.reduce_empty(op, T))
+
+@overlay TRIM_METHOD_TABLE Base.Sys.__init_build() = nothing
+
+# function __init__()
+#     try
+#         ccall((:__gmp_set_memory_functions, libgmp), Cvoid,
+#               (Ptr{Cvoid},Ptr{Cvoid},Ptr{Cvoid}),
+#               cglobal(:jl_gc_counted_malloc),
+#               cglobal(:jl_gc_counted_realloc_with_old_size),
+#               cglobal(:jl_gc_counted_free_with_size))
+#         ZERO.alloc, ZERO.size, ZERO.d = 0, 0, C_NULL
+#         ONE.alloc, ONE.size, ONE.d = 1, 1, pointer(_ONE)
+#     catch ex
+#         Base.showerror_nostdio(ex, "WARNING: Error during initialization of module GMP")
+#     end
+#     # This only works with a patched version of GMP, ignore otherwise
+#     try
+#         ccall((:__gmp_set_alloc_overflow_function, libgmp), Cvoid,
+#               (Ptr{Cvoid},),
+#               cglobal(:jl_throw_out_of_memory_error))
+#         ALLOC_OVERFLOW_FUNCTION[] = true
+#     catch ex
+#         # ErrorException("ccall: could not find function...")
+#         if typeof(ex) != ErrorException
+#             rethrow()
+#         end
+#     end
+# end
+
+@overlay TRIM_METHOD_TABLE Base.Sort.issorted(itr;
+    lt::T=isless, by::F=identity, rev::Union{Bool,Nothing}=nothing, order::Base.Sort.Ordering=Base.Order.Forward) where {T,F} = # `Forward` qualified: no `using` in this module brings it into scope
+    Base.Sort.issorted(itr, Base.Sort.ord(lt,by,rev,order))
+
+@overlay TRIM_METHOD_TABLE function Base.TOML.try_return_datetime(p, year, month, day, h, m, s, ms)
+    return Base.TOML.DateTime(year, month, day, h, m, s, ms) # `p` is unused
+end
+@overlay TRIM_METHOD_TABLE function Base.TOML.try_return_date(p, year, month, day)
+    return Base.TOML.Date(year, month, day) # `p` is unused
+end
+@overlay TRIM_METHOD_TABLE function Base.TOML.parse_local_time(l::Base.TOML.Parser)
+    h = Base.TOML.@try Base.TOML.parse_int(l, false)
+    h in 0:23 || return Base.TOML.ParserError(Base.TOML.ErrParsingDateTime)
+    _, m, s, ms = Base.TOML.@try Base.TOML._parse_local_time(l, true)
+    # TODO: Could potentially parse greater accuracy for the
+    # fractional seconds here.
+    return Base.TOML.try_return_time(l, h, m, s, ms)
+end
+@overlay TRIM_METHOD_TABLE function Base.TOML.try_return_time(p, h, m, s, ms)
+    return Base.TOML.Time(h, m, s, ms) # `p` is unused
+end
+
+# analysis injections
+# ===================
+
+function CC.abstract_call_gf_by_type(analyzer::TrimAnalyzer,
+    @nospecialize(func), arginfo::CC.ArgInfo, si::CC.StmtInfo, @nospecialize(atype), sv::CC.InferenceState,
+    max_methods::Int)
+    ret = @invoke CC.abstract_call_gf_by_type(analyzer::ToplevelAbstractAnalyzer, # run the ToplevelAbstractAnalyzer method first
+        func::Any, arginfo::CC.ArgInfo, si::CC.StmtInfo, atype::Any, sv::CC.InferenceState, max_methods::Int)
+    atype′ = Ref{Any}(atype) # the closure below reads `atype` through this Ref
+    function after_abstract_call_gf_by_type(analyzer′::TrimAnalyzer, sv′::CC.InferenceState)
+        ret′ = ret[]
+        report_dispatch_error!(analyzer′, sv′, ret′, atype′[])
+        return true
+    end
+    if isready(ret) # result already available: check it right away
+        after_abstract_call_gf_by_type(analyzer, sv)
+    else # otherwise queue the check on `sv.tasks`
+        push!(sv.tasks, after_abstract_call_gf_by_type)
+    end
+    return ret
+end
+
+# analysis
+# ========
+
+# DispatchErrorReport
+# -------------------
+
+@jetreport struct DispatchErrorReport <: InferenceErrorReport
+    @nospecialize t # ::Union{Type, Vector{Type}}
+end
+JETInterface.print_report_message(io::IO, report::DispatchErrorReport) = print(io, "Unresolved call found")
+
+function is_inlineable(analyzer::TrimAnalyzer, match, info) # false unless a specialization and its cached code already exist
+    mi = CC.specialize_method(match; preexisting=true)
+    isnothing(mi) && return false
+    ci = get(CC.code_cache(analyzer), mi, nothing)
+    isnothing(ci) && return false
+    src = @atomic :monotonic ci.inferred
+    return CC.src_inlining_policy(analyzer, src, info, zero(UInt32))
+end
+
+function report_dispatch_error!(analyzer::TrimAnalyzer, sv::CC.InferenceState, call::CC.CallMeta, @nospecialize(atype)) # adds reports to `sv.result`; always returns false
+    info = call.info
+    if info === CC.NoCallInfo() # no call info at all: report the call
+        report = DispatchErrorReport(sv, atype)
+        add_new_report!(analyzer, sv.result, report)
+    else
+        if info isa CC.ConstCallInfo # look at the wrapped call info
+            info = info.call
+        end
+        if info isa CC.MethodMatchInfo
+            for match in info.results
+                if (isnothing(CC.get_compileable_sig(match.method, match.spec_types, match.sparams)) && # no compileable signature ...
+                    !is_inlineable(analyzer, match, info)) # ... and not inlineable either
+                    report = DispatchErrorReport(sv, atype)
+                    add_new_report!(analyzer, sv.result, report)
+                end
+            end
+        else
+            @assert info isa CC.UnionSplitInfo # any other call-info kind fails this assertion
+            for info in info.split # same per-match check as above, once per split
+                for match in info.results
+                    if (isnothing(CC.get_compileable_sig(match.method, match.spec_types, match.sparams)) &&
+                        !is_inlineable(analyzer, match, info))
+                        report = DispatchErrorReport(sv, atype)
+                        add_new_report!(analyzer, sv.result, report)
+                    end
+                end
+            end
+        end
+    end
+    return false
+end
+
+# Constructor
+# ===========
+
+# the entry constructor: builds an AnalyzerState for `world` from `jetconfigs`
+function TrimAnalyzer(world::UInt = Base.get_world_counter(); jetconfigs...)
+    jetconfigs = JET.kwargs_dict(jetconfigs)
+    jetconfigs[:max_methods] = 3 # replaces any caller-supplied `max_methods`
+    # jetconfigs[:assume_bindings_static] = true # TODO
+    state = AnalyzerState(world; jetconfigs...)
+    return TrimAnalyzer(state)
+end
+
+JETInterface.valid_configurations(::TrimAnalyzer) = JET.GENERAL_CONFIGURATIONS
+
+end # module TrimAnalyzerImpl
diff --git a/TrimAnalyzer/src/app.jl b/TrimAnalyzer/src/app.jl
new file mode 100644
index 000000000..c314c432c
--- /dev/null
+++ b/TrimAnalyzer/src/app.jl
@@ -0,0 +1,232 @@
+module TrimAnalyzerApp
+
+using ..TrimAnalyzer: report_trim
+using JET: JET
+using LSP
+using LSP.URIs2
+using JSON3
+
+function print_usage() # prints the command-line help text to stdout
+    println("""TrimAnalyzer - Detect dispatch errors in Julia code
+
+    Usage:
+        report-trim [options] <file>
+
+    Options:
+        --project[=<dir>]   Set project/environment (same as Julia's --project)
+        --json              Output results in JSON format
+        -h, --help          Show this help message
+
+    Examples:
+        report-trim example.jl
+        report-trim --json example.jl
+        report-trim --project=@. example.jl
+        report-trim --project=/path/to/project example.jl
+    """)
+end
+
+# TODO Share code with JETLS.jl
+
+"""
+    fix_build_path(path::AbstractString) -> fixed_path::AbstractString
+
+If this Julia is a built one, convert `path` to `fixed_path`, which is a path to the main
+files that are editable (or tracked by git).
+"""
+function fix_build_path end
+let build_dir = normpath(Sys.BINDIR, "..", ".."), # with path separator at the end
+    share_path = normpath(Sys.BINDIR, Base.DATAROOTDIR, "julia") # without path separator at the end
+    global fix_build_path
+    if ispath(normpath(build_dir), "base") # a `base` entry under `build_dir` selects the rewriting method
+        build_path = splitdir(build_dir)[1] # remove the path separator
+        fix_build_path(path::AbstractString) = replace(path, share_path => build_path)
+    else
+        fix_build_path(path::AbstractString) = path
+    end
+end
+
+to_full_path(file::Symbol) = to_full_path(String(file))
+function to_full_path(file::AbstractString) # resolves `file` to an absolute path, then applies `fix_build_path`
+    file = Base.fixup_stdlib_path(file)
+    file = something(Base.find_source_file(file), file) # keep `file` as is when no source file is found
+    # TODO we should probably make this configurable
+    return fix_build_path(abspath(file))
+end
+
+function jet_frame_to_range(frame) # whole-line range for the line of `frame`
+    line = JET.fixed_line_number(frame)
+    return line_range(fixed_line_number(line))
+end
+
+fixed_line_number(line) = line == 0 ? line : line - 1 # subtracts one, except that 0 stays 0
+
+function line_range(line::Int) # range from character 0 to typemax(Int32) on `line`
+    start = Position(; line, character=0)
+    var"end" = Position(; line, character=Int(typemax(Int32)))
+    return Range(; start, var"end")
+end
+
+function jet_inference_error_report_to_diagnostic(@nospecialize report::JET.InferenceErrorReport)
+    topframe = report.vst[1] # the diagnostic itself is placed on the first frame
+    message = JET.with_bufferring(:limit=>true) do io
+        JET.print_report_message(io, report)
+    end
+    relatedInformation = DiagnosticRelatedInformation[ # one entry per remaining frame
+        let frame = report.vst[i],
+            message = sprint(JET.print_frame_sig, frame, JET.PrintConfig())
+            DiagnosticRelatedInformation(;
+                location = Location(;
+                    uri = filepath2uri(to_full_path(frame.file)),
+                    range = jet_frame_to_range(frame)),
+                message)
+        end
+        for i = 2:length(report.vst)]
+    return Diagnostic(;
+        range = jet_frame_to_range(topframe),
+        severity = LSP.DiagnosticSeverity.Error,
+        message,
+        source = "TrimAnalyzer",
+        relatedInformation)
+end
+
+module MainModule end # NOTE(review): not referenced below; `main` evaluates its own `Main.MainModule` into a local of the same name
+
+function parse_project_path(project::String, filename::String) # returns a directory for LOAD_PATH; throws via `error` when a search finds nothing
+    if project == "@temp"
+        return mktempdir()
+    elseif project == "@." || project == "."
+        # Search upward from the directory of `filename` for a project file
+        dir = dirname(abspath(filename))
+        while true
+            if isfile(joinpath(dir, "Project.toml")) || isfile(joinpath(dir, "JuliaProject.toml"))
+                return dir
+            end
+            parent = dirname(dir)
+            if parent == dir # Reached root
+                error("No Project.toml or JuliaProject.toml found in parent directories")
+            end
+            dir = parent
+        end
+    elseif startswith(project, "@script")
+        # Handle `@script` alone, or `@script` followed by a path relative to the script directory
+        scriptdir = dirname(abspath(filename))
+        if project == "@script"
+            search_dir = scriptdir
+        else
+            # Extract the relative path that follows the prefix
+            rel_path = project[8:end] # Remove "@script" prefix
+            search_dir = normpath(joinpath(scriptdir, rel_path))
+        end
+
+        # Search up from `search_dir` (same loop as in the `@.` branch above)
+        dir = search_dir
+        while true
+            if isfile(joinpath(dir, "Project.toml")) || isfile(joinpath(dir, "JuliaProject.toml"))
+                return dir
+            end
+            parent = dirname(dir)
+            if parent == dir # Reached root
+                error("No Project.toml or JuliaProject.toml found searching from $search_dir")
+            end
+            dir = parent
+        end
+    else
+        # Regular directory path
+        return project
+    end
+end
+
+function (@main)(args::Vector{String}) # returns the exit code: 0 when no reports were found (or for --help), 1 otherwise
+    json_output = false
+    filepath = nothing
+    project = nothing
+
+    i = 1
+    while i <= length(args)
+        arg = args[i]
+
+        if arg == "--json"
+            json_output = true
+        elseif arg == "-h" || arg == "--help"
+            print_usage()
+            return 0
+        elseif startswith(arg, "--project=")
+            project = arg[11:end]
+        elseif arg == "--project"
+            # Bare --project is handled like "@.": search upward from the input file's directory
+            project = "."
+        elseif startswith(arg, "-")
+            println(stderr, "Error: Unknown option: $arg")
+            println(stderr, "Run with --help to see available options")
+            return 1
+        else
+            if filepath !== nothing
+                println(stderr, "Error: Multiple file paths provided")
+                return 1
+            end
+            filepath = arg
+        end
+        i += 1
+    end
+
+    if filepath === nothing
+        println(stderr, "Error: No file path provided")
+        println(stderr)
+        print_usage()
+        return 1
+    end
+
+    if !isfile(filepath)
+        println(stderr, "Error: File not found: $filepath")
+        return 1
+    end
+
+    # Set up LOAD_PATH based on project
+    if Base.should_use_main_entrypoint()
+        empty!(LOAD_PATH)
+        push!(LOAD_PATH, "@", "@v$(VERSION.major).$(VERSION.minor)", "@stdlib")
+    end
+
+    if project !== nothing
+        project_path = parse_project_path(project, filepath) # NOTE(review): can throw (see its `error` calls); not caught here
+        pushfirst!(LOAD_PATH, project_path)
+    end
+
+    MainModule = Core.eval(Main, :(module MainModule end)) # fresh module that receives the user's file
+    try
+        Base.include(MainModule, filepath)
+    catch e
+        println(stderr, "Error loading file: ", sprint(showerror, e)) # readable message instead of the raw exception object
+        return 1
+    end
+
+    if !(@invokelatest isdefinedglobal(MainModule, :main))
+        println(stderr, "Error: `main` is not defined in $filepath")
+        return 1
+    end
+
+    result = report_trim(@invokelatest(MainModule.main), (Vector{String},)) # analyze the loaded `main` for a `Vector{String}` argument
+
+    reports = JET.get_reports(result)
+    success = isempty(reports)
+
+    if json_output
+        diagnostics = LSP.Diagnostic[]
+        if !success
+            for report in reports
+                push!(diagnostics, jet_inference_error_report_to_diagnostic(report))
+            end
+        end
+        JSON3.write(stdout, (;
+            filepath,
+            success,
+            diagnostics))
+        println(stdout)
+    else
+        show(stdout, result)
+    end
+
+    return success ? 0 : 1
+end
+
+end # module TrimAnalyzerApp