From 18b86dccba9a03f59be68bda83b44e4a2d431052 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 1 Dec 2025 21:16:27 +0900 Subject: [PATCH] full-analysis: Parallelize signature analysis phase Parallelize signature analysis using `Threads.@spawn`, leveraging the thread-safe inference pipeline introduced in Julia v1.12. Each analysis task creates its own analyzer with fresh local caches to avoid data races. Key changes: - Add `SignatureAnalysisProgress` struct with atomic fields for thread-safe progress tracking - Make `LS_ANALYZER_CACHE` thread-safe using `CASContainer` - Move `AtomicContainers` include earlier to make it available in `Analyzer.jl` On a 4-core machine (`--threads=4,2`): - CSV.jl first-time analysis: 30s -> 18s (~1.7x faster) - JETLS.jl first-time analysis: 154s -> 36s (~4.3x faster) --- CHANGELOG.md | 7 +++ src/JETLS.jl | 11 +++-- src/analysis/Analyzer.jl | 8 ++- src/analysis/Interpreter.jl | 97 ++++++++++++++++++++++++------------- 4 files changed, 82 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8e7b8028..a48192ad0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). See [Initialization options](https://aviatesk.github.io/JETLS.jl/dev/launching/#init-options) for details. +### Changed + +- Parallelized signature analysis phase using `Threads.@spawn`, leveraging the + thread-safe inference pipeline introduced in Julia v1.12. On a 4-core machine, + first-time analysis of CSV.jl improved from 30s to 18s (~1.7x faster), and + JETLS.jl itself from 154s to 36s (~4.3x faster). + ### Fixed - Fixed handling of messages received before the initialize request per diff --git a/src/JETLS.jl b/src/JETLS.jl index 4817a453e..39974655c 100644 --- a/src/JETLS.jl +++ b/src/JETLS.jl @@ -45,6 +45,12 @@ using Glob: Glob abstract type AnalysisEntry end # used by `Analyzer.LSAnalyzer` +include("AtomicContainers/AtomicContainers.jl") +using .AtomicContainers +const SWStats = JETLS_DEV_MODE ? AtomicContainers.SWStats : Nothing +const LWStats = JETLS_DEV_MODE ? AtomicContainers.LWStats : Nothing +const CASStats = JETLS_DEV_MODE ? AtomicContainers.CASStats : Nothing + include("analysis/Analyzer.jl") using .Analyzer @@ -54,12 +60,7 @@ Analyzer.LSAnalyzer(args...; kwargs...) = LSAnalyzer(ScriptAnalysisEntry(filepat include("analysis/resolver.jl") -include("AtomicContainers/AtomicContainers.jl") include("FixedSizeFIFOQueue/FixedSizeFIFOQueue.jl") -using .AtomicContainers -const SWStats = JETLS_DEV_MODE ? AtomicContainers.SWStats : Nothing -const LWStats = JETLS_DEV_MODE ? AtomicContainers.LWStats : Nothing -const CASStats = JETLS_DEV_MODE ? AtomicContainers.CASStats : Nothing include("utils/general.jl") diff --git a/src/analysis/Analyzer.jl b/src/analysis/Analyzer.jl index 883be0cb7..f9693decc 100644 --- a/src/analysis/Analyzer.jl +++ b/src/analysis/Analyzer.jl @@ -7,6 +7,7 @@ using JET.JETInterface using JET: JET, CC using ..JETLS: AnalysisEntry +using ..JETLS.AtomicContainers: CASContainer, CASStats, store! using ..LSP # JETLS internal interface @@ -83,7 +84,10 @@ struct LSAnalyzer <: ToplevelAbstractAnalyzer end function LSAnalyzer(@nospecialize(entry::AnalysisEntry), state::AnalyzerState) analysis_cache_key = JET.compute_hash(entry, state.inf_params) - analysis_token = get!(AnalysisToken, LS_ANALYZER_CACHE, analysis_cache_key) + analysis_token = store!(LS_ANALYZER_CACHE) do cache + token = get!(AnalysisToken, cache, analysis_cache_key) + return cache, token + end cache = InterpretationStateCache() return LSAnalyzer(state, analysis_token, cache) end @@ -110,7 +114,7 @@ function JETInterface.AbstractAnalyzer(analyzer::LSAnalyzer, state::AnalyzerStat end JETInterface.AnalysisToken(analyzer::LSAnalyzer) = analyzer.analysis_token -const LS_ANALYZER_CACHE = Dict{UInt, AnalysisToken}() +const LS_ANALYZER_CACHE = CASContainer{Dict{UInt,AnalysisToken},CASStats}(Dict{UInt,AnalysisToken}()) # internal API # ============ diff --git a/src/analysis/Interpreter.jl b/src/analysis/Interpreter.jl index 5d916f88f..044c4e1d9 100644 --- a/src/analysis/Interpreter.jl +++ b/src/analysis/Interpreter.jl @@ -61,6 +61,18 @@ JET.ToplevelAbstractAnalyzer(interp::LSInterpreter) = interp.analyzer # overloads # ========= +mutable struct SignatureAnalysisProgress + const reports::Vector{JET.InferenceErrorReport} + const reports_lock::ReentrantLock + @atomic done::Int + const interval::Int + @atomic next_interval::Int + function SignatureAnalysisProgress(n_sigs::Int) + interval = max(n_sigs ÷ 25, 1) + new(JET.InferenceErrorReport[], ReentrantLock(), 0, interval, interval) + end +end + function compute_percentage(count, total, max=100) return min(round(Int, (count / total) * max), max) end @@ -83,7 +95,6 @@ function JET.analyze_from_definitions!(interp::LSInterpreter, config::JET.Toplev # This makes module context information available immediately for LS features cache_intermediate_analysis_result!(interp) - analyzer = JET.ToplevelAbstractAnalyzer(interp, JET.non_toplevel_concretized; refresh_local_cache = false) entrypoint = config.analyze_from_definitions res = JET.InterpretationState(interp).res n_sigs = length(res.toplevel_signatures) @@ -100,45 +111,63 @@ function JET.analyze_from_definitions!(interp::LSInterpreter, config::JET.Toplev percentage = 50)) yield_to_endpoint() end - next_interval = interval = max(n_sigs ÷ 25, 1) - all_reports = JET.InferenceErrorReport[] - for i = 1:n_sigs - if cancellable_token !== nothing - if is_cancelled(cancellable_token.cancel_flag) + + progress = SignatureAnalysisProgress(n_sigs) + + tasks = map(1:n_sigs) do i + Threads.@spawn :default try + if cancellable_token !== nothing && is_cancelled(cancellable_token.cancel_flag) return end - if i == next_interval - percentage = compute_percentage(i, n_sigs, 50) + 50 - send_progress(interp.server, cancellable_token.token, - WorkDoneProgressReport(; - cancellable = true, - message = "$i / $n_sigs [signature analysis]", - percentage)) - yield_to_endpoint(0.01) - next_interval += interval + tt = res.toplevel_signatures[i] + # Create a new analyzer with fresh local caches (`inf_cache` and `analysis_results`) + # to avoid data races between concurrent signature analysis tasks + analyzer = JET.ToplevelAbstractAnalyzer(interp, JET.non_toplevel_concretized; + refresh_local_cache = true) + match = Base._which(tt; + # NOTE use the latest world counter with `method_table(analyzer)` unwrapped, + # otherwise it may use a world counter when this method isn't defined yet + method_table = CC.method_table(analyzer), + world = CC.get_inference_world(analyzer), + raise = false) + if (match !== nothing && + (!(entrypoint isa Symbol) || # implies `analyze_from_definitions===true` + match.method.name === entrypoint)) + analyzer, result = JET.analyze_method_signature!(analyzer, + match.method, match.spec_types, match.sparams) + reports = JET.get_reports(analyzer, result) + isempty(reports) || @lock progress.reports_lock append!(progress.reports, reports) + else + JETLS_DEV_MODE && @warn "Couldn't find a single method matching the signature" tt end - end - tt = res.toplevel_signatures[i] - match = Base._which(tt; - # NOTE use the latest world counter with `method_table(analyzer)` unwrapped, - # otherwise it may use a world counter when this method isn't defined yet - method_table = CC.method_table(analyzer), - world = CC.get_inference_world(analyzer), - raise = false) - if (match !== nothing && - (!(entrypoint isa Symbol) || # implies `analyze_from_definitions===true` - match.method.name === entrypoint)) - analyzer, result = JET.analyze_method_signature!(analyzer, - match.method, match.spec_types, match.sparams) - append!(all_reports, JET.get_reports(analyzer, result)) - else - # something went wrong - if JETLS_DEV_MODE - @warn "Couldn't find a single method matching the signature `", tt, "`" + done = (@atomic progress.done += 1) + if cancellable_token !== nothing + current_next = @atomic progress.next_interval + if done >= current_next + # Try to update next_interval (may race with other tasks) + @atomicreplace progress.next_interval current_next => current_next + progress.interval + percentage = compute_percentage(done, n_sigs, 50) + 50 + send_progress(interp.server, cancellable_token.token, + WorkDoneProgressReport(; + cancellable = true, + message = "$done / $n_sigs [signature analysis]", + percentage)) + end end + catch e + @error "Error during signature analysis" + Base.showerror(stderr, e, catch_backtrace()) end end - append!(res.inference_error_reports, all_reports) + + for task in tasks + wait(task) + if cancellable_token !== nothing && is_cancelled(cancellable_token.cancel_flag) + break + end + end + + append!(res.inference_error_reports, progress.reports) end function JET.virtual_process!(interp::LSInterpreter,