Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
See [Initialization options](https://aviatesk.github.io/JETLS.jl/dev/launching/#init-options)
for details.

### Changed

- Parallelized signature analysis phase using `Threads.@spawn`, leveraging the
thread-safe inference pipeline introduced in Julia v1.12. On a 4-core machine,
first-time analysis of CSV.jl improved from 30s to 18s (~1.7x faster), and
JETLS.jl itself from 154s to 36s (~4.3x faster).

### Fixed

- Fixed handling of messages received before the initialize request per
Expand Down
11 changes: 6 additions & 5 deletions src/JETLS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ using Glob: Glob

abstract type AnalysisEntry end # used by `Analyzer.LSAnalyzer`

include("AtomicContainers/AtomicContainers.jl")
using .AtomicContainers
const SWStats = JETLS_DEV_MODE ? AtomicContainers.SWStats : Nothing
const LWStats = JETLS_DEV_MODE ? AtomicContainers.LWStats : Nothing
const CASStats = JETLS_DEV_MODE ? AtomicContainers.CASStats : Nothing

include("analysis/Analyzer.jl")
using .Analyzer

Expand All @@ -54,12 +60,7 @@ Analyzer.LSAnalyzer(args...; kwargs...) = LSAnalyzer(ScriptAnalysisEntry(filepat

include("analysis/resolver.jl")

include("AtomicContainers/AtomicContainers.jl")
include("FixedSizeFIFOQueue/FixedSizeFIFOQueue.jl")
using .AtomicContainers
const SWStats = JETLS_DEV_MODE ? AtomicContainers.SWStats : Nothing
const LWStats = JETLS_DEV_MODE ? AtomicContainers.LWStats : Nothing
const CASStats = JETLS_DEV_MODE ? AtomicContainers.CASStats : Nothing

include("utils/general.jl")

Expand Down
8 changes: 6 additions & 2 deletions src/analysis/Analyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using JET.JETInterface
using JET: JET, CC

using ..JETLS: AnalysisEntry
using ..JETLS.AtomicContainers: CASContainer, CASStats, store!
using ..LSP

# JETLS internal interface
Expand Down Expand Up @@ -83,7 +84,10 @@ struct LSAnalyzer <: ToplevelAbstractAnalyzer
end
function LSAnalyzer(@nospecialize(entry::AnalysisEntry), state::AnalyzerState)
analysis_cache_key = JET.compute_hash(entry, state.inf_params)
analysis_token = get!(AnalysisToken, LS_ANALYZER_CACHE, analysis_cache_key)
analysis_token = store!(LS_ANALYZER_CACHE) do cache
token = get!(AnalysisToken, cache, analysis_cache_key)
return cache, token
end
cache = InterpretationStateCache()
return LSAnalyzer(state, analysis_token, cache)
end
Expand All @@ -110,7 +114,7 @@ function JETInterface.AbstractAnalyzer(analyzer::LSAnalyzer, state::AnalyzerStat
end
JETInterface.AnalysisToken(analyzer::LSAnalyzer) = analyzer.analysis_token

const LS_ANALYZER_CACHE = Dict{UInt, AnalysisToken}()
const LS_ANALYZER_CACHE = CASContainer{Dict{UInt,AnalysisToken},CASStats}(Dict{UInt,AnalysisToken}())

# internal API
# ============
Expand Down
97 changes: 63 additions & 34 deletions src/analysis/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ JET.ToplevelAbstractAnalyzer(interp::LSInterpreter) = interp.analyzer
# overloads
# =========

mutable struct SignatureAnalysisProgress
const reports::Vector{JET.InferenceErrorReport}
const reports_lock::ReentrantLock
@atomic done::Int
const interval::Int
@atomic next_interval::Int
function SignatureAnalysisProgress(n_sigs::Int)
interval = max(n_sigs ÷ 25, 1)
new(JET.InferenceErrorReport[], ReentrantLock(), 0, interval, interval)
end
end

function compute_percentage(count, total, max=100)
return min(round(Int, (count / total) * max), max)
end
Expand All @@ -83,7 +95,6 @@ function JET.analyze_from_definitions!(interp::LSInterpreter, config::JET.Toplev
# This makes module context information available immediately for LS features
cache_intermediate_analysis_result!(interp)

analyzer = JET.ToplevelAbstractAnalyzer(interp, JET.non_toplevel_concretized; refresh_local_cache = false)
entrypoint = config.analyze_from_definitions
res = JET.InterpretationState(interp).res
n_sigs = length(res.toplevel_signatures)
Expand All @@ -100,45 +111,63 @@ function JET.analyze_from_definitions!(interp::LSInterpreter, config::JET.Toplev
percentage = 50))
yield_to_endpoint()
end
next_interval = interval = max(n_sigs ÷ 25, 1)
all_reports = JET.InferenceErrorReport[]
for i = 1:n_sigs
if cancellable_token !== nothing
if is_cancelled(cancellable_token.cancel_flag)

progress = SignatureAnalysisProgress(n_sigs)

tasks = map(1:n_sigs) do i
Threads.@spawn :default try
if cancellable_token !== nothing && is_cancelled(cancellable_token.cancel_flag)
return
end
if i == next_interval
percentage = compute_percentage(i, n_sigs, 50) + 50
send_progress(interp.server, cancellable_token.token,
WorkDoneProgressReport(;
cancellable = true,
message = "$i / $n_sigs [signature analysis]",
percentage))
yield_to_endpoint(0.01)
next_interval += interval
tt = res.toplevel_signatures[i]
# Create a new analyzer with fresh local caches (`inf_cache` and `analysis_results`)
# to avoid data races between concurrent signature analysis tasks
analyzer = JET.ToplevelAbstractAnalyzer(interp, JET.non_toplevel_concretized;
refresh_local_cache = true)
match = Base._which(tt;
# NOTE use the latest world counter with `method_table(analyzer)` unwrapped,
# otherwise it may use a world counter when this method isn't defined yet
method_table = CC.method_table(analyzer),
world = CC.get_inference_world(analyzer),
raise = false)
if (match !== nothing &&
(!(entrypoint isa Symbol) || # implies `analyze_from_definitions===true`
match.method.name === entrypoint))
analyzer, result = JET.analyze_method_signature!(analyzer,
match.method, match.spec_types, match.sparams)
reports = JET.get_reports(analyzer, result)
isempty(reports) || @lock progress.reports_lock append!(progress.reports, reports)
else
JETLS_DEV_MODE && @warn "Couldn't find a single method matching the signature" tt
end
end
tt = res.toplevel_signatures[i]
match = Base._which(tt;
# NOTE use the latest world counter with `method_table(analyzer)` unwrapped,
# otherwise it may use a world counter when this method isn't defined yet
method_table = CC.method_table(analyzer),
world = CC.get_inference_world(analyzer),
raise = false)
if (match !== nothing &&
(!(entrypoint isa Symbol) || # implies `analyze_from_definitions===true`
match.method.name === entrypoint))
analyzer, result = JET.analyze_method_signature!(analyzer,
match.method, match.spec_types, match.sparams)
append!(all_reports, JET.get_reports(analyzer, result))
else
# something went wrong
if JETLS_DEV_MODE
@warn "Couldn't find a single method matching the signature `", tt, "`"
done = (@atomic progress.done += 1)
if cancellable_token !== nothing
current_next = @atomic progress.next_interval
if done >= current_next
# Try to update next_interval (may race with other tasks)
@atomicreplace progress.next_interval current_next => current_next + progress.interval
percentage = compute_percentage(done, n_sigs, 50) + 50
send_progress(interp.server, cancellable_token.token,
WorkDoneProgressReport(;
cancellable = true,
message = "$done / $n_sigs [signature analysis]",
percentage))
end
end
catch e
@error "Error during signature analysis"
Base.showerror(stderr, e, catch_backtrace())
end
end
append!(res.inference_error_reports, all_reports)

for task in tasks
wait(task)
if cancellable_token !== nothing && is_cancelled(cancellable_token.cancel_flag)
break
end
end

append!(res.inference_error_reports, progress.reports)
end

function JET.virtual_process!(interp::LSInterpreter,
Expand Down
Loading