From 55dcca5fcab9277601b5a46d96e8c56dcfae490b Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:10:06 +0000 Subject: [PATCH 01/69] Rework test cases a bit --- src/test_resources.jl | 40 ++++++++++++++++++++++++++++++++++++++++ test/issues.jl | 42 ++---------------------------------------- test/tf.jl | 34 ---------------------------------- 3 files changed, 42 insertions(+), 74 deletions(-) create mode 100644 src/test_resources.jl delete mode 100644 test/tf.jl diff --git a/src/test_resources.jl b/src/test_resources.jl new file mode 100644 index 00000000..d05e7f0b --- /dev/null +++ b/src/test_resources.jl @@ -0,0 +1,40 @@ +module TestResources + +# Old test case without any produce statements used to test TapedFunction. Since this +# doesn't exist as a distinct entity anymore, not clear that this test case is useful. +mutable struct S + i::Int + S(x, y) = new(x + y) +end + +# Old test case without any produce statements. Might make sense to ensure that something +# vaguely like this is included in the test suite, but isn't directly relevant. +function g(x, y) + if x>y + r = string(sin(x)) + else + r = sin(x) * cos(y) + end + return r +end + +# Old test case -- github.com/TuringLang/Libtask.jl/issues/148, unused argument +function f(x) + produce(1) +end + +# Old test case. Probably redundant, but makes sense to check. Might want to replace the +# final statement with a produce statement to make the test case meaningful. +function g(x, y) + c = x + y + return (; c, x, y) +end + +# Make sure I provide a test case in which a function contains consts. +function f() + # this line generates: %1 = 1::Core.Const(1) + r = (a = 1) + return nothing +end + +end diff --git a/test/issues.jl b/test/issues.jl index 370c0235..f534d0b1 100644 --- a/test/issues.jl +++ b/test/issues.jl @@ -26,6 +26,8 @@ @test ex === nothing end + # TODO: this test will need to change because I'm going to modify the interface _very_ + # slightly. @testset "Issue-140, copy unstarted task" begin function f(x) for i in 1:3 @@ -40,44 +42,4 @@ @test consume(ttask3) == 5 end - @testset "Issue-148, unused argument" begin - function f(x) - produce(1) - end - - ttask = TapedTask(f, 2) - @test consume(ttask) == 1 - end - - @testset "Issue-Turing-1873, NamedTuple syntax" begin - function g(x, y) - c = x + y - return (; c, x, y) - end - - tf = Libtask.TapedFunction(g, 1, 2) - r = tf(1, 2) - @test r == (c=3, x=1, y=2) - end - - @testset "Issue-Libtask-174, SSAValue=Int and static parameter" begin - # SSAValue = Int - function f() - # this line generates: %1 = 1::Core.Const(1) - r = (a = 1) - return nothing - end - tf = Libtask.TapedFunction(f) - r = tf() - @test r == nothing - - # static parameter - function g(::Type{T}) where {T} - a = zeros(T, 10) - end - tf = Libtask.TapedFunction(g, Float64) - r = tf(Float64) - @test r == zeros(Float64, 10) - end - end diff --git a/test/tf.jl b/test/tf.jl deleted file mode 100644 index f1fd87c5..00000000 --- a/test/tf.jl +++ /dev/null @@ -1,34 +0,0 @@ -using Libtask - -@testset "tapedfunction" begin - # Test case 1: stack allocated objects are deep copied. - @testset "Instruction{typeof(__new__)}" begin - mutable struct S - i::Int - S(x, y) = new(x + y) - end - - tf = Libtask.TapedFunction(S, 1, 2) - s1 = tf(1, 2) - @test s1.i == 3 - newins = findall(x -> isa(x, Libtask.Instruction{typeof(Libtask.__new__)}), tf.tape) - @test length(newins) == 1 - end - - @testset "Compiled Tape" begin - function g(x, y) - if x>y - r= string(sin(x)) - else - r= sin(x) * cos(y) - end - return r - end - - tf = Libtask.TapedFunction(g, 1., 2.) - ctf = Libtask.compile(tf) - r = ctf(1., 2.) - - @test typeof(r) === Float64 - end -end From 8f9c3ce073080d52cffde8fbcc7b0e745f03f460 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:18:45 +0000 Subject: [PATCH 02/69] Add formatter config --- .JuliaFormatter.toml | 1 + 1 file changed, 1 insertion(+) create mode 100644 .JuliaFormatter.toml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 00000000..c7439503 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "blue" \ No newline at end of file From 9f0ad12b75947f737e0b5c39dddf99fb871517d3 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:19:08 +0000 Subject: [PATCH 03/69] Update gitignore to remove redundant items --- .gitignore | 57 ------------------------------------------------------ 1 file changed, 57 deletions(-) diff --git a/.gitignore b/.gitignore index d039efaf..1c7787bd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,60 +1,3 @@ -# Prerequisites -*.d - -# Object files -*.o -*.ko -*.obj -*.elf - -# Linker output -*.ilk -*.map -*.exp - -# Precompiled Headers -*.gch -*.pch - -# Libraries -*.lib -*.a -*.la -*.lo - -# Shared objects (inc. Windows DLLs) -*.dll -*.so -*.so.* -*.dylib - -# Executables -*.exe -*.out -*.app -*.i*86 -*.x86_64 -*.hex - -# Debug files -*.dSYM/ -*.su -*.idb -*.pdb - -# Kernel Module Compile Results -*.mod* -*.cmd -.tmp_versions/ -modules.order -Module.symvers -Mkfile.old -dkms.conf - # Projects files Manifest.toml -deps/build.log -deps/deps.jl -deps/usr/ -deps/tmp-build.jl *.cov From ffd86e2088f592817e08ba2bd28145b909232ce0 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:19:28 +0000 Subject: [PATCH 04/69] Bump minor version because small breaking change --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 41385419..d683d9b2 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.8.8" +version = "0.9.0" [deps] FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" From 2d063467edc36c022a64ebb5ee9001f24401751c Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:20:02 +0000 Subject: [PATCH 05/69] Tell users that various old types have been actually removed --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index c9b623fe..219bd5b4 100644 --- a/README.md +++ b/README.md @@ -97,3 +97,5 @@ Notes: to a tape and copying that tape. Before that version, it is based on a tricky hack on the Julia internals. You can check the commit history of this repo to see the details. + +- From version 0.9.0, the old `TArray` and `TRef` types are completely removed, where previously they were deprecated. \ No newline at end of file From 0c545bb243196ef856c1cb5d51bb80aefbf44b82 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:22:29 +0000 Subject: [PATCH 06/69] Formatting of perf --- perf/benchmark.jl | 37 ++++++++++++++++++++----------------- perf/p0.jl | 9 ++++----- perf/p1.jl | 25 +++++++++++-------------- 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/perf/benchmark.jl b/perf/benchmark.jl index dcfb3638..a78a59bf 100644 --- a/perf/benchmark.jl +++ b/perf/benchmark.jl @@ -24,15 +24,15 @@ function benchmark_driver!(f, x...; f_displayname=string(f)) GC.gc() print(" Run TapedTask: ") - x = (x[1:end-1]..., produce) + x = (x[1:(end - 1)]..., produce) # show the number of produce calls inside `f` function f_task(f, x; verbose=false) tt = TapedTask(f, x...) c = 0 - while consume(tt)!==nothing - c+=1 + while consume(tt) !== nothing + c += 1 end - verbose && print("#produce=", c, "; ") + return verbose && print("#produce=", c, "; ") end # Note that we need to pass `f` instead of `tf` to avoid # default continuation in `TapedTask` constructor, see, e.g. @@ -40,15 +40,15 @@ function benchmark_driver!(f, x...; f_displayname=string(f)) f_task(f, x; verbose=true) # print #produce calls @btime $f_task($f, $x) GC.gc() + return nothing end #################################################################### - function rosenbrock(x, callback=nothing) i = x[2:end] - j = x[1:end-1] - ret = sum((1 .- j).^2 + 100*(i - j.^2).^2) + j = x[1:(end - 1)] + ret = sum((1 .- j) .^ 2 + 100 * (i - j .^ 2) .^ 2) callback !== nothing && callback(ret) return ret end @@ -59,17 +59,20 @@ benchmark_driver!(rosenbrock, x) #################################################################### function ackley(x::AbstractVector, callback=nothing) - a, b, c = 20.0, -0.2, 2.0*π + a, b, c = 20.0, -0.2, 2.0 * π len_recip = inv(length(x)) sum_sqrs = zero(eltype(x)) sum_cos = sum_sqrs for i in x - sum_cos += cos(c*i) + sum_cos += cos(c * i) sum_sqrs += i^2 callback !== nothing && callback(sum_sqrs) end - return (-a * exp(b * sqrt(len_recip*sum_sqrs)) - - exp(len_recip*sum_cos) + a + MathConstants.e) + return ( + -a * exp(b * sqrt(len_recip * sum_sqrs)) - exp(len_recip * sum_cos) + + a + + MathConstants.e + ) end x = rand(100000) @@ -79,8 +82,8 @@ benchmark_driver!(ackley, x) function generate_matrix_test(n) return (x, callback=nothing) -> begin # @assert length(x) == 2n^2 + n - a = reshape(x[1:n^2], n, n) - b = reshape(x[n^2 + 1:2n^2], n, n) + a = reshape(x[1:(n^2)], n, n) + b = reshape(x[(n^2 + 1):(2n^2)], n, n) ret = log.((a * b) + a - b) callback !== nothing && callback(ret) return ret @@ -94,7 +97,7 @@ benchmark_driver!(matrix_test, x; f_displayname="matrix_test") #################################################################### relu(x) = log.(1.0 .+ exp.(x)) -sigmoid(n) = 1. / (1. + exp(-n)) +sigmoid(n) = 1.0 / (1.0 + exp(-n)) function neural_net(w1, w2, w3, x1, callback=nothing) x2 = relu(w1 * x1) @@ -104,7 +107,7 @@ function neural_net(w1, w2, w3, x1, callback=nothing) return ret end -xs = (randn(10,10), randn(10,10), randn(10), rand(10)) +xs = (randn(10, 10), randn(10, 10), randn(10), rand(10)) benchmark_driver!(neural_net, xs...) #################################################################### @@ -114,11 +117,11 @@ println("======= breakdown benchmark =======") x = rand(100000) tf = Libtask.TapedFunction(ackley, x, nothing) tf(x, nothing); -idx = findlast((x)->isa(x, Libtask.Instruction), tf.tape) +idx = findlast((x) -> isa(x, Libtask.Instruction), tf.tape) ins = tf.tape[idx] b = ins.input[1] -@show ins.input |> length +@show length(ins.input) @btime map(x -> Libtask._lookup(tf, x), ins.input) @btime Libtask._lookup(tf, b) @btime tf.binding_values[b] diff --git a/perf/p0.jl b/perf/p0.jl index 757d8b3a..23a71250 100644 --- a/perf/p0.jl +++ b/perf/p0.jl @@ -5,17 +5,16 @@ using BenchmarkTools @model gdemo(x, y) = begin # Assumptions - σ ~ InverseGamma(2,3) - μ ~ Normal(0,sqrt(σ)) + σ ~ InverseGamma(2, 3) + μ ~ Normal(0, sqrt(σ)) # Observations x ~ Normal(μ, sqrt(σ)) y ~ Normal(μ, sqrt(σ)) end - # Case 1: Sample from the prior. rng = MersenneTwister() -m = Turing.Core.TracedModel(gdemo(1.5, 2.), SampleFromPrior(), VarInfo(), rng) +m = Turing.Core.TracedModel(gdemo(1.5, 2.0), SampleFromPrior(), VarInfo(), rng) f = m.evaluator[1]; args = m.evaluator[2:end]; @@ -28,7 +27,7 @@ println("Run a tape...") @btime t.tf(args...) # Case 2: SMC sampler -m = Turing.Core.TracedModel(gdemo(1.5, 2.), Sampler(SMC(50)), VarInfo(), rng) +m = Turing.Core.TracedModel(gdemo(1.5, 2.0), Sampler(SMC(50)), VarInfo(), rng) f = m.evaluator[1]; args = m.evaluator[2:end]; diff --git a/perf/p1.jl b/perf/p1.jl index 4ecd2ec8..34797f3c 100644 --- a/perf/p1.jl +++ b/perf/p1.jl @@ -2,25 +2,22 @@ using Turing, Test, AbstractMCMC, DynamicPPL, Random import AbstractMCMC.AbstractSampler -function check_numerical(chain, - symbols::Vector, - exact_vals::Vector; - atol=0.2, - rtol=0.0) +function check_numerical(chain, symbols::Vector, exact_vals::Vector; atol=0.2, rtol=0.0) for (sym, val) in zip(symbols, exact_vals) - E = val isa Real ? - mean(chain[sym]) : - vec(mean(chain[sym], dims=1)) + E = val isa Real ? mean(chain[sym]) : vec(mean(chain[sym]; dims=1)) @info (symbol=sym, exact=val, evaluated=E) - @test E ≈ val atol=atol rtol=rtol + @test E ≈ val atol = atol rtol = rtol end end function check_MoGtest_default(chain; atol=0.2, rtol=0.0) - check_numerical(chain, - [:z1, :z2, :z3, :z4, :mu1, :mu2], - [1.0, 1.0, 2.0, 2.0, 1.0, 4.0], - atol=atol, rtol=rtol) + return check_numerical( + chain, + [:z1, :z2, :z3, :z4, :mu1, :mu2], + [1.0, 1.0, 2.0, 2.0, 1.0, 4.0]; + atol=atol, + rtol=rtol, + ) end @model gdemo_d(x, y) = begin @@ -36,4 +33,4 @@ chain = sample(gdemo_d(1.5, 2.0), alg, 5_000) @show chain -check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) +check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) From dc5ab50a7c6f0d1bb2628292af449158f3a48814 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:22:57 +0000 Subject: [PATCH 07/69] Remove old code --- src/tapedfunction.jl | 514 ------------------------------------------- src/tapedtask.jl | 219 ------------------ test/tape_copy.jl | 195 ---------------- test/tapedtask.jl | 159 ------------- 4 files changed, 1087 deletions(-) delete mode 100644 src/tapedfunction.jl delete mode 100644 src/tapedtask.jl delete mode 100644 test/tape_copy.jl delete mode 100644 test/tapedtask.jl diff --git a/src/tapedfunction.jl b/src/tapedfunction.jl deleted file mode 100644 index 5e85b87b..00000000 --- a/src/tapedfunction.jl +++ /dev/null @@ -1,514 +0,0 @@ -#= -`TapedFunction` converts a Julia function to a friendly tape for user-specified interpreters. -With this tape-like abstraction for functions, we gain some control over how a function is -executed, like capturing continuations, caching variables, injecting additional control flows -(i.e., produce/consume) between instructions on the tape, etc. - -Under the hood, we first used Julia's compiler API to get the IR code of the original function. -We use the unoptimized typed code in a non-strict SSA form. Then we convert each IR instruction -to a Julia data structure (an object of a subtype of AbstractInstruction). All the operands -(i.e., the variables) these instructions use are stored in a data structure called `Bindings`. -This conversion/binding process is performed at compile-time / tape-recording time and is only -done once for each function. - -In a nutshell, there are two types of instructions (or primitives) on a tape: - - Ordinary function call - - Control-flow instruction: GotoInstruction and CondGotoInstruction, ReturnInstruction - -Once the tape is recorded, we can run the tape just like calling the original function. -We first plugin the arguments, run each instruction on the tape and stop after encountering -a ReturnInstruction. We also provide a mechanism to add a callback after each instruction. -This API allowed us to implement the `produce/consume` mechanism in TapedTask. And exploiting -these features, we implemented a fork mechanism for TapedTask. - -Some potentially sharp edges of this implementation: - - 1. GlobalRef is evaluated at the tape-recording time (compile-time). Most times, - the value/object associated with a GlobalRef does not change at run time. - So this works well. But, if you do something like `module A v=1 end; make tapedfunction; A.eval(:(v=2)); run tf;`, - The assignment won't work. - 2. QuoteNode is also evaluated at the tape-recording time (compile-time). Primarily - the result of evaluating a QuoteNode is a Symbol, which usually works well. - 3. Each Instruction execution contains one unnecessary allocation at the moment. - So writing a function with vectorized computation will be more performant, - for example, using broadcasting instead of a loop. -=# - -const LOGGING = Ref(false) - -## Instruction and TapedFunction -abstract type AbstractInstruction end -const RawTape = Vector{AbstractInstruction} - -function _infer(f, args_type) - # `code_typed` returns a vector: [Pair{Core.CodeInfo, DataType}] - ir0 = code_typed(f, Tuple{args_type...}, optimize=false)[1][1] - return ir0 -end - -const Bindings = Vector{Any} - -mutable struct TapedFunction{F, TapeType} - func::F # maybe a function, a constructor, or a callable object - arity::Int - ir::Core.CodeInfo - tape::TapeType - counter::Int - binding_values::Bindings - arg_binding_slots::Vector{Int} # arg indices in binding_values - retval_binding_slot::Int # 0 indicates the function has not returned - deepcopy_types::Type # use a Union type for multiple types - - function TapedFunction{F, T}(f::F, args...; cache=false, deepcopy_types=Union{}) where {F, T} - args_type = _accurate_typeof.(args) - cache_key = (f, deepcopy_types, args_type...) - - if cache && haskey(TRCache, cache_key) # use cache - cached_tf = TRCache[cache_key]::TapedFunction{F, T} - tf = copy(cached_tf) - tf.counter = 1 - return tf - end - ir = _infer(f, args_type) - binding_values, slots, tape = translate!(RawTape(), ir) - - tf = new{F, T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types) - TRCache[cache_key] = tf # set cache - return tf - end - - TapedFunction(f, args...; cache=false, deepcopy_types=Union{}) = - TapedFunction{typeof(f), RawTape}(f, args...; cache=cache, deepcopy_types=deepcopy_types) - - function TapedFunction{F, T0}(tf::TapedFunction{F, T1}) where {F, T0, T1} - new{F, T0}(tf.func, tf.arity, tf.ir, tf.tape, - tf.counter, tf.binding_values, tf.arg_binding_slots, 0, tf.deepcopy_types) - end - - TapedFunction(tf::TapedFunction{F, T}) where {F, T} = TapedFunction{F, T}(tf) -end - -const TRCache = LRU{Tuple, TapedFunction}(maxsize=10) -const CompiledTape = Vector{FunctionWrapper{Nothing, Tuple{TapedFunction}}} - -function Base.convert(::Type{CompiledTape}, tape::RawTape) - ctape = CompiledTape(undef, length(tape)) - for idx in 1:length(tape) - ctape[idx] = FunctionWrapper{Nothing, Tuple{TapedFunction}}(tape[idx]) - end - return ctape -end - -compile(tf::TapedFunction{F, RawTape}) where {F} = TapedFunction{F, CompiledTape}(tf) - -@inline _lookup(tf::TapedFunction, v::Int) = @inbounds tf.binding_values[v] -@inline _update_var!(tf::TapedFunction, v::Int, c) = @inbounds tf.binding_values[v] = c - -""" - Instruction - -An `Instruction` stands for a function call -""" -struct Instruction{F, N} <: AbstractInstruction - func::F - input::NTuple{N, Int} - output::Int -end - -struct GotoInstruction <: AbstractInstruction - # we enusre a 1-to-1 mapping between ir.code and instruction - # so here we can use the index directly. - dest::Int -end - -struct CondGotoInstruction <: AbstractInstruction - condition::Int - dest::Int -end - -struct ReturnInstruction <: AbstractInstruction - arg::Int -end - -struct NOOPInstruction <: AbstractInstruction end - -@inline result(t::TapedFunction) = t.binding_values[t.retval_binding_slot] -@inline function _arg(tf::TapedFunction, i::Int; default=nothing) - length(tf.arg_binding_slots) < i && return default - tf.arg_binding_slots[i] > 0 && return tf.binding_values[tf.arg_binding_slots[i]] - return default -end -@inline function _arg!(tf::TapedFunction, i::Int, v) - length(tf.arg_binding_slots) >= i && - tf.arg_binding_slots[i] > 0 && _update_var!(tf, tf.arg_binding_slots[i], v) -end - -function (tf::TapedFunction)(args...; callback=nothing, continuation=false) - if !continuation # reset counter and retval_binding_slot to run from the start - tf.counter = 1 - tf.retval_binding_slot = 0 - end - - # set args - if tf.counter <= 1 - # The first slot in `binding_values` is assumed to be `tf.func`. - _arg!(tf, 1, tf.func) - for i in 1:length(args) # the subsequent arg_binding_slots are arguments - slot = i + 1 - _arg!(tf, slot, args[i]) - end - end - - # run the raw tape - while true - ins = tf.tape[tf.counter] - ins(tf) - callback !== nothing && callback() - tf.retval_binding_slot != 0 && break - end - return result(tf) -end - -function Base.show(io::IO, tf::TapedFunction) - # we use an extra IOBuffer to collect all the data and then - # output it once to avoid output interrupt during task context - # switching - buf = IOBuffer() - println(buf, "TapedFunction:") - println(buf, "* .func => $(tf.func)") - println(buf, "* .ir =>") - println(buf, "------------------") - println(buf, tf.ir) - println(buf, "------------------") - print(io, String(take!(buf))) -end - -function Base.show(io::IO, rtape::RawTape) - buf = IOBuffer() - print(buf, length(rtape), "-element RawTape") - isempty(rtape) || println(buf, ":") - i = 1 - for instr in rtape - print(buf, "\t", i, " => ") - show(buf, instr) - i += 1 - end - print(io, String(take!(buf))) -end - -## methods for Instruction -Base.show(io::IO, instr::AbstractInstruction) = println(io, "A ", typeof(instr)) - -function Base.show(io::IO, instr::Instruction) - println(io, "Instruction(", instr.output, "=", instr.func, instr.input) -end - -function Base.show(io::IO, instr::GotoInstruction) - println(io, "GotoInstruction(dest=", instr.dest, ")") -end - -function Base.show(io::IO, instr::CondGotoInstruction) - println(io, "CondGotoInstruction(", instr.condition, ", dest=", instr.dest, ")") -end - -function (instr::Instruction{F})(tf::TapedFunction) where F - # catch run-time exceptions / errors. - try - func = F === Int ? _lookup(tf, instr.func) : instr.func - inputs = map(x -> _lookup(tf, x), instr.input) - output = func(inputs...) - _update_var!(tf, instr.output, output) - tf.counter += 1 - catch e - println("counter=", tf.counter) - println("tf=", tf) - println(e, catch_backtrace()); - rethrow(e); - end -end - -function (instr::GotoInstruction)(tf::TapedFunction) - tf.counter = instr.dest -end - -function (instr::CondGotoInstruction)(tf::TapedFunction) - cond = _lookup(tf, instr.condition) - if cond - tf.counter += 1 - else # goto dest unless cond - tf.counter = instr.dest - end -end - -function (instr::ReturnInstruction)(tf::TapedFunction) - tf.retval_binding_slot = instr.arg -end - -function (instr::NOOPInstruction)(tf::TapedFunction) - tf.counter += 1 -end - -## internal functions -_accurate_typeof(v) = typeof(v) -_accurate_typeof(::Type{V}) where V = Type{V} -_loose_type(t) = t -_loose_type(::Type{Type{T}}) where T = isa(T, DataType) ? Type{T} : typeof(T) - -""" - __new__(T, args...) - -Return a new instance of `T` with `args` even when there is no inner constructor for these args. -Source: https://discourse.julialang.org/t/create-a-struct-with-uninitialized-fields/6967/5 -""" -@generated function __new__(T, args...) - return Expr(:splatnew, :T, :args) -end - - -## Translation: CodeInfo -> Tape - -const IRVar = Union{Core.SSAValue, Core.SlotNumber} - -function bind_var!(var, bindings::Bindings, ir::Core.CodeInfo) - # for literal constants, and static parameters - var = Meta.isexpr(var, :static_parameter) ? ir.parent.sparam_vals[var.args[1]] : var - push!(bindings, var) - idx = length(bindings) - return idx -end -function bind_var!(var::GlobalRef, bindings::Bindings, ir::Core.CodeInfo) - in(var.mod, (Base, Core)) || - LOGGING[] && @info "evaluating GlobalRef $var at compile time" - bind_var!(getproperty(var.mod, var.name), bindings, ir) -end -function bind_var!(var::QuoteNode, bindings::Bindings, ir::Core.CodeInfo) - LOGGING[] && @info "evaluating QuoteNode $var at compile time" - bind_var!(eval(var), bindings, ir) -end -function bind_var!(var::TypedSlot, bindings::Bindings, ir::Core.CodeInfo) - # turn TypedSlot to SlotNumber - bind_var!(Core.SlotNumber(var.id), bindings, ir) -end -function bind_var!(var::Core.SlotNumber, bindings::Bindings, ir::Core.CodeInfo) - get!(bindings[1], var, allocate_binding!(var, bindings, ir.slottypes[var.id])) -end -function bind_var!(var::Core.SSAValue, bindings::Bindings, ir::Core.CodeInfo) - get!(bindings[1], var, allocate_binding!(var, bindings, ir.ssavaluetypes[var.id])) -end -allocate_binding!(var, bindings::Bindings, c::Core.Const) = - allocate_binding!(var, bindings, _loose_type(Type{_accurate_typeof(c.val)})) - -allocate_binding!(var, bindings::Bindings, c::Core.PartialStruct) = - allocate_binding!(var, bindings, _loose_type(c.typ)) -function allocate_binding!(var, bindings::Bindings, ::Type{T}) where T - # we may use the type info (T) here - push!(bindings, nothing) - idx = length(bindings) - return idx -end - -function translate!(tape::RawTape, ir::Core.CodeInfo) - binding_values = Bindings() - sizehint!(binding_values, 128) - bcache = Dict{IRVar, Int}() - # the first slot of binding_values is used to store a cache at compile time - push!(binding_values, bcache) - slots = Dict{Int, Int}() - - for (idx, line) in enumerate(ir.code) - isa(line, Core.Const) && (line = line.val) # unbox Core.Const - isconst = isa(ir.ssavaluetypes[idx], Core.Const) - ins = translate!!(Core.SSAValue(idx), line, binding_values, isconst, ir) - push!(tape, ins) - end - for (k, v) in bcache - isa(k, Core.SlotNumber) && (slots[k.id] = v) - end - arg_binding_slots = fill(0, maximum(keys(slots); init=0)) - for (k, v) in slots - arg_binding_slots[k] = v - end - binding_values[1] = 0 # drop bcache - return (binding_values, arg_binding_slots, tape) -end - -function _const_instruction(var::IRVar, v, bindings::Bindings, ir) - if isa(var, Core.SSAValue) - box = bind_var!(var, bindings, ir) - bindings[box] = v - return NOOPInstruction() - end - return Instruction(identity, (bind_var!(v, bindings, ir),), bind_var!(var, bindings, ir)) -end - -function translate!!(var::IRVar, line::Core.NewvarNode, - bindings::Bindings, isconst::Bool, @nospecialize(ir)) - # use a no-op to ensure the 1-to-1 mapping from ir.code to instructions on tape. - return NOOPInstruction() -end - -function translate!!(var::IRVar, line::GlobalRef, - bindings::Bindings, isconst::Bool, ir) - if isconst - v = ir.ssavaluetypes[var.id].val - return _const_instruction(var, v, bindings, ir) - end - func() = getproperty(line.mod, line.name) - return Instruction(func, (), bind_var!(var, bindings, ir)) -end - -function translate!!(var::IRVar, line::Core.SlotNumber, - bindings::Bindings, isconst::Bool, ir) - if isconst - v = ir.ssavaluetypes[var.id].val - return _const_instruction(var, v, bindings, ir) - end - func = identity - input = (bind_var!(line, bindings, ir),) - output = bind_var!(var, bindings, ir) - return Instruction(func, input, output) -end - -function translate!!(var::IRVar, line::Number, # literal vars - bindings::Bindings, isconst::Bool, ir) - func = identity - input = (bind_var!(line, bindings, ir),) - output = bind_var!(var, bindings, ir) - return Instruction(func, input, output) -end - -function translate!!(var::IRVar, line::NTuple{N, Symbol}, - bindings::Bindings, isconst::Bool, ir) where {N} - # for syntax (; x, y, z), see Turing.jl#1873 - func = identity - input = (bind_var!(line, bindings, ir),) - output = bind_var!(var, bindings, ir) - return Instruction(func, input, output) -end - -function translate!!(var::IRVar, line::TypedSlot, - bindings::Bindings, isconst::Bool, ir) - input_box = bind_var!(Core.SlotNumber(line.id), bindings, ir) - return Instruction(identity, (input_box,), bind_var!(var, bindings, ir)) -end - -function translate!!(var::IRVar, line::Core.GotoIfNot, - bindings::Bindings, isconst::Bool, ir) - cond = bind_var!(line.cond, bindings, ir) - return CondGotoInstruction(cond, line.dest) -end - -function translate!!(var::IRVar, line::Core.GotoNode, - bindings::Bindings, isconst::Bool, @nospecialize(ir)) - return GotoInstruction(line.label) -end - -function translate!!(var::IRVar, line::Core.ReturnNode, - bindings::Bindings, isconst::Bool, ir) - return ReturnInstruction(bind_var!(line.val, bindings, ir)) -end - -_canbeoptimized(v) = isa(v, DataType) || isprimitivetype(typeof(v)) -function translate!!(var::IRVar, line::Expr, - bindings::Bindings, isconst::Bool, ir::Core.CodeInfo) - head = line.head - _bind_fn = (x) -> bind_var!(x, bindings, ir) - if head === :new - args = map(_bind_fn, line.args) - return Instruction(__new__, args |> Tuple, _bind_fn(var)) - elseif head === :call - # Only some of the function calls can be optimized even though many of their results are - # inferred as constants: we only optimize primitive and datatype constants for now. For - # optimised function calls, we will evaluate the function at compile-time and cache results. - if isconst - v = ir.ssavaluetypes[var.id].val - _canbeoptimized(v) && return _const_instruction(var, v, bindings, ir) - end - args = map(_bind_fn, line.args) - # args[1] is the function - func = line.args[1] - if Meta.isexpr(func, :static_parameter) # func is a type parameter - func = ir.parent.sparam_vals[func.args[1]] - elseif isa(func, GlobalRef) - func = getproperty(func.mod, func.name) # Staging out global reference variable (constants). - else # a var? - func = args[1] # a var(box) - end - return Instruction(func, args[2:end] |> Tuple, _bind_fn(var)) - elseif head === :(=) - # line.args[1] (the left hand side) is a SlotNumber, and it should be the output - lhs = line.args[1] - rhs = line.args[2] # the right hand side, maybe a Expr, or a var, or ... - if Meta.isexpr(rhs, (:new, :call)) - return translate!!(lhs, rhs, bindings, false, ir) - else # rhs is a single value - if isconst - v = ir.ssavaluetypes[var.id].val - return Instruction(identity, (_bind_fn(v),), _bind_fn(lhs)) - end - return Instruction(identity, (_bind_fn(rhs),), _bind_fn(lhs)) - end - elseif head === :static_parameter - v = ir.parent.sparam_vals[line.args[1]] - return Instruction(identity, (_bind_fn(v),), _bind_fn(var)) - else - @error "Unknown Expression: " typeof(var) var typeof(line) line - throw(ErrorException("Unknown Expression")) - end -end - -function translate!!(var, line, bindings, isconst, ir) - @error "Unknown IR code: " typeof(var) var typeof(line) line - throw(ErrorException("Unknown IR code")) -end - -## copy Bindings, TapedFunction - -""" - tape_shallowcopy(x) - tape_deepcopy(x) - -Function `tape_shallowcopy` and `tape_deepcopy` are used to copy data -while copying a TapedFunction. A value in the bindings of a -TapedFunction is either `tape_shallowcopy`ed or `tape_deepcopy`ed. For -TapedFunction, all types are shallow copied by default, and you can -specify some types to be deep copied by giving the `deepcopy_types` -kwyword argument while constructing a TapedFunction. - -The default behaviour of `tape_shallowcopy` is, we return its argument -untouched, like `identity` does, i.e., `tape_copy(x) = x`. The default -behaviour of `tape_deepcopy` is, we call `deepcopy` on its argument -and return the result, `tape_deepcopy(x) = deepcopy(x)`. If one wants -some kinds of data to be copied (shallowly or deeply) in a different -way, one can overload these functions. - -""" -function tape_shallowcopy end, function tape_deepcopy end - -tape_shallowcopy(x) = x -tape_deepcopy(x) = deepcopy(x) - -# Core.Box is used as closure captured variable container, so we should tape_copy its contents -_tape_copy(box::Core.Box, deepcopy_types) = Core.Box(_tape_copy(box.contents, deepcopy_types)) - -function _tape_copy(v, deepcopy_types) - if isa(v, deepcopy_types) - tape_deepcopy(v) - else - tape_shallowcopy(v) - end -end - -function copy_bindings(old::Bindings, deepcopy_types) - newb = copy(old) - for k in 1:length(old) - newb[k] = _tape_copy(old[k], deepcopy_types) - end - return newb -end - -function Base.copy(tf::TapedFunction) - new_tf = TapedFunction(tf) - new_tf.binding_values = copy_bindings(tf.binding_values, tf.deepcopy_types) - return new_tf -end diff --git a/src/tapedtask.jl b/src/tapedtask.jl deleted file mode 100644 index c96f4720..00000000 --- a/src/tapedtask.jl +++ /dev/null @@ -1,219 +0,0 @@ -struct TapedTaskException - exc::Exception - backtrace::Vector{Any} -end - -struct TapedTask{F, AT<:Tuple} - task::Task - tf::TapedFunction{F} - args::AT - produce_ch::Channel{Any} - consume_ch::Channel{Int} - produced_val::Vector{Any} - - function TapedTask( - t::Task, - tf::TapedFunction{F}, - args::AT, - produce_ch::Channel{Any}, - consume_ch::Channel{Int} - ) where {F, AT<:Tuple} - new{F, AT}(t, tf, args, produce_ch, consume_ch, Any[]) - end -end - -function producer() - ttask = current_task().storage[:tapedtask]::TapedTask - if length(ttask.produced_val) > 0 - val = pop!(ttask.produced_val) - put!(ttask.produce_ch, val) - take!(ttask.consume_ch) # wait for next consumer - end - return nothing -end - -function wrap_task(tf, produce_ch, consume_ch, args...) - try - tf(args...; callback=producer, continuation=true) - catch e - bt = catch_backtrace() - put!(produce_ch, TapedTaskException(e, bt)) - # @error "TapedTask Error: " exception=(e, bt) - rethrow() - finally - @static if VERSION >= v"1.4" - # we don't do this under Julia 1.3, because `isempty` always hangs on - # an empty channel. - while !isempty(produce_ch) - yield() - end - end - close(produce_ch) - close(consume_ch) - end -end - -function TapedTask(tf::TapedFunction, args...) - produce_ch = Channel() - consume_ch = Channel{Int}() - task = @task wrap_task(tf, produce_ch, consume_ch, args...) - t = TapedTask(task, tf, args, produce_ch, consume_ch) - task.storage === nothing && (task.storage = IdDict()) - task.storage[:tapedtask] = t - return t -end - -BASE_COPY_TYPES = Union{Array, Ref} - -# NOTE: evaluating model without a trace, see -# https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329 -function TapedTask(f, args...; deepcopy_types=nothing) # deepcoy Array and Ref by default. - if isnothing(deepcopy_types) - deepcopy = BASE_COPY_TYPES - else - deepcopy = Union{BASE_COPY_TYPES, deepcopy_types} - end - tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy) - TapedTask(tf, args...) -end - -TapedTask(finfo::Tuple{Any, Type}, args...) = TapedTask(finfo[1], args...; deepcopy_types=finfo[2]) -TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...; deepcopy_types=t.tf.deepcopy_types) -func(t::TapedTask) = t.tf.func - -#= -# ** Approach (A) to implement `produce`: -# Make`produce` a standalone instturction. This approach does NOT -# support `produce` in a nested call -function internal_produce(instr::Instruction, val) - tf = instr.tape - ttask = tf.owner - put!(ttask.produce_ch, val) - take!(ttask.consume_ch) # wait for next consumer -end - -function produce(val) - error("Libtask.produce can only be directly called in a task!") -end - -function (instr::Instruction{typeof(produce)})() - args = val(instr.input[1]) - internal_produce(instr, args) -end -=# - - -# ** Approach (B) to implement `produce`: -# This way has its caveat: -# `produce` may deeply hide in an instruction, but not be an instruction -# itself, and when we copy a task, the newly copied task will resume from -# the instruction after the one which contains this `produce` call. If the -# call to `produce` is not the last expression in the instuction, that -# instruction will not be whole executed in the copied task. -@inline function is_in_tapedtask() - ct = current_task() - ct.storage === nothing && return false - haskey(ct.storage, :tapedtask) || return false - # check if we are recording a tape - ttask = ct.storage[:tapedtask]::TapedTask - return !isempty(ttask.tf.tape) -end - -function produce(val) - is_in_tapedtask() || return nothing - ttask = current_task().storage[:tapedtask]::TapedTask - length(ttask.produced_val) > 1 && - error("There is a produced value which is not consumed.") - push!(ttask.produced_val, val) - return nothing -end - -function consume(ttask::TapedTask) - if istaskstarted(ttask.task) - # tell producer that a consumer is coming - put!(ttask.consume_ch, 0) - else - schedule(ttask.task) - end - - val = try - take!(ttask.produce_ch) - catch e - isa(e, InvalidStateException) || rethrow() - istaskfailed(ttask.task) && throw(ttask.task.exception) - # TODO: we return nothing to indicate the end of a task, - # remove this when AdvancedPS is udpated. - istaskdone(ttask.task) && return nothing - end - - # yield to let the task resume, this is necessary when there's - # an exception is thrown in the task, it gives the task the chance - # to rethow the exception and set its proper status: - yield() - isa(val, TapedTaskException) && throw(val.exc) - return val -end - -# Iteration interface. -function Base.iterate(t::TapedTask, state=nothing) - try - consume(t), nothing - catch ex - !isa(ex, InvalidStateException) && rethrow - nothing - end -end -Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown() -Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown() - - -# copy the task - -function Base.copy(t::TapedTask; args=()) - length(args) > 0 && t.tf.counter >1 && - error("can't copy started task with new arguments") - tf = copy(t.tf) - task_args = if length(args) > 0 - # this cond implies t.tf.counter == 0, i.e., the task is not started yet - typeof(args) == typeof(t.args) || error("bad arguments") - args - else - if t.tf.counter > 1 - # the task is running, we find the real args from the copied binding_values - map(1:length(t.args)) do i - s = i + 1 - _arg(tf, s; default=t.args[i]) - end - else - # the task is not started yet, but no args is given - map(a -> _tape_copy(a, t.tf.deepcopy_types), t.args) - end - end - new_t = TapedTask(tf, task_args...) - storage = t.task.storage::IdDict{Any,Any} - new_t.task.storage = copy(storage) - new_t.task.storage[:tapedtask] = new_t - return new_t -end - -# TArray and TRef back-compat -function TArray(args...) - Base.depwarn("`TArray` is deprecated, please use `Array` instead.", :TArray) - Array(args...) -end -function TArray(T::Type, dim) - Base.depwarn("`TArray` is deprecated, please use `Array` instead.", :TArray) - Array{T}(undef, dim) -end -function tzeros(args...) - Base.depwarn("`tzeros` is deprecated, please use `zeros` instead.", :tzeros) - zeros(args...) -end -function tfill(args...) - Base.depwarn("`tfill` is deprecated, please use `fill` instead.", :tzeros) - fill(args...) -end -function TRef(x) - Base.depwarn("`TRef` is deprecated, please use `Ref` instead.", :TRef) - Ref(x) -end diff --git a/test/tape_copy.jl b/test/tape_copy.jl deleted file mode 100644 index 6edc3f7c..00000000 --- a/test/tape_copy.jl +++ /dev/null @@ -1,195 +0,0 @@ -@testset "tape copy" begin - # Test case 1: stack allocated objects are deep copied. - @testset "stack allocated objects shallow copy" begin - function f() - t = 0 - while true - produce(t) - t = 1 + t - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 0 - @test consume(ttask) == 1 - a = copy(ttask) - @test consume(a) == 2 - @test consume(a) == 3 - @test consume(ttask) == 2 - @test consume(ttask) == 3 - - @inferred Libtask.TapedFunction(f) - end - - # Test case 2: Array objects are deeply copied. - @testset "Array objects deep copy" begin - function f() - t = [0 1 2] - while true - produce(t[1]) - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 0 - @test consume(ttask) == 1 - a = copy(ttask) - @test consume(a) == 2 - @test consume(a) == 3 - @test consume(ttask) == 2 - @test consume(ttask) == 3 - @test consume(ttask) == 4 - @test consume(ttask) == 5 - end - - # Test case 3: Dict objects are shallowly copied. - @testset "Dict objects shallow copy" begin - function f() - t = Dict(1=>10, 2=>20) - while true - produce(t[1]) - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - - @test consume(ttask) == 10 - @test consume(ttask) == 11 - - a = copy(ttask) - @test consume(a) == 12 - @test consume(a) == 13 - - @test consume(ttask) == 14 - @test consume(ttask) == 15 - end - - @testset "Array deep copy 2" begin - function f() - t = Array{Int}(undef, 1) - t[1] = 0 - while true - produce(t[1]) - t[1] - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - - consume(ttask) - consume(ttask) - a = copy(ttask) - consume(a) - consume(a) - - @test consume(ttask) == 2 - @test consume(a) == 4 - - DATA = Dict{Task, Array}() - function g() - ta = zeros(UInt64, 4) - for i in 1:4 - ta[i] = hash(current_task()) - DATA[current_task()] = ta - produce(ta[i]) - end - end - - ttask = TapedTask(g) - @test consume(ttask) == hash(ttask.task) # index = 1 - @test consume(ttask) == hash(ttask.task) # index = 2 - - a = copy(ttask) - @test consume(a) == hash(a.task) # index = 3 - @test consume(a) == hash(a.task) # index = 4 - - @test consume(ttask) == hash(ttask.task) # index = 3 - - @test DATA[ttask.task] == [hash(ttask.task), hash(ttask.task), hash(ttask.task), 0] - @test DATA[a.task] == [hash(ttask.task), hash(ttask.task), hash(a.task), hash(a.task)] - end - - # Test atomic values. - @testset "ref atomic" begin - function f() - t = Ref(1) - t[] = 0 - for _ in 1:6 - produce(t[]) - t[] - t[] += 1 - end - end - - ctask = TapedTask(f) - - consume(ctask) - consume(ctask) - - a = copy(ctask) - consume(a) - consume(a) - - @test consume(ctask) == 2 - @test consume(a) == 4 - end - - @testset "ref of dictionary deep copy" begin - function f() - t = Ref(Dict("A" => 1, 5 => "B")) - t[]["A"] = 0 - for _ in 1:6 - produce(t[]["A"]) - t[]["A"] += 1 - end - end - - ctask = TapedTask(f) - - consume(ctask) - consume(ctask) - - a = copy(ctask) - consume(a) - consume(a) - - @test consume(ctask) == 2 - @test consume(a) == 4 - end - - @testset "ref of array deep copy" begin - # Create a TRef storing a matrix. - x = TRef([1 2 3; 4 5 6]) - x[][1, 3] = 900 - @test x[][1,3] == 900 - - # TRef holding an array. - y = TRef([1,2,3]) - y[][2] = 19 - @test y[][2] == 19 - end - - @testset "override deepcopy_types #57" begin - struct DummyType end - - function f(start::Int) - t = [start] - while true - produce(t[1]) - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f, 0; deepcopy_types=DummyType) - consume(ttask) - - ttask2 = copy(ttask) - consume(ttask2) - - @test consume(ttask) == 1 - @test consume(ttask2) == 2 - end -end diff --git a/test/tapedtask.jl b/test/tapedtask.jl deleted file mode 100644 index f55f83e1..00000000 --- a/test/tapedtask.jl +++ /dev/null @@ -1,159 +0,0 @@ -@testset "tapedtask" begin - @testset "construction" begin - function f() - t = 1 - while true - produce(t) - t = 1 + t - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 1 - - ttask = TapedTask((f, Union{})) - @test consume(ttask) == 1 - end - - @testset "iteration" begin - function f() - t = 1 - while true - produce(t) - t = 1 + t - end - end - - ttask = TapedTask(f) - - next = iterate(ttask) - @test next === (1, nothing) - - val, state = next - next = iterate(ttask, state) - @test next === (2, nothing) - - val, state = next - next = iterate(ttask, state) - @test next === (3, nothing) - - a = collect(Iterators.take(ttask, 7)) - @test eltype(a) === Int - @test a == 4:10 - end - - # Test of `Exception`. - @testset "Exception" begin - @testset "method error" begin - function f() - t = 0 - while true - t[3] = 1 - produce(t) - t = t + 1 - end - end - - ttask = TapedTask(f) - try - consume(ttask) - catch ex - @test ex isa MethodError - end - if VERSION >= v"1.5" - @test ttask.task.exception isa MethodError - end - end - - @testset "error test" begin - function f() - x = 1 - while true - error("error test") - produce(x) - x += 1 - end - end - - ttask = TapedTask(f) - try - consume(ttask) - catch ex - @test ex isa ErrorException - end - if VERSION >= v"1.5" - @test ttask.task.exception isa ErrorException - end - end - - @testset "OutOfBounds Test Before" begin - function f() - x = zeros(2) - while true - x[1] = 1 - x[2] = 2 - x[3] = 3 - produce(x[1]) - end - end - - ttask = TapedTask(f) - try - consume(ttask) - catch ex - @test ex isa BoundsError - end - if VERSION >= v"1.5" - @test ttask.task.exception isa BoundsError - end - end - - @testset "OutOfBounds Test After `produce`" begin - function f() - x = zeros(2) - while true - x[1] = 1 - x[2] = 2 - produce(x[2]) - x[3] = 3 - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 2 - try - consume(ttask) - catch ex - @test ex isa BoundsError - end - if VERSION >= v"1.5" - @test ttask.task.exception isa BoundsError - end - end - - @testset "OutOfBounds Test After `copy`" begin - function f() - x = zeros(2) - while true - x[1] = 1 - x[2] = 2 - produce(x[2]) - x[3] = 3 - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 2 - ttask2 = copy(ttask) - try - consume(ttask2) - catch ex - @test ex isa BoundsError - end - @test ttask.task.exception === nothing - if VERSION >= v"1.5" - @test ttask2.task.exception isa BoundsError - end - end - end -end From 5228dc2fe3f01186ea61b9fc6d117a82fcad26a3 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:23:34 +0000 Subject: [PATCH 08/69] Rework includes in package and runtests --- src/Libtask.jl | 23 ++--------------------- test/runtests.jl | 4 +--- 2 files changed, 3 insertions(+), 24 deletions(-) diff --git a/src/Libtask.jl b/src/Libtask.jl index 8fa79533..51b60f95 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -1,26 +1,7 @@ module Libtask -using FunctionWrappers: FunctionWrapper -using LRUCache +include("copyable_task.jl") -export TapedTask, consume, produce - -export TArray, tzeros, tfill, TRef # legacy types back compat - - -@static if isdefined(Core, :TypedSlot) || isdefined(Core.Compiler, :TypedSlot) - # Julia v1.10 moved Core.TypedSlot to Core.Compiler - # Julia v1.11 removed Core.Compiler.TypedSlot - const TypedSlot = @static if isdefined(Core, :TypedSlot) - Core.TypedSlot - else - Core.Compiler.TypedSlot - end -else - struct TypedSlot end # Dummy -end - -include("tapedfunction.jl") -include("tapedtask.jl") +export CopyableTask, consume, produce end diff --git a/test/runtests.jl b/test/runtests.jl index a045454f..f0e021e5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,9 +1,7 @@ using Libtask using Test -include("tf.jl") -include("tapedtask.jl") -include("tape_copy.jl") +include("copyable_task.jl") include("issues.jl") if haskey(ENV, "BENCHMARK") From fc0173224e17004d0fc477a5925bc48629c17180 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:24:19 +0000 Subject: [PATCH 09/69] More formatting --- src/test_resources.jl | 3 ++- test/issues.jl | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_resources.jl b/src/test_resources.jl index d05e7f0b..932733ae 100644 --- a/src/test_resources.jl +++ b/src/test_resources.jl @@ -10,7 +10,7 @@ end # Old test case without any produce statements. Might make sense to ensure that something # vaguely like this is included in the test suite, but isn't directly relevant. function g(x, y) - if x>y + if x > y r = string(sin(x)) else r = sin(x) * cos(y) @@ -21,6 +21,7 @@ end # Old test case -- github.com/TuringLang/Libtask.jl/issues/148, unused argument function f(x) produce(1) + return nothing end # Old test case. Probably redundant, but makes sense to check. Might want to replace the diff --git a/test/issues.jl b/test/issues.jl index f534d0b1..d649a151 100644 --- a/test/issues.jl +++ b/test/issues.jl @@ -41,5 +41,4 @@ ttask3 = copy(ttask; args=(4,)) @test consume(ttask3) == 5 end - end From 5ddfb6058823d79cece5f3da1beb205daaf80aa4 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:24:27 +0000 Subject: [PATCH 10/69] Add in copyable_task files --- src/copyable_task.jl | 1 + test/copyable_task.jl | 357 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 358 insertions(+) create mode 100644 src/copyable_task.jl create mode 100644 test/copyable_task.jl diff --git a/src/copyable_task.jl b/src/copyable_task.jl new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/copyable_task.jl @@ -0,0 +1 @@ + diff --git a/test/copyable_task.jl b/test/copyable_task.jl new file mode 100644 index 00000000..bd401a71 --- /dev/null +++ b/test/copyable_task.jl @@ -0,0 +1,357 @@ +@testset "copyable_task" begin + @testset "construction" begin + function f() + t = 1 + while true + produce(t) + t = 1 + t + end + end + + ttask = TapedTask(f) + @test consume(ttask) == 1 + + ttask = TapedTask((f, Union{})) + @test consume(ttask) == 1 + end + + @testset "iteration" begin + function f() + t = 1 + while true + produce(t) + t = 1 + t + end + end + + ttask = TapedTask(f) + + next = iterate(ttask) + @test next === (1, nothing) + + val, state = next + next = iterate(ttask, state) + @test next === (2, nothing) + + val, state = next + next = iterate(ttask, state) + @test next === (3, nothing) + + a = collect(Iterators.take(ttask, 7)) + @test eltype(a) === Int + @test a == 4:10 + end + + # Test of `Exception`. + @testset "Exception" begin + @testset "method error" begin + function f() + t = 0 + while true + t[3] = 1 + produce(t) + t = t + 1 + end + end + + ttask = TapedTask(f) + try + consume(ttask) + catch ex + @test ex isa MethodError + end + if VERSION >= v"1.5" + @test ttask.task.exception isa MethodError + end + end + + @testset "error test" begin + function f() + x = 1 + while true + error("error test") + produce(x) + x += 1 + end + end + + ttask = TapedTask(f) + try + consume(ttask) + catch ex + @test ex isa ErrorException + end + if VERSION >= v"1.5" + @test ttask.task.exception isa ErrorException + end + end + + @testset "OutOfBounds Test Before" begin + function f() + x = zeros(2) + while true + x[1] = 1 + x[2] = 2 + x[3] = 3 + produce(x[1]) + end + end + + ttask = TapedTask(f) + try + consume(ttask) + catch ex + @test ex isa BoundsError + end + if VERSION >= v"1.5" + @test ttask.task.exception isa BoundsError + end + end + + @testset "OutOfBounds Test After `produce`" begin + function f() + x = zeros(2) + while true + x[1] = 1 + x[2] = 2 + produce(x[2]) + x[3] = 3 + end + end + + ttask = TapedTask(f) + @test consume(ttask) == 2 + try + consume(ttask) + catch ex + @test ex isa BoundsError + end + if VERSION >= v"1.5" + @test ttask.task.exception isa BoundsError + end + end + + @testset "OutOfBounds Test After `copy`" begin + function f() + x = zeros(2) + while true + x[1] = 1 + x[2] = 2 + produce(x[2]) + x[3] = 3 + end + end + + ttask = TapedTask(f) + @test consume(ttask) == 2 + ttask2 = copy(ttask) + try + consume(ttask2) + catch ex + @test ex isa BoundsError + end + @test ttask.task.exception === nothing + if VERSION >= v"1.5" + @test ttask2.task.exception isa BoundsError + end + end + end + + @testset "copying" begin + # Test case 1: stack allocated objects are deep copied. + @testset "stack allocated objects shallow copy" begin + function f() + t = 0 + while true + produce(t) + t = 1 + t + end + end + + ttask = TapedTask(f) + @test consume(ttask) == 0 + @test consume(ttask) == 1 + a = copy(ttask) + @test consume(a) == 2 + @test consume(a) == 3 + @test consume(ttask) == 2 + @test consume(ttask) == 3 + + @inferred Libtask.TapedFunction(f) + end + + # Test case 2: Array objects are deeply copied. + @testset "Array objects deep copy" begin + function f() + t = [0 1 2] + while true + produce(t[1]) + t[1] = 1 + t[1] + end + end + + ttask = TapedTask(f) + @test consume(ttask) == 0 + @test consume(ttask) == 1 + a = copy(ttask) + @test consume(a) == 2 + @test consume(a) == 3 + @test consume(ttask) == 2 + @test consume(ttask) == 3 + @test consume(ttask) == 4 + @test consume(ttask) == 5 + end + + # Test case 3: Dict objects are shallowly copied. + @testset "Dict objects shallow copy" begin + function f() + t = Dict(1 => 10, 2 => 20) + while true + produce(t[1]) + t[1] = 1 + t[1] + end + end + + ttask = TapedTask(f) + + @test consume(ttask) == 10 + @test consume(ttask) == 11 + + a = copy(ttask) + @test consume(a) == 12 + @test consume(a) == 13 + + @test consume(ttask) == 14 + @test consume(ttask) == 15 + end + + @testset "Array deep copy 2" begin + function f() + t = Array{Int}(undef, 1) + t[1] = 0 + while true + produce(t[1]) + t[1] + t[1] = 1 + t[1] + end + end + + ttask = TapedTask(f) + + consume(ttask) + consume(ttask) + a = copy(ttask) + consume(a) + consume(a) + + @test consume(ttask) == 2 + @test consume(a) == 4 + + DATA = Dict{Task,Array}() + function g() + ta = zeros(UInt64, 4) + for i in 1:4 + ta[i] = hash(current_task()) + DATA[current_task()] = ta + produce(ta[i]) + end + end + + ttask = TapedTask(g) + @test consume(ttask) == hash(ttask.task) # index = 1 + @test consume(ttask) == hash(ttask.task) # index = 2 + + a = copy(ttask) + @test consume(a) == hash(a.task) # index = 3 + @test consume(a) == hash(a.task) # index = 4 + + @test consume(ttask) == hash(ttask.task) # index = 3 + + @test DATA[ttask.task] == + [hash(ttask.task), hash(ttask.task), hash(ttask.task), 0] + @test DATA[a.task] == + [hash(ttask.task), hash(ttask.task), hash(a.task), hash(a.task)] + end + + # Test atomic values. + @testset "ref atomic" begin + function f() + t = Ref(1) + t[] = 0 + for _ in 1:6 + produce(t[]) + t[] + t[] += 1 + end + end + + ctask = TapedTask(f) + + consume(ctask) + consume(ctask) + + a = copy(ctask) + consume(a) + consume(a) + + @test consume(ctask) == 2 + @test consume(a) == 4 + end + + @testset "ref of dictionary deep copy" begin + function f() + t = Ref(Dict("A" => 1, 5 => "B")) + t[]["A"] = 0 + for _ in 1:6 + produce(t[]["A"]) + t[]["A"] += 1 + end + end + + ctask = TapedTask(f) + + consume(ctask) + consume(ctask) + + a = copy(ctask) + consume(a) + consume(a) + + @test consume(ctask) == 2 + @test consume(a) == 4 + end + + @testset "ref of array deep copy" begin + # Create a TRef storing a matrix. + x = TRef([1 2 3; 4 5 6]) + x[][1, 3] = 900 + @test x[][1, 3] == 900 + + # TRef holding an array. + y = TRef([1, 2, 3]) + y[][2] = 19 + @test y[][2] == 19 + end + + @testset "override deepcopy_types #57" begin + struct DummyType end + + function f(start::Int) + t = [start] + while true + produce(t[1]) + t[1] = 1 + t[1] + end + end + + ttask = TapedTask(f, 0; deepcopy_types=DummyType) + consume(ttask) + + ttask2 = copy(ttask) + consume(ttask2) + + @test consume(ttask) == 1 + @test consume(ttask2) == 2 + end + end +end From 159fe13a939bc06a00309b646e25b86bc9f2de45 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:26:34 +0000 Subject: [PATCH 11/69] Clean up runtests --- test/runtests.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index f0e021e5..8235535d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,9 +1,7 @@ -using Libtask -using Test +using JuliaFormatter, Libtask, Test -include("copyable_task.jl") -include("issues.jl") +@testset "Libtask" begin -if haskey(ENV, "BENCHMARK") - include("benchmarks.jl") + include("copyable_task.jl") + include("issues.jl") end From 61594e4bf50e210c6dd2f788f103d186b6fc5901 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:30:01 +0000 Subject: [PATCH 12/69] Update project deps etc --- Project.toml | 17 ++++++++--------- perf/Project.toml | 2 +- test/runtests.jl | 5 ++++- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index d683d9b2..4a62edfe 100644 --- a/Project.toml +++ b/Project.toml @@ -6,19 +6,18 @@ repo = "https://github.com/TuringLang/Libtask.jl.git" version = "0.9.0" [deps] -FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" -LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [compat] -FunctionWrappers = "1.1" -LRUCache = "1.3" -julia = "1.7" +Aqua = "0.8.11" +JuliaFormatter = "1.0.62" +Mooncake = "0.4.99" +julia = "1.10.8" [extras] -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "BenchmarkTools"] +test = ["Aqua", "JuliaFormatter", "Test"] diff --git a/perf/Project.toml b/perf/Project.toml index 9e9ab49b..6522964d 100644 --- a/perf/Project.toml +++ b/perf/Project.toml @@ -8,7 +8,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -julia = "1.3" +julia = "1.10.8" [targets] test = ["Test", "BenchmarkTools"] diff --git a/test/runtests.jl b/test/runtests.jl index 8235535d..309d498d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,10 @@ using JuliaFormatter, Libtask, Test @testset "Libtask" begin - + @testset "quality" begin + Aqua.test_all(Libtask) + @test JuliaFormatter.format(Mooncake; verbose=false, overwrite=false) + end include("copyable_task.jl") include("issues.jl") end From ae6d00017ebbdfc950f9c8fd2d3a8186970bf914 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 16:30:35 +0000 Subject: [PATCH 13/69] Drop 1.7 from CI --- .github/workflows/Testing.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/Testing.yaml b/.github/workflows/Testing.yaml index 728f92e1..2b8b8685 100644 --- a/.github/workflows/Testing.yaml +++ b/.github/workflows/Testing.yaml @@ -11,7 +11,6 @@ jobs: strategy: matrix: version: - - '1.7' - '1.10' - '1' - 'nightly' From dc4a0becc39fc198dc2ee3b3172ca2cace65ed66 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 17:40:54 +0000 Subject: [PATCH 14/69] Initial transfer of code --- Project.toml | 4 + src/Libtask.jl | 13 + src/copyable_task.jl | 351 +++++++++++++++++++++ src/test_resources.jl | 41 --- src/test_utils.jl | 110 +++++++ test/copyable_task.jl | 711 +++++++++++++++++++++--------------------- test/front_matter.jl | 1 + test/runtests.jl | 7 +- 8 files changed, 839 insertions(+), 399 deletions(-) delete mode 100644 src/test_resources.jl create mode 100644 src/test_utils.jl create mode 100644 test/front_matter.jl diff --git a/Project.toml b/Project.toml index 4a62edfe..1df3b546 100644 --- a/Project.toml +++ b/Project.toml @@ -6,12 +6,16 @@ repo = "https://github.com/TuringLang/Libtask.jl.git" version = "0.9.0" [deps] +MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8.11" JuliaFormatter = "1.0.62" +MistyClosures = "2.0.0" Mooncake = "0.4.99" +Test = "1.11.0" julia = "1.10.8" [extras] diff --git a/src/Libtask.jl b/src/Libtask.jl index 51b60f95..36abd4bd 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -1,6 +1,19 @@ module Libtask +# Need this for BBCode. +using Mooncake +using Mooncake: BBCode, BBlock, ID, new_inst, stmt, seed_id! +using Mooncake: IDGotoIfNot, IDGotoNode, IDPhiNode, Switch + +# We'll emit `MistyClosure`s rather than `OpaqueClosure`s. +using MistyClosures + +# Import some names from the compiler. +const CC = Core.Compiler +using Core.Compiler: Argument, IRCode, ReturnNode + include("copyable_task.jl") +include("test_utils.jl") export CopyableTask, consume, produce diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 8b137891..49431942 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -1 +1,352 @@ +__v::Int = 5 +@noinline function produce(x) + global __v = 4 + return nothing +end +mutable struct CopyableTask{Tmc<:MistyClosure,Targs} + const mc::Tmc + args::Targs + const position::Base.RefValue{Int32} +end + +@inline consume(t::CopyableTask) = t.mc(t.args...) + +function initialise!(t::CopyableTask, args::Vararg{Any,N})::Nothing where {N} + t.position[] = -1 + t.args = args + return nothing +end + +function CopyableTask(fargs...) + sig = typeof(fargs) + mc, count_ref = build_callable(Base.code_ircode_by_type(sig)[1][1]) + return CopyableTask(mc, fargs[2:end], count_ref) +end + +function build_callable(ir::IRCode) + seed_id!() + bb, refs = derive_copyable_task_ir(BBCode(ir)) + ir = IRCode(bb) + optimised_ir = Mooncake.optimise_ir!(ir) + return MistyClosure(optimised_ir, refs...; do_compile=true), refs[end] +end + +""" + might_produce(sig::Type{<:Tuple})::Bool + +`true` if a call to method with signature `sig` is permitted to contain +`CopyableTasks.produce` statements. + +This is an opt-in mechanism. the fallback method of this function returns `false` indicating +that, by default, we assume that calls do not contain `CopyableTasks.produce` statements. +""" +might_produce(::Type{<:Tuple}) = false + +# Helper struct used in `derive_copyable_task_ir`. +struct TupleRef + n::Int +end + +# Unclear whether this is needed. +get_value(x::GlobalRef) = getglobal(x.mod, x.name) +get_value(x::QuoteNode) = x.value +get_value(x) = x + +""" + is_produce_stmt(x)::Bool + +`true` if `x` is an expression of the form `Expr(:call, produce, %x)` or a similar `:invoke` +expression, otherwise `false`. +""" +function is_produce_stmt(x)::Bool + if Meta.isexpr(x, :invoke) && length(x.args) == 3 + return get_value(x.args[2]) === produce + elseif Meta.isexpr(x, :call) && length(x.args) == 2 + return get_value(x.args[1]) === produce + else + return false + end +end + +""" + produce_value(x::Expr) + +Returns the value that a `produce` statement returns. For example, for the statment +`produce(%x)`, this function will return `%x`. +""" +function produce_value(x::Expr) + is_produce_stmt(x) || throw(error("Not a produce statement. Please report this error.")) + Meta.isexpr(x, :invoke) && return x.args[3] + return x.args[2] # must be a `:call` Expr. +end + +""" + derive_copyable_task_ir(ir::IRCode)::IRCode + + +""" +function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} + + # Replace all existing `ReturnNode`s with `ReturnNode(nothing)` in order to provide the + # same semantics as `Libtask`. + for bb in ir.blocks + for (n, inst) in enumerate(bb.insts) + stmt = inst.stmt + if stmt isa ReturnNode + bb.insts[n] = new_inst(ReturnNode(nothing)) + end + end + end + + # The location at which `refs` will be stored. + refs_id = Argument(1) + + # Mapping in which each key-value pairs says: "if we exited from block `key`, we must + # resume by jumping to basic block `value`". + resume_block_ids = Dict{ID,ID}() + + # For each basic block `bb`: + # - count the number of produce statements, `n_produce`. + # - construct `n_produce + 1` new basic blocks. The 1st new basic block runs from the + # first stmt in `bb` to the first `produce(%x)` statement (inclusive), the second + # from the next statement after the first `produce(%x)` statement until the next + # `produce(%x)` statement, etc. The final new basic block runs from the statement + # following the final `produce(%x)` statment, until the end of `bb`. + # Furthermore, each `produce(%x)` statement is replaced with a `ReturnNode(%x)`. + # We log the `ID`s of each of these new basic blocks, for use later. + new_bblocks = map(ir.blocks) do bb + + # Find all of the `produce` statements. + produce_indices = findall(x -> is_produce_stmt(x.stmt), bb.insts) + terminator_indices = vcat(produce_indices, length(bb)) + + # TODO: WHAT HAPPENS IF THERE ARE NO PRODUCE STATEMENTS? + # TODO: WHAT HAPPENS IF THE PRODUCE STATEMENT IS A FALLTHROUGH TERMINATOR????? + + # The `ID`s of the new basic blocks. + new_block_ids = vcat(bb.id, [ID() for _ in produce_indices]) + + # Construct `n_produce + 1` new basic blocks. The first basic block retains the + # `ID` of `bb`, the remaining `n_produce + 1` blocks get new `ID`s (which we log). + # All `produce(%x)` statements are replaced with `Return(%x)` statements. + return map(enumerate(terminator_indices)) do (n, term_ind) + + # The first new block has the same `ID` as `bb`. The others gets new ones. + block_id = new_block_ids[n] + + # Pull out the instructions and their `ID`s for the new block. + start_ind = n == 1 ? 1 : terminator_indices[n - 1] + 1 + inst_ids = bb.inst_ids[start_ind:term_ind] + insts = bb.insts[start_ind:term_ind] + + # If n < length(terminator_indices) then it must end with a `produce` statement. + # In this case, we replace the `produce(%x)` statement with a call to set the + # `resume_block` to the next block, which ensures that execution jumps to the + # statement immediately following this `produce(%x)` statement next time the + # function is called. We also insert a `ReturnNode(%x)` i.e. to implement the + # `produce` statement. + # Also log the mapping between the current new block ID, and the ID of the block + # we should resume to. + if n < length(terminator_indices) + resume_id = new_block_ids[n + 1] + resume_block_ids[block_id] = resume_id + set_resume = Expr(:call, set_resume_block!, refs_id, resume_id.id) + return_node = ReturnNode(produce_value(insts[end].stmt)) + inst_ids = vcat(inst_ids[1:(end - 1)], [ID(), ID()]) # actual ID values are irrelevant (no uses). + insts = vcat( + insts[1:(end - 1)], [new_inst(set_resume), new_inst(return_node)] + ) + end + + # Construct + return new basic block. + return BBlock(block_id, inst_ids, insts) + end + end + new_bblocks = reduce(vcat, new_bblocks) + + # Construct map between SSA IDs and their index in the state data structure and back. + # Optimisation TODO: don't create an entry for literally every line in the IR, just the + # ones which produce values that might be needed later. + ssa_id_to_ref_index_map = Dict{ID,Int}() + ref_index_to_ssa_id_map = Dict{Int,ID}() + ref_index_to_type_map = Dict{Int,Type}() + n = 0 + for bb in new_bblocks + for (id, stmt) in zip(bb.inst_ids, bb.insts) + stmt.stmt isa IDGotoNode && continue + stmt.stmt isa IDGotoIfNot && continue + stmt.stmt === nothing && continue + stmt.stmt isa ReturnNode && continue + n += 1 + ssa_id_to_ref_index_map[id] = n + ref_index_to_ssa_id_map[n] = id + ref_index_to_type_map[n] = stmt.type + end + end + + # Specify data structure containing `Ref`s for all of the SSAs. + # Optimisation TODO: permit users to construct custom data structures to make their + # lives involve less indirection. + # Optimisation TODO: make there be only one `Ref` per basic block, and only write to it + # at the end of basic block execution (or something like that). Probably need to base + # this on what the basic blocks _will_ _be_ after we've transformed everything, so need + # to figure out when this can happen. + _refs = map(p -> Ref{ref_index_to_type_map[p]}(), 1:length(ref_index_to_ssa_id_map)) + refs = (_refs..., Ref{Int32}(-1)) + + # For each instruction in each basic block, replace it with a call to the refs. + new_bblocks = map(new_bblocks) do bb + inst_pairs = Mooncake.IDInstPair[] + + # + # Handle all other nodes in the block. + # + + foreach(zip(bb.inst_ids, bb.insts)) do (id, inst) + stmt = inst.stmt + if Meta.isexpr(stmt, :invoke) || Meta.isexpr(stmt, :call) + + # Skip over set_resume_block! statements inserted in the previous pass. + if stmt.args[1] == set_resume_block! + push!(inst_pairs, (id, inst)) + return nothing + end + + # Find any `ID`s and replace them with calls to read whatever is stored in + # the `Ref`s that they are associated to. + for (n, arg) in enumerate(stmt.args) + arg isa ID || continue + + new_id = ID() + ref_ind = ssa_id_to_ref_index_map[arg] + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (new_id, new_inst(expr))) + stmt.args[n] = new_id + end + + # Push the target instruction to the list. + push!(inst_pairs, (id, inst)) + + # Push the result to its `Ref`. + out_ind = ssa_id_to_ref_index_map[id] + set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, id) + push!(inst_pairs, (ID(), new_inst(set_ref))) + elseif Meta.isexpr(stmt, :new) + push!(inst_pairs, (id, inst)) + elseif stmt isa ReturnNode + push!(inst_pairs, (id, inst)) + elseif stmt isa IDGotoNode + push!(inst_pairs, (id, inst)) + elseif stmt isa IDGotoIfNot + push!(inst_pairs, (id, inst)) + elseif stmt isa IDPhiNode + # we'll fix up the PhiNodes after this, so identity transform for now. + push!(inst_pairs, (id, inst)) + elseif stmt isa Nothing + push!(inst_pairs, (id, inst)) + else + throw(error("Unhandled stmt $stmt")) + end + end + + # + # Handle `(ID)PhiNode`s. + # + + phi_inds = findall(x -> x.stmt isa IDPhiNode, bb.insts) + phi_inst_pairs = Mooncake.IDInstPair[] + + # Replace SSA IDs with `TupleRef`s, and record these instructions. + phi_ids = map(phi_inds) do n + phi = bb.insts[n].stmt + for i in eachindex(phi.values) + isassigned(phi.values, i) || continue + v = phi.values[i] + v isa ID || continue + phi.values[i] = TupleRef(ssa_id_to_ref_index_map[v]) + end + phi_id = ID() + push!(phi_inst_pairs, (phi_id, new_inst(phi, Any))) + return phi_id + end + + # De-reference values associated to `IDPhiNode`s. + deref_ids = map(phi_inds) do n + id = bb.inst_ids[n] + phi_id = phi_ids[n] + + # # Re-reference the PhiNode. + # n_id = ID() + # push!(phi_inst_pairs, (n_id, new_inst(Expr(:call, getfield, phi_id, :n)))) + # ref_id = ID() + # push!(phi_inst_pairs, (ref_id, new_inst(Expr(:call, getfield, refs_id, n_id)))) + # push!(phi_inst_pairs, (id, new_inst(Expr(:call, getfield, ref_id, :x)))) + + push!(phi_inst_pairs, (id, new_inst(Expr(:call, deref_phi, refs_id, phi_id)))) + return id + end + + # Update values stored in `Ref`s associated to `PhiNode`s. + for n in phi_inds + ref_ind = ssa_id_to_ref_index_map[bb.inst_ids[n]] + expr = Expr(:call, set_ref_at!, refs_id, ref_ind, deref_ids[n]) + push!(phi_inst_pairs, (ID(), new_inst(expr))) + end + + # Concatenate new phi stmts, removing old ones. + inst_pairs = vcat(phi_inst_pairs, inst_pairs[(length(phi_inds) + 1):end]) + + return BBlock(bb.id, inst_pairs) + end + + # Insert statements at the top. + cases = map(collect(resume_block_ids)) do (pred, succ) + return ID(), succ, Expr(:call, resume_block_is, refs_id, succ.id) + end + cond_ids = ID[x[1] for x in cases] + cond_dests = ID[x[2] for x in cases] + cond_stmts = Any[x[3] for x in cases] + switch_stmt = Switch(Any[x for x in cond_ids], cond_dests, first(new_bblocks).id) + entry_stmts = vcat(cond_stmts, switch_stmt) + entry_block = BBlock(ID(), vcat(cond_ids, ID()), map(new_inst, entry_stmts)) + new_bblocks = vcat(entry_block, new_bblocks) + + # New argtypes are the same as the old ones, except we have `Ref`s in the first argument + # rather than nothing at all. + new_argtypes = copy(ir.argtypes) + new_argtypes[1] = typeof(refs) + + # Return BBCode and the `Ref`s. + return BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta), refs +end + +# Helper used in `derive_copyable_task_ir`. +@inline get_ref_at(refs::R, n::Int) where {R<:Tuple} = refs[n][] + +# Helper used in `derive_copyable_task_ir`. +@inline function set_ref_at!(refs::R, n::Int, val) where {R<:Tuple} + refs[n][] = val + return nothing +end + +# Helper used in `derive_copyable_task_ir`. +@inline function set_resume_block!(refs::R, id::Int32) where {R<:Tuple} + refs[end][] = id + return nothing +end + +# Helper used in `derive_copyable_task_ir`. +@inline resume_block_is(refs::R, id::Int32) where {R<:Tuple} = !(refs[end][] === id) + +# Helper used in `derive_copyable_task_ir`. +@inline deref_phi(refs::R, n::TupleRef) where {R<:Tuple} = refs[n.n][] +@inline deref_phi(::R, x) where {R<:Tuple} = x + +# Implement iterator interface. +function Base.iterate(t::CopyableTask, state::Nothing=nothing) + v = consume(t) + return v === nothing ? nothing : (v, nothing) +end +Base.IteratorSize(::Type{<:CopyableTask}) = Base.SizeUnknown() +Base.IteratorEltype(::Type{<:CopyableTask}) = Base.EltypeUnknown() diff --git a/src/test_resources.jl b/src/test_resources.jl deleted file mode 100644 index 932733ae..00000000 --- a/src/test_resources.jl +++ /dev/null @@ -1,41 +0,0 @@ -module TestResources - -# Old test case without any produce statements used to test TapedFunction. Since this -# doesn't exist as a distinct entity anymore, not clear that this test case is useful. -mutable struct S - i::Int - S(x, y) = new(x + y) -end - -# Old test case without any produce statements. Might make sense to ensure that something -# vaguely like this is included in the test suite, but isn't directly relevant. -function g(x, y) - if x > y - r = string(sin(x)) - else - r = sin(x) * cos(y) - end - return r -end - -# Old test case -- github.com/TuringLang/Libtask.jl/issues/148, unused argument -function f(x) - produce(1) - return nothing -end - -# Old test case. Probably redundant, but makes sense to check. Might want to replace the -# final statement with a produce statement to make the test case meaningful. -function g(x, y) - c = x + y - return (; c, x, y) -end - -# Make sure I provide a test case in which a function contains consts. -function f() - # this line generates: %1 = 1::Core.Const(1) - r = (a = 1) - return nothing -end - -end diff --git a/src/test_utils.jl b/src/test_utils.jl new file mode 100644 index 00000000..b8736dbc --- /dev/null +++ b/src/test_utils.jl @@ -0,0 +1,110 @@ +module TestUtils + +using ..Libtask +using Test +using ..Libtask: CopyableTask + +struct Testcase + name::String + fargs::Tuple + expected_iteration_results::Vector +end + +function (case::Testcase)() + testset = @testset "$(case.name)" begin + + # Construct the task. + t = CopyableTask(case.fargs...) + + # Iterate through t. Record the results, and take a copy after each iteration. + iteration_results = [] + t_copies = [deepcopy(t)] + for val in t + push!(iteration_results, val) + push!(t_copies, deepcopy(t)) + end + + # Check that iterating the original task gives the expected results. + @test iteration_results == case.expected_iteration_results + + # Check that iterating the copies yields the correct results. + for (n, t_copy) in enumerate(t_copies) + @test iteration_results[n:end] == collect(t_copy) + end + end + return testset +end + +function test_cases() + return Testcase[Testcase( + "single block", + (single_block, 5.0), + [sin(5.0), sin(sin(5.0)), sin(sin(sin(5.0))), sin(sin(sin(sin(5.0))))], + ), + # Testcase("no produce", (no_produce_test, 5.0, 4.0), []), +] +end + +function single_block(x::Float64) + x1 = sin(x) + produce(x1) + x2 = sin(x1) + produce(x2) + x3 = sin(x2) + produce(x3) + x4 = sin(x3) + produce(x4) + return cos(x4) +end + +function no_produce_test(x, y) + c = x + y + return (; c, x, y) +end + +# Old test case without any produce statements used to test TapedFunction. Since this +# doesn't exist as a distinct entity anymore, not clear that this test case is useful. +mutable struct C + i::Int + C(x, y) = new(x + y) +end + +function new_object_test(x, y) + produce(C(x, y)) + return nothing +end + +function branching_test(x, y) + if x > y + r = string(sin(x)) + else + r = sin(x) * cos(y) + end + return r +end + +function unused_argument_test(x) + produce(1) + return nothing +end + +function test_with_const() + # this line generates: %1 = 1::Core.Const(1) + r = (a = 1) + return nothing +end + +@noinline function nested_inner() + produce(true) + return nothing +end + +might_produce(::Type{Tuple{typeof(nested_inner)}}) = true + +function nested_outer() + nested_inner() + produce(false) + return nothing +end + +end diff --git a/test/copyable_task.jl b/test/copyable_task.jl index bd401a71..312f987d 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -1,357 +1,360 @@ @testset "copyable_task" begin - @testset "construction" begin - function f() - t = 1 - while true - produce(t) - t = 1 + t - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 1 - - ttask = TapedTask((f, Union{})) - @test consume(ttask) == 1 - end - - @testset "iteration" begin - function f() - t = 1 - while true - produce(t) - t = 1 + t - end - end - - ttask = TapedTask(f) - - next = iterate(ttask) - @test next === (1, nothing) - - val, state = next - next = iterate(ttask, state) - @test next === (2, nothing) - - val, state = next - next = iterate(ttask, state) - @test next === (3, nothing) - - a = collect(Iterators.take(ttask, 7)) - @test eltype(a) === Int - @test a == 4:10 - end - - # Test of `Exception`. - @testset "Exception" begin - @testset "method error" begin - function f() - t = 0 - while true - t[3] = 1 - produce(t) - t = t + 1 - end - end - - ttask = TapedTask(f) - try - consume(ttask) - catch ex - @test ex isa MethodError - end - if VERSION >= v"1.5" - @test ttask.task.exception isa MethodError - end - end - - @testset "error test" begin - function f() - x = 1 - while true - error("error test") - produce(x) - x += 1 - end - end - - ttask = TapedTask(f) - try - consume(ttask) - catch ex - @test ex isa ErrorException - end - if VERSION >= v"1.5" - @test ttask.task.exception isa ErrorException - end - end - - @testset "OutOfBounds Test Before" begin - function f() - x = zeros(2) - while true - x[1] = 1 - x[2] = 2 - x[3] = 3 - produce(x[1]) - end - end - - ttask = TapedTask(f) - try - consume(ttask) - catch ex - @test ex isa BoundsError - end - if VERSION >= v"1.5" - @test ttask.task.exception isa BoundsError - end - end - - @testset "OutOfBounds Test After `produce`" begin - function f() - x = zeros(2) - while true - x[1] = 1 - x[2] = 2 - produce(x[2]) - x[3] = 3 - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 2 - try - consume(ttask) - catch ex - @test ex isa BoundsError - end - if VERSION >= v"1.5" - @test ttask.task.exception isa BoundsError - end - end - - @testset "OutOfBounds Test After `copy`" begin - function f() - x = zeros(2) - while true - x[1] = 1 - x[2] = 2 - produce(x[2]) - x[3] = 3 - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 2 - ttask2 = copy(ttask) - try - consume(ttask2) - catch ex - @test ex isa BoundsError - end - @test ttask.task.exception === nothing - if VERSION >= v"1.5" - @test ttask2.task.exception isa BoundsError - end - end - end - - @testset "copying" begin - # Test case 1: stack allocated objects are deep copied. - @testset "stack allocated objects shallow copy" begin - function f() - t = 0 - while true - produce(t) - t = 1 + t - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 0 - @test consume(ttask) == 1 - a = copy(ttask) - @test consume(a) == 2 - @test consume(a) == 3 - @test consume(ttask) == 2 - @test consume(ttask) == 3 - - @inferred Libtask.TapedFunction(f) - end - - # Test case 2: Array objects are deeply copied. - @testset "Array objects deep copy" begin - function f() - t = [0 1 2] - while true - produce(t[1]) - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 0 - @test consume(ttask) == 1 - a = copy(ttask) - @test consume(a) == 2 - @test consume(a) == 3 - @test consume(ttask) == 2 - @test consume(ttask) == 3 - @test consume(ttask) == 4 - @test consume(ttask) == 5 - end - - # Test case 3: Dict objects are shallowly copied. - @testset "Dict objects shallow copy" begin - function f() - t = Dict(1 => 10, 2 => 20) - while true - produce(t[1]) - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - - @test consume(ttask) == 10 - @test consume(ttask) == 11 - - a = copy(ttask) - @test consume(a) == 12 - @test consume(a) == 13 - - @test consume(ttask) == 14 - @test consume(ttask) == 15 - end - - @testset "Array deep copy 2" begin - function f() - t = Array{Int}(undef, 1) - t[1] = 0 - while true - produce(t[1]) - t[1] - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - - consume(ttask) - consume(ttask) - a = copy(ttask) - consume(a) - consume(a) - - @test consume(ttask) == 2 - @test consume(a) == 4 - - DATA = Dict{Task,Array}() - function g() - ta = zeros(UInt64, 4) - for i in 1:4 - ta[i] = hash(current_task()) - DATA[current_task()] = ta - produce(ta[i]) - end - end - - ttask = TapedTask(g) - @test consume(ttask) == hash(ttask.task) # index = 1 - @test consume(ttask) == hash(ttask.task) # index = 2 - - a = copy(ttask) - @test consume(a) == hash(a.task) # index = 3 - @test consume(a) == hash(a.task) # index = 4 - - @test consume(ttask) == hash(ttask.task) # index = 3 - - @test DATA[ttask.task] == - [hash(ttask.task), hash(ttask.task), hash(ttask.task), 0] - @test DATA[a.task] == - [hash(ttask.task), hash(ttask.task), hash(a.task), hash(a.task)] - end - - # Test atomic values. - @testset "ref atomic" begin - function f() - t = Ref(1) - t[] = 0 - for _ in 1:6 - produce(t[]) - t[] - t[] += 1 - end - end - - ctask = TapedTask(f) - - consume(ctask) - consume(ctask) - - a = copy(ctask) - consume(a) - consume(a) - - @test consume(ctask) == 2 - @test consume(a) == 4 - end - - @testset "ref of dictionary deep copy" begin - function f() - t = Ref(Dict("A" => 1, 5 => "B")) - t[]["A"] = 0 - for _ in 1:6 - produce(t[]["A"]) - t[]["A"] += 1 - end - end - - ctask = TapedTask(f) - - consume(ctask) - consume(ctask) - - a = copy(ctask) - consume(a) - consume(a) - - @test consume(ctask) == 2 - @test consume(a) == 4 - end - - @testset "ref of array deep copy" begin - # Create a TRef storing a matrix. - x = TRef([1 2 3; 4 5 6]) - x[][1, 3] = 900 - @test x[][1, 3] == 900 - - # TRef holding an array. - y = TRef([1, 2, 3]) - y[][2] = 19 - @test y[][2] == 19 - end - - @testset "override deepcopy_types #57" begin - struct DummyType end - - function f(start::Int) - t = [start] - while true - produce(t[1]) - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f, 0; deepcopy_types=DummyType) - consume(ttask) - - ttask2 = copy(ttask) - consume(ttask2) - - @test consume(ttask) == 1 - @test consume(ttask2) == 2 - end + for case in Libtask.TestUtils.test_cases() + case() end + # @testset "construction" begin + # function f() + # t = 1 + # while true + # produce(t) + # t = 1 + t + # end + # end + + # ttask = TapedTask(f) + # @test consume(ttask) == 1 + + # ttask = TapedTask((f, Union{})) + # @test consume(ttask) == 1 + # end + + # @testset "iteration" begin + # function f() + # t = 1 + # while true + # produce(t) + # t = 1 + t + # end + # end + + # ttask = TapedTask(f) + + # next = iterate(ttask) + # @test next === (1, nothing) + + # val, state = next + # next = iterate(ttask, state) + # @test next === (2, nothing) + + # val, state = next + # next = iterate(ttask, state) + # @test next === (3, nothing) + + # a = collect(Iterators.take(ttask, 7)) + # @test eltype(a) === Int + # @test a == 4:10 + # end + + # # Test of `Exception`. + # @testset "Exception" begin + # @testset "method error" begin + # function f() + # t = 0 + # while true + # t[3] = 1 + # produce(t) + # t = t + 1 + # end + # end + + # ttask = TapedTask(f) + # try + # consume(ttask) + # catch ex + # @test ex isa MethodError + # end + # if VERSION >= v"1.5" + # @test ttask.task.exception isa MethodError + # end + # end + + # @testset "error test" begin + # function f() + # x = 1 + # while true + # error("error test") + # produce(x) + # x += 1 + # end + # end + + # ttask = TapedTask(f) + # try + # consume(ttask) + # catch ex + # @test ex isa ErrorException + # end + # if VERSION >= v"1.5" + # @test ttask.task.exception isa ErrorException + # end + # end + + # @testset "OutOfBounds Test Before" begin + # function f() + # x = zeros(2) + # while true + # x[1] = 1 + # x[2] = 2 + # x[3] = 3 + # produce(x[1]) + # end + # end + + # ttask = TapedTask(f) + # try + # consume(ttask) + # catch ex + # @test ex isa BoundsError + # end + # if VERSION >= v"1.5" + # @test ttask.task.exception isa BoundsError + # end + # end + + # @testset "OutOfBounds Test After `produce`" begin + # function f() + # x = zeros(2) + # while true + # x[1] = 1 + # x[2] = 2 + # produce(x[2]) + # x[3] = 3 + # end + # end + + # ttask = TapedTask(f) + # @test consume(ttask) == 2 + # try + # consume(ttask) + # catch ex + # @test ex isa BoundsError + # end + # if VERSION >= v"1.5" + # @test ttask.task.exception isa BoundsError + # end + # end + + # @testset "OutOfBounds Test After `copy`" begin + # function f() + # x = zeros(2) + # while true + # x[1] = 1 + # x[2] = 2 + # produce(x[2]) + # x[3] = 3 + # end + # end + + # ttask = TapedTask(f) + # @test consume(ttask) == 2 + # ttask2 = copy(ttask) + # try + # consume(ttask2) + # catch ex + # @test ex isa BoundsError + # end + # @test ttask.task.exception === nothing + # if VERSION >= v"1.5" + # @test ttask2.task.exception isa BoundsError + # end + # end + # end + + # @testset "copying" begin + # # Test case 1: stack allocated objects are deep copied. + # @testset "stack allocated objects shallow copy" begin + # function f() + # t = 0 + # while true + # produce(t) + # t = 1 + t + # end + # end + + # ttask = TapedTask(f) + # @test consume(ttask) == 0 + # @test consume(ttask) == 1 + # a = copy(ttask) + # @test consume(a) == 2 + # @test consume(a) == 3 + # @test consume(ttask) == 2 + # @test consume(ttask) == 3 + + # @inferred Libtask.TapedFunction(f) + # end + + # # Test case 2: Array objects are deeply copied. + # @testset "Array objects deep copy" begin + # function f() + # t = [0 1 2] + # while true + # produce(t[1]) + # t[1] = 1 + t[1] + # end + # end + + # ttask = TapedTask(f) + # @test consume(ttask) == 0 + # @test consume(ttask) == 1 + # a = copy(ttask) + # @test consume(a) == 2 + # @test consume(a) == 3 + # @test consume(ttask) == 2 + # @test consume(ttask) == 3 + # @test consume(ttask) == 4 + # @test consume(ttask) == 5 + # end + + # # Test case 3: Dict objects are shallowly copied. + # @testset "Dict objects shallow copy" begin + # function f() + # t = Dict(1 => 10, 2 => 20) + # while true + # produce(t[1]) + # t[1] = 1 + t[1] + # end + # end + + # ttask = TapedTask(f) + + # @test consume(ttask) == 10 + # @test consume(ttask) == 11 + + # a = copy(ttask) + # @test consume(a) == 12 + # @test consume(a) == 13 + + # @test consume(ttask) == 14 + # @test consume(ttask) == 15 + # end + + # @testset "Array deep copy 2" begin + # function f() + # t = Array{Int}(undef, 1) + # t[1] = 0 + # while true + # produce(t[1]) + # t[1] + # t[1] = 1 + t[1] + # end + # end + + # ttask = TapedTask(f) + + # consume(ttask) + # consume(ttask) + # a = copy(ttask) + # consume(a) + # consume(a) + + # @test consume(ttask) == 2 + # @test consume(a) == 4 + + # DATA = Dict{Task,Array}() + # function g() + # ta = zeros(UInt64, 4) + # for i in 1:4 + # ta[i] = hash(current_task()) + # DATA[current_task()] = ta + # produce(ta[i]) + # end + # end + + # ttask = TapedTask(g) + # @test consume(ttask) == hash(ttask.task) # index = 1 + # @test consume(ttask) == hash(ttask.task) # index = 2 + + # a = copy(ttask) + # @test consume(a) == hash(a.task) # index = 3 + # @test consume(a) == hash(a.task) # index = 4 + + # @test consume(ttask) == hash(ttask.task) # index = 3 + + # @test DATA[ttask.task] == + # [hash(ttask.task), hash(ttask.task), hash(ttask.task), 0] + # @test DATA[a.task] == + # [hash(ttask.task), hash(ttask.task), hash(a.task), hash(a.task)] + # end + + # # Test atomic values. + # @testset "ref atomic" begin + # function f() + # t = Ref(1) + # t[] = 0 + # for _ in 1:6 + # produce(t[]) + # t[] + # t[] += 1 + # end + # end + + # ctask = TapedTask(f) + + # consume(ctask) + # consume(ctask) + + # a = copy(ctask) + # consume(a) + # consume(a) + + # @test consume(ctask) == 2 + # @test consume(a) == 4 + # end + + # @testset "ref of dictionary deep copy" begin + # function f() + # t = Ref(Dict("A" => 1, 5 => "B")) + # t[]["A"] = 0 + # for _ in 1:6 + # produce(t[]["A"]) + # t[]["A"] += 1 + # end + # end + + # ctask = TapedTask(f) + + # consume(ctask) + # consume(ctask) + + # a = copy(ctask) + # consume(a) + # consume(a) + + # @test consume(ctask) == 2 + # @test consume(a) == 4 + # end + + # @testset "ref of array deep copy" begin + # # Create a TRef storing a matrix. + # x = TRef([1 2 3; 4 5 6]) + # x[][1, 3] = 900 + # @test x[][1, 3] == 900 + + # # TRef holding an array. + # y = TRef([1, 2, 3]) + # y[][2] = 19 + # @test y[][2] == 19 + # end + + # @testset "override deepcopy_types #57" begin + # struct DummyType end + + # function f(start::Int) + # t = [start] + # while true + # produce(t[1]) + # t[1] = 1 + t[1] + # end + # end + + # ttask = TapedTask(f, 0; deepcopy_types=DummyType) + # consume(ttask) + + # ttask2 = copy(ttask) + # consume(ttask2) + + # @test consume(ttask) == 1 + # @test consume(ttask2) == 2 + # end + # end end diff --git a/test/front_matter.jl b/test/front_matter.jl new file mode 100644 index 00000000..92a26f35 --- /dev/null +++ b/test/front_matter.jl @@ -0,0 +1 @@ +using Aqua, JuliaFormatter, Libtask, Test diff --git a/test/runtests.jl b/test/runtests.jl index 309d498d..54876973 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,10 +1,9 @@ -using JuliaFormatter, Libtask, Test - +include("front_matter.jl") @testset "Libtask" begin @testset "quality" begin Aqua.test_all(Libtask) - @test JuliaFormatter.format(Mooncake; verbose=false, overwrite=false) + @test JuliaFormatter.format(Libtask; verbose=false, overwrite=false) end include("copyable_task.jl") - include("issues.jl") + # include("issues.jl") end From 7e9ba1e1fb1ef0d8d549ebe8a00874a5f73dc8e5 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 17:50:56 +0000 Subject: [PATCH 15/69] Fix bug for function with no produce statements --- src/copyable_task.jl | 4 ++-- src/test_utils.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 49431942..9d7c3fe5 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -308,8 +308,8 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} cond_dests = ID[x[2] for x in cases] cond_stmts = Any[x[3] for x in cases] switch_stmt = Switch(Any[x for x in cond_ids], cond_dests, first(new_bblocks).id) - entry_stmts = vcat(cond_stmts, switch_stmt) - entry_block = BBlock(ID(), vcat(cond_ids, ID()), map(new_inst, entry_stmts)) + entry_stmts = vcat(cond_stmts, nothing, switch_stmt) + entry_block = BBlock(ID(), vcat(cond_ids, ID(), ID()), map(new_inst, entry_stmts)) new_bblocks = vcat(entry_block, new_bblocks) # New argtypes are the same as the old ones, except we have `Ref`s in the first argument diff --git a/src/test_utils.jl b/src/test_utils.jl index b8736dbc..1321de04 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -41,7 +41,7 @@ function test_cases() (single_block, 5.0), [sin(5.0), sin(sin(5.0)), sin(sin(sin(5.0))), sin(sin(sin(sin(5.0))))], ), - # Testcase("no produce", (no_produce_test, 5.0, 4.0), []), + Testcase("no produce", (no_produce_test, 5.0, 4.0), []), ] end From 8e6f1b6c9837c0d188a2409e720038f18a241405 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 17:52:29 +0000 Subject: [PATCH 16/69] Test for construction of new mutable struct --- src/test_utils.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index 1321de04..ce18f7f3 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -42,6 +42,7 @@ function test_cases() [sin(5.0), sin(sin(5.0)), sin(sin(sin(5.0))), sin(sin(sin(sin(5.0))))], ), Testcase("no produce", (no_produce_test, 5.0, 4.0), []), + Testcase("new object", (new_object_test, 5, 4), Any[C(5, 4)]), ] end @@ -69,6 +70,8 @@ mutable struct C C(x, y) = new(x + y) end +Base.:(==)(c::C, d::C) = c.i == d.i + function new_object_test(x, y) produce(C(x, y)) return nothing From 71e6444a8113426d8c2d9c4199a05f7ba027b075 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 18:04:38 +0000 Subject: [PATCH 17/69] Fix more test cases --- src/copyable_task.jl | 8 +++++++- src/test_utils.jl | 9 +++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 9d7c3fe5..b927c4ab 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -117,11 +117,17 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # We log the `ID`s of each of these new basic blocks, for use later. new_bblocks = map(ir.blocks) do bb + # If the final statement in the block is a `produce` statement, insert an additional + # statement afterwards. + if is_produce_stmt(bb.insts[end].stmt) + push!(bb.inst_ids, ID()) + push!(bb.insts, new_inst(nothing, Nothing)) + end + # Find all of the `produce` statements. produce_indices = findall(x -> is_produce_stmt(x.stmt), bb.insts) terminator_indices = vcat(produce_indices, length(bb)) - # TODO: WHAT HAPPENS IF THERE ARE NO PRODUCE STATEMENTS? # TODO: WHAT HAPPENS IF THE PRODUCE STATEMENT IS A FALLTHROUGH TERMINATOR????? # The `ID`s of the new basic blocks. diff --git a/src/test_utils.jl b/src/test_utils.jl index ce18f7f3..43542964 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -42,7 +42,8 @@ function test_cases() [sin(5.0), sin(sin(5.0)), sin(sin(sin(5.0))), sin(sin(sin(sin(5.0))))], ), Testcase("no produce", (no_produce_test, 5.0, 4.0), []), - Testcase("new object", (new_object_test, 5, 4), Any[C(5, 4)]), + Testcase("new object", (new_object_test, 5, 4), [C(5, 4)]), + Testcase("branching test l", (branching_test, 5.0, 4.0), [string(sin(5.0))]), ] end @@ -79,11 +80,11 @@ end function branching_test(x, y) if x > y - r = string(sin(x)) + produce(string(sin(x))) else - r = sin(x) * cos(y) + produce(sin(x) * cos(y)) end - return r + return nothing end function unused_argument_test(x) From 4d18d477681a3b069a6aab39c657a45e988d9bb0 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 18:06:02 +0000 Subject: [PATCH 18/69] Add another test case --- src/test_utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index 43542964..a39d0ccb 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -44,6 +44,7 @@ function test_cases() Testcase("no produce", (no_produce_test, 5.0, 4.0), []), Testcase("new object", (new_object_test, 5, 4), [C(5, 4)]), Testcase("branching test l", (branching_test, 5.0, 4.0), [string(sin(5.0))]), + Testcase("branching test r", (branching_test, 4.0, 5.0), [sin(4.0) * cos(5.0)]), ] end From f1b247d76b5016925f460bbbc37b0bf974d24cd6 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 18:08:27 +0000 Subject: [PATCH 19/69] More test cases --- src/test_utils.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index a39d0ccb..c8272be4 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -45,6 +45,9 @@ function test_cases() Testcase("new object", (new_object_test, 5, 4), [C(5, 4)]), Testcase("branching test l", (branching_test, 5.0, 4.0), [string(sin(5.0))]), Testcase("branching test r", (branching_test, 4.0, 5.0), [sin(4.0) * cos(5.0)]), + Testcase("unused argument test", (unused_argument_test, 3), [1]), + Testcase("test with const", (test_with_const, ), [1]), + Testcase("nested", (nested_outer, ), [true, false]), ] end @@ -96,6 +99,7 @@ end function test_with_const() # this line generates: %1 = 1::Core.Const(1) r = (a = 1) + produce(r) return nothing end From 85ec6a2d93b4963034b95f8df4765aa6c005241a Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 19:36:37 +0000 Subject: [PATCH 20/69] More work --- src/copyable_task.jl | 75 +++-- src/test_utils.jl | 70 ++++- test/copyable_task.jl | 644 ++++++++++++++++++++---------------------- 3 files changed, 418 insertions(+), 371 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index b927c4ab..eda0c036 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -115,6 +115,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # following the final `produce(%x)` statment, until the end of `bb`. # Furthermore, each `produce(%x)` statement is replaced with a `ReturnNode(%x)`. # We log the `ID`s of each of these new basic blocks, for use later. + replacements = Dict{ID,ID}() new_bblocks = map(ir.blocks) do bb # If the final statement in the block is a `produce` statement, insert an additional @@ -128,17 +129,18 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} produce_indices = findall(x -> is_produce_stmt(x.stmt), bb.insts) terminator_indices = vcat(produce_indices, length(bb)) - # TODO: WHAT HAPPENS IF THE PRODUCE STATEMENT IS A FALLTHROUGH TERMINATOR????? - # The `ID`s of the new basic blocks. - new_block_ids = vcat(bb.id, [ID() for _ in produce_indices]) + old_id = bb.id + new_block_ids = vcat([ID() for _ in produce_indices], bb.id) + new_id = first(new_block_ids) + replacements[old_id] = new_id - # Construct `n_produce + 1` new basic blocks. The first basic block retains the - # `ID` of `bb`, the remaining `n_produce + 1` blocks get new `ID`s (which we log). + # Construct `n_produce + 1` new basic blocks. The last basic block retains the + # `ID` of `bb`, the remaining `n_produce` blocks get new `ID`s (which we log). # All `produce(%x)` statements are replaced with `Return(%x)` statements. return map(enumerate(terminator_indices)) do (n, term_ind) - # The first new block has the same `ID` as `bb`. The others gets new ones. + # The last new block has the same `ID` as `bb`. The others gets new ones. block_id = new_block_ids[n] # Pull out the instructions and their `ID`s for the new block. @@ -165,12 +167,29 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} ) end - # Construct + return new basic block. + # Construct + return new basic block. return BBlock(block_id, inst_ids, insts) end end new_bblocks = reduce(vcat, new_bblocks) + # Hunt for `IDGotoNode`s and `IDGotoIfNot`s, and replace them with the new ID of the + # start of these blocks. + for (old_id, new_id) in replacements, bb in new_bblocks + inst = last(bb.insts) + stmt = inst.stmt + new_stmt = if stmt isa IDGotoNode && stmt.label == old_id + IDGotoNode(new_id) + elseif stmt isa IDGotoIfNot && stmt.dest == old_id + IDGotoIfNot(stmt.cond, new_id) + else + continue + end + bb.insts[end] = CC.NewInstruction( + new_stmt, inst.type, inst.info, inst.line, inst.flag + ) + end + # Construct map between SSA IDs and their index in the state data structure and back. # Optimisation TODO: don't create an entry for literally every line in the IR, just the # ones which produce values that might be needed later. @@ -211,7 +230,10 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} foreach(zip(bb.inst_ids, bb.insts)) do (id, inst) stmt = inst.stmt - if Meta.isexpr(stmt, :invoke) || Meta.isexpr(stmt, :call) + if Meta.isexpr(stmt, :invoke) || + Meta.isexpr(stmt, :call) || + Meta.isexpr(stmt, :new) || + Meta.isexpr(stmt, :foreigncall) # Skip over set_resume_block! statements inserted in the previous pass. if stmt.args[1] == set_resume_block! @@ -238,13 +260,34 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} out_ind = ssa_id_to_ref_index_map[id] set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, id) push!(inst_pairs, (ID(), new_inst(set_ref))) - elseif Meta.isexpr(stmt, :new) + elseif Meta.isexpr(stmt, :boundscheck) push!(inst_pairs, (id, inst)) elseif stmt isa ReturnNode - push!(inst_pairs, (id, inst)) - elseif stmt isa IDGotoNode - push!(inst_pairs, (id, inst)) + # If returning an SSA, it might be one whose value was restored from before. + # Therefore, grab it out of storage, rather than assuming that it is def-ed. + if isdefined(stmt, :val) && stmt.val isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.val] + val_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (val_id, new_inst(expr))) + push!(inst_pairs, (ID(), new_inst(ReturnNode(val_id)))) + else + push!(inst_pairs, (id, inst)) + end elseif stmt isa IDGotoIfNot + # If the condition is an SSA, it might be one whose value was restored from + # before. Therefore, grab it out of storage, rather than assuming that it is + # defined. + if stmt.cond isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.cond] + cond_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (cond_id, new_inst(expr))) + push!(inst_pairs, (ID(), new_inst(IDGotoIfNot(cond_id, stmt.dest)))) + else + push!(inst_pairs, (id, inst)) + end + elseif stmt isa IDGotoNode push!(inst_pairs, (id, inst)) elseif stmt isa IDPhiNode # we'll fix up the PhiNodes after this, so identity transform for now. @@ -281,14 +324,6 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} deref_ids = map(phi_inds) do n id = bb.inst_ids[n] phi_id = phi_ids[n] - - # # Re-reference the PhiNode. - # n_id = ID() - # push!(phi_inst_pairs, (n_id, new_inst(Expr(:call, getfield, phi_id, :n)))) - # ref_id = ID() - # push!(phi_inst_pairs, (ref_id, new_inst(Expr(:call, getfield, refs_id, n_id)))) - # push!(phi_inst_pairs, (id, new_inst(Expr(:call, getfield, ref_id, :x)))) - push!(phi_inst_pairs, (id, new_inst(Expr(:call, deref_phi, refs_id, phi_id)))) return id end diff --git a/src/test_utils.jl b/src/test_utils.jl index c8272be4..49ee7de7 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -36,19 +36,29 @@ function (case::Testcase)() end function test_cases() - return Testcase[Testcase( - "single block", - (single_block, 5.0), - [sin(5.0), sin(sin(5.0)), sin(sin(sin(5.0))), sin(sin(sin(sin(5.0))))], - ), - Testcase("no produce", (no_produce_test, 5.0, 4.0), []), - Testcase("new object", (new_object_test, 5, 4), [C(5, 4)]), - Testcase("branching test l", (branching_test, 5.0, 4.0), [string(sin(5.0))]), - Testcase("branching test r", (branching_test, 4.0, 5.0), [sin(4.0) * cos(5.0)]), - Testcase("unused argument test", (unused_argument_test, 3), [1]), - Testcase("test with const", (test_with_const, ), [1]), - Testcase("nested", (nested_outer, ), [true, false]), -] + return Testcase[ + Testcase( + "single block", + (single_block, 5.0), + [sin(5.0), sin(sin(5.0)), sin(sin(sin(5.0))), sin(sin(sin(sin(5.0))))], + ), + Testcase("produce old", (produce_old_value, 5.0), [sin(5.0), sin(5.0)]), + Testcase("branch on old value l", (branch_on_old_value, 2.0), [true, 2.0]), + Testcase("branch on old value r", (branch_on_old_value, -1.0), [false, -2.0]), + Testcase("no produce", (no_produce_test, 5.0, 4.0), []), + Testcase("new object", (new_object_test, 5, 4), [C(5, 4), C(5, 4)]), + Testcase("branching test l", (branching_test, 5.0, 4.0), [string(sin(5.0))]), + Testcase("branching test r", (branching_test, 4.0, 5.0), [sin(4.0) * cos(5.0)]), + Testcase("unused argument test", (unused_argument_test, 3), [1]), + Testcase("test with const", (test_with_const,), [1]), + Testcase("while loop", (while_loop,), collect(1:9)), + Testcase( + "foreigncall tester", (foreigncall_tester, "hi"), [Ptr{UInt8}, Ptr{UInt8}] + ), + + # Failing tests + # Testcase("nested", (nested_outer, ), [true, false]), + ] end function single_block(x::Float64) @@ -63,6 +73,20 @@ function single_block(x::Float64) return cos(x4) end +function produce_old_value(x::Float64) + v = sin(x) + produce(v) + produce(v) + return nothing +end + +function branch_on_old_value(x::Float64) + b = x > 0 + produce(b) + produce(b ? x : 2x) + return nothing +end + function no_produce_test(x, y) c = x + y return (; c, x, y) @@ -78,7 +102,9 @@ end Base.:(==)(c::C, d::C) = c.i == d.i function new_object_test(x, y) - produce(C(x, y)) + c = C(x, y) + produce(c) + produce(c) return nothing end @@ -103,6 +129,22 @@ function test_with_const() return nothing end +function while_loop() + t = 1 + while t < 10 + produce(t) + t = 1 + t + end + return nothing +end + +function foreigncall_tester(s::String) + ptr = ccall(:jl_string_ptr, Ptr{UInt8}, (Any,), s) + produce(typeof(ptr)) + produce(typeof(ptr)) + return nothing +end + @noinline function nested_inner() produce(true) return nothing diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 312f987d..30e925a2 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -18,343 +18,313 @@ # @test consume(ttask) == 1 # end - # @testset "iteration" begin - # function f() - # t = 1 - # while true - # produce(t) - # t = 1 + t - # end - # end - - # ttask = TapedTask(f) - - # next = iterate(ttask) - # @test next === (1, nothing) - - # val, state = next - # next = iterate(ttask, state) - # @test next === (2, nothing) - - # val, state = next - # next = iterate(ttask, state) - # @test next === (3, nothing) - - # a = collect(Iterators.take(ttask, 7)) - # @test eltype(a) === Int - # @test a == 4:10 - # end - - # # Test of `Exception`. - # @testset "Exception" begin - # @testset "method error" begin - # function f() - # t = 0 - # while true - # t[3] = 1 - # produce(t) - # t = t + 1 - # end - # end - - # ttask = TapedTask(f) - # try - # consume(ttask) - # catch ex - # @test ex isa MethodError - # end - # if VERSION >= v"1.5" - # @test ttask.task.exception isa MethodError - # end - # end - - # @testset "error test" begin - # function f() - # x = 1 - # while true - # error("error test") - # produce(x) - # x += 1 - # end - # end - - # ttask = TapedTask(f) - # try - # consume(ttask) - # catch ex - # @test ex isa ErrorException - # end - # if VERSION >= v"1.5" - # @test ttask.task.exception isa ErrorException - # end - # end - - # @testset "OutOfBounds Test Before" begin - # function f() - # x = zeros(2) - # while true - # x[1] = 1 - # x[2] = 2 - # x[3] = 3 - # produce(x[1]) - # end - # end - - # ttask = TapedTask(f) - # try - # consume(ttask) - # catch ex - # @test ex isa BoundsError - # end - # if VERSION >= v"1.5" - # @test ttask.task.exception isa BoundsError - # end - # end - - # @testset "OutOfBounds Test After `produce`" begin - # function f() - # x = zeros(2) - # while true - # x[1] = 1 - # x[2] = 2 - # produce(x[2]) - # x[3] = 3 - # end - # end - - # ttask = TapedTask(f) - # @test consume(ttask) == 2 - # try - # consume(ttask) - # catch ex - # @test ex isa BoundsError - # end - # if VERSION >= v"1.5" - # @test ttask.task.exception isa BoundsError - # end - # end - - # @testset "OutOfBounds Test After `copy`" begin - # function f() - # x = zeros(2) - # while true - # x[1] = 1 - # x[2] = 2 - # produce(x[2]) - # x[3] = 3 - # end - # end - - # ttask = TapedTask(f) - # @test consume(ttask) == 2 - # ttask2 = copy(ttask) - # try - # consume(ttask2) - # catch ex - # @test ex isa BoundsError - # end - # @test ttask.task.exception === nothing - # if VERSION >= v"1.5" - # @test ttask2.task.exception isa BoundsError - # end - # end - # end - - # @testset "copying" begin - # # Test case 1: stack allocated objects are deep copied. - # @testset "stack allocated objects shallow copy" begin - # function f() - # t = 0 - # while true - # produce(t) - # t = 1 + t - # end - # end - - # ttask = TapedTask(f) - # @test consume(ttask) == 0 - # @test consume(ttask) == 1 - # a = copy(ttask) - # @test consume(a) == 2 - # @test consume(a) == 3 - # @test consume(ttask) == 2 - # @test consume(ttask) == 3 - - # @inferred Libtask.TapedFunction(f) - # end - - # # Test case 2: Array objects are deeply copied. - # @testset "Array objects deep copy" begin - # function f() - # t = [0 1 2] - # while true - # produce(t[1]) - # t[1] = 1 + t[1] - # end - # end - - # ttask = TapedTask(f) - # @test consume(ttask) == 0 - # @test consume(ttask) == 1 - # a = copy(ttask) - # @test consume(a) == 2 - # @test consume(a) == 3 - # @test consume(ttask) == 2 - # @test consume(ttask) == 3 - # @test consume(ttask) == 4 - # @test consume(ttask) == 5 - # end - - # # Test case 3: Dict objects are shallowly copied. - # @testset "Dict objects shallow copy" begin - # function f() - # t = Dict(1 => 10, 2 => 20) - # while true - # produce(t[1]) - # t[1] = 1 + t[1] - # end - # end - - # ttask = TapedTask(f) - - # @test consume(ttask) == 10 - # @test consume(ttask) == 11 - - # a = copy(ttask) - # @test consume(a) == 12 - # @test consume(a) == 13 - - # @test consume(ttask) == 14 - # @test consume(ttask) == 15 - # end - - # @testset "Array deep copy 2" begin - # function f() - # t = Array{Int}(undef, 1) - # t[1] = 0 - # while true - # produce(t[1]) - # t[1] - # t[1] = 1 + t[1] - # end - # end - - # ttask = TapedTask(f) - - # consume(ttask) - # consume(ttask) - # a = copy(ttask) - # consume(a) - # consume(a) - - # @test consume(ttask) == 2 - # @test consume(a) == 4 - - # DATA = Dict{Task,Array}() - # function g() - # ta = zeros(UInt64, 4) - # for i in 1:4 - # ta[i] = hash(current_task()) - # DATA[current_task()] = ta - # produce(ta[i]) - # end - # end - - # ttask = TapedTask(g) - # @test consume(ttask) == hash(ttask.task) # index = 1 - # @test consume(ttask) == hash(ttask.task) # index = 2 - - # a = copy(ttask) - # @test consume(a) == hash(a.task) # index = 3 - # @test consume(a) == hash(a.task) # index = 4 - - # @test consume(ttask) == hash(ttask.task) # index = 3 - - # @test DATA[ttask.task] == - # [hash(ttask.task), hash(ttask.task), hash(ttask.task), 0] - # @test DATA[a.task] == - # [hash(ttask.task), hash(ttask.task), hash(a.task), hash(a.task)] - # end - - # # Test atomic values. - # @testset "ref atomic" begin - # function f() - # t = Ref(1) - # t[] = 0 - # for _ in 1:6 - # produce(t[]) - # t[] - # t[] += 1 - # end - # end - - # ctask = TapedTask(f) - - # consume(ctask) - # consume(ctask) - - # a = copy(ctask) - # consume(a) - # consume(a) - - # @test consume(ctask) == 2 - # @test consume(a) == 4 - # end - - # @testset "ref of dictionary deep copy" begin - # function f() - # t = Ref(Dict("A" => 1, 5 => "B")) - # t[]["A"] = 0 - # for _ in 1:6 - # produce(t[]["A"]) - # t[]["A"] += 1 - # end - # end - - # ctask = TapedTask(f) - - # consume(ctask) - # consume(ctask) - - # a = copy(ctask) - # consume(a) - # consume(a) - - # @test consume(ctask) == 2 - # @test consume(a) == 4 - # end - - # @testset "ref of array deep copy" begin - # # Create a TRef storing a matrix. - # x = TRef([1 2 3; 4 5 6]) - # x[][1, 3] = 900 - # @test x[][1, 3] == 900 - - # # TRef holding an array. - # y = TRef([1, 2, 3]) - # y[][2] = 19 - # @test y[][2] == 19 - # end - - # @testset "override deepcopy_types #57" begin - # struct DummyType end - - # function f(start::Int) - # t = [start] - # while true - # produce(t[1]) - # t[1] = 1 + t[1] - # end - # end - - # ttask = TapedTask(f, 0; deepcopy_types=DummyType) - # consume(ttask) + @testset "iteration" begin + function f() + t = 1 + while true + produce(t) + t = 1 + t + end + end + + ttask = CopyableTask(f) + + next = iterate(ttask) + @test next === (1, nothing) + + val, state = next + next = iterate(ttask, state) + @test next === (2, nothing) + + val, state = next + next = iterate(ttask, state) + @test next === (3, nothing) + + a = collect(Iterators.take(ttask, 7)) + @test eltype(a) === Int + @test a == 4:10 + end - # ttask2 = copy(ttask) - # consume(ttask2) + # Test of `Exception`. + @testset "Exception" begin + @testset "method error" begin + function f() + t = 0 + while true + t[3] = 1 + produce(t) + t = t + 1 + end + end + + ttask = CopyableTask(f) + try + consume(ttask) + catch ex + @test ex isa MethodError + end + end + + @testset "error test" begin + function f() + x = 1 + while true + error("error test") + produce(x) + x += 1 + end + end + + ttask = CopyableTask(f) + try + consume(ttask) + catch ex + @test ex isa ErrorException + end + end + + @testset "OutOfBounds Test Before" begin + function f() + x = zeros(2) + while true + x[1] = 1 + x[2] = 2 + x[3] = 3 + produce(x[1]) + end + end + + ttask = CopyableTask(f) + try + consume(ttask) + catch ex + @test ex isa BoundsError + end + end + + @testset "OutOfBounds Test After `produce`" begin + function f() + x = zeros(2) + while true + x[1] = 1 + x[2] = 2 + produce(x[2]) + x[3] = 3 + end + end + + ttask = CopyableTask(f) + @test consume(ttask) == 2 + try + consume(ttask) + catch ex + @test ex isa BoundsError + end + end + + @testset "OutOfBounds Test After `copy`" begin + function f() + x = zeros(2) + while true + x[1] = 1 + x[2] = 2 + produce(x[2]) + x[3] = 3 + end + end + + ttask = CopyableTask(f) + @test consume(ttask) == 2 + ttask2 = deepcopy(ttask) + try + consume(ttask2) + catch ex + @test ex isa BoundsError + end + end + end - # @test consume(ttask) == 1 - # @test consume(ttask2) == 2 - # end - # end + @testset "copying" begin + # Test case 1: stack allocated objects are deep copied. + @testset "stack allocated objects shallow copy" begin + function f() + t = 0 + while true + produce(t) + t = 1 + t + end + end + + ttask = CopyableTask(f) + @test consume(ttask) == 0 + @test consume(ttask) == 1 + a = deepcopy(ttask) + @test consume(a) == 2 + @test consume(a) == 3 + @test consume(ttask) == 2 + @test consume(ttask) == 3 + end + + # Test case 2: Array objects are deeply copied. + @testset "Array objects deep copy" begin + function f() + t = [0 1 2] + while true + produce(t[1]) + t[1] = 1 + t[1] + end + end + + ttask = CopyableTask(f) + @test consume(ttask) == 0 + @test consume(ttask) == 1 + a = deepcopy(ttask) + @test consume(a) == 2 + @test consume(a) == 3 + @test consume(ttask) == 2 + @test consume(ttask) == 3 + @test consume(ttask) == 4 + @test consume(ttask) == 5 + end + + # # Test case 3: Dict objects are shallowly copied. + # @testset "Dict objects shallow copy" begin + # function f() + # t = Dict(1 => 10, 2 => 20) + # while true + # produce(t[1]) + # t[1] = 1 + t[1] + # end + # end + + # ttask = TapedTask(f) + + # @test consume(ttask) == 10 + # @test consume(ttask) == 11 + + # a = copy(ttask) + # @test consume(a) == 12 + # @test consume(a) == 13 + + # @test consume(ttask) == 14 + # @test consume(ttask) == 15 + # end + + @testset "Array deep copy 2" begin + function f() + t = Array{Int}(undef, 1) + t[1] = 0 + while true + produce(t[1]) + t[1] + t[1] = 1 + t[1] + end + end + + ttask = CopyableTask(f) + + consume(ttask) + consume(ttask) + a = deepcopy(ttask) + consume(a) + consume(a) + + @test consume(ttask) == 2 + @test consume(a) == 4 + + # DATA = Dict{Task,Array}() + # function g() + # ta = zeros(UInt64, 4) + # for i in 1:4 + # ta[i] = hash(current_task()) + # DATA[current_task()] = ta + # produce(ta[i]) + # end + # end + + # ttask = TapedTask(g) + # @test consume(ttask) == hash(ttask.task) # index = 1 + # @test consume(ttask) == hash(ttask.task) # index = 2 + + # a = copy(ttask) + # @test consume(a) == hash(a.task) # index = 3 + # @test consume(a) == hash(a.task) # index = 4 + + # @test consume(ttask) == hash(ttask.task) # index = 3 + + # @test DATA[ttask.task] == + # [hash(ttask.task), hash(ttask.task), hash(ttask.task), 0] + # @test DATA[a.task] == + # [hash(ttask.task), hash(ttask.task), hash(a.task), hash(a.task)] + end + + # # Test atomic values. + # @testset "ref atomic" begin + # function f() + # t = Ref(1) + # t[] = 0 + # for _ in 1:6 + # produce(t[]) + # t[] + # t[] += 1 + # end + # end + + # ctask = TapedTask(f) + + # consume(ctask) + # consume(ctask) + + # a = copy(ctask) + # consume(a) + # consume(a) + + # @test consume(ctask) == 2 + # @test consume(a) == 4 + # end + + # @testset "ref of dictionary deep copy" begin + # function f() + # t = Ref(Dict("A" => 1, 5 => "B")) + # t[]["A"] = 0 + # for _ in 1:6 + # produce(t[]["A"]) + # t[]["A"] += 1 + # end + # end + + # ctask = TapedTask(f) + + # consume(ctask) + # consume(ctask) + + # a = copy(ctask) + # consume(a) + # consume(a) + + # @test consume(ctask) == 2 + # @test consume(a) == 4 + # end + + # @testset "override deepcopy_types #57" begin + # struct DummyType end + + # function f(start::Int) + # t = [start] + # while true + # produce(t[1]) + # t[1] = 1 + t[1] + # end + # end + + # ttask = TapedTask(f, 0; deepcopy_types=DummyType) + # consume(ttask) + + # ttask2 = copy(ttask) + # consume(ttask2) + + # @test consume(ttask) == 1 + # @test consume(ttask2) == 2 + # end + end end From 04544ca4dc97849d432324fcb9ba68275664e097 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 19:37:12 +0000 Subject: [PATCH 21/69] Formatting --- .JuliaFormatter.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index c7439503..323237ba 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1 +1 @@ -style = "blue" \ No newline at end of file +style = "blue" From ecbf41c2754c50bae5c2d322a27ca4749e5dfa20 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 19:42:06 +0000 Subject: [PATCH 22/69] Remove unhelpful docstring --- src/copyable_task.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index eda0c036..0e33e5c6 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -81,11 +81,6 @@ function produce_value(x::Expr) return x.args[2] # must be a `:call` Expr. end -""" - derive_copyable_task_ir(ir::IRCode)::IRCode - - -""" function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # Replace all existing `ReturnNode`s with `ReturnNode(nothing)` in order to provide the From 22f6f6719747dc5d6f0107bbdebafec500257bee Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 19:43:08 +0000 Subject: [PATCH 23/69] Relax Test compat a bit --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1df3b546..6477b21f 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,7 @@ Aqua = "0.8.11" JuliaFormatter = "1.0.62" MistyClosures = "2.0.0" Mooncake = "0.4.99" -Test = "1.11.0" +Test = "1" julia = "1.10.8" [extras] From 04e4dcda3d5a34f49a4cc547221bcc0d4a6ec58f Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 19:45:13 +0000 Subject: [PATCH 24/69] Lower minor version to make integration tests run --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6477b21f..00fbaebd 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.9.0" +version = "0.8.8" [deps] MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" From b25d2f418156955e773a7fc41337880fd2d63e8d Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 27 Feb 2025 19:50:00 +0000 Subject: [PATCH 25/69] Handle code_coverage_effect --- src/copyable_task.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 0e33e5c6..7493c134 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -257,6 +257,8 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} push!(inst_pairs, (ID(), new_inst(set_ref))) elseif Meta.isexpr(stmt, :boundscheck) push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :code_coverage_effect) + push!(inst_pairs, (id, inst)) elseif stmt isa ReturnNode # If returning an SSA, it might be one whose value was restored from before. # Therefore, grab it out of storage, rather than assuming that it is def-ed. From 94aed2f6b5673573fb99da9135f7f82256bc5017 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 28 Feb 2025 13:50:07 +0000 Subject: [PATCH 26/69] Some tweaks + docs --- src/Libtask.jl | 2 +- src/copyable_task.jl | 99 ++++++++++++++++++++++++++++++++++++++----- src/test_utils.jl | 8 ++-- test/copyable_task.jl | 26 ++++++------ 4 files changed, 107 insertions(+), 28 deletions(-) diff --git a/src/Libtask.jl b/src/Libtask.jl index 36abd4bd..430c98a6 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -15,6 +15,6 @@ using Core.Compiler: Argument, IRCode, ReturnNode include("copyable_task.jl") include("test_utils.jl") -export CopyableTask, consume, produce +export TapedTask, consume, produce end diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 7493c134..bcaa6083 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -4,24 +4,103 @@ __v::Int = 5 return nothing end -mutable struct CopyableTask{Tmc<:MistyClosure,Targs} +mutable struct TapedTask{Tmc<:MistyClosure,Targs} const mc::Tmc args::Targs const position::Base.RefValue{Int32} + const deepcopy_types::Type end -@inline consume(t::CopyableTask) = t.mc(t.args...) +""" + Base.copy(t::TapedTask) + +Makes a copy of `t` which can be run. For the most part, calls to [`consume`](@ref) on the +copied task will give the same results as the original. There are, however, substantial +limitations to this, detailed in the extended help. + +# Extended Help + +We call a copy of a `TapedTask` _consistent_ with the original if the call to `==` in the +loop below always returns `true`: +```julia +t = +tc = copy(t) +for (v, vc) in zip(t, tc) + v == vc +end +``` +(provided that `==` is implemented for all `v` that are produced). Convesely, we refer to a +copy as _inconsistent_ if this property doesn't hold. In order to ensure +consistency, we need to ensure that independent copies are made of anything which might be +mutated by the task or its copy during subsequent `consume` calls. Failure to do this can +cause problems if, for example, a task reads-to and writes-from some memory. +If we call `consume` on the original task, and then on a copy of it, any changes made by the +original will be visible to the copy, potentially causing its behaviour to differ. This can +manifest itself as a race condition if the task and its copies are run concurrently. + +To understand a bit more about when a task is / is not consistent, we need to dig into the +rather specific semantics of `copy`. Calling `copy` on a `TapedTask` does the following: +1. `copy` the `position` field, +2. `map`s `_tape_copy` over the `args` field, and +3. `map`s `_tape_copy` over the all of the data closed over in the `OpaqueClosure` which + implements the task (specifically the values _inside_ the `Ref`s) -- call these the + `captures`. Except the last elements of this data, because this is `===` to the + `position` field -- for this element we use the copy we made in step 1. + +`_tape_copy` doesn't actually make a copy of the object at all if it is not either an +`Array`, a `Ref`, or an instance of one of the types listed in the task's `deepcopy_type` +field. If it is an instance of one of these types then `_tape_copy` just calls `deepcopy`. + +This behaviour is plainly entirely acceptable if the argument to `_tape_copy` is a bits +type. For any `mutable struct`s which aren't flagged for `deepcopy`ing, we have an immediate +risk of inconsistency. Similarly, for any `struct` types which aren't bits types (e.g. +those which contain an `Array`, `Ref`, or some other `mutable struct` either directly as one +of their fields, or as a field of a field, etc), we have an inconsistency risk. + +Furthermore, for anything which _is_ `deepcopy`ed we introduce inconsistency risks. If, for +example, two elements of the data closed over by the task alias one another, calling +`deepcopy` on them separately will cause the copies to _not_ alias one another. +The same thing can happen if one element is `deepcopy`ed and the other not. For example, if +we have both an `Array` `x` and `view(x, inds)` stored in separate elements of `captures`, +`x` will be `deepcopy`ed, while `view(x, inds)` will not. In the copy of `captures`, the +`view` will still be a view into the original `x`, not the `deepcopy`ed version. Again, this +introduces inconsistency. + +Why do we have these semantics? We have them because Libtask has always had them, and at the +time of writing we're unsure whether AdvancedPS.jl, and by extension Turing.jl rely on this +behaviour. + +What other options do we have? Simply calling `deepcopy` on a `TapedTask` works fine, and +should reliably result in consistent behaviour between a `TapedTask` and any copies of it. +This would, therefore, be a preferable implementation. We should try to determine whether +this is a viable option. +""" +function Base.copy(t::T) where {T<:TapedTask} + captures = t.mc.oc.captures + new_captures = map(Base.Fix2(_tape_copy, t.deepcopy_types), captures) + new_position = new_captures[end] # baked in later on. + new_args = map(Base.Fix2(_tape_copy, t.deepcopy_types), t.args) + new_mc = Mooncake.replace_captures(t.mc, new_captures) + return T(new_mc, new_args, new_position, t.deepcopy_types) +end + +_tape_copy(v, deepcopy_types::Type) = v isa deepcopy_types ? deepcopy(v) : v + +# Not sure that we need this in the new implementation. +_tape_copy(box::Core.Box, deepcopy_types::Type) = error("Found a box") + +@inline consume(t::TapedTask) = t.mc(t.args...) -function initialise!(t::CopyableTask, args::Vararg{Any,N})::Nothing where {N} +function initialise!(t::TapedTask, args::Vararg{Any,N})::Nothing where {N} t.position[] = -1 t.args = args return nothing end -function CopyableTask(fargs...) +function TapedTask(fargs...; deepcopy_types::Type=Union{}) sig = typeof(fargs) mc, count_ref = build_callable(Base.code_ircode_by_type(sig)[1][1]) - return CopyableTask(mc, fargs[2:end], count_ref) + return TapedTask(mc, fargs[2:end], count_ref, Union{deepcopy_types, Array, Ref}) end function build_callable(ir::IRCode) @@ -36,10 +115,10 @@ end might_produce(sig::Type{<:Tuple})::Bool `true` if a call to method with signature `sig` is permitted to contain -`CopyableTasks.produce` statements. +`Libtask.produce` statements. This is an opt-in mechanism. the fallback method of this function returns `false` indicating -that, by default, we assume that calls do not contain `CopyableTasks.produce` statements. +that, by default, we assume that calls do not contain `Libtask.produce` statements. """ might_produce(::Type{<:Tuple}) = false @@ -382,9 +461,9 @@ end @inline deref_phi(::R, x) where {R<:Tuple} = x # Implement iterator interface. -function Base.iterate(t::CopyableTask, state::Nothing=nothing) +function Base.iterate(t::TapedTask, state::Nothing=nothing) v = consume(t) return v === nothing ? nothing : (v, nothing) end -Base.IteratorSize(::Type{<:CopyableTask}) = Base.SizeUnknown() -Base.IteratorEltype(::Type{<:CopyableTask}) = Base.EltypeUnknown() +Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown() +Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown() diff --git a/src/test_utils.jl b/src/test_utils.jl index 49ee7de7..fd343b5b 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -2,7 +2,7 @@ module TestUtils using ..Libtask using Test -using ..Libtask: CopyableTask +using ..Libtask: TapedTask struct Testcase name::String @@ -14,14 +14,14 @@ function (case::Testcase)() testset = @testset "$(case.name)" begin # Construct the task. - t = CopyableTask(case.fargs...) + t = TapedTask(case.fargs...) # Iterate through t. Record the results, and take a copy after each iteration. iteration_results = [] - t_copies = [deepcopy(t)] + t_copies = [copy(t)] for val in t push!(iteration_results, val) - push!(t_copies, deepcopy(t)) + push!(t_copies, copy(t)) end # Check that iterating the original task gives the expected results. diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 30e925a2..19074e2a 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -27,7 +27,7 @@ end end - ttask = CopyableTask(f) + ttask = TapedTask(f) next = iterate(ttask) @test next === (1, nothing) @@ -57,7 +57,7 @@ end end - ttask = CopyableTask(f) + ttask = TapedTask(f) try consume(ttask) catch ex @@ -75,7 +75,7 @@ end end - ttask = CopyableTask(f) + ttask = TapedTask(f) try consume(ttask) catch ex @@ -94,7 +94,7 @@ end end - ttask = CopyableTask(f) + ttask = TapedTask(f) try consume(ttask) catch ex @@ -113,7 +113,7 @@ end end - ttask = CopyableTask(f) + ttask = TapedTask(f) @test consume(ttask) == 2 try consume(ttask) @@ -133,9 +133,9 @@ end end - ttask = CopyableTask(f) + ttask = TapedTask(f) @test consume(ttask) == 2 - ttask2 = deepcopy(ttask) + ttask2 = copy(ttask) try consume(ttask2) catch ex @@ -155,10 +155,10 @@ end end - ttask = CopyableTask(f) + ttask = TapedTask(f) @test consume(ttask) == 0 @test consume(ttask) == 1 - a = deepcopy(ttask) + a = copy(ttask) @test consume(a) == 2 @test consume(a) == 3 @test consume(ttask) == 2 @@ -175,10 +175,10 @@ end end - ttask = CopyableTask(f) + ttask = TapedTask(f) @test consume(ttask) == 0 @test consume(ttask) == 1 - a = deepcopy(ttask) + a = copy(ttask) @test consume(a) == 2 @test consume(a) == 3 @test consume(ttask) == 2 @@ -221,11 +221,11 @@ end end - ttask = CopyableTask(f) + ttask = TapedTask(f) consume(ttask) consume(ttask) - a = deepcopy(ttask) + a = copy(ttask) consume(a) consume(a) From 1579bff25ad8ff4d62e53e5cd8890eb92a99c643 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 28 Feb 2025 13:59:10 +0000 Subject: [PATCH 27/69] Fix copying --- src/copyable_task.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index bcaa6083..7fca5c96 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -48,7 +48,7 @@ rather specific semantics of `copy`. Calling `copy` on a `TapedTask` does the fo `position` field -- for this element we use the copy we made in step 1. `_tape_copy` doesn't actually make a copy of the object at all if it is not either an -`Array`, a `Ref`, or an instance of one of the types listed in the task's `deepcopy_type` +`Array`, a `Ref`, or an instance of one of the types listed in the task's `deepcopy_types` field. If it is an instance of one of these types then `_tape_copy` just calls `deepcopy`. This behaviour is plainly entirely acceptable if the argument to `_tape_copy` is a bits @@ -77,13 +77,21 @@ this is a viable option. """ function Base.copy(t::T) where {T<:TapedTask} captures = t.mc.oc.captures - new_captures = map(Base.Fix2(_tape_copy, t.deepcopy_types), captures) + new_captures = map(Base.Fix2(_copy_capture, t.deepcopy_types), captures) new_position = new_captures[end] # baked in later on. new_args = map(Base.Fix2(_tape_copy, t.deepcopy_types), t.args) new_mc = Mooncake.replace_captures(t.mc, new_captures) return T(new_mc, new_args, new_position, t.deepcopy_types) end +function _copy_capture(r::Ref{T}, deepcopy_types::Type) where {T} + new_capture = Ref{T}() + if isassigned(r) + new_capture[] = _tape_copy(r[], deepcopy_types) + end + return new_capture +end + _tape_copy(v, deepcopy_types::Type) = v isa deepcopy_types ? deepcopy(v) : v # Not sure that we need this in the new implementation. From 5114d649fdf8f06f6762d14aa19522a67f9a3d9a Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 28 Feb 2025 14:46:01 +0000 Subject: [PATCH 28/69] Formatting --- src/copyable_task.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 7fca5c96..56adf250 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -108,7 +108,7 @@ end function TapedTask(fargs...; deepcopy_types::Type=Union{}) sig = typeof(fargs) mc, count_ref = build_callable(Base.code_ircode_by_type(sig)[1][1]) - return TapedTask(mc, fargs[2:end], count_ref, Union{deepcopy_types, Array, Ref}) + return TapedTask(mc, fargs[2:end], count_ref, Union{deepcopy_types,Array,Ref}) end function build_callable(ir::IRCode) From b803f58afd36c15af3a561e74100bedd04770faa Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 28 Feb 2025 15:16:22 +0000 Subject: [PATCH 29/69] Enable more tests --- src/copyable_task.jl | 4 + test/copyable_task.jl | 190 +++++++++++++++++++++--------------------- 2 files changed, 99 insertions(+), 95 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 56adf250..b2b7753e 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -346,6 +346,10 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} push!(inst_pairs, (id, inst)) elseif Meta.isexpr(stmt, :code_coverage_effect) push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :gc_preserve_begin) + push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :gc_preserve_end) + push!(inst_pairs, (id, inst)) elseif stmt isa ReturnNode # If returning an SSA, it might be one whose value was restored from before. # Therefore, grab it out of storage, rather than assuming that it is def-ed. diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 19074e2a..f7313ead 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -187,28 +187,28 @@ @test consume(ttask) == 5 end - # # Test case 3: Dict objects are shallowly copied. - # @testset "Dict objects shallow copy" begin - # function f() - # t = Dict(1 => 10, 2 => 20) - # while true - # produce(t[1]) - # t[1] = 1 + t[1] - # end - # end + # Test case 3: Dict objects are shallowly copied. + @testset "Dict objects shallow copy" begin + function f() + t = Dict(1 => 10, 2 => 20) + while true + produce(t[1]) + t[1] = 1 + t[1] + end + end - # ttask = TapedTask(f) + ttask = TapedTask(f) - # @test consume(ttask) == 10 - # @test consume(ttask) == 11 + @test consume(ttask) == 10 + @test consume(ttask) == 11 - # a = copy(ttask) - # @test consume(a) == 12 - # @test consume(a) == 13 + a = copy(ttask) + @test consume(a) == 12 + @test consume(a) == 13 - # @test consume(ttask) == 14 - # @test consume(ttask) == 15 - # end + @test consume(ttask) == 14 + @test consume(ttask) == 15 + end @testset "Array deep copy 2" begin function f() @@ -232,15 +232,15 @@ @test consume(ttask) == 2 @test consume(a) == 4 - # DATA = Dict{Task,Array}() - # function g() - # ta = zeros(UInt64, 4) - # for i in 1:4 - # ta[i] = hash(current_task()) - # DATA[current_task()] = ta - # produce(ta[i]) - # end - # end + DATA = Dict{Task,Array}() + function g() + ta = zeros(UInt64, 4) + for i in 1:4 + ta[i] = hash(current_task()) + DATA[current_task()] = ta + produce(ta[i]) + end + end # ttask = TapedTask(g) # @test consume(ttask) == hash(ttask.task) # index = 1 @@ -258,73 +258,73 @@ # [hash(ttask.task), hash(ttask.task), hash(a.task), hash(a.task)] end - # # Test atomic values. - # @testset "ref atomic" begin - # function f() - # t = Ref(1) - # t[] = 0 - # for _ in 1:6 - # produce(t[]) - # t[] - # t[] += 1 - # end - # end - - # ctask = TapedTask(f) - - # consume(ctask) - # consume(ctask) - - # a = copy(ctask) - # consume(a) - # consume(a) - - # @test consume(ctask) == 2 - # @test consume(a) == 4 - # end - - # @testset "ref of dictionary deep copy" begin - # function f() - # t = Ref(Dict("A" => 1, 5 => "B")) - # t[]["A"] = 0 - # for _ in 1:6 - # produce(t[]["A"]) - # t[]["A"] += 1 - # end - # end - - # ctask = TapedTask(f) - - # consume(ctask) - # consume(ctask) - - # a = copy(ctask) - # consume(a) - # consume(a) - - # @test consume(ctask) == 2 - # @test consume(a) == 4 - # end - - # @testset "override deepcopy_types #57" begin - # struct DummyType end - - # function f(start::Int) - # t = [start] - # while true - # produce(t[1]) - # t[1] = 1 + t[1] - # end - # end - - # ttask = TapedTask(f, 0; deepcopy_types=DummyType) - # consume(ttask) - - # ttask2 = copy(ttask) - # consume(ttask2) - - # @test consume(ttask) == 1 - # @test consume(ttask2) == 2 - # end + # Test atomic values. + @testset "ref atomic" begin + function f() + t = Ref(1) + t[] = 0 + for _ in 1:6 + produce(t[]) + t[] + t[] += 1 + end + end + + ctask = TapedTask(f) + + consume(ctask) + consume(ctask) + + a = copy(ctask) + consume(a) + consume(a) + + @test consume(ctask) == 2 + @test consume(a) == 4 + end + + @testset "ref of dictionary deep copy" begin + function f() + t = Ref(Dict("A" => 1, 5 => "B")) + t[]["A"] = 0 + for _ in 1:6 + produce(t[]["A"]) + t[]["A"] += 1 + end + end + + ctask = TapedTask(f) + + consume(ctask) + consume(ctask) + + a = copy(ctask) + consume(a) + consume(a) + + @test consume(ctask) == 2 + @test consume(a) == 4 + end + + @testset "override deepcopy_types #57" begin + struct DummyType end + + function f(start::Int) + t = [start] + while true + produce(t[1]) + t[1] = 1 + t[1] + end + end + + ttask = TapedTask(f, 0; deepcopy_types=DummyType) + consume(ttask) + + ttask2 = copy(ttask) + consume(ttask2) + + @test consume(ttask) == 1 + @test consume(ttask2) == 2 + end end end From 6aeaac3c08127a7beff1e3582b43845eb18a320e Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 12:23:51 +0000 Subject: [PATCH 30/69] Implement dynamic scope --- src/copyable_task.jl | 180 ++++++++++++++++++------------------------ src/test_utils.jl | 26 +++--- test/copyable_task.jl | 130 +++--------------------------- 3 files changed, 102 insertions(+), 234 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index b2b7753e..bfe92a74 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -1,122 +1,79 @@ +const dynamic_scope = Base.ScopedValues.ScopedValue{Any}(0) + +""" + get_dynamic_scope() + +Returns the dynamic scope associated to `Libtask`. If called from inside a `TapedTask`, this +will return whatever is contained in its `dynamic_scope` field. + +See also [`set_dynamic_scope!`](@ref). +""" +get_dynamic_scope() = dynamic_scope[] + __v::Int = 5 @noinline function produce(x) global __v = 4 return nothing end -mutable struct TapedTask{Tmc<:MistyClosure,Targs} - const mc::Tmc +function build_callable(ir::IRCode) + seed_id!() + bb, refs = derive_copyable_task_ir(BBCode(ir)) + ir = IRCode(bb) + optimised_ir = Mooncake.optimise_ir!(ir) + return MistyClosure(optimised_ir, refs...; do_compile=true), refs[end] +end + +mutable struct TapedTask{Tdynamic_scope,Targs,Tmc<:MistyClosure} + dynamic_scope::Tdynamic_scope args::Targs - const position::Base.RefValue{Int32} - const deepcopy_types::Type + const mc::Tmc + const position::Base.RefValue{Int32} # As does this end """ - Base.copy(t::TapedTask) - -Makes a copy of `t` which can be run. For the most part, calls to [`consume`](@ref) on the -copied task will give the same results as the original. There are, however, substantial -limitations to this, detailed in the extended help. + TapedTask(dynamic_scope::Any, f, args...) -# Extended Help - -We call a copy of a `TapedTask` _consistent_ with the original if the call to `==` in the -loop below always returns `true`: -```julia -t = -tc = copy(t) -for (v, vc) in zip(t, tc) - v == vc -end -``` -(provided that `==` is implemented for all `v` that are produced). Convesely, we refer to a -copy as _inconsistent_ if this property doesn't hold. In order to ensure -consistency, we need to ensure that independent copies are made of anything which might be -mutated by the task or its copy during subsequent `consume` calls. Failure to do this can -cause problems if, for example, a task reads-to and writes-from some memory. -If we call `consume` on the original task, and then on a copy of it, any changes made by the -original will be visible to the copy, potentially causing its behaviour to differ. This can -manifest itself as a race condition if the task and its copies are run concurrently. - -To understand a bit more about when a task is / is not consistent, we need to dig into the -rather specific semantics of `copy`. Calling `copy` on a `TapedTask` does the following: -1. `copy` the `position` field, -2. `map`s `_tape_copy` over the `args` field, and -3. `map`s `_tape_copy` over the all of the data closed over in the `OpaqueClosure` which - implements the task (specifically the values _inside_ the `Ref`s) -- call these the - `captures`. Except the last elements of this data, because this is `===` to the - `position` field -- for this element we use the copy we made in step 1. - -`_tape_copy` doesn't actually make a copy of the object at all if it is not either an -`Array`, a `Ref`, or an instance of one of the types listed in the task's `deepcopy_types` -field. If it is an instance of one of these types then `_tape_copy` just calls `deepcopy`. - -This behaviour is plainly entirely acceptable if the argument to `_tape_copy` is a bits -type. For any `mutable struct`s which aren't flagged for `deepcopy`ing, we have an immediate -risk of inconsistency. Similarly, for any `struct` types which aren't bits types (e.g. -those which contain an `Array`, `Ref`, or some other `mutable struct` either directly as one -of their fields, or as a field of a field, etc), we have an inconsistency risk. - -Furthermore, for anything which _is_ `deepcopy`ed we introduce inconsistency risks. If, for -example, two elements of the data closed over by the task alias one another, calling -`deepcopy` on them separately will cause the copies to _not_ alias one another. -The same thing can happen if one element is `deepcopy`ed and the other not. For example, if -we have both an `Array` `x` and `view(x, inds)` stored in separate elements of `captures`, -`x` will be `deepcopy`ed, while `view(x, inds)` will not. In the copy of `captures`, the -`view` will still be a view into the original `x`, not the `deepcopy`ed version. Again, this -introduces inconsistency. - -Why do we have these semantics? We have them because Libtask has always had them, and at the -time of writing we're unsure whether AdvancedPS.jl, and by extension Turing.jl rely on this -behaviour. - -What other options do we have? Simply calling `deepcopy` on a `TapedTask` works fine, and -should reliably result in consistent behaviour between a `TapedTask` and any copies of it. -This would, therefore, be a preferable implementation. We should try to determine whether -this is a viable option. +Construct a `TapedTask` with the specified `dynamic_scope`, for function `f` and positional +arguments `args`. """ -function Base.copy(t::T) where {T<:TapedTask} - captures = t.mc.oc.captures - new_captures = map(Base.Fix2(_copy_capture, t.deepcopy_types), captures) - new_position = new_captures[end] # baked in later on. - new_args = map(Base.Fix2(_tape_copy, t.deepcopy_types), t.args) - new_mc = Mooncake.replace_captures(t.mc, new_captures) - return T(new_mc, new_args, new_position, t.deepcopy_types) -end - -function _copy_capture(r::Ref{T}, deepcopy_types::Type) where {T} - new_capture = Ref{T}() - if isassigned(r) - new_capture[] = _tape_copy(r[], deepcopy_types) - end - return new_capture +function TapedTask(dynamic_scope::Any, fargs...) + mc, count_ref = build_callable(Base.code_ircode_by_type(typeof(fargs))[1][1]) + return TapedTask(dynamic_scope, fargs[2:end], mc, count_ref) end -_tape_copy(v, deepcopy_types::Type) = v isa deepcopy_types ? deepcopy(v) : v - -# Not sure that we need this in the new implementation. -_tape_copy(box::Core.Box, deepcopy_types::Type) = error("Found a box") +""" + set_dynamic_scope!(t::TapedTask, new_dynamic_scope)::Nothing -@inline consume(t::TapedTask) = t.mc(t.args...) +Set the `dynamic_scope` of `t` to `new_dynamic_scope`. Any references to +`LibTask.dynamic_scope` in future calls to `consume(t)` (either directly, or implicitly via +iteration) will see this new value. -function initialise!(t::TapedTask, args::Vararg{Any,N})::Nothing where {N} - t.position[] = -1 - t.args = args +See also: [`get_dynamic_scope`](@ref). +""" +function set_dynamic_scope!(t::TapedTask{T}, new_dynamic_scope::T)::Nothing where {T} + t.dynamic_scope = new_dynamic_scope return nothing end -function TapedTask(fargs...; deepcopy_types::Type=Union{}) - sig = typeof(fargs) - mc, count_ref = build_callable(Base.code_ircode_by_type(sig)[1][1]) - return TapedTask(mc, fargs[2:end], count_ref, Union{deepcopy_types,Array,Ref}) -end +""" + Base.copy(t::TapedTask) -function build_callable(ir::IRCode) - seed_id!() - bb, refs = derive_copyable_task_ir(BBCode(ir)) - ir = IRCode(bb) - optimised_ir = Mooncake.optimise_ir!(ir) - return MistyClosure(optimised_ir, refs...; do_compile=true), refs[end] +Makes a completely independent copy of `t`. `consume` can be applied to either the copy of +`t` or the original without advancing the state of the other. +""" +Base.copy(t::T) where {T<:TapedTask} = deepcopy(t) + +""" + consume(t::TapedTask) + +Run `t` until it makes a call to `produce`. If this is the first time that `t` has been +called, it start execution from the entry point. If `consume` has previously been called on +`t`, it will resume from the last `produce` call. If there are no more `produce` calls, +`nothing` will be returned. +""" +@inline function consume(t::TapedTask) + return Base.ScopedValues.with(() -> t.mc(t.args...), dynamic_scope => t.dynamic_scope) end """ @@ -288,7 +245,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} n += 1 ssa_id_to_ref_index_map[id] = n ref_index_to_ssa_id_map[n] = id - ref_index_to_type_map[n] = stmt.type + ref_index_to_type_map[n] = CC.widenconst(stmt.type) end end @@ -382,8 +339,25 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} push!(inst_pairs, (id, inst)) elseif stmt isa Nothing push!(inst_pairs, (id, inst)) + elseif stmt isa GlobalRef + ref_ind = ssa_id_to_ref_index_map[id] + expr = Expr(:call, set_ref_at!, refs_id, ref_ind, stmt) + push!(inst_pairs, (id, new_inst(expr))) + elseif stmt isa Core.PiNode + if stmt.val isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.val] + val_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (val_id, new_inst(expr))) + push!(inst_pairs, (id, new_inst(Core.PiNode(val_id, stmt.typ)))) + else + push!(inst_pairs, (id, inst)) + end + set_ind = ssa_id_to_ref_index_map[id] + set_expr = Expr(:call, set_ref_at!, refs_id, set_ind, id) + push!(inst_pairs, (ID(), new_inst(set_expr))) else - throw(error("Unhandled stmt $stmt")) + throw(error("Unhandled stmt $stmt of type $(typeof(stmt))")) end end @@ -451,7 +425,9 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} end # Helper used in `derive_copyable_task_ir`. -@inline get_ref_at(refs::R, n::Int) where {R<:Tuple} = refs[n][] +@inline function get_ref_at(refs::R, n::Int) where {R<:Tuple} + return refs[n][] +end # Helper used in `derive_copyable_task_ir`. @inline function set_ref_at!(refs::R, n::Int, val) where {R<:Tuple} diff --git a/src/test_utils.jl b/src/test_utils.jl index fd343b5b..94deea52 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -6,6 +6,7 @@ using ..Libtask: TapedTask struct Testcase name::String + dynamic_scope::Any fargs::Tuple expected_iteration_results::Vector end @@ -14,7 +15,7 @@ function (case::Testcase)() testset = @testset "$(case.name)" begin # Construct the task. - t = TapedTask(case.fargs...) + t = TapedTask(case.dynamic_scope, case.fargs...) # Iterate through t. Record the results, and take a copy after each iteration. iteration_results = [] @@ -39,21 +40,22 @@ function test_cases() return Testcase[ Testcase( "single block", + nothing, (single_block, 5.0), [sin(5.0), sin(sin(5.0)), sin(sin(sin(5.0))), sin(sin(sin(sin(5.0))))], ), - Testcase("produce old", (produce_old_value, 5.0), [sin(5.0), sin(5.0)]), - Testcase("branch on old value l", (branch_on_old_value, 2.0), [true, 2.0]), - Testcase("branch on old value r", (branch_on_old_value, -1.0), [false, -2.0]), - Testcase("no produce", (no_produce_test, 5.0, 4.0), []), - Testcase("new object", (new_object_test, 5, 4), [C(5, 4), C(5, 4)]), - Testcase("branching test l", (branching_test, 5.0, 4.0), [string(sin(5.0))]), - Testcase("branching test r", (branching_test, 4.0, 5.0), [sin(4.0) * cos(5.0)]), - Testcase("unused argument test", (unused_argument_test, 3), [1]), - Testcase("test with const", (test_with_const,), [1]), - Testcase("while loop", (while_loop,), collect(1:9)), + Testcase("produce old", nothing, (produce_old_value, 5.0), [sin(5.0), sin(5.0)]), + Testcase("branch on old value l", nothing, (branch_on_old_value, 2.0), [true, 2.0]), + Testcase("branch on old value r", nothing, (branch_on_old_value, -1.0), [false, -2.0]), + Testcase("no produce", nothing, (no_produce_test, 5.0, 4.0), []), + Testcase("new object", nothing, (new_object_test, 5, 4), [C(5, 4), C(5, 4)]), + Testcase("branching test l", nothing, (branching_test, 5.0, 4.0), [string(sin(5.0))]), + Testcase("branching test r", nothing, (branching_test, 4.0, 5.0), [sin(4.0) * cos(5.0)]), + Testcase("unused argument test", nothing, (unused_argument_test, 3), [1]), + Testcase("test with const", nothing, (test_with_const,), [1]), + Testcase("while loop", nothing, (while_loop,), collect(1:9)), Testcase( - "foreigncall tester", (foreigncall_tester, "hi"), [Ptr{UInt8}, Ptr{UInt8}] + "foreigncall tester", nothing, (foreigncall_tester, "hi"), [Ptr{UInt8}, Ptr{UInt8}] ), # Failing tests diff --git a/test/copyable_task.jl b/test/copyable_task.jl index f7313ead..0a76cff3 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -2,22 +2,6 @@ for case in Libtask.TestUtils.test_cases() case() end - # @testset "construction" begin - # function f() - # t = 1 - # while true - # produce(t) - # t = 1 + t - # end - # end - - # ttask = TapedTask(f) - # @test consume(ttask) == 1 - - # ttask = TapedTask((f, Union{})) - # @test consume(ttask) == 1 - # end - @testset "iteration" begin function f() t = 1 @@ -27,7 +11,7 @@ end end - ttask = TapedTask(f) + ttask = TapedTask(nothing, f) next = iterate(ttask) @test next === (1, nothing) @@ -57,7 +41,7 @@ end end - ttask = TapedTask(f) + ttask = TapedTask(nothing, f) try consume(ttask) catch ex @@ -75,7 +59,7 @@ end end - ttask = TapedTask(f) + ttask = TapedTask(nothing, f) try consume(ttask) catch ex @@ -94,7 +78,7 @@ end end - ttask = TapedTask(f) + ttask = TapedTask(nothing, f) try consume(ttask) catch ex @@ -113,7 +97,7 @@ end end - ttask = TapedTask(f) + ttask = TapedTask(nothing, f) @test consume(ttask) == 2 try consume(ttask) @@ -133,7 +117,7 @@ end end - ttask = TapedTask(f) + ttask = TapedTask(nothing, f) @test consume(ttask) == 2 ttask2 = copy(ttask) try @@ -155,7 +139,7 @@ end end - ttask = TapedTask(f) + ttask = TapedTask(nothing, f) @test consume(ttask) == 0 @test consume(ttask) == 1 a = copy(ttask) @@ -175,7 +159,7 @@ end end - ttask = TapedTask(f) + ttask = TapedTask(nothing, f) @test consume(ttask) == 0 @test consume(ttask) == 1 a = copy(ttask) @@ -187,29 +171,6 @@ @test consume(ttask) == 5 end - # Test case 3: Dict objects are shallowly copied. - @testset "Dict objects shallow copy" begin - function f() - t = Dict(1 => 10, 2 => 20) - while true - produce(t[1]) - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - - @test consume(ttask) == 10 - @test consume(ttask) == 11 - - a = copy(ttask) - @test consume(a) == 12 - @test consume(a) == 13 - - @test consume(ttask) == 14 - @test consume(ttask) == 15 - end - @testset "Array deep copy 2" begin function f() t = Array{Int}(undef, 1) @@ -221,7 +182,7 @@ end end - ttask = TapedTask(f) + ttask = TapedTask(nothing, f) consume(ttask) consume(ttask) @@ -231,56 +192,6 @@ @test consume(ttask) == 2 @test consume(a) == 4 - - DATA = Dict{Task,Array}() - function g() - ta = zeros(UInt64, 4) - for i in 1:4 - ta[i] = hash(current_task()) - DATA[current_task()] = ta - produce(ta[i]) - end - end - - # ttask = TapedTask(g) - # @test consume(ttask) == hash(ttask.task) # index = 1 - # @test consume(ttask) == hash(ttask.task) # index = 2 - - # a = copy(ttask) - # @test consume(a) == hash(a.task) # index = 3 - # @test consume(a) == hash(a.task) # index = 4 - - # @test consume(ttask) == hash(ttask.task) # index = 3 - - # @test DATA[ttask.task] == - # [hash(ttask.task), hash(ttask.task), hash(ttask.task), 0] - # @test DATA[a.task] == - # [hash(ttask.task), hash(ttask.task), hash(a.task), hash(a.task)] - end - - # Test atomic values. - @testset "ref atomic" begin - function f() - t = Ref(1) - t[] = 0 - for _ in 1:6 - produce(t[]) - t[] - t[] += 1 - end - end - - ctask = TapedTask(f) - - consume(ctask) - consume(ctask) - - a = copy(ctask) - consume(a) - consume(a) - - @test consume(ctask) == 2 - @test consume(a) == 4 end @testset "ref of dictionary deep copy" begin @@ -293,7 +204,7 @@ end end - ctask = TapedTask(f) + ctask = TapedTask(nothing, f) consume(ctask) consume(ctask) @@ -305,26 +216,5 @@ @test consume(ctask) == 2 @test consume(a) == 4 end - - @testset "override deepcopy_types #57" begin - struct DummyType end - - function f(start::Int) - t = [start] - while true - produce(t[1]) - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f, 0; deepcopy_types=DummyType) - consume(ttask) - - ttask2 = copy(ttask) - consume(ttask2) - - @test consume(ttask) == 1 - @test consume(ttask2) == 2 - end end end From b546f212016baf36741c250103c1f9f40e58cf7e Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 12:30:45 +0000 Subject: [PATCH 31/69] Test dynamic scope correctness --- src/test_utils.jl | 7 +++++++ test/copyable_task.jl | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index 94deea52..83d9bb03 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -57,6 +57,8 @@ function test_cases() Testcase( "foreigncall tester", nothing, (foreigncall_tester, "hi"), [Ptr{UInt8}, Ptr{UInt8}] ), + Testcase("dynamic scope 1", 5, (dynamic_scope_tester_1,), [5]), + Testcase("dynamic scope 2", 6, (dynamic_scope_tester_1,), [6]), # Failing tests # Testcase("nested", (nested_outer, ), [true, false]), @@ -160,4 +162,9 @@ function nested_outer() return nothing end +function dynamic_scope_tester_1() + produce(Libtask.get_dynamic_scope()) + return nothing +end + end diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 0a76cff3..9fd18c4c 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -2,6 +2,18 @@ for case in Libtask.TestUtils.test_cases() case() end + @testset "set_dynamic_scope" begin + function f() + produce(typeassert(Libtask.get_dynamic_scope(), Int)) + produce(typeassert(Libtask.get_dynamic_scope(), Int)) + return nothing + end + t = TapedTask(5, f) + @test consume(t) == 5 + Libtask.set_dynamic_scope!(t, 6) + @test consume(t) == 6 + @test consume(t) === nothing + end @testset "iteration" begin function f() t = 1 From 8f4e4b5661212de9679f2b5b1471e7fd1e38cc6e Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 14:28:27 +0000 Subject: [PATCH 32/69] README and NEWS overhaul --- NEWS.md | 8 +++++ README.md | 98 ++----------------------------------------------------- 2 files changed, 11 insertions(+), 95 deletions(-) create mode 100644 NEWS.md diff --git a/NEWS.md b/NEWS.md new file mode 100644 index 00000000..db7b2b8b --- /dev/null +++ b/NEWS.md @@ -0,0 +1,8 @@ +- From v0.6.0, Libtask is implemented by recording all the computing + to a tape and copying that tape. Before that version, it is based on + a tricky hack on the Julia internals. You can check the commit + history of this repo to see the details. + +- From version 0.9.0, the old `TArray` and `TRef` types are completely removed, where + previously they were only deprecated. Additionally, the internals have been completely + overhauled, and the public interface more precisely defined. See the docs for more info. diff --git a/README.md b/README.md index 219bd5b4..73b650a8 100644 --- a/README.md +++ b/README.md @@ -2,100 +2,8 @@ [![Libtask Testing](https://github.com/TuringLang/Libtask.jl/workflows/Libtask%20Testing/badge.svg)](https://github.com/TuringLang/Libtask.jl/actions?branch=master) -Tape based task copying in Turing -## Getting Started +Resumable and copyable functions in Julia, with optional dynamic scope. +See the docs for example usage. -Stack allocated objects are always deep copied: - -```julia -using Libtask - -function f() - t = 0 - for _ in 1:10 - produce(t) - t = 1 + t - end -end - -ttask = TapedTask(f) - -@show consume(ttask) # 0 -@show consume(ttask) # 1 - -a = copy(ttask) -@show consume(a) # 2 -@show consume(a) # 3 - -@show consume(ttask) # 2 -@show consume(ttask) # 3 -``` - -Heap-allocated Array and Ref objects are deep copied by default: - -```julia -using Libtask - -function f() - t = [0 1 2] - for _ in 1:10 - produce(t[1]) - t[1] = 1 + t[1] - end -end - -ttask = TapedTask(f) - -@show consume(ttask) # 0 -@show consume(ttask) # 1 - -a = copy(ttask) -@show consume(a) # 2 -@show consume(a) # 3 - -@show consume(ttask) # 2 -@show consume(ttask) # 3 -``` - -Other heap-allocated objects (e.g., `Dict`) are shallow copied, by default: - -```julia -using Libtask - -function f() - t = Dict(1=>10, 2=>20) - while true - produce(t[1]) - t[1] = 1 + t[1] - end -end - -ttask = TapedTask(f) - -@show consume(ttask) # 10 -@show consume(ttask) # 11 - -a = copy(ttask) -@show consume(a) # 12 -@show consume(a) # 13 - -@show consume(ttask) # 14 -@show consume(ttask) # 15 -``` - -Notes: - -- The [Turing](https://github.com/TuringLang/Turing.jl) probabilistic - programming language uses this task copying feature in an efficient - implementation of the [particle - filtering](https://en.wikipedia.org/wiki/Particle_filter) sampling - algorithm for arbitrary order [Markov - processes](https://en.wikipedia.org/wiki/Markov_model#Hidden_Markov_model). - -- From v0.6.0, Libtask is implemented by recording all the computing - to a tape and copying that tape. Before that version, it is based on - a tricky hack on the Julia internals. You can check the commit - history of this repo to see the details. - -- From version 0.9.0, the old `TArray` and `TRef` types are completely removed, where previously they were deprecated. \ No newline at end of file +Used in the [Turing](https://github.com/TuringLang/Turing.jl) probabilistic programming language to implement various particle-based inference methods, for example those in [AdvancedPS.jl](https://github.com/TuringLang/AdvancedPS.jl/). From 4d0b4234499fa3cf1843b043af4361689856a2bb Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 14:28:44 +0000 Subject: [PATCH 33/69] Export get_dynamic_scope and set_dynamic_scope --- src/Libtask.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Libtask.jl b/src/Libtask.jl index 430c98a6..7730f10e 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -15,6 +15,6 @@ using Core.Compiler: Argument, IRCode, ReturnNode include("copyable_task.jl") include("test_utils.jl") -export TapedTask, consume, produce +export TapedTask, consume, produce, get_dynamic_scope, set_dynamic_scope! end From bde142f88870a775119e0f0a1437b8c13b659eca Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 14:28:56 +0000 Subject: [PATCH 34/69] Placeholder docstring for produce --- src/copyable_task.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index bfe92a74..a6118319 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -11,6 +11,15 @@ See also [`set_dynamic_scope!`](@ref). get_dynamic_scope() = dynamic_scope[] __v::Int = 5 + +""" + produce(x) + +When run inside a [`TapedTask`](@ref), will immediately yield to the caller, producing value +`x`. + +See also: [`Libtask.consume`](@ref) +""" @noinline function produce(x) global __v = 4 return nothing From 4e58e11853f22f4e37397ad134b10a6de1d7742c Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 14:29:22 +0000 Subject: [PATCH 35/69] Initial docs --- docs/Project.toml | 3 +++ docs/make.jl | 3 +++ docs/src/index.md | 25 +++++++++++++++++++++++++ 3 files changed, 31 insertions(+) create mode 100644 docs/Project.toml create mode 100644 docs/make.jl create mode 100644 docs/src/index.md diff --git a/docs/Project.toml b/docs/Project.toml new file mode 100644 index 00000000..acf52482 --- /dev/null +++ b/docs/Project.toml @@ -0,0 +1,3 @@ +[deps] +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" diff --git a/docs/make.jl b/docs/make.jl new file mode 100644 index 00000000..67303158 --- /dev/null +++ b/docs/make.jl @@ -0,0 +1,3 @@ +using Documenter, Libtask + +makedocs(sitename="Libtask") diff --git a/docs/src/index.md b/docs/src/index.md new file mode 100644 index 00000000..3015115a --- /dev/null +++ b/docs/src/index.md @@ -0,0 +1,25 @@ +# Libtask + +Libtask is best explained by the docstring for [`TapedTask`](@ref): +```@docs; canonical=true +Libtask.TapedTask +``` + +The functions discussed the above docstring (in addition to [`TapedTask`](@ref) itself) form the +public interface of Libtask.jl. +They divide neatly into two kinds of functions: those which are used to construct and +manipulate [`TapedTask`](@ref)s, and those which are intended to be used _inside_ a +[`TapedTask`](@ref). + +First, manipulation of [`TapedTask`](@ref)s: +```@docs; canonical=true +Libtask.consume +Base.copy(::Libtask.TapedTask) +Libtask.set_dynamic_scope! +``` + +The functions which enable special functionality inside a [`TapedTask`](@ref)s are: +```@docs; canonical=true +Libtask.produce +Libtask.get_dynamic_scope +``` From 21816ee2f7ebe187811ec6c9c62801d975868545 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 14:29:30 +0000 Subject: [PATCH 36/69] Ignore build folder of docs --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 1c7787bd..40e18fcd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ # Projects files Manifest.toml *.cov +docs/build \ No newline at end of file From 32052736dd2bfc5d79cbc043d59a6e974058a291 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 14:31:16 +0000 Subject: [PATCH 37/69] Update cache action --- .github/workflows/Testing.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/Testing.yaml b/.github/workflows/Testing.yaml index 2b8b8685..0d0facc7 100644 --- a/.github/workflows/Testing.yaml +++ b/.github/workflows/Testing.yaml @@ -30,7 +30,7 @@ jobs: with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: From 06c168f530f7b30feb73e9b7faa8d8c840c48008 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 14:40:52 +0000 Subject: [PATCH 38/69] Formatting --- docs/make.jl | 2 +- src/test_utils.jl | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 67303158..caf80499 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,3 +1,3 @@ using Documenter, Libtask -makedocs(sitename="Libtask") +makedocs(; sitename="Libtask") diff --git a/src/test_utils.jl b/src/test_utils.jl index 83d9bb03..e126179f 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -46,16 +46,25 @@ function test_cases() ), Testcase("produce old", nothing, (produce_old_value, 5.0), [sin(5.0), sin(5.0)]), Testcase("branch on old value l", nothing, (branch_on_old_value, 2.0), [true, 2.0]), - Testcase("branch on old value r", nothing, (branch_on_old_value, -1.0), [false, -2.0]), + Testcase( + "branch on old value r", nothing, (branch_on_old_value, -1.0), [false, -2.0] + ), Testcase("no produce", nothing, (no_produce_test, 5.0, 4.0), []), Testcase("new object", nothing, (new_object_test, 5, 4), [C(5, 4), C(5, 4)]), - Testcase("branching test l", nothing, (branching_test, 5.0, 4.0), [string(sin(5.0))]), - Testcase("branching test r", nothing, (branching_test, 4.0, 5.0), [sin(4.0) * cos(5.0)]), + Testcase( + "branching test l", nothing, (branching_test, 5.0, 4.0), [string(sin(5.0))] + ), + Testcase( + "branching test r", nothing, (branching_test, 4.0, 5.0), [sin(4.0) * cos(5.0)] + ), Testcase("unused argument test", nothing, (unused_argument_test, 3), [1]), Testcase("test with const", nothing, (test_with_const,), [1]), Testcase("while loop", nothing, (while_loop,), collect(1:9)), Testcase( - "foreigncall tester", nothing, (foreigncall_tester, "hi"), [Ptr{UInt8}, Ptr{UInt8}] + "foreigncall tester", + nothing, + (foreigncall_tester, "hi"), + [Ptr{UInt8}, Ptr{UInt8}], ), Testcase("dynamic scope 1", 5, (dynamic_scope_tester_1,), [5]), Testcase("dynamic scope 2", 6, (dynamic_scope_tester_1,), [6]), From 258a4cdd80f5c559834bba42f0a50dd63c8848f8 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 14:55:14 +0000 Subject: [PATCH 39/69] Ignore all top-level manifest files --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 40e18fcd..bb8f9cf6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ # Projects files -Manifest.toml +Manifest* *.cov docs/build \ No newline at end of file From a3c2162102786c8159d4e9e0b347bd4842c9569c Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 14:55:24 +0000 Subject: [PATCH 40/69] Add dependency on ScopedValues --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 00fbaebd..ce0f4b20 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ version = "0.8.8" [deps] MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] @@ -15,6 +16,7 @@ Aqua = "0.8.11" JuliaFormatter = "1.0.62" MistyClosures = "2.0.0" Mooncake = "0.4.99" +ScopedValues = "1.3.0" Test = "1" julia = "1.10.8" From 4bb0dfa3d2485a833fcf4a0eb7f924fbdd4f72db Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 14:55:37 +0000 Subject: [PATCH 41/69] Fix on LTS --- src/Libtask.jl | 8 ++++++++ src/copyable_task.jl | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/Libtask.jl b/src/Libtask.jl index 7730f10e..757e8294 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -8,6 +8,14 @@ using Mooncake: IDGotoIfNot, IDGotoNode, IDPhiNode, Switch # We'll emit `MistyClosure`s rather than `OpaqueClosure`s. using MistyClosures +# ScopedValues only became available as part of `Base` in v1.11. Therefore, on v1.10 we +# need to use the `ScopedValues` package. +@static if VERSION < v"1.11" + using ScopedValues: ScopedValue, with +else + using Base.ScopedValues: ScopedValue, with +end + # Import some names from the compiler. const CC = Core.Compiler using Core.Compiler: Argument, IRCode, ReturnNode diff --git a/src/copyable_task.jl b/src/copyable_task.jl index a6118319..8f29342a 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -1,4 +1,4 @@ -const dynamic_scope = Base.ScopedValues.ScopedValue{Any}(0) +const dynamic_scope = ScopedValue{Any}(0) """ get_dynamic_scope() @@ -82,7 +82,7 @@ called, it start execution from the entry point. If `consume` has previously bee `nothing` will be returned. """ @inline function consume(t::TapedTask) - return Base.ScopedValues.with(() -> t.mc(t.args...), dynamic_scope => t.dynamic_scope) + return with(() -> t.mc(t.args...), dynamic_scope => t.dynamic_scope) end """ From a63660e621f921ecdec9a80808a265f436d3d418 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 16:53:44 +0000 Subject: [PATCH 42/69] Do not check for stale deps on 1.11 --- test/runtests.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 54876973..0a049368 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,8 @@ include("front_matter.jl") @testset "Libtask" begin @testset "quality" begin - Aqua.test_all(Libtask) + # ScopedValues is stale on 1.11. + Aqua.test_all(Libtask; stale_deps=VERSION < v"1.11" ? true : false) @test JuliaFormatter.format(Libtask; verbose=false, overwrite=false) end include("copyable_task.jl") From 4e1e5dfe027068a893642eda1655a76f3b5411d4 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 17:36:17 +0000 Subject: [PATCH 43/69] Some docs --- docs/make.jl | 15 ++++++++- docs/src/internals.md | 7 +++++ src/copyable_task.jl | 72 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 docs/src/internals.md diff --git a/docs/make.jl b/docs/make.jl index caf80499..e3fa7cac 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,3 +1,16 @@ using Documenter, Libtask -makedocs(; sitename="Libtask") +DocMeta.setdocmeta!( + Libtask, + :DocTestSetup, + quote + using Libtask + end; + recursive=true, +) + +makedocs(; + sitename="Libtask", doctest=true, pages=["index.md", "internals.md"], modules=[Libtask] +) + +deploydocs(; repo="github.com/TuringLang/Libtask.jl.git", push_preview=true) diff --git a/docs/src/internals.md b/docs/src/internals.md new file mode 100644 index 00000000..d98bb6c4 --- /dev/null +++ b/docs/src/internals.md @@ -0,0 +1,7 @@ +# Internals + +```@docs; canonical=true +Libtask.produce_value +Libtask.is_produce_stmt +Libtask.might_produce +``` diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 8f29342a..6ce015e3 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -45,6 +45,78 @@ end Construct a `TapedTask` with the specified `dynamic_scope`, for function `f` and positional arguments `args`. + +# Extended Help + +There are three central features of a `TapedTask`, which we demonstrate via three examples. + +## Resumption + +```jldoctest tt +julia> function f() + for t in 1:2 + produce(t) + t += 1 + end + return nothing + end +f (generic function with 1 method) +``` + +```jldoctest tt +julia> t = TapedTask(nothing, f); + +julia> consume(t) +1 +``` + +```jldoctest tt +julia> consume(t) +2 + +julia> consume(t) + +``` + +## Copying + +```jldoctest tt +julia> t2 = TapedTask(nothing, f); + +julia> consume(t2) +1 +``` + +```jldoctest tt +julia> t3 = copy(t2); + +julia> consume(t3) +2 + +julia> consume(t2) +2 +``` + +## Scoped Values + +```jldoctest +julia> function f() + produce(get_dynamic_scope()) + produce(get_dynamic_scope()) + return nothing + end +f (generic function with 1 method) + +julia> t = TapedTask(1, f); + +julia> consume(t) +1 + +julia> set_dynamic_scope!(t, 2) + +julia> consume(t) +2 +``` """ function TapedTask(dynamic_scope::Any, fargs...) mc, count_ref = build_callable(Base.code_ircode_by_type(typeof(fargs))[1][1]) From 54c745c72a5655d43050e2f66f390d3b5b22ba60 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 17:37:21 +0000 Subject: [PATCH 44/69] Docs action --- .github/workflows/Documentation.yml | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/Documentation.yml diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml new file mode 100644 index 00000000..0ec2baa2 --- /dev/null +++ b/.github/workflows/Documentation.yml @@ -0,0 +1,32 @@ +name: Documentation + +on: + push: + branches: + - main + tags: '*' + pull_request: + +jobs: + build: + permissions: + contents: write + pull-requests: read + statuses: write + actions: write + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: '1' + arch: x64 + include-all-prereleases: false + - name: Install dependencies + run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.update(); Pkg.instantiate()' + - name: Build and deploy + env: + GKSwstype: nul # turn off GR's interactive plotting for notebooks + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token + DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key + run: julia --project=docs/ docs/make.jl From 010265317fd5f2b596fe65034f83ee7636cb1024 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 17:53:01 +0000 Subject: [PATCH 45/69] Tidy up docs --- docs/make.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index e3fa7cac..318d04ea 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,13 +1,6 @@ using Documenter, Libtask -DocMeta.setdocmeta!( - Libtask, - :DocTestSetup, - quote - using Libtask - end; - recursive=true, -) +DocMeta.setdocmeta!(Libtask, :DocTestSetup, :(using Libtask); recursive=true) makedocs(; sitename="Libtask", doctest=true, pages=["index.md", "internals.md"], modules=[Libtask] From e751db3323fef3f7edda14af28ad1007d7f9c6f3 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 7 Mar 2025 18:29:06 +0000 Subject: [PATCH 46/69] Tidy up docs slightly --- docs/src/index.md | 7 +++---- src/copyable_task.jl | 40 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 3015115a..3bc62e1f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -7,10 +7,9 @@ Libtask.TapedTask The functions discussed the above docstring (in addition to [`TapedTask`](@ref) itself) form the public interface of Libtask.jl. -They divide neatly into two kinds of functions: those which are used to construct and -manipulate [`TapedTask`](@ref)s, and those which are intended to be used _inside_ a +They divide neatly into two kinds of functions: those which are used to manipulate +[`TapedTask`](@ref)s, and those which are intended to be used _inside_ a [`TapedTask`](@ref). - First, manipulation of [`TapedTask`](@ref)s: ```@docs; canonical=true Libtask.consume @@ -18,7 +17,7 @@ Base.copy(::Libtask.TapedTask) Libtask.set_dynamic_scope! ``` -The functions which enable special functionality inside a [`TapedTask`](@ref)s are: +Functions for use inside a [`TapedTask`](@ref)s are: ```@docs; canonical=true Libtask.produce Libtask.get_dynamic_scope diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 6ce015e3..107fc631 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -52,6 +52,8 @@ There are three central features of a `TapedTask`, which we demonstrate via thre ## Resumption +The function [`Libtask.produce`](@ref) has a special meaning in Libtask. You can insert it +into regular Julia functions anywhere that you like. For example ```jldoctest tt julia> function f() for t in 1:2 @@ -63,23 +65,36 @@ julia> function f() f (generic function with 1 method) ``` +If you construct a `TapedTask` from `f`, and call [`Libtask.consume`](@ref) on it, you'll +see ```jldoctest tt julia> t = TapedTask(nothing, f); julia> consume(t) 1 ``` +The semantics of this are that [`Libtask.consume`](@ref) runs the function `f` until it +reaches the call to [`Libtask.produce`](@ref), at which point it will return the argument +to [`Libtask.produce`](@ref). +Subsequent calls to [`Libtask.produce`](@ref) will _resume_ execution of `f` immediately +after the last [`Libtask.produce`](@ref) statement that was hit. ```jldoctest tt julia> consume(t) 2 +``` +When there are no more [`Libtask.produce`](@ref) statements to hit, calling +[`Libtask.consume`](@ref) will return `nothing`: +```jldoctest tt julia> consume(t) ``` ## Copying +[`TapedTask`](@ref)s can be copied. Doing so creates a completely independent object. +For example: ```jldoctest tt julia> t2 = TapedTask(nothing, f); @@ -87,36 +102,59 @@ julia> consume(t2) 1 ``` +If we make a copy and advance its state, it produces the same value that the original would +have produced: ```jldoctest tt julia> t3 = copy(t2); julia> consume(t3) 2 +``` +Moreover, advancing the state of the copy has not advanced the state of the original, +because they are completely independent copies: +```jldoctest tt julia> consume(t2) 2 ``` ## Scoped Values -```jldoctest +It is often desirable to permit a copy of a task and the original to differ in very specific +ways. For example, in the context of Sequential Monte Carlo, you might want the only +difference between two copies to be their random number generator. + +A generic mechanism is available to achieve this. [`Libtask.get_dynamic_scope`](@ref) and +[`Libtask.set_dynamic_scope!`](@ref) let you set and retrieve a variable which is specific +to a given [`Libtask.TapedTask`](@ref). The former can be called inside a function: +```jldoctest sv julia> function f() produce(get_dynamic_scope()) produce(get_dynamic_scope()) return nothing end f (generic function with 1 method) +``` +The first argument to [`Libtask.TapedTask`](@ref) is the value that +[`Libtask.get_dynamic_scope`](@ref) will return: +```jldoctest sv julia> t = TapedTask(1, f); julia> consume(t) 1 +``` +The value that it returns can be changed between [`Libtask.consume`](@ref) calls: +```jldoctest sv julia> set_dynamic_scope!(t, 2) julia> consume(t) 2 ``` + +`Int`s have been used here, but it is permissible to set the value returned by +[`Libtask.get_dynamic_scope`](@ref) to anything you like. """ function TapedTask(dynamic_scope::Any, fargs...) mc, count_ref = build_callable(Base.code_ircode_by_type(typeof(fargs))[1][1]) From 6a221c2a157516e1af609a9b38b9cbcefa8fca87 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Tue, 11 Mar 2025 18:54:04 +0000 Subject: [PATCH 47/69] Refactor + basic nested produce handling --- src/Libtask.jl | 2 +- src/copyable_task.jl | 712 ++++++++++++++++++++++++++++--------------- src/test_utils.jl | 14 +- 3 files changed, 473 insertions(+), 255 deletions(-) diff --git a/src/Libtask.jl b/src/Libtask.jl index 757e8294..df61d87f 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -2,7 +2,7 @@ module Libtask # Need this for BBCode. using Mooncake -using Mooncake: BBCode, BBlock, ID, new_inst, stmt, seed_id! +using Mooncake: BBCode, BBlock, ID, new_inst, stmt, seed_id!, terminator using Mooncake: IDGotoIfNot, IDGotoNode, IDPhiNode, Switch # We'll emit `MistyClosure`s rather than `OpaqueClosure`s. diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 107fc631..79f4c0c0 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -21,23 +21,29 @@ When run inside a [`TapedTask`](@ref), will immediately yield to the caller, pro See also: [`Libtask.consume`](@ref) """ @noinline function produce(x) - global __v = 4 - return nothing + global __v = 4 # silly side-effect to prevent this call getting constant-folded away. Should really use the effects system. + return ProducedValue(x) end -function build_callable(ir::IRCode) - seed_id!() +function callable_ret_type(sig) + return Union{Base.code_ircode_by_type(sig)[1][2], ProducedValue} +end + +function build_callable(sig::Type{<:Tuple}) + ir = Base.code_ircode_by_type(sig)[1][1] bb, refs = derive_copyable_task_ir(BBCode(ir)) - ir = IRCode(bb) - optimised_ir = Mooncake.optimise_ir!(ir) - return MistyClosure(optimised_ir, refs...; do_compile=true), refs[end] + unoptimised_ir = IRCode(bb) + optimised_ir = Mooncake.optimise_ir!(unoptimised_ir) + mc_ret_type = callable_ret_type(sig) + mc = Mooncake.misty_closure(mc_ret_type, optimised_ir, refs...; do_compile=true) + return mc, refs[end] end mutable struct TapedTask{Tdynamic_scope,Targs,Tmc<:MistyClosure} dynamic_scope::Tdynamic_scope args::Targs const mc::Tmc - const position::Base.RefValue{Int32} # As does this + const position::Base.RefValue{Int32} end """ @@ -157,7 +163,8 @@ julia> consume(t) [`Libtask.get_dynamic_scope`](@ref) to anything you like. """ function TapedTask(dynamic_scope::Any, fargs...) - mc, count_ref = build_callable(Base.code_ircode_by_type(typeof(fargs))[1][1]) + seed_id!() + mc, count_ref = build_callable(typeof(fargs)) return TapedTask(dynamic_scope, fargs[2:end], mc, count_ref) end @@ -192,7 +199,8 @@ called, it start execution from the entry point. If `consume` has previously bee `nothing` will be returned. """ @inline function consume(t::TapedTask) - return with(() -> t.mc(t.args...), dynamic_scope => t.dynamic_scope) + v = with(() -> t.mc(t.args...), dynamic_scope => t.dynamic_scope) + return v isa ProducedValue ? v[] : nothing end """ @@ -232,6 +240,25 @@ function is_produce_stmt(x)::Bool end end +""" + stmt_might_produce(x)::Bool + +`true` if `x` might contain a call to `produce`, and `false` otherwise. +""" +function stmt_might_produce(x)::Bool + is_produce_stmt(x) && return true + Meta.isexpr(x, :invoke) && return might_produce(x.args[1].specTypes) + return false + + # # TODO: make this correct + # Meta.isexpr(x, :call) && + # return !isa(x.args[1], Union{Core.IntrinsicFunction,Core.Builtin}) + # Meta.isexpr(x, :invoke) && return false # todo: make this more accurate + # return false +end + +get_function(x::Expr) = x. + """ produce_value(x::Expr) @@ -244,118 +271,27 @@ function produce_value(x::Expr) return x.args[2] # must be a `:call` Expr. end -function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} - - # Replace all existing `ReturnNode`s with `ReturnNode(nothing)` in order to provide the - # same semantics as `Libtask`. - for bb in ir.blocks - for (n, inst) in enumerate(bb.insts) - stmt = inst.stmt - if stmt isa ReturnNode - bb.insts[n] = new_inst(ReturnNode(nothing)) - end - end - end - - # The location at which `refs` will be stored. - refs_id = Argument(1) - - # Mapping in which each key-value pairs says: "if we exited from block `key`, we must - # resume by jumping to basic block `value`". - resume_block_ids = Dict{ID,ID}() - - # For each basic block `bb`: - # - count the number of produce statements, `n_produce`. - # - construct `n_produce + 1` new basic blocks. The 1st new basic block runs from the - # first stmt in `bb` to the first `produce(%x)` statement (inclusive), the second - # from the next statement after the first `produce(%x)` statement until the next - # `produce(%x)` statement, etc. The final new basic block runs from the statement - # following the final `produce(%x)` statment, until the end of `bb`. - # Furthermore, each `produce(%x)` statement is replaced with a `ReturnNode(%x)`. - # We log the `ID`s of each of these new basic blocks, for use later. - replacements = Dict{ID,ID}() - new_bblocks = map(ir.blocks) do bb - - # If the final statement in the block is a `produce` statement, insert an additional - # statement afterwards. - if is_produce_stmt(bb.insts[end].stmt) - push!(bb.inst_ids, ID()) - push!(bb.insts, new_inst(nothing, Nothing)) - end +struct ProducedValue{T} + x::T +end - # Find all of the `produce` statements. - produce_indices = findall(x -> is_produce_stmt(x.stmt), bb.insts) - terminator_indices = vcat(produce_indices, length(bb)) - - # The `ID`s of the new basic blocks. - old_id = bb.id - new_block_ids = vcat([ID() for _ in produce_indices], bb.id) - new_id = first(new_block_ids) - replacements[old_id] = new_id - - # Construct `n_produce + 1` new basic blocks. The last basic block retains the - # `ID` of `bb`, the remaining `n_produce` blocks get new `ID`s (which we log). - # All `produce(%x)` statements are replaced with `Return(%x)` statements. - return map(enumerate(terminator_indices)) do (n, term_ind) - - # The last new block has the same `ID` as `bb`. The others gets new ones. - block_id = new_block_ids[n] - - # Pull out the instructions and their `ID`s for the new block. - start_ind = n == 1 ? 1 : terminator_indices[n - 1] + 1 - inst_ids = bb.inst_ids[start_ind:term_ind] - insts = bb.insts[start_ind:term_ind] - - # If n < length(terminator_indices) then it must end with a `produce` statement. - # In this case, we replace the `produce(%x)` statement with a call to set the - # `resume_block` to the next block, which ensures that execution jumps to the - # statement immediately following this `produce(%x)` statement next time the - # function is called. We also insert a `ReturnNode(%x)` i.e. to implement the - # `produce` statement. - # Also log the mapping between the current new block ID, and the ID of the block - # we should resume to. - if n < length(terminator_indices) - resume_id = new_block_ids[n + 1] - resume_block_ids[block_id] = resume_id - set_resume = Expr(:call, set_resume_block!, refs_id, resume_id.id) - return_node = ReturnNode(produce_value(insts[end].stmt)) - inst_ids = vcat(inst_ids[1:(end - 1)], [ID(), ID()]) # actual ID values are irrelevant (no uses). - insts = vcat( - insts[1:(end - 1)], [new_inst(set_resume), new_inst(return_node)] - ) - end +@inline Base.getindex(x::ProducedValue) = x.x - # Construct + return new basic block. - return BBlock(block_id, inst_ids, insts) - end - end - new_bblocks = reduce(vcat, new_bblocks) +function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} - # Hunt for `IDGotoNode`s and `IDGotoIfNot`s, and replace them with the new ID of the - # start of these blocks. - for (old_id, new_id) in replacements, bb in new_bblocks - inst = last(bb.insts) - stmt = inst.stmt - new_stmt = if stmt isa IDGotoNode && stmt.label == old_id - IDGotoNode(new_id) - elseif stmt isa IDGotoIfNot && stmt.dest == old_id - IDGotoIfNot(stmt.cond, new_id) - else - continue - end - bb.insts[end] = CC.NewInstruction( - new_stmt, inst.type, inst.info, inst.line, inst.flag - ) - end + # The location from which all state can be retrieved. Since we're using `OpaqueClosure`s + # to implement `TapedTask`s, this appears via the first argument. + refs_id = Argument(1) # Construct map between SSA IDs and their index in the state data structure and back. - # Optimisation TODO: don't create an entry for literally every line in the IR, just the - # ones which produce values that might be needed later. + # Also construct a map from each ref index to its type. We only construct `Ref`s + # for statements which return a value e.g. `IDGotoIfNot`s do not have a meaningful + # return value, so there's no need to allocate a `Ref` for them. ssa_id_to_ref_index_map = Dict{ID,Int}() ref_index_to_ssa_id_map = Dict{Int,ID}() ref_index_to_type_map = Dict{Int,Type}() n = 0 - for bb in new_bblocks + for bb in ir.blocks for (id, stmt) in zip(bb.inst_ids, bb.insts) stmt.stmt isa IDGotoNode && continue stmt.stmt isa IDGotoIfNot && continue @@ -369,162 +305,419 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} end # Specify data structure containing `Ref`s for all of the SSAs. - # Optimisation TODO: permit users to construct custom data structures to make their - # lives involve less indirection. - # Optimisation TODO: make there be only one `Ref` per basic block, and only write to it - # at the end of basic block execution (or something like that). Probably need to base - # this on what the basic blocks _will_ _be_ after we've transformed everything, so need - # to figure out when this can happen. - _refs = map(p -> Ref{ref_index_to_type_map[p]}(), 1:length(ref_index_to_ssa_id_map)) - refs = (_refs..., Ref{Int32}(-1)) + _refs = Any[Ref{ref_index_to_type_map[p]}() for p in 1:length(ref_index_to_ssa_id_map)] + + # Ensure that each basic block ends with a non-producing statement. This is achieved by + # replacing any fall-through terminators with `IDGotoNode`s. This is not strictly + # necessary, but simplifies later stages of the pipeline, as discussed variously below. + for (n, block) in enumerate(ir.blocks) + if terminator(block) === nothing + # Fall-through terminator, so next block in `ir.blocks` is the unique successor + # block of `block`. Final block cannot have a fall-through terminator, so asking + # for element `n + 1` is always going to be valid. + successor_id = ir.blocks[n + 1].id + push!(block.insts, new_inst(IDGotoNode(successor_id))) + push!(block.inst_ids, ID()) + end + end - # For each instruction in each basic block, replace it with a call to the refs. - new_bblocks = map(new_bblocks) do bb - inst_pairs = Mooncake.IDInstPair[] + # For each existing basic block, produce a sequence of `NamedTuple`s which + # define the manner in which it must be split. + # A block will in general be split as follows: + # 1 - %1 = φ(...) + # 1 - %2 = φ(...) + # 1 - %3 = call_which_must_not_produce(...) + # 1 - %4 = produce(%3) + # 2 - %5 = call_which_must_not_produce(...) + # 2 - %6 = call_which_might_produce(...) + # 3 - %7 = call_which_must_not_produce(...) + # 3 - terminator (GotoIfNot, GotoNode, etc) + # + # The numbers on the left indicate which split each statement falls. The first + # split comprises all statements up until the first produce / call-which-might-produce. + # Consequently, the first split will always contain any `PhiNode`s present in the block. + # The next set of statements up until the next produce / call-which-might-produce form + # the second split, and so on. + # We enforced above the condition that the final statement in a basic block must not + # produce. This ensures that the final split does not produce. While not strictly + # necessary, this simplifies the implementation (see below). + # + # As a result of the above, a basic block will be associated to exactly one split if it + # does not contain any statements which may produce. + # + # Each `NamedTuple` contains a `start` index and `last` index, indicating the position + # in the block at which the corresponding split starts and finishes. + all_splits = map(ir.blocks) do block + split_ends = vcat( + findall(inst -> stmt_might_produce(inst.stmt), block.insts), length(block) + ) + return map(enumerate(split_ends)) do (n, split_end) + return (start=(n == 1 ? 0 : split_ends[n - 1]) + 1, last=split_end) + end + end - # - # Handle all other nodes in the block. - # + # Owing to splitting blocks up, we will need to re-label some `GotoNode`s and + # `GotoIfNot`s. To understand this, consider the following block, whose original `ID` + # we assume to be `ID(old_id)`. + # ID(new_id) - %1 = φ(ID(3) => ...) + # ID(new_id) - %3 = call_which_must_not_produce(...) + # ID(new_id) - %4 = produce(%3) + # ID(old_id) - GotoNode(ID(5)) + # + # In the above, the entire block was original associated to a single ID, `ID(old_id)`, + # but is now split into two sections. We keep the original ID for the final split, and + # assign a new one to the first split. As a result, any `PhiNode`s in other blocks + # which have edges incoming from `ID(old_id)` will remain valid. + # However, if we adopt this strategy for all blocks, `ID(5)` in the `GotoNode` at the + # end of the block will refer to the wrong block if the block original associated to + # `ID(5)` was itself split, since the "top" of that block will have a new `ID`. + # + # To resolve this, we: + # 1. Associate an ID to each split in each block, ensuring that the ID for the final + # split of each block is the same ID as that of the original block. + all_split_ids = map(zip(ir.blocks, all_splits)) do (block, splits) + return vcat([ID() for _ in splits[1:end-1]], block.id) + end - foreach(zip(bb.inst_ids, bb.insts)) do (id, inst) - stmt = inst.stmt - if Meta.isexpr(stmt, :invoke) || - Meta.isexpr(stmt, :call) || - Meta.isexpr(stmt, :new) || - Meta.isexpr(stmt, :foreigncall) + # 2. Construct a map between the ID of each block and the ID associated to its split. + top_split_id_map = Dict{ID,ID}(b.id => x[1] for (b, x) in zip(ir.blocks, all_split_ids)) - # Skip over set_resume_block! statements inserted in the previous pass. - if stmt.args[1] == set_resume_block! - push!(inst_pairs, (id, inst)) - return nothing + # 3. Update all `GotoNode`s and `GotoIfNot`s to refer to these new names. + for block in ir.blocks + t = terminator(block) + if t isa IDGotoNode + block.insts[end] = new_inst(IDGotoNode(top_split_id_map[t.label])) + elseif t isa IDGotoIfNot + block.insts[end] = new_inst(IDGotoIfNot(t.cond, top_split_id_map[t.dest])) + end + end + + # A set of blocks from which we might wish to resume computation. + resume_block_ids = Vector{ID}() + + # This where most of the action happens. + # + # For each split of each block, we must + # 1. translate all statements which accept any SSAs as arguments, or return a value, + # into statements which read in data from the `Ref`s containing the value associated + # to each SSA, and write the result to `Ref`s associated to the SSA of the line in + # question. + # 2. add additional code at the end of the split to handle the possibility that the + # last statement produces (per the definition of the splits above). This applies to + # all splits except the last, which cannot produce by construction. Exactly what + # happens here depends on whether the last statement is a `produce` call, or a + # call-which-might-produce -- see below for specifics. + # + # This code transforms each block (and its splits) into a new collection of blocks. + # Note that the total number of new blocks may be greater than the total number of + # splits, because each split ending in a call-which-might-produce requires more than a + # single block to implement the required resumption functionality. + new_bblocks = map(zip(ir.blocks, all_splits, all_split_ids)) do (bb, splits, splits_ids) + new_blocks = map(enumerate(splits)) do (n, split) + # We'll push ID-NewInstruction pairs to this as we proceed through the split. + inst_pairs = Mooncake.IDInstPair[] + + # PhiNodes: + # + # A single `PhiNode` + # + # ID(%1) = φ(ID(#1) => 1, ID(#2) => ID(%n)) + # + # sets `ID(%1)` to either `1` or whatever value is currently associated to + # `ID(%n)`, depending upon whether the predecessor block was `ID(#1)` or + # `ID(#2)`. Consequently, a single `PhiNode` can be transformed into something + # along the lines of: + # + # ID(%1) = φ(ID(#1) => 1, ID(#2) => TupleRef(ref_ind_for_ID(%n))) + # ID(%2) = deref_phi(refs, ID(%1)) + # set_ref_at!(refs, ref_ind_for_ID(%1), ID(%2)) + # + # where `deref_phi` retrives the value in position `ref_ind_for_ID(%n)` if + # ID(%1) is a `TupleRef`, and `1` otherwise, and `set_ref_at!` sets the `Ref` + # at position `ref_ind_for_ID(%1)` to the value of `ID(%2)`. See the actual + # implementations below. + # + # If we have multiple `PhiNode`s at the start of a block, we must run all of + # them, then dereference all of their values, and finally write all of the + # de-referenced values to the appropriate locations. This is because + # a. we require all `PhiNode`s appear together at the top of a given basic + # block, and + # b. the semantics of `PhiNode`s is that they are all "run" simultaneously. This + # only matters if one `PhiNode` in the block can refer to the value stored in + # the SSA associated to another. For example, something along the lines of: + # + # ID(%1) = φ(ID(#1) => 1, ID(#2) => ID(%2)) + # ID(%2) = φ(ID(#1) => 1, ID(#2) => 2) + # + # (we leave it as an exercise for the reader to figure out why this particular + # semantic feature of `PhiNode`s is relevant in this specific case). + # + # So, in general, the code produced by this block will look roughly like + # + # ID(%1) = φ(...) + # ID(%2) = φ(...) + # ID(%3) = φ(...) + # ID(%4) = deref_phi(refs, ID(%1)) + # ID(%5) = deref_phi(refs, ID(%2)) + # ID(%6) = deref_phi(refs, ID(%3)) + # set_ref_at!(refs, ref_ind_for_ID(%1), ID(%4)) + # set_ref_at!(refs, ref_ind_for_ID(%2), ID(%5)) + # set_ref_at!(refs, ref_ind_for_ID(%3), ID(%6)) + if n == 1 + # Find all PhiNodes in the block -- will definitely be in this split. + phi_inds = findall(x -> x.stmt isa IDPhiNode, bb.insts) + + # Replace SSA IDs with `TupleRef`s, and record these instructions. + phi_ids = map(phi_inds) do n + phi = bb.insts[n].stmt + for i in eachindex(phi.values) + isassigned(phi.values, i) || continue + v = phi.values[i] + v isa ID || continue + phi.values[i] = TupleRef(ssa_id_to_ref_index_map[v]) + end + phi_id = ID() + push!(inst_pairs, (phi_id, new_inst(phi, Any))) + return phi_id end - # Find any `ID`s and replace them with calls to read whatever is stored in - # the `Ref`s that they are associated to. - for (n, arg) in enumerate(stmt.args) - arg isa ID || continue + # De-reference values associated to `IDPhiNode`s. + deref_ids = map(phi_inds) do n + id = bb.inst_ids[n] + phi_id = phi_ids[n] + push!(inst_pairs, (id, new_inst(Expr(:call, deref_phi, refs_id, phi_id)))) + return id + end - new_id = ID() - ref_ind = ssa_id_to_ref_index_map[arg] - expr = Expr(:call, get_ref_at, refs_id, ref_ind) - push!(inst_pairs, (new_id, new_inst(expr))) - stmt.args[n] = new_id + # Update values stored in `Ref`s associated to `PhiNode`s. + for n in phi_inds + ref_ind = ssa_id_to_ref_index_map[bb.inst_ids[n]] + expr = Expr(:call, set_ref_at!, refs_id, ref_ind, deref_ids[n]) + push!(inst_pairs, (ID(), new_inst(expr))) end + end - # Push the target instruction to the list. - push!(inst_pairs, (id, inst)) - - # Push the result to its `Ref`. - out_ind = ssa_id_to_ref_index_map[id] - set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, id) - push!(inst_pairs, (ID(), new_inst(set_ref))) - elseif Meta.isexpr(stmt, :boundscheck) - push!(inst_pairs, (id, inst)) - elseif Meta.isexpr(stmt, :code_coverage_effect) - push!(inst_pairs, (id, inst)) - elseif Meta.isexpr(stmt, :gc_preserve_begin) - push!(inst_pairs, (id, inst)) - elseif Meta.isexpr(stmt, :gc_preserve_end) - push!(inst_pairs, (id, inst)) - elseif stmt isa ReturnNode - # If returning an SSA, it might be one whose value was restored from before. - # Therefore, grab it out of storage, rather than assuming that it is def-ed. - if isdefined(stmt, :val) && stmt.val isa ID - ref_ind = ssa_id_to_ref_index_map[stmt.val] - val_id = ID() - expr = Expr(:call, get_ref_at, refs_id, ref_ind) - push!(inst_pairs, (val_id, new_inst(expr))) - push!(inst_pairs, (ID(), new_inst(ReturnNode(val_id)))) - else + # Statements which do not produce: + # + # Iterate every statement in the split other than the final one, replacing uses + # of SSAs with de-referenced `Ref`s, and writing the results of statements to + # the corresponding `Ref`s. + _ids = view(bb.inst_ids, split.start:(split.last - 1)) + _insts = view(bb.insts, split.start:(split.last - 1)) + for (id, inst) in zip(_ids, _insts) + stmt = inst.stmt + if Meta.isexpr(stmt, :invoke) || + Meta.isexpr(stmt, :call) || + Meta.isexpr(stmt, :new) || + Meta.isexpr(stmt, :foreigncall) + + # Find any `ID`s and replace them with calls to read whatever is stored + # in the `Ref`s that they are associated to. + for (n, arg) in enumerate(stmt.args) + arg isa ID || continue + + new_id = ID() + ref_ind = ssa_id_to_ref_index_map[arg] + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (new_id, new_inst(expr))) + stmt.args[n] = new_id + end + + # Push the target instruction to the list. push!(inst_pairs, (id, inst)) - end - elseif stmt isa IDGotoIfNot - # If the condition is an SSA, it might be one whose value was restored from - # before. Therefore, grab it out of storage, rather than assuming that it is - # defined. - if stmt.cond isa ID - ref_ind = ssa_id_to_ref_index_map[stmt.cond] - cond_id = ID() - expr = Expr(:call, get_ref_at, refs_id, ref_ind) - push!(inst_pairs, (cond_id, new_inst(expr))) - push!(inst_pairs, (ID(), new_inst(IDGotoIfNot(cond_id, stmt.dest)))) - else + + # If we know it is not possible for this statement to contain any calls + # to produce, then simply write out the result to its `Ref`. + out_ind = ssa_id_to_ref_index_map[id] + set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, id) + push!(inst_pairs, (ID(), new_inst(set_ref))) + elseif Meta.isexpr(stmt, :boundscheck) push!(inst_pairs, (id, inst)) - end - elseif stmt isa IDGotoNode - push!(inst_pairs, (id, inst)) - elseif stmt isa IDPhiNode - # we'll fix up the PhiNodes after this, so identity transform for now. - push!(inst_pairs, (id, inst)) - elseif stmt isa Nothing - push!(inst_pairs, (id, inst)) - elseif stmt isa GlobalRef - ref_ind = ssa_id_to_ref_index_map[id] - expr = Expr(:call, set_ref_at!, refs_id, ref_ind, stmt) - push!(inst_pairs, (id, new_inst(expr))) - elseif stmt isa Core.PiNode - if stmt.val isa ID - ref_ind = ssa_id_to_ref_index_map[stmt.val] - val_id = ID() - expr = Expr(:call, get_ref_at, refs_id, ref_ind) - push!(inst_pairs, (val_id, new_inst(expr))) - push!(inst_pairs, (id, new_inst(Core.PiNode(val_id, stmt.typ)))) - else + elseif Meta.isexpr(stmt, :code_coverage_effect) + push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :gc_preserve_begin) + push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :gc_preserve_end) push!(inst_pairs, (id, inst)) + elseif stmt isa Nothing + push!(inst_pairs, (id, inst)) + elseif stmt isa GlobalRef + ref_ind = ssa_id_to_ref_index_map[id] + expr = Expr(:call, set_ref_at!, refs_id, ref_ind, stmt) + push!(inst_pairs, (id, new_inst(expr))) + elseif stmt isa Core.PiNode + if stmt.val isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.val] + val_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (val_id, new_inst(expr))) + push!(inst_pairs, (id, new_inst(Core.PiNode(val_id, stmt.typ)))) + else + push!(inst_pairs, (id, inst)) + end + set_ind = ssa_id_to_ref_index_map[id] + set_expr = Expr(:call, set_ref_at!, refs_id, set_ind, id) + push!(inst_pairs, (ID(), new_inst(set_expr))) + elseif stmt isa IDPhiNode + # do nothing -- we've already handled any `PhiNode`s. + else + throw(error("Unhandled stmt $stmt of type $(typeof(stmt))")) end - set_ind = ssa_id_to_ref_index_map[id] - set_expr = Expr(:call, set_ref_at!, refs_id, set_ind, id) - push!(inst_pairs, (ID(), new_inst(set_expr))) - else - throw(error("Unhandled stmt $stmt of type $(typeof(stmt))")) end - end - # - # Handle `(ID)PhiNode`s. - # - - phi_inds = findall(x -> x.stmt isa IDPhiNode, bb.insts) - phi_inst_pairs = Mooncake.IDInstPair[] - - # Replace SSA IDs with `TupleRef`s, and record these instructions. - phi_ids = map(phi_inds) do n - phi = bb.insts[n].stmt - for i in eachindex(phi.values) - isassigned(phi.values, i) || continue - v = phi.values[i] - v isa ID || continue - phi.values[i] = TupleRef(ssa_id_to_ref_index_map[v]) - end - phi_id = ID() - push!(phi_inst_pairs, (phi_id, new_inst(phi, Any))) - return phi_id - end + # TODO: explain this better. + new_blocks = BBlock[] - # De-reference values associated to `IDPhiNode`s. - deref_ids = map(phi_inds) do n - id = bb.inst_ids[n] - phi_id = phi_ids[n] - push!(phi_inst_pairs, (id, new_inst(Expr(:call, deref_phi, refs_id, phi_id)))) - return id - end + # Produce and Terminators: + # + # Handle the last statement in the split. + id = bb.inst_ids[split.last] + inst = bb.insts[split.last] + stmt = inst.stmt + if n == length(splits) + # This is the last split in the block, so it must end with a non-producing + # terminator. We handle this in a similar way to the statements above. + + if stmt isa ReturnNode + # If returning an SSA, it might be one whose value was restored from + # before. Therefore, grab it out of storage, rather than assuming that + # it is def-ed. + if isdefined(stmt, :val) && stmt.val isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.val] + val_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (val_id, new_inst(expr))) + push!(inst_pairs, (ID(), new_inst(ReturnNode(val_id)))) + else + push!(inst_pairs, (id, inst)) + end + elseif stmt isa IDGotoIfNot + # If the condition is an SSA, it might be one whose value was restored + # from before. Therefore, grab it out of storage, rather than assuming + # that it is defined. + if stmt.cond isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.cond] + cond_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (cond_id, new_inst(expr))) + push!(inst_pairs, (ID(), new_inst(IDGotoIfNot(cond_id, stmt.dest)))) + else + push!(inst_pairs, (id, inst)) + end + elseif stmt isa IDGotoNode + push!(inst_pairs, (id, inst)) + else + error("Unexpected terminator $stmt") + end + push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) + elseif is_produce_stmt(stmt) + + # When this TapedTask is next called, we should resume from the first + # statement of the next split. + resume_id = splits_ids[n + 1] + push!(resume_block_ids, resume_id) + + # Insert statement to enforce correct resumption behaviour. + resume_stmt = Expr(:call, set_resume_block!, refs_id, resume_id.id) + push!(inst_pairs, (ID(), new_inst(resume_stmt))) + + # Insert statement to construct a `ProducedValue` from the value. + # Could be that the produce references an SSA, in which case we need to + # de-reference, rather than just return the thing. + prod_val = produce_value(stmt) + if prod_val isa ID + deref_id = ID() + ref_ind = ssa_id_to_ref_index_map[prod_val] + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (deref_id, new_inst(expr))) + prod_val = deref_id + end - # Update values stored in `Ref`s associated to `PhiNode`s. - for n in phi_inds - ref_ind = ssa_id_to_ref_index_map[bb.inst_ids[n]] - expr = Expr(:call, set_ref_at!, refs_id, ref_ind, deref_ids[n]) - push!(phi_inst_pairs, (ID(), new_inst(expr))) - end + # Construct a `ProducedValue`. + val_id = ID() + push!(inst_pairs, (val_id, new_inst(Expr(:call, ProducedValue, prod_val)))) - # Concatenate new phi stmts, removing old ones. - inst_pairs = vcat(phi_inst_pairs, inst_pairs[(length(phi_inds) + 1):end]) + # Insert statement to return the `ProducedValue`. + push!(inst_pairs, (ID(), new_inst(ReturnNode(val_id)))) - return BBlock(bb.id, inst_pairs) + # Construct a single new basic block from all of the inst-pairs. + push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) + else + # The final statement is one which might produce, but is not itself a + # `produce` statement. + + # Create a new basic block from the existing statements, since all new + # statement need to live in their own basic blocks. + callable_block_id = ID() + push!(inst_pairs, (ID(), new_inst(IDGotoNode(callable_block_id)))) + push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) + + # Derive TapedTask for this statement. + callable = if Meta.isexpr(stmt, :invoke) + sig = stmt.args[1].specTypes + LazyCallable{sig,callable_ret_type(sig)}() + else + error("unhandled statement which might produce $stmt") + end + + # Allocate a slot in the _refs vector for this callable. + push!(_refs, Ref(callable)) + callable_ind = length(_refs) + + # Retrieve the callable from the refs. + callable_id = ID() + callable = Expr(:call, get_ref_at, refs_id, callable_ind) + + # Call the callable. + result = Expr(:call, callable_id, stmt.args[3:end]...) + result_id = ID() + + # Determine whether this TapedTask has produced a not-a-`ProducedValue`. + not_produced = Expr(:call, not_a_produced, result_id) + not_produced_id = ID() + + # Go to a block which just returns the `ProducedValue`, if a + # `ProducedValue` is returned, otherwise continue to the next split. + is_produced_block_id = ID() + next_block_id = splits_ids[n + 1] # safe since the last split ends with a terminator + # switch = Switch(Any[not_produced_id], [is_produced_block_id], next_block_id) + switch = IDGotoIfNot(not_produced_id, is_produced_block_id) + + # Insert a new block to hold the three previous statements. + callable_inst_pairs = Mooncake.IDInstPair[ + (callable_id, new_inst(callable)), + (result_id, new_inst(result)), + (not_produced_id, new_inst(not_produced)), + (ID(), new_inst(switch)), + ] + push!(new_blocks, BBlock(callable_block_id, callable_inst_pairs)) + + goto_block = BBlock(ID(), [(ID(), new_inst(IDGotoNode(next_block_id)))]) + push!(new_blocks, goto_block) + + # Construct block which handles the case that we got a `ProducedValue`. If + # this happens, it means that `callable` has more things to produce still. + # This means that we need to call it again next time we enter this function. + # To achieve this, we set the resume block to the `callable_block_id`, + # and return the `ProducedValue` currently located in `result_id`. + push!(resume_block_ids, callable_block_id) + set_res = Expr(:call, set_resume_block!, refs_id, callable_block_id.id) + return_id = ID() + produced_block_inst_pairs = Mooncake.IDInstPair[ + (ID(), new_inst(set_res)), + (return_id, new_inst(ReturnNode(result_id))), + ] + push!(new_blocks, BBlock(is_produced_block_id, produced_block_inst_pairs)) + end + return new_blocks + end + return reduce(vcat, new_blocks) end + new_bblocks = reduce(vcat, new_bblocks) # Insert statements at the top. - cases = map(collect(resume_block_ids)) do (pred, succ) - return ID(), succ, Expr(:call, resume_block_is, refs_id, succ.id) + cases = map(resume_block_ids) do id + return ID(), id, Expr(:call, resume_block_is, refs_id, id.id) end cond_ids = ID[x[1] for x in cases] cond_dests = ID[x[2] for x in cases] @@ -537,6 +730,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # New argtypes are the same as the old ones, except we have `Ref`s in the first argument # rather than nothing at all. new_argtypes = copy(ir.argtypes) + refs = (_refs..., Ref{Int32}(-1)) new_argtypes[1] = typeof(refs) # Return BBCode and the `Ref`s. @@ -567,6 +761,9 @@ end @inline deref_phi(refs::R, n::TupleRef) where {R<:Tuple} = refs[n.n][] @inline deref_phi(::R, x) where {R<:Tuple} = x +# Helper used in `derived_copyable_task_ir`. +@inline not_a_produced(x) = !(isa(x, ProducedValue)) + # Implement iterator interface. function Base.iterate(t::TapedTask, state::Nothing=nothing) v = consume(t) @@ -574,3 +771,24 @@ function Base.iterate(t::TapedTask, state::Nothing=nothing) end Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown() + +""" + +""" +mutable struct LazyCallable{sig<:Tuple,Tret} + mc::MistyClosure + position::Base.RefValue{Int32} + LazyCallable{sig,Tret}() where {sig,Tret} = new{sig,Tret}() +end + +function (l::LazyCallable)(args::Vararg{Any,N}) where {N} + isdefined(l, :mc) || construct_callable!(l) + return l.mc(args...) +end + +function construct_callable!(l::LazyCallable{sig}) where {sig} + mc, pos = build_callable(sig) + l.mc = mc + l.position = pos + return nothing +end diff --git a/src/test_utils.jl b/src/test_utils.jl index e126179f..f12ba951 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -70,7 +70,7 @@ function test_cases() Testcase("dynamic scope 2", 6, (dynamic_scope_tester_1,), [6]), # Failing tests - # Testcase("nested", (nested_outer, ), [true, false]), + Testcase("nested", nothing, (nested_outer, ), [true, false]), ] end @@ -158,12 +158,17 @@ function foreigncall_tester(s::String) return nothing end +function dynamic_scope_tester_1() + produce(Libtask.get_dynamic_scope()) + return nothing +end + @noinline function nested_inner() produce(true) return nothing end -might_produce(::Type{Tuple{typeof(nested_inner)}}) = true +Libtask.might_produce(::Type{Tuple{typeof(nested_inner)}}) = true function nested_outer() nested_inner() @@ -171,9 +176,4 @@ function nested_outer() return nothing end -function dynamic_scope_tester_1() - produce(Libtask.get_dynamic_scope()) - return nothing -end - end From f94ea17e0fcec8db6e64f64364a41ec80e23a239 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Tue, 11 Mar 2025 18:59:30 +0000 Subject: [PATCH 48/69] Fomatting --- src/copyable_task.jl | 21 +++++++++++---------- src/test_utils.jl | 2 +- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 79f4c0c0..ec62cb25 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -26,7 +26,7 @@ See also: [`Libtask.consume`](@ref) end function callable_ret_type(sig) - return Union{Base.code_ircode_by_type(sig)[1][2], ProducedValue} + return Union{Base.code_ircode_by_type(sig)[1][2],ProducedValue} end function build_callable(sig::Type{<:Tuple}) @@ -257,8 +257,6 @@ function stmt_might_produce(x)::Bool # return false end -get_function(x::Expr) = x. - """ produce_value(x::Expr) @@ -376,7 +374,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # 1. Associate an ID to each split in each block, ensuring that the ID for the final # split of each block is the same ID as that of the original block. all_split_ids = map(zip(ir.blocks, all_splits)) do (block, splits) - return vcat([ID() for _ in splits[1:end-1]], block.id) + return vcat([ID() for _ in splits[1:(end - 1)]], block.id) end # 2. Construct a map between the ID of each block and the ID associated to its split. @@ -485,7 +483,10 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} deref_ids = map(phi_inds) do n id = bb.inst_ids[n] phi_id = phi_ids[n] - push!(inst_pairs, (id, new_inst(Expr(:call, deref_phi, refs_id, phi_id)))) + push!( + inst_pairs, + (id, new_inst(Expr(:call, deref_phi, refs_id, phi_id))), + ) return id end @@ -502,8 +503,8 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # Iterate every statement in the split other than the final one, replacing uses # of SSAs with de-referenced `Ref`s, and writing the results of statements to # the corresponding `Ref`s. - _ids = view(bb.inst_ids, split.start:(split.last - 1)) - _insts = view(bb.insts, split.start:(split.last - 1)) + _ids = view(bb.inst_ids, (split.start):(split.last - 1)) + _insts = view(bb.insts, (split.start):(split.last - 1)) for (id, inst) in zip(_ids, _insts) stmt = inst.stmt if Meta.isexpr(stmt, :invoke) || @@ -515,17 +516,17 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # in the `Ref`s that they are associated to. for (n, arg) in enumerate(stmt.args) arg isa ID || continue - + new_id = ID() ref_ind = ssa_id_to_ref_index_map[arg] expr = Expr(:call, get_ref_at, refs_id, ref_ind) push!(inst_pairs, (new_id, new_inst(expr))) stmt.args[n] = new_id end - + # Push the target instruction to the list. push!(inst_pairs, (id, inst)) - + # If we know it is not possible for this statement to contain any calls # to produce, then simply write out the result to its `Ref`. out_ind = ssa_id_to_ref_index_map[id] diff --git a/src/test_utils.jl b/src/test_utils.jl index f12ba951..0de3cf53 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -70,7 +70,7 @@ function test_cases() Testcase("dynamic scope 2", 6, (dynamic_scope_tester_1,), [6]), # Failing tests - Testcase("nested", nothing, (nested_outer, ), [true, false]), + Testcase("nested", nothing, (nested_outer,), [true, false]), ] end From 57b808c7649e356421477d7ab0742a0cbb7bd0bf Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 13 Mar 2025 12:01:12 +0000 Subject: [PATCH 49/69] Dynamic nested calls and uses of return values of calls which might produce --- src/copyable_task.jl | 120 ++++++++++++++++++++++++++++++++----------- src/test_utils.jl | 44 ++++++++++++++-- 2 files changed, 130 insertions(+), 34 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index ec62cb25..85c3af8e 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -241,22 +241,34 @@ function is_produce_stmt(x)::Bool end """ - stmt_might_produce(x)::Bool + stmt_might_produce(x, ret_type::Type)::Bool `true` if `x` might contain a call to `produce`, and `false` otherwise. """ -function stmt_might_produce(x)::Bool +function stmt_might_produce(x, ret_type::Type)::Bool + + # Statement will terminate in an unusual fashion, so don't bother recursing. + # This isn't _strictly_ correct (there could be a `produce` statement before the + # `throw` call is hit), but this seems unlikely to happen in practice. If it does, the + # user should get a sensible error message anyway. + ret_type == Union{} && return false + + # Statement will terminate in the usual fashion, so _do_ bother recusing. is_produce_stmt(x) && return true Meta.isexpr(x, :invoke) && return might_produce(x.args[1].specTypes) + if Meta.isexpr(x, :call) + # This is a hack -- it's perfectly possible for `DataType` calls to produce in general. + f = get_function(x.args[1]) + _might_produce = !isa(f, Union{Core.IntrinsicFunction,Core.Builtin,DataType}) + return _might_produce + end return false - - # # TODO: make this correct - # Meta.isexpr(x, :call) && - # return !isa(x.args[1], Union{Core.IntrinsicFunction,Core.Builtin}) - # Meta.isexpr(x, :invoke) && return false # todo: make this more accurate - # return false end +get_function(x) = x +get_function(x::Expr) = eval(x) +get_function(x::GlobalRef) = isconst(x) ? getglobal(x.mod, x.name) : x.binding + """ produce_value(x::Expr) @@ -347,7 +359,11 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # in the block at which the corresponding split starts and finishes. all_splits = map(ir.blocks) do block split_ends = vcat( - findall(inst -> stmt_might_produce(inst.stmt), block.insts), length(block) + findall( + inst -> stmt_might_produce(inst.stmt, CC.widenconst(inst.type)), + block.insts, + ), + length(block), ) return map(enumerate(split_ends)) do (n, split_end) return (start=(n == 1 ? 0 : split_ends[n - 1]) + 1, last=split_end) @@ -654,48 +670,63 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) # Derive TapedTask for this statement. - callable = if Meta.isexpr(stmt, :invoke) + (callable, callable_args) = if Meta.isexpr(stmt, :invoke) sig = stmt.args[1].specTypes - LazyCallable{sig,callable_ret_type(sig)}() + (LazyCallable{sig,callable_ret_type(sig)}(), stmt.args[2:end]) + elseif Meta.isexpr(stmt, :call) + (DynamicCallable(), stmt.args) else + display(stmt) + println() error("unhandled statement which might produce $stmt") end + # Find any `ID`s and replace them with calls to read whatever is stored + # in the `Ref`s that they are associated to. + callable_inst_pairs = Mooncake.IDInstPair[] + for (n, arg) in enumerate(callable_args) + arg isa ID || continue + + new_id = ID() + ref_ind = ssa_id_to_ref_index_map[arg] + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(callable_inst_pairs, (new_id, new_inst(expr))) + callable_args[n] = new_id + end + # Allocate a slot in the _refs vector for this callable. push!(_refs, Ref(callable)) callable_ind = length(_refs) # Retrieve the callable from the refs. callable_id = ID() - callable = Expr(:call, get_ref_at, refs_id, callable_ind) + callable_stmt = Expr(:call, get_ref_at, refs_id, callable_ind) + push!(callable_inst_pairs, (callable_id, new_inst(callable_stmt))) # Call the callable. - result = Expr(:call, callable_id, stmt.args[3:end]...) result_id = ID() + result_stmt = Expr(:call, callable_id, callable_args...) + push!(callable_inst_pairs, (result_id, new_inst(result_stmt))) # Determine whether this TapedTask has produced a not-a-`ProducedValue`. - not_produced = Expr(:call, not_a_produced, result_id) not_produced_id = ID() + not_produced_stmt = Expr(:call, not_a_produced, result_id) + push!(callable_inst_pairs, (not_produced_id, new_inst(not_produced_stmt))) # Go to a block which just returns the `ProducedValue`, if a # `ProducedValue` is returned, otherwise continue to the next split. is_produced_block_id = ID() - next_block_id = splits_ids[n + 1] # safe since the last split ends with a terminator - # switch = Switch(Any[not_produced_id], [is_produced_block_id], next_block_id) - switch = IDGotoIfNot(not_produced_id, is_produced_block_id) - - # Insert a new block to hold the three previous statements. - callable_inst_pairs = Mooncake.IDInstPair[ - (callable_id, new_inst(callable)), - (result_id, new_inst(result)), - (not_produced_id, new_inst(not_produced)), - (ID(), new_inst(switch)), - ] + is_not_produced_block_id = ID() + switch = Switch( + Any[not_produced_id], + [is_produced_block_id], + is_not_produced_block_id, + ) + push!(callable_inst_pairs, (ID(), new_inst(switch))) + + # Push the above statements onto a new block. push!(new_blocks, BBlock(callable_block_id, callable_inst_pairs)) - goto_block = BBlock(ID(), [(ID(), new_inst(IDGotoNode(next_block_id)))]) - push!(new_blocks, goto_block) - # Construct block which handles the case that we got a `ProducedValue`. If # this happens, it means that `callable` has more things to produce still. # This means that we need to call it again next time we enter this function. @@ -709,6 +740,21 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} (return_id, new_inst(ReturnNode(result_id))), ] push!(new_blocks, BBlock(is_produced_block_id, produced_block_inst_pairs)) + + # Construct block which handles the case that we did not get a + # `ProducedValue`. In this case, we must first push the result to the `Ref` + # associated to the call, and goto the next split. + next_block_id = splits_ids[n + 1] # safe since the last split ends with a terminator + result_ref_ind = ssa_id_to_ref_index_map[id] + set_ref = Expr(:call, set_ref_at!, refs_id, result_ref_ind, result_id) + not_produced_block_inst_pairs = Mooncake.IDInstPair[ + (ID(), new_inst(set_ref)) + (ID(), new_inst(IDGotoNode(next_block_id))) + ] + push!( + new_blocks, + BBlock(is_not_produced_block_id, not_produced_block_inst_pairs), + ) end return new_blocks end @@ -784,7 +830,7 @@ end function (l::LazyCallable)(args::Vararg{Any,N}) where {N} isdefined(l, :mc) || construct_callable!(l) - return l.mc(args...) + return l.mc(args[2:end]...) end function construct_callable!(l::LazyCallable{sig}) where {sig} @@ -793,3 +839,19 @@ function construct_callable!(l::LazyCallable{sig}) where {sig} l.position = pos return nothing end + +mutable struct DynamicCallable{V} + cache::V +end + +DynamicCallable() = DynamicCallable(Dict{Any,Any}()) + +function (dynamic_callable::DynamicCallable)(args::Vararg{Any,N}) where {N} + sig = Mooncake._typeof(args) + callable = get(dynamic_callable.cache, sig, nothing) + if callable === nothing + callable = build_callable(sig) + dynamic_callable.cache[sig] = callable + end + return callable[1](args[2:end]...) +end diff --git a/src/test_utils.jl b/src/test_utils.jl index 0de3cf53..083224ce 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -68,9 +68,25 @@ function test_cases() ), Testcase("dynamic scope 1", 5, (dynamic_scope_tester_1,), [5]), Testcase("dynamic scope 2", 6, (dynamic_scope_tester_1,), [6]), - - # Failing tests - Testcase("nested", nothing, (nested_outer,), [true, false]), + Testcase("nested (static)", nothing, (static_nested_outer,), [true, false]), + Testcase( + "nested (static + used)", + nothing, + (static_nested_outer_use_produced,), + [true, 1], + ), + Testcase( + "nested (dynamic)", + nothing, + (dynamic_nested_outer, Ref{Any}(nested_inner)), + [true, false], + ), + Testcase( + "nested (dynamic + used)", + nothing, + (dynamic_nested_outer_use_produced, Ref{Any}(nested_inner)), + [true, 1], + ), ] end @@ -165,15 +181,33 @@ end @noinline function nested_inner() produce(true) - return nothing + return 1 end Libtask.might_produce(::Type{Tuple{typeof(nested_inner)}}) = true -function nested_outer() +function static_nested_outer() nested_inner() produce(false) return nothing end +function static_nested_outer_use_produced() + y = nested_inner() + produce(y) + return nothing +end + +function dynamic_nested_outer(f::Ref{Any}) + f[]() + produce(false) + return nothing +end + +function dynamic_nested_outer_use_produced(f::Ref{Any}) + y = f[]() + produce(y) + return nothing +end + end From bf1893c733cf891ac076abf574f2d1095be5361a Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 13 Mar 2025 12:44:39 +0000 Subject: [PATCH 50/69] Fix docs build --- docs/src/internals.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/internals.md b/docs/src/internals.md index d98bb6c4..4becd94d 100644 --- a/docs/src/internals.md +++ b/docs/src/internals.md @@ -4,4 +4,6 @@ Libtask.produce_value Libtask.is_produce_stmt Libtask.might_produce +Libtask.stmt_might_produce +Libtask.LazyCallable ``` From 2c82159b9ca7699aba5623839a364f37a28d3d4a Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 13 Mar 2025 12:47:42 +0000 Subject: [PATCH 51/69] Update CI --- .github/workflows/Testing.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/Testing.yaml b/.github/workflows/Testing.yaml index 0d0facc7..59fbb3f8 100644 --- a/.github/workflows/Testing.yaml +++ b/.github/workflows/Testing.yaml @@ -25,8 +25,8 @@ jobs: - os: macOS-latest arch: x86 steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} From b393c54df2dfa144fa1e6fdfb9f2b69eebab4083 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 13 Mar 2025 15:56:29 +0000 Subject: [PATCH 52/69] Handle callable structs --- src/copyable_task.jl | 47 +++++++++++++++++++++++++++++++++++++------- src/test_utils.jl | 12 +++++++++++ 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 85c3af8e..ff5fd939 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -39,9 +39,9 @@ function build_callable(sig::Type{<:Tuple}) return mc, refs[end] end -mutable struct TapedTask{Tdynamic_scope,Targs,Tmc<:MistyClosure} +mutable struct TapedTask{Tdynamic_scope,Tfargs,Tmc<:MistyClosure} dynamic_scope::Tdynamic_scope - args::Targs + fargs::Tfargs const mc::Tmc const position::Base.RefValue{Int32} end @@ -165,7 +165,7 @@ julia> consume(t) function TapedTask(dynamic_scope::Any, fargs...) seed_id!() mc, count_ref = build_callable(typeof(fargs)) - return TapedTask(dynamic_scope, fargs[2:end], mc, count_ref) + return TapedTask(dynamic_scope, fargs, mc, count_ref) end """ @@ -199,7 +199,7 @@ called, it start execution from the entry point. If `consume` has previously bee `nothing` will be returned. """ @inline function consume(t::TapedTask) - v = with(() -> t.mc(t.args...), dynamic_scope => t.dynamic_scope) + v = with(() -> t.mc(t.fargs...), dynamic_scope => t.dynamic_scope) return v isa ProducedValue ? v[] : nothing end @@ -287,12 +287,45 @@ end @inline Base.getindex(x::ProducedValue) = x.x +""" + inc_args(stmt) + +Increment by `1` the `n` field of any `Argument`s present in `stmt`. +Used in `make_ad_stmts!`. +""" +inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...) +inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x +inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest) +inc_args(x::IDGotoNode) = x +function inc_args(x::IDPhiNode) + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = __inc(x.values[n]) + end + end + return IDPhiNode(x.edges, new_values) +end +inc_args(::Nothing) = nothing +inc_args(x::GlobalRef) = x +inc_args(x::Core.PiNode) = Core.PiNode(__inc(x.val), __inc(x.typ)) + +__inc(x::Argument) = Argument(x.n + 1) +__inc(x) = x + function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # The location from which all state can be retrieved. Since we're using `OpaqueClosure`s # to implement `TapedTask`s, this appears via the first argument. refs_id = Argument(1) + # Increment all arguments by 1. + for bb in ir.blocks, (n, inst) in enumerate(bb.insts) + bb.insts[n] = CC.NewInstruction( + inc_args(inst.stmt), inst.type, inst.info, inst.line, inst.flag + ) + end + # Construct map between SSA IDs and their index in the state data structure and back. # Also construct a map from each ref index to its type. We only construct `Ref`s # for statements which return a value e.g. `IDGotoIfNot`s do not have a meaningful @@ -778,7 +811,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # rather than nothing at all. new_argtypes = copy(ir.argtypes) refs = (_refs..., Ref{Int32}(-1)) - new_argtypes[1] = typeof(refs) + new_argtypes = vcat(typeof(refs), copy(ir.argtypes)) # Return BBCode and the `Ref`s. return BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta), refs @@ -830,7 +863,7 @@ end function (l::LazyCallable)(args::Vararg{Any,N}) where {N} isdefined(l, :mc) || construct_callable!(l) - return l.mc(args[2:end]...) + return l.mc(args...) end function construct_callable!(l::LazyCallable{sig}) where {sig} @@ -853,5 +886,5 @@ function (dynamic_callable::DynamicCallable)(args::Vararg{Any,N}) where {N} callable = build_callable(sig) dynamic_callable.cache[sig] = callable end - return callable[1](args[2:end]...) + return callable[1](args...) end diff --git a/src/test_utils.jl b/src/test_utils.jl index 083224ce..1b1a4c6f 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -87,6 +87,7 @@ function test_cases() (dynamic_nested_outer_use_produced, Ref{Any}(nested_inner)), [true, 1], ), + Testcase("callable struct", nothing, (CallableStruct(5), 4), [5, 4, 9]), ] end @@ -210,4 +211,15 @@ function dynamic_nested_outer_use_produced(f::Ref{Any}) return nothing end +struct CallableStruct{T} + x::T +end + +function (c::CallableStruct)(y) + produce(c.x) + produce(y) + produce(c.x + y) + return nothing +end + end From 906cb6636aaedc0e79ca70d3e8db54157d4dc0bb Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 20 Mar 2025 09:04:18 +0000 Subject: [PATCH 53/69] Fix docs and add doctest --- docs/src/internals.md | 1 + src/copyable_task.jl | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/src/internals.md b/docs/src/internals.md index 4becd94d..fb3ca3bd 100644 --- a/docs/src/internals.md +++ b/docs/src/internals.md @@ -6,4 +6,5 @@ Libtask.is_produce_stmt Libtask.might_produce Libtask.stmt_might_produce Libtask.LazyCallable +Libtask.inc_args ``` diff --git a/src/copyable_task.jl b/src/copyable_task.jl index ff5fd939..2ef2baa6 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -288,10 +288,14 @@ end @inline Base.getindex(x::ProducedValue) = x.x """ - inc_args(stmt) + inc_args(stmt::T)::T where {T} -Increment by `1` the `n` field of any `Argument`s present in `stmt`. -Used in `make_ad_stmts!`. +Returns a new `T` which is equal to `stmt`, except any `Argument`s present in `stmt` are +incremented by `1`. For example +```jldoctest +julia> Libtask.inc_args(Core.ReturnNode(Core.Argument(1))) +:(return _2) +``` """ inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...) inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x From 4f80127533e10c4311541bc74ca695985191f900 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 8 Apr 2025 14:29:44 +0100 Subject: [PATCH 54/69] Test kwargs --- src/copyable_task.jl | 9 +++--- src/test_utils.jl | 74 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 64 insertions(+), 19 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 2ef2baa6..045d0e2d 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -47,7 +47,7 @@ mutable struct TapedTask{Tdynamic_scope,Tfargs,Tmc<:MistyClosure} end """ - TapedTask(dynamic_scope::Any, f, args...) + TapedTask(dynamic_scope::Any, f, args...; kwargs...) Construct a `TapedTask` with the specified `dynamic_scope`, for function `f` and positional arguments `args`. @@ -162,10 +162,11 @@ julia> consume(t) `Int`s have been used here, but it is permissible to set the value returned by [`Libtask.get_dynamic_scope`](@ref) to anything you like. """ -function TapedTask(dynamic_scope::Any, fargs...) +function TapedTask(dynamic_scope::Any, fargs...; kwargs...) seed_id!() - mc, count_ref = build_callable(typeof(fargs)) - return TapedTask(dynamic_scope, fargs, mc, count_ref) + all_args = isempty(kwargs) ? fargs : (Core.kwcall, getfield(kwargs, :data), fargs...) + mc, count_ref = build_callable(typeof(all_args)) + return TapedTask(dynamic_scope, all_args, mc, count_ref) end """ diff --git a/src/test_utils.jl b/src/test_utils.jl index 1b1a4c6f..c3a24707 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -8,6 +8,7 @@ struct Testcase name::String dynamic_scope::Any fargs::Tuple + kwargs::Union{NamedTuple,Nothing} expected_iteration_results::Vector end @@ -15,7 +16,11 @@ function (case::Testcase)() testset = @testset "$(case.name)" begin # Construct the task. - t = TapedTask(case.dynamic_scope, case.fargs...) + if case.kwargs === nothing + t = TapedTask(case.dynamic_scope, case.fargs...) + else + t = TapedTask(case.dynamic_scope, case.fargs...; case.kwargs...) + end # Iterate through t. Record the results, and take a copy after each iteration. iteration_results = [] @@ -42,52 +47,89 @@ function test_cases() "single block", nothing, (single_block, 5.0), + nothing, [sin(5.0), sin(sin(5.0)), sin(sin(sin(5.0))), sin(sin(sin(sin(5.0))))], ), - Testcase("produce old", nothing, (produce_old_value, 5.0), [sin(5.0), sin(5.0)]), - Testcase("branch on old value l", nothing, (branch_on_old_value, 2.0), [true, 2.0]), Testcase( - "branch on old value r", nothing, (branch_on_old_value, -1.0), [false, -2.0] + "produce old", nothing, (produce_old_value, 5.0), nothing, [sin(5.0), sin(5.0)] + ), + Testcase( + "branch on old value l", + nothing, + (branch_on_old_value, 2.0), + nothing, + [true, 2.0], + ), + Testcase( + "branch on old value r", + nothing, + (branch_on_old_value, -1.0), + nothing, + [false, -2.0], ), - Testcase("no produce", nothing, (no_produce_test, 5.0, 4.0), []), - Testcase("new object", nothing, (new_object_test, 5, 4), [C(5, 4), C(5, 4)]), + Testcase("no produce", nothing, (no_produce_test, 5.0, 4.0), nothing, []), Testcase( - "branching test l", nothing, (branching_test, 5.0, 4.0), [string(sin(5.0))] + "new object", nothing, (new_object_test, 5, 4), nothing, [C(5, 4), C(5, 4)] ), Testcase( - "branching test r", nothing, (branching_test, 4.0, 5.0), [sin(4.0) * cos(5.0)] + "branching test l", + nothing, + (branching_test, 5.0, 4.0), + nothing, + [string(sin(5.0))], ), - Testcase("unused argument test", nothing, (unused_argument_test, 3), [1]), - Testcase("test with const", nothing, (test_with_const,), [1]), - Testcase("while loop", nothing, (while_loop,), collect(1:9)), + Testcase( + "branching test r", + nothing, + (branching_test, 4.0, 5.0), + nothing, + [sin(4.0) * cos(5.0)], + ), + Testcase("unused argument test", nothing, (unused_argument_test, 3), nothing, [1]), + Testcase("test with const", nothing, (test_with_const,), nothing, [1]), + Testcase("while loop", nothing, (while_loop,), nothing, collect(1:9)), Testcase( "foreigncall tester", nothing, (foreigncall_tester, "hi"), + nothing, [Ptr{UInt8}, Ptr{UInt8}], ), - Testcase("dynamic scope 1", 5, (dynamic_scope_tester_1,), [5]), - Testcase("dynamic scope 2", 6, (dynamic_scope_tester_1,), [6]), - Testcase("nested (static)", nothing, (static_nested_outer,), [true, false]), + Testcase("dynamic scope 1", 5, (dynamic_scope_tester_1,), nothing, [5]), + Testcase("dynamic scope 2", 6, (dynamic_scope_tester_1,), nothing, [6]), + Testcase( + "nested (static)", nothing, (static_nested_outer,), nothing, [true, false] + ), Testcase( "nested (static + used)", nothing, (static_nested_outer_use_produced,), + nothing, [true, 1], ), Testcase( "nested (dynamic)", nothing, (dynamic_nested_outer, Ref{Any}(nested_inner)), + nothing, [true, false], ), Testcase( "nested (dynamic + used)", nothing, (dynamic_nested_outer_use_produced, Ref{Any}(nested_inner)), + nothing, [true, 1], ), - Testcase("callable struct", nothing, (CallableStruct(5), 4), [5, 4, 9]), + Testcase("callable struct", nothing, (CallableStruct(5), 4), nothing, [5, 4, 9]), + Testcase( + "kwarg tester 1", + nothing, + (Core.kwcall, (; y=5.0), kwarg_tester, 4.0), + nothing, + [], + ), + Testcase("kwargs tester 2", nothing, (kwarg_tester, 4.0), (; y=5.0), []), ] end @@ -222,4 +264,6 @@ function (c::CallableStruct)(y) return nothing end +kwarg_tester(x; y) = x + y + end From 065fb1954a1f44b8e07898a00d046dbb53571c67 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 8 Apr 2025 14:31:17 +0100 Subject: [PATCH 55/69] More tests --- src/test_utils.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index c3a24707..346cad29 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -130,6 +130,8 @@ function test_cases() [], ), Testcase("kwargs tester 2", nothing, (kwarg_tester, 4.0), (; y=5.0), []), + Testcase("default kwarg tester", nothing, (default_kwarg_tester, 4.0), nothing, []), + Testcase("default kwarg tester", nothing, (default_kwarg_tester, 4.0), (;), []), ] end @@ -266,4 +268,6 @@ end kwarg_tester(x; y) = x + y +default_kwarg_tester(x; y=5.0) = x * y + end From 1ba2dfe295afdf221b4add494e17f48123329585 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 8 Apr 2025 15:13:15 +0100 Subject: [PATCH 56/69] Fix inference bug --- src/copyable_task.jl | 2 +- src/test_utils.jl | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 045d0e2d..62d3d0ed 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -22,7 +22,7 @@ See also: [`Libtask.consume`](@ref) """ @noinline function produce(x) global __v = 4 # silly side-effect to prevent this call getting constant-folded away. Should really use the effects system. - return ProducedValue(x) + return x end function callable_ret_type(sig) diff --git a/src/test_utils.jl b/src/test_utils.jl index 346cad29..3c2f066d 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -15,6 +15,9 @@ end function (case::Testcase)() testset = @testset "$(case.name)" begin + # Display some information. + @info "$(case.name)" + # Construct the task. if case.kwargs === nothing t = TapedTask(case.dynamic_scope, case.fargs...) @@ -132,6 +135,9 @@ function test_cases() Testcase("kwargs tester 2", nothing, (kwarg_tester, 4.0), (; y=5.0), []), Testcase("default kwarg tester", nothing, (default_kwarg_tester, 4.0), nothing, []), Testcase("default kwarg tester", nothing, (default_kwarg_tester, 4.0), (;), []), + Testcase( + "final statment produce", nothing, (final_statement_produce,), nothing, [1, 2] + ), ] end @@ -270,4 +276,9 @@ kwarg_tester(x; y) = x + y default_kwarg_tester(x; y=5.0) = x * y +function final_statement_produce() + produce(1) + return produce(2) +end + end From 999f5bceff528610f845de6f39e8e88861962cf1 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 8 Apr 2025 15:55:19 +0100 Subject: [PATCH 57/69] Improve documentation --- src/copyable_task.jl | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 62d3d0ed..1eca5feb 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -41,7 +41,7 @@ end mutable struct TapedTask{Tdynamic_scope,Tfargs,Tmc<:MistyClosure} dynamic_scope::Tdynamic_scope - fargs::Tfargs + const fargs::Tfargs const mc::Tmc const position::Base.RefValue{Int32} end @@ -369,7 +369,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} end end - # For each existing basic block, produce a sequence of `NamedTuple`s which + # For each existing basic block, create a sequence of `NamedTuple`s which # define the manner in which it must be split. # A block will in general be split as follows: # 1 - %1 = φ(...) @@ -666,6 +666,16 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} end push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) elseif is_produce_stmt(stmt) + # This is a statement of the form + # %n = produce(arg) + # + # We transform this into + # Libtask.set_resume_block!(refs_id, id_of_next_block) + # return ProducedValue(arg) + # + # The point is to ensure that, next time that this `TapedTask` is called, + # computation is resumed from the statement _after_ this produce statement, + # and to return whatever this produce statement returns. # When this TapedTask is next called, we should resume from the first # statement of the next split. @@ -699,7 +709,29 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) else # The final statement is one which might produce, but is not itself a - # `produce` statement. + # `produce` statement. For example + # y = f(x) + # + # becomes (morally speaking) + # y = f(x) + # if y isa ProducedValue + # set_resume_block!(refs_id, id_of_current_block) + # return y + # end + # + # The point is to ensure that, if `f` "produces" (as indicated by `y` being + # a `ProducedValue`) then the next time that this TapedTask is called, we + # must resume from the call to `f`, as subsequent runs might also produce. + # On the other hand, if anything other than a `ProducedValue` is returned, + # we know that `f` has nothing else to produce, and execution can safely + # continue to the next split. + # In addition to the above, we must do the usual thing and ensure that any + # ssas are read from storage, and write the result of this computation to + # storage before continuing to the next instruction. + # + # You should look at the IR generated by a simple example in the test suite + # which involves calls that might produce, in order to get a sense of what + # the resulting code looks like prior to digging into the code below. # Create a new basic block from the existing statements, since all new # statement need to live in their own basic blocks. From ec9edb70a746f97c09a77bdc7eb3b565d91af610 Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 09:52:50 +0100 Subject: [PATCH 58/69] Performance enhancements --- src/Libtask.jl | 1 + src/copyable_task.jl | 102 +++++++++++++++++++++++++++++++++--------- src/test_utils.jl | 95 ++++++++++++++++++++++++++++++++------- test/copyable_task.jl | 4 +- 4 files changed, 162 insertions(+), 40 deletions(-) diff --git a/src/Libtask.jl b/src/Libtask.jl index df61d87f..b40e9f10 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -4,6 +4,7 @@ module Libtask using Mooncake using Mooncake: BBCode, BBlock, ID, new_inst, stmt, seed_id!, terminator using Mooncake: IDGotoIfNot, IDGotoNode, IDPhiNode, Switch +using Mooncake.BasicBlockCode: collect_stmts, characterise_used_ids # We'll emit `MistyClosure`s rather than `OpaqueClosure`s. using MistyClosures diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 1eca5feb..13741fe7 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -1,14 +1,16 @@ -const dynamic_scope = ScopedValue{Any}(0) - """ - get_dynamic_scope() + get_dynamic_scope(T::Type) Returns the dynamic scope associated to `Libtask`. If called from inside a `TapedTask`, this will return whatever is contained in its `dynamic_scope` field. +The type `T` is required for optimal performance. If you know that the result of this +operation must return a specific type, specific `T`. If you do not know what type it will +return, pass `Any` -- this will typically yield type instabilities, but will run correctly. + See also [`set_dynamic_scope!`](@ref). """ -get_dynamic_scope() = dynamic_scope[] +get_dynamic_scope(::Type{T}) where {T} = typeassert(task_local_storage(:task_variable), T) __v::Int = 5 @@ -25,16 +27,21 @@ See also: [`Libtask.consume`](@ref) return x end -function callable_ret_type(sig) - return Union{Base.code_ircode_by_type(sig)[1][2],ProducedValue} +function callable_ret_type(sig, types) + produce_type = Union{} + for t in types + p = isconcretetype(t) ? ProducedValue{t} : ProducedValue{T} where {T<:t} + produce_type = CC.tmerge(p, produce_type) + end + return Union{Base.code_ircode_by_type(sig)[1][2],produce_type} end function build_callable(sig::Type{<:Tuple}) ir = Base.code_ircode_by_type(sig)[1][1] - bb, refs = derive_copyable_task_ir(BBCode(ir)) + bb, refs, types = derive_copyable_task_ir(BBCode(ir)) unoptimised_ir = IRCode(bb) optimised_ir = Mooncake.optimise_ir!(unoptimised_ir) - mc_ret_type = callable_ret_type(sig) + mc_ret_type = callable_ret_type(sig, types) mc = Mooncake.misty_closure(mc_ret_type, optimised_ir, refs...; do_compile=true) return mc, refs[end] end @@ -200,7 +207,8 @@ called, it start execution from the entry point. If `consume` has previously bee `nothing` will be returned. """ @inline function consume(t::TapedTask) - v = with(() -> t.mc(t.fargs...), dynamic_scope => t.dynamic_scope) + task_local_storage(:task_variable, t.dynamic_scope) + v = t.mc.oc(t.fargs...) return v isa ProducedValue ? v[] : nothing end @@ -285,6 +293,7 @@ end struct ProducedValue{T} x::T end +ProducedValue(::Type{T}) where {T} = ProducedValue{Type{T}}(T) @inline Base.getindex(x::ProducedValue) = x.x @@ -318,7 +327,35 @@ inc_args(x::Core.PiNode) = Core.PiNode(__inc(x.val), __inc(x.typ)) __inc(x::Argument) = Argument(x.n + 1) __inc(x) = x -function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} +const TypeInfo = Tuple{Vector{Any},Dict{ID,Type}} + +""" + _typeof(x) + +Central definition of typeof, which is specific to the use-required in this package. +""" +_typeof(x) = Base._stable_typeof(x) +_typeof(x::Tuple) = Tuple{tuple_map(_typeof, x)...} +_typeof(x::NamedTuple{names}) where {names} = NamedTuple{names,_typeof(Tuple(x))} + +""" + get_type(info::ADInfo, x) + +Returns the static / inferred type associated to `x`. +""" +get_type(info::TypeInfo, x::Argument) = info[1][x.n - 1] +get_type(info::TypeInfo, x::ID) = CC.widenconst(info[2][x]) +get_type(::TypeInfo, x::QuoteNode) = _typeof(x.value) +get_type(::TypeInfo, x) = _typeof(x) +function get_type(::TypeInfo, x::GlobalRef) + return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty +end +function get_type(::TypeInfo, x::Expr) + x.head === :boundscheck && return Bool + return error("Unrecognised expression $x found in argument slot.") +end + +function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} # The location from which all state can be retrieved. Since we're using `OpaqueClosure`s # to implement `TapedTask`s, this appears via the first argument. @@ -338,13 +375,17 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} ssa_id_to_ref_index_map = Dict{ID,Int}() ref_index_to_ssa_id_map = Dict{Int,ID}() ref_index_to_type_map = Dict{Int,Type}() + id_to_type_map = Dict{ID,Type}() + is_used_dict = characterise_used_ids(collect_stmts(ir)) n = 0 for bb in ir.blocks for (id, stmt) in zip(bb.inst_ids, bb.insts) + id_to_type_map[id] = CC.widenconst(stmt.type) stmt.stmt isa IDGotoNode && continue stmt.stmt isa IDGotoIfNot && continue stmt.stmt === nothing && continue stmt.stmt isa ReturnNode && continue + is_used_dict[id] || continue n += 1 ssa_id_to_ref_index_map[id] = n ref_index_to_ssa_id_map[n] = id @@ -447,6 +488,9 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # A set of blocks from which we might wish to resume computation. resume_block_ids = Vector{ID}() + # A list onto which we'll push the type of any statement which might produce. + possible_produce_types = Any[] + # This where most of the action happens. # # For each split of each block, we must @@ -582,10 +626,13 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} push!(inst_pairs, (id, inst)) # If we know it is not possible for this statement to contain any calls - # to produce, then simply write out the result to its `Ref`. - out_ind = ssa_id_to_ref_index_map[id] - set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, id) - push!(inst_pairs, (ID(), new_inst(set_ref))) + # to produce, then simply write out the result to its `Ref`. If it is + # never used, then there is no need to store it. + if is_used_dict[id] + out_ind = ssa_id_to_ref_index_map[id] + set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, id) + push!(inst_pairs, (ID(), new_inst(set_ref))) + end elseif Meta.isexpr(stmt, :boundscheck) push!(inst_pairs, (id, inst)) elseif Meta.isexpr(stmt, :code_coverage_effect) @@ -677,6 +724,10 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # computation is resumed from the statement _after_ this produce statement, # and to return whatever this produce statement returns. + # Log the result type of this statement. + arg = stmt.args[Meta.isexpr(stmt, :invoke) ? 3 : 2] + push!(possible_produce_types, get_type((ir.argtypes, id_to_type_map), arg)) + # When this TapedTask is next called, we should resume from the first # statement of the next split. resume_id = splits_ids[n + 1] @@ -733,6 +784,11 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # which involves calls that might produce, in order to get a sense of what # the resulting code looks like prior to digging into the code below. + # At present, we're not able to properly infer the values which might + # potentially be produced by a call-which-might-produce. Consequently, we + # have to assume they can produce anything. + push!(possible_produce_types, Any) + # Create a new basic block from the existing statements, since all new # statement need to live in their own basic blocks. callable_block_id = ID() @@ -742,7 +798,8 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # Derive TapedTask for this statement. (callable, callable_args) = if Meta.isexpr(stmt, :invoke) sig = stmt.args[1].specTypes - (LazyCallable{sig,callable_ret_type(sig)}(), stmt.args[2:end]) + v = Any[Any] + (LazyCallable{sig,callable_ret_type(sig, v)}(), stmt.args[2:end]) elseif Meta.isexpr(stmt, :call) (DynamicCallable(), stmt.args) else @@ -815,8 +872,12 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} # `ProducedValue`. In this case, we must first push the result to the `Ref` # associated to the call, and goto the next split. next_block_id = splits_ids[n + 1] # safe since the last split ends with a terminator - result_ref_ind = ssa_id_to_ref_index_map[id] - set_ref = Expr(:call, set_ref_at!, refs_id, result_ref_ind, result_id) + if is_used_dict[id] + result_ref_ind = ssa_id_to_ref_index_map[id] + set_ref = Expr(:call, set_ref_at!, refs_id, result_ref_ind, result_id) + else + set_ref = nothing + end not_produced_block_inst_pairs = Mooncake.IDInstPair[ (ID(), new_inst(set_ref)) (ID(), new_inst(IDGotoNode(next_block_id))) @@ -851,13 +912,12 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple} new_argtypes = vcat(typeof(refs), copy(ir.argtypes)) # Return BBCode and the `Ref`s. - return BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta), refs + new_ir = BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta) + return new_ir, refs, possible_produce_types end # Helper used in `derive_copyable_task_ir`. -@inline function get_ref_at(refs::R, n::Int) where {R<:Tuple} - return refs[n][] -end +@inline get_ref_at(refs::R, n::Int) where {R<:Tuple} = refs[n][] # Helper used in `derive_copyable_task_ir`. @inline function set_ref_at!(refs::R, n::Int, val) where {R<:Tuple} diff --git a/src/test_utils.jl b/src/test_utils.jl index 3c2f066d..f2580a80 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -4,12 +4,20 @@ using ..Libtask using Test using ..Libtask: TapedTask +# Function barrier to ensure inference in value types. +function count_allocs(f::F, x::Vararg{Any,N}) where {F,N} + @allocations f(x...) +end + +@enum PerfFlag none allocs + struct Testcase name::String dynamic_scope::Any fargs::Tuple kwargs::Union{NamedTuple,Nothing} expected_iteration_results::Vector + perf::PerfFlag end function (case::Testcase)() @@ -40,6 +48,21 @@ function (case::Testcase)() for (n, t_copy) in enumerate(t_copies) @test iteration_results[n:end] == collect(t_copy) end + + # Check no allocations if requested. + if case.perf == allocs + + # Construct the task. + if case.kwargs === nothing + t = TapedTask(case.dynamic_scope, case.fargs...) + else + t = TapedTask(case.dynamic_scope, case.fargs...; case.kwargs...) + end + + for _ in iteration_results + @test count_allocs(consume, t) == 0 + end + end end return testset end @@ -52,9 +75,15 @@ function test_cases() (single_block, 5.0), nothing, [sin(5.0), sin(sin(5.0)), sin(sin(sin(5.0))), sin(sin(sin(sin(5.0))))], + allocs, ), Testcase( - "produce old", nothing, (produce_old_value, 5.0), nothing, [sin(5.0), sin(5.0)] + "produce old", + nothing, + (produce_old_value, 5.0), + nothing, + [sin(5.0), sin(5.0)], + allocs, ), Testcase( "branch on old value l", @@ -62,6 +91,7 @@ function test_cases() (branch_on_old_value, 2.0), nothing, [true, 2.0], + allocs, ), Testcase( "branch on old value r", @@ -69,17 +99,24 @@ function test_cases() (branch_on_old_value, -1.0), nothing, [false, -2.0], + allocs, ), - Testcase("no produce", nothing, (no_produce_test, 5.0, 4.0), nothing, []), + Testcase("no produce", nothing, (no_produce_test, 5.0, 4.0), nothing, [], allocs), Testcase( - "new object", nothing, (new_object_test, 5, 4), nothing, [C(5, 4), C(5, 4)] + "new object", + nothing, + (new_object_test, 5, 4), + nothing, + [C(5, 4), C(5, 4)], + none, ), Testcase( "branching test l", nothing, (branching_test, 5.0, 4.0), nothing, - [string(sin(5.0))], + [complex(sin(5.0))], + allocs, ), Testcase( "branching test r", @@ -87,21 +124,25 @@ function test_cases() (branching_test, 4.0, 5.0), nothing, [sin(4.0) * cos(5.0)], + allocs, + ), + Testcase( + "unused argument test", nothing, (unused_argument_test, 3), nothing, [1], allocs ), - Testcase("unused argument test", nothing, (unused_argument_test, 3), nothing, [1]), - Testcase("test with const", nothing, (test_with_const,), nothing, [1]), - Testcase("while loop", nothing, (while_loop,), nothing, collect(1:9)), + Testcase("test with const", nothing, (test_with_const,), nothing, [1], allocs), + Testcase("while loop", nothing, (while_loop,), nothing, collect(1:9), allocs), Testcase( "foreigncall tester", nothing, (foreigncall_tester, "hi"), nothing, [Ptr{UInt8}, Ptr{UInt8}], + allocs, ), - Testcase("dynamic scope 1", 5, (dynamic_scope_tester_1,), nothing, [5]), - Testcase("dynamic scope 2", 6, (dynamic_scope_tester_1,), nothing, [6]), + Testcase("dynamic scope 1", 5, (dynamic_scope_tester_1,), nothing, [5], allocs), + Testcase("dynamic scope 2", 6, (dynamic_scope_tester_1,), nothing, [6], none), Testcase( - "nested (static)", nothing, (static_nested_outer,), nothing, [true, false] + "nested (static)", nothing, (static_nested_outer,), nothing, [true, false], none ), Testcase( "nested (static + used)", @@ -109,6 +150,7 @@ function test_cases() (static_nested_outer_use_produced,), nothing, [true, 1], + none, ), Testcase( "nested (dynamic)", @@ -116,6 +158,7 @@ function test_cases() (dynamic_nested_outer, Ref{Any}(nested_inner)), nothing, [true, false], + none, ), Testcase( "nested (dynamic + used)", @@ -123,20 +166,38 @@ function test_cases() (dynamic_nested_outer_use_produced, Ref{Any}(nested_inner)), nothing, [true, 1], + none, + ), + Testcase( + "callable struct", nothing, (CallableStruct(5), 4), nothing, [5, 4, 9], allocs ), - Testcase("callable struct", nothing, (CallableStruct(5), 4), nothing, [5, 4, 9]), Testcase( "kwarg tester 1", nothing, (Core.kwcall, (; y=5.0), kwarg_tester, 4.0), nothing, [], + allocs, + ), + Testcase("kwargs tester 2", nothing, (kwarg_tester, 4.0), (; y=5.0), [], allocs), + Testcase( + "default kwarg tester", + nothing, + (default_kwarg_tester, 4.0), + nothing, + [], + allocs, + ), + Testcase( + "default kwarg tester", nothing, (default_kwarg_tester, 4.0), (;), [], allocs ), - Testcase("kwargs tester 2", nothing, (kwarg_tester, 4.0), (; y=5.0), []), - Testcase("default kwarg tester", nothing, (default_kwarg_tester, 4.0), nothing, []), - Testcase("default kwarg tester", nothing, (default_kwarg_tester, 4.0), (;), []), Testcase( - "final statment produce", nothing, (final_statement_produce,), nothing, [1, 2] + "final statment produce", + nothing, + (final_statement_produce,), + nothing, + [1, 2], + allocs, ), ] end @@ -190,7 +251,7 @@ end function branching_test(x, y) if x > y - produce(string(sin(x))) + produce(complex(sin(x))) else produce(sin(x) * cos(y)) end @@ -226,7 +287,7 @@ function foreigncall_tester(s::String) end function dynamic_scope_tester_1() - produce(Libtask.get_dynamic_scope()) + produce(Libtask.get_dynamic_scope(Int)) return nothing end diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 9fd18c4c..42625176 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -4,8 +4,8 @@ end @testset "set_dynamic_scope" begin function f() - produce(typeassert(Libtask.get_dynamic_scope(), Int)) - produce(typeassert(Libtask.get_dynamic_scope(), Int)) + produce(Libtask.get_dynamic_scope(Int)) + produce(Libtask.get_dynamic_scope(Int)) return nothing end t = TapedTask(5, f) From 5135c4ae917debcf9a29c9a6b8c6d0d88a77e1b9 Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 11:49:06 +0100 Subject: [PATCH 59/69] Caching and tweaks --- src/copyable_task.jl | 49 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 13741fe7..0e8323ec 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -37,13 +37,20 @@ function callable_ret_type(sig, types) end function build_callable(sig::Type{<:Tuple}) - ir = Base.code_ircode_by_type(sig)[1][1] - bb, refs, types = derive_copyable_task_ir(BBCode(ir)) - unoptimised_ir = IRCode(bb) - optimised_ir = Mooncake.optimise_ir!(unoptimised_ir) - mc_ret_type = callable_ret_type(sig, types) - mc = Mooncake.misty_closure(mc_ret_type, optimised_ir, refs...; do_compile=true) - return mc, refs[end] + key = CacheKey(Base.get_world_counter(), sig) + if haskey(mc_cache, key) + v = fresh_copy(mc_cache[key]) + return v + else + ir = Base.code_ircode_by_type(sig)[1][1] + bb, refs, types = derive_copyable_task_ir(BBCode(ir)) + unoptimised_ir = IRCode(bb) + optimised_ir = Mooncake.optimise_ir!(unoptimised_ir) + mc_ret_type = callable_ret_type(sig, types) + mc = Mooncake.misty_closure(mc_ret_type, optimised_ir, refs...; do_compile=true) + mc_cache[key] = mc + return mc, refs[end] + end end mutable struct TapedTask{Tdynamic_scope,Tfargs,Tmc<:MistyClosure} @@ -53,6 +60,13 @@ mutable struct TapedTask{Tdynamic_scope,Tfargs,Tmc<:MistyClosure} const position::Base.RefValue{Int32} end +struct CacheKey + world_age::UInt + key::Any +end + +const mc_cache = Dict{CacheKey,MistyClosure}() + """ TapedTask(dynamic_scope::Any, f, args...; kwargs...) @@ -170,12 +184,27 @@ julia> consume(t) [`Libtask.get_dynamic_scope`](@ref) to anything you like. """ function TapedTask(dynamic_scope::Any, fargs...; kwargs...) - seed_id!() all_args = isempty(kwargs) ? fargs : (Core.kwcall, getfield(kwargs, :data), fargs...) + seed_id!() mc, count_ref = build_callable(typeof(all_args)) return TapedTask(dynamic_scope, all_args, mc, count_ref) end +function fresh_copy(mc::T) where {T<:MistyClosure} + new_captures = Mooncake.tuple_map(mc.oc.captures) do r + if eltype(r) <: DynamicCallable + return Base.RefValue(DynamicCallable()) + elseif eltype(r) <: LazyCallable + return _typeof(r)(eltype(r)()) + else + return _typeof(r)() + end + end + new_position = new_captures[end] + new_position[] = -1 + return Mooncake.replace_captures(mc, new_captures), new_position +end + """ set_dynamic_scope!(t::TapedTask, new_dynamic_scope)::Nothing @@ -335,7 +364,7 @@ const TypeInfo = Tuple{Vector{Any},Dict{ID,Type}} Central definition of typeof, which is specific to the use-required in this package. """ _typeof(x) = Base._stable_typeof(x) -_typeof(x::Tuple) = Tuple{tuple_map(_typeof, x)...} +_typeof(x::Tuple) = Tuple{map(_typeof, x)...} _typeof(x::NamedTuple{names}) where {names} = NamedTuple{names,_typeof(Tuple(x))} """ @@ -641,6 +670,8 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} push!(inst_pairs, (id, inst)) elseif Meta.isexpr(stmt, :gc_preserve_end) push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :throw_undef_if_not) + push!(inst_pairs, (id, inst)) elseif stmt isa Nothing push!(inst_pairs, (id, inst)) elseif stmt isa GlobalRef From b42227ea640fcb03ac050cad7a099a11cf3edb06 Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 12:59:28 +0100 Subject: [PATCH 60/69] Fix docs build --- Project.toml | 2 -- docs/src/internals.md | 2 ++ src/Libtask.jl | 8 -------- src/copyable_task.jl | 6 +++--- test/runtests.jl | 3 +-- 5 files changed, 6 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index ce0f4b20..00fbaebd 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ version = "0.8.8" [deps] MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] @@ -16,7 +15,6 @@ Aqua = "0.8.11" JuliaFormatter = "1.0.62" MistyClosures = "2.0.0" Mooncake = "0.4.99" -ScopedValues = "1.3.0" Test = "1" julia = "1.10.8" diff --git a/docs/src/internals.md b/docs/src/internals.md index fb3ca3bd..e92ad4f8 100644 --- a/docs/src/internals.md +++ b/docs/src/internals.md @@ -7,4 +7,6 @@ Libtask.might_produce Libtask.stmt_might_produce Libtask.LazyCallable Libtask.inc_args +Libtask.get_type +Libtask._typeof ``` diff --git a/src/Libtask.jl b/src/Libtask.jl index b40e9f10..8345d6bd 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -9,14 +9,6 @@ using Mooncake.BasicBlockCode: collect_stmts, characterise_used_ids # We'll emit `MistyClosure`s rather than `OpaqueClosure`s. using MistyClosures -# ScopedValues only became available as part of `Base` in v1.11. Therefore, on v1.10 we -# need to use the `ScopedValues` package. -@static if VERSION < v"1.11" - using ScopedValues: ScopedValue, with -else - using Base.ScopedValues: ScopedValue, with -end - # Import some names from the compiler. const CC = Core.Compiler using Core.Compiler: Argument, IRCode, ReturnNode diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 0e8323ec..adb7db2b 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -145,7 +145,7 @@ julia> consume(t2) 2 ``` -## Scoped Values +## TapedTask-Specific Globals It is often desirable to permit a copy of a task and the original to differ in very specific ways. For example, in the context of Sequential Monte Carlo, you might want the only @@ -156,8 +156,8 @@ A generic mechanism is available to achieve this. [`Libtask.get_dynamic_scope`]( to a given [`Libtask.TapedTask`](@ref). The former can be called inside a function: ```jldoctest sv julia> function f() - produce(get_dynamic_scope()) - produce(get_dynamic_scope()) + produce(get_dynamic_scope(Int)) + produce(get_dynamic_scope(Int)) return nothing end f (generic function with 1 method) diff --git a/test/runtests.jl b/test/runtests.jl index 0a049368..54876973 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,7 @@ include("front_matter.jl") @testset "Libtask" begin @testset "quality" begin - # ScopedValues is stale on 1.11. - Aqua.test_all(Libtask; stale_deps=VERSION < v"1.11" ? true : false) + Aqua.test_all(Libtask) @test JuliaFormatter.format(Libtask; verbose=false, overwrite=false) end include("copyable_task.jl") From aa86594f220c9f58fd3baae7eb6e765827f99fa2 Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 13:14:38 +0100 Subject: [PATCH 61/69] Docs and tidy up --- docs/src/index.md | 4 ++-- src/Libtask.jl | 2 +- src/copyable_task.jl | 52 +++++++++++++++++++++---------------------- src/test_utils.jl | 18 +++++++-------- test/copyable_task.jl | 8 +++---- 5 files changed, 41 insertions(+), 43 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 3bc62e1f..9beb0078 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -14,11 +14,11 @@ First, manipulation of [`TapedTask`](@ref)s: ```@docs; canonical=true Libtask.consume Base.copy(::Libtask.TapedTask) -Libtask.set_dynamic_scope! +Libtask.set_taped_globals! ``` Functions for use inside a [`TapedTask`](@ref)s are: ```@docs; canonical=true Libtask.produce -Libtask.get_dynamic_scope +Libtask.get_taped_globals ``` diff --git a/src/Libtask.jl b/src/Libtask.jl index 8345d6bd..34dcceb3 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -16,6 +16,6 @@ using Core.Compiler: Argument, IRCode, ReturnNode include("copyable_task.jl") include("test_utils.jl") -export TapedTask, consume, produce, get_dynamic_scope, set_dynamic_scope! +export TapedTask, consume, produce, get_taped_globals, set_taped_globals! end diff --git a/src/copyable_task.jl b/src/copyable_task.jl index adb7db2b..947eafd9 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -1,16 +1,16 @@ """ - get_dynamic_scope(T::Type) + get_taped_globals(T::Type) Returns the dynamic scope associated to `Libtask`. If called from inside a `TapedTask`, this -will return whatever is contained in its `dynamic_scope` field. +will return whatever is contained in its `taped_globals` field. The type `T` is required for optimal performance. If you know that the result of this operation must return a specific type, specific `T`. If you do not know what type it will return, pass `Any` -- this will typically yield type instabilities, but will run correctly. -See also [`set_dynamic_scope!`](@ref). +See also [`set_taped_globals!`](@ref). """ -get_dynamic_scope(::Type{T}) where {T} = typeassert(task_local_storage(:task_variable), T) +get_taped_globals(::Type{T}) where {T} = typeassert(task_local_storage(:task_variable), T) __v::Int = 5 @@ -53,8 +53,8 @@ function build_callable(sig::Type{<:Tuple}) end end -mutable struct TapedTask{Tdynamic_scope,Tfargs,Tmc<:MistyClosure} - dynamic_scope::Tdynamic_scope +mutable struct TapedTask{Ttaped_globals,Tfargs,Tmc<:MistyClosure} + taped_globals::Ttaped_globals const fargs::Tfargs const mc::Tmc const position::Base.RefValue{Int32} @@ -68,10 +68,10 @@ end const mc_cache = Dict{CacheKey,MistyClosure}() """ - TapedTask(dynamic_scope::Any, f, args...; kwargs...) + TapedTask(taped_globals::Any, f, args...; kwargs...) -Construct a `TapedTask` with the specified `dynamic_scope`, for function `f` and positional -arguments `args`. +Construct a `TapedTask` with the specified `taped_globals`, for function `f`, positional +arguments `args`, and keyword argument `kwargs`. # Extended Help @@ -151,20 +151,20 @@ It is often desirable to permit a copy of a task and the original to differ in v ways. For example, in the context of Sequential Monte Carlo, you might want the only difference between two copies to be their random number generator. -A generic mechanism is available to achieve this. [`Libtask.get_dynamic_scope`](@ref) and -[`Libtask.set_dynamic_scope!`](@ref) let you set and retrieve a variable which is specific +A generic mechanism is available to achieve this. [`Libtask.get_taped_globals`](@ref) and +[`Libtask.set_taped_globals!`](@ref) let you set and retrieve a variable which is specific to a given [`Libtask.TapedTask`](@ref). The former can be called inside a function: ```jldoctest sv julia> function f() - produce(get_dynamic_scope(Int)) - produce(get_dynamic_scope(Int)) + produce(get_taped_globals(Int)) + produce(get_taped_globals(Int)) return nothing end f (generic function with 1 method) ``` The first argument to [`Libtask.TapedTask`](@ref) is the value that -[`Libtask.get_dynamic_scope`](@ref) will return: +[`Libtask.get_taped_globals`](@ref) will return: ```jldoctest sv julia> t = TapedTask(1, f); @@ -174,20 +174,20 @@ julia> consume(t) The value that it returns can be changed between [`Libtask.consume`](@ref) calls: ```jldoctest sv -julia> set_dynamic_scope!(t, 2) +julia> set_taped_globals!(t, 2) julia> consume(t) 2 ``` `Int`s have been used here, but it is permissible to set the value returned by -[`Libtask.get_dynamic_scope`](@ref) to anything you like. +[`Libtask.get_taped_globals`](@ref) to anything you like. """ -function TapedTask(dynamic_scope::Any, fargs...; kwargs...) +function TapedTask(taped_globals::Any, fargs...; kwargs...) all_args = isempty(kwargs) ? fargs : (Core.kwcall, getfield(kwargs, :data), fargs...) seed_id!() mc, count_ref = build_callable(typeof(all_args)) - return TapedTask(dynamic_scope, all_args, mc, count_ref) + return TapedTask(taped_globals, all_args, mc, count_ref) end function fresh_copy(mc::T) where {T<:MistyClosure} @@ -206,16 +206,14 @@ function fresh_copy(mc::T) where {T<:MistyClosure} end """ - set_dynamic_scope!(t::TapedTask, new_dynamic_scope)::Nothing + set_taped_globals!(t::TapedTask, new_taped_globals)::Nothing -Set the `dynamic_scope` of `t` to `new_dynamic_scope`. Any references to -`LibTask.dynamic_scope` in future calls to `consume(t)` (either directly, or implicitly via -iteration) will see this new value. - -See also: [`get_dynamic_scope`](@ref). +Set the `taped_globals` of `t` to `new_taped_globals`. Any calls to +[`get_taped_globals`](@ref) in future calls to `consume(t)` (either directly, or implicitly +via iteration) will see this new value. """ -function set_dynamic_scope!(t::TapedTask{T}, new_dynamic_scope::T)::Nothing where {T} - t.dynamic_scope = new_dynamic_scope +function set_taped_globals!(t::TapedTask{T}, new_taped_globals::T)::Nothing where {T} + t.taped_globals = new_taped_globals return nothing end @@ -236,7 +234,7 @@ called, it start execution from the entry point. If `consume` has previously bee `nothing` will be returned. """ @inline function consume(t::TapedTask) - task_local_storage(:task_variable, t.dynamic_scope) + task_local_storage(:task_variable, t.taped_globals) v = t.mc.oc(t.fargs...) return v isa ProducedValue ? v[] : nothing end diff --git a/src/test_utils.jl b/src/test_utils.jl index f2580a80..59ec70f0 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -13,7 +13,7 @@ end struct Testcase name::String - dynamic_scope::Any + taped_globals::Any fargs::Tuple kwargs::Union{NamedTuple,Nothing} expected_iteration_results::Vector @@ -28,9 +28,9 @@ function (case::Testcase)() # Construct the task. if case.kwargs === nothing - t = TapedTask(case.dynamic_scope, case.fargs...) + t = TapedTask(case.taped_globals, case.fargs...) else - t = TapedTask(case.dynamic_scope, case.fargs...; case.kwargs...) + t = TapedTask(case.taped_globals, case.fargs...; case.kwargs...) end # Iterate through t. Record the results, and take a copy after each iteration. @@ -54,9 +54,9 @@ function (case::Testcase)() # Construct the task. if case.kwargs === nothing - t = TapedTask(case.dynamic_scope, case.fargs...) + t = TapedTask(case.taped_globals, case.fargs...) else - t = TapedTask(case.dynamic_scope, case.fargs...; case.kwargs...) + t = TapedTask(case.taped_globals, case.fargs...; case.kwargs...) end for _ in iteration_results @@ -139,8 +139,8 @@ function test_cases() [Ptr{UInt8}, Ptr{UInt8}], allocs, ), - Testcase("dynamic scope 1", 5, (dynamic_scope_tester_1,), nothing, [5], allocs), - Testcase("dynamic scope 2", 6, (dynamic_scope_tester_1,), nothing, [6], none), + Testcase("dynamic scope 1", 5, (taped_globals_tester_1,), nothing, [5], allocs), + Testcase("dynamic scope 2", 6, (taped_globals_tester_1,), nothing, [6], none), Testcase( "nested (static)", nothing, (static_nested_outer,), nothing, [true, false], none ), @@ -286,8 +286,8 @@ function foreigncall_tester(s::String) return nothing end -function dynamic_scope_tester_1() - produce(Libtask.get_dynamic_scope(Int)) +function taped_globals_tester_1() + produce(Libtask.get_taped_globals(Int)) return nothing end diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 42625176..5b3feb27 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -2,15 +2,15 @@ for case in Libtask.TestUtils.test_cases() case() end - @testset "set_dynamic_scope" begin + @testset "set_taped_globals!" begin function f() - produce(Libtask.get_dynamic_scope(Int)) - produce(Libtask.get_dynamic_scope(Int)) + produce(Libtask.get_taped_globals(Int)) + produce(Libtask.get_taped_globals(Int)) return nothing end t = TapedTask(5, f) @test consume(t) == 5 - Libtask.set_dynamic_scope!(t, 6) + Libtask.set_taped_globals!(t, 6) @test consume(t) == 6 @test consume(t) === nothing end From 5b07ec1a84ba14e18fac0076d68c9c5fdd2c8422 Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 13:33:21 +0100 Subject: [PATCH 62/69] Include specifics from Mooncake --- Project.toml | 4 +- src/Libtask.jl | 12 +- src/bbcode.jl | 1004 ++++++++++++++++++++++++++++++++++++++++++ src/copyable_task.jl | 18 +- src/utils.jl | 204 +++++++++ 5 files changed, 1225 insertions(+), 17 deletions(-) create mode 100644 src/bbcode.jl create mode 100644 src/utils.jl diff --git a/Project.toml b/Project.toml index 00fbaebd..a7a110f6 100644 --- a/Project.toml +++ b/Project.toml @@ -6,15 +6,15 @@ repo = "https://github.com/TuringLang/Libtask.jl.git" version = "0.8.8" [deps] +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8.11" +Graphs = "1.12.1" JuliaFormatter = "1.0.62" MistyClosures = "2.0.0" -Mooncake = "0.4.99" Test = "1" julia = "1.10.8" diff --git a/src/Libtask.jl b/src/Libtask.jl index 34dcceb3..ff4692f5 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -1,18 +1,18 @@ module Libtask -# Need this for BBCode. -using Mooncake -using Mooncake: BBCode, BBlock, ID, new_inst, stmt, seed_id!, terminator -using Mooncake: IDGotoIfNot, IDGotoNode, IDPhiNode, Switch -using Mooncake.BasicBlockCode: collect_stmts, characterise_used_ids - # We'll emit `MistyClosure`s rather than `OpaqueClosure`s. using MistyClosures # Import some names from the compiler. const CC = Core.Compiler +using Core: OpaqueClosure using Core.Compiler: Argument, IRCode, ReturnNode +# IR-related functionality from Mooncake. +include("utils.jl") +include("bbcode.jl") +using .BasicBlockCode + include("copyable_task.jl") include("test_utils.jl") diff --git a/src/bbcode.jl b/src/bbcode.jl new file mode 100644 index 00000000..646aa0e9 --- /dev/null +++ b/src/bbcode.jl @@ -0,0 +1,1004 @@ +""" + module BasicBlockCode + +Copied over from Mooncake.jl in order to avoid making this package depend on Mooncake. +Refer to Mooncake's developer docs for context on this file. +""" +module BasicBlockCode + +using Graphs + +using Core.Compiler: + ReturnNode, + PhiNode, + GotoIfNot, + GotoNode, + NewInstruction, + IRCode, + SSAValue, + PiNode, + Argument +const CC = Core.Compiler + +export ID, + seed_id!, + IDPhiNode, + IDGotoNode, + IDGotoIfNot, + Switch, + BBlock, + phi_nodes, + terminator, + insert_before_terminator!, + collect_stmts, + compute_all_predecessors, + BBCode, + remove_unreachable_blocks!, + characterise_used_ids, + characterise_unique_predecessor_blocks, + sort_blocks!, + InstVector, + IDInstPair, + __line_numbers_to_block_numbers!, + is_reachable_return_node, + new_inst + +const _id_count::Dict{Int,Int32} = Dict{Int,Int32}() + +""" + new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction + +Create a `NewInstruction` with fields: +- `stmt` = `stmt` +- `type` = `type` +- `info` = `CC.NoCallInfo()` +- `line` = `Int32(1)` +- `flag` = `flag` +""" +function new_inst(@nospecialize(stmt), @nospecialize(type)=Any, flag=CC.IR_FLAG_REFINED) + return NewInstruction(stmt, type, CC.NoCallInfo(), Int32(1), flag) +end + +""" + const InstVector = Vector{NewInstruction} + +Note: the `CC.NewInstruction` type is used to represent instructions because it has the +correct fields. While it is only used to represent new instrucdtions in `Core.Compiler`, it +is used to represent all instructions in `BBCode`. +""" +const InstVector = Vector{NewInstruction} + +""" + ID() + +An `ID` (read: unique name) is just a wrapper around an `Int32`. Uniqueness is ensured via a +global counter, which is incremented each time that an `ID` is created. + +This counter can be reset using `seed_id!` if you need to ensure deterministic `ID`s are +produced, in the same way that seed for random number generators can be set. +""" +struct ID + id::Int32 + function ID() + current_thread_id = Threads.threadid() + id_count = get(_id_count, current_thread_id, Int32(0)) + _id_count[current_thread_id] = id_count + Int32(1) + return new(id_count) + end +end + +Base.copy(id::ID) = id + +""" + seed_id!() + +Set the global counter used to ensure ID uniqueness to 0. This is useful when you want to +ensure determinism between two runs of the same function which makes use of `ID`s. + +This is akin to setting the random seed associated to a random number generator globally. +""" +function seed_id!() + return global _id_count[Threads.threadid()] = 0 +end + +""" + IDPhiNode(edges::Vector{ID}, values::Vector{Any}) + +Like a `PhiNode`, but `edges` are `ID`s rather than `Int32`s. +""" +struct IDPhiNode + edges::Vector{ID} + values::Vector{Any} +end + +Base.:(==)(x::IDPhiNode, y::IDPhiNode) = x.edges == y.edges && x.values == y.values + +Base.copy(node::IDPhiNode) = IDPhiNode(copy(node.edges), copy(node.values)) + +""" + IDGotoNode(label::ID) + +Like a `GotoNode`, but `label` is an `ID` rather than an `Int64`. +""" +struct IDGotoNode + label::ID +end + +Base.copy(node::IDGotoNode) = IDGotoNode(copy(node.label)) + +""" + IDGotoIfNot(cond::Any, dest::ID) + +Like a `GotoIfNot`, but `dest` is an `ID` rather than an `Int64`. +""" +struct IDGotoIfNot + cond::Any + dest::ID +end + +Base.copy(node::IDGotoIfNot) = IDGotoIfNot(copy(node.cond), copy(node.dest)) + +""" + Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID) + +A switch-statement node. These can be inserted in the `BBCode` representation of Julia IR. +`Switch` has the following semantics: +```julia +goto dests[1] if not conds[1] +goto dests[2] if not conds[2] +... +goto dests[N] if not conds[N] +goto fallthrough_dest +``` +where the value associated to each element of `conds` is a `Bool`, and `dests` indicate +which block to jump to. If none of the conditions are met, then we go to whichever block is +specified by `fallthrough_dest`. + +`Switch` statements are lowered into the above sequence of `GotoIfNot`s and `GotoNode`s +when converting `BBCode` back into `IRCode`, because `Switch` statements are not valid +nodes in regular Julia IR. +""" +struct Switch + conds::Vector{Any} + dests::Vector{ID} + fallthrough_dest::ID + function Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID) + @assert length(conds) == length(dests) + return new(conds, dests, fallthrough_dest) + end +end + +""" + Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode} + +A Union of the possible types of a terminator node. +""" +const Terminator = Union{Switch,IDGotoIfNot,IDGotoNode,ReturnNode} + +""" + BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector) + +A basic block data structure (not called `BasicBlock` to avoid accidental confusion with +`CC.BasicBlock`). Forms a single basic block. + +Each `BBlock` has an `ID` (a unique name). This makes it possible to refer to blocks in a +way that does not change when additional `BBlocks` are inserted into a `BBCode`. +This differs from the positional block numbering found in `IRCode`, in which the number +associated to a basic block changes when new blocks are inserted. + +The `n`th line of code in a `BBlock` is associated to `ID` `stmt_ids[n]`, and the `n`th +instruction from `stmts`. + +Note that `PhiNode`s, `GotoIfNot`s, and `GotoNode`s should not appear in a `BBlock` -- +instead an `IDPhiNode`, `IDGotoIfNot`, or `IDGotoNode` should be used. +""" +mutable struct BBlock + id::ID + inst_ids::Vector{ID} + insts::InstVector + function BBlock(id::ID, inst_ids::Vector{ID}, insts::InstVector) + @assert length(inst_ids) == length(insts) + return new(id, inst_ids, insts) + end +end + +""" + const IDInstPair = Tuple{ID, NewInstruction} +""" +const IDInstPair = Tuple{ID,NewInstruction} + +""" + BBlock(id::ID, inst_pairs::Vector{IDInstPair}) + +Convenience constructor -- splits `inst_pairs` into a `Vector{ID}` and `InstVector` in order +to build a `BBlock`. +""" +function BBlock(id::ID, inst_pairs::Vector{IDInstPair}) + return BBlock(id, first.(inst_pairs), last.(inst_pairs)) +end + +Base.length(bb::BBlock) = length(bb.inst_ids) + +Base.copy(bb::BBlock) = BBlock(bb.id, copy(bb.inst_ids), copy(bb.insts)) + +""" + phi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}} + +Returns all of the `IDPhiNode`s at the start of `bb`, along with their `ID`s. If there are +no `IDPhiNode`s at the start of `bb`, then both vectors will be empty. +""" +function phi_nodes(bb::BBlock) + n_phi_nodes = findlast(x -> x.stmt isa IDPhiNode, bb.insts) + if n_phi_nodes === nothing + n_phi_nodes = 0 + end + return bb.inst_ids[1:n_phi_nodes], bb.insts[1:n_phi_nodes] +end + +""" + Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing + +Inserts `stmt` and `id` into `bb` immediately before the `n`th instruction. +""" +function Base.insert!(bb::BBlock, n::Int, id::ID, inst::NewInstruction)::Nothing + insert!(bb.inst_ids, n, id) + insert!(bb.insts, n, inst) + return nothing +end + +""" + terminator(bb::BBlock) + +Returns the terminator associated to `bb`. If the last instruction in `bb` isa +`Terminator` then that is returned, otherwise `nothing` is returned. +""" +terminator(bb::BBlock) = isa(bb.insts[end].stmt, Terminator) ? bb.insts[end].stmt : nothing + +""" + insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing + +If the final instruction in `bb` is a `Terminator`, insert `inst` immediately before it. +Otherwise, insert `inst` at the end of the block. +""" +function insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing + insert!(bb, length(bb.insts) + (terminator(bb) === nothing ? 1 : 0), id, inst) + return nothing +end + +""" + collect_stmts(bb::BBlock)::Vector{IDInstPair} + +Returns a `Vector` containing the `ID`s and instructions associated to each line in `bb`. +These should be assumed to be ordered. +""" +collect_stmts(bb::BBlock)::Vector{IDInstPair} = collect(zip(bb.inst_ids, bb.insts)) + +""" + BBCode( + blocks::Vector{BBlock} + argtypes::Vector{Any} + sptypes::Vector{CC.VarState} + linetable::Vector{Core.LineInfoNode} + meta::Vector{Expr} + ) + +A `BBCode` is a data structure which is similar to `IRCode`, but adds additional structure. + +In particular, a `BBCode` comprises a sequence of basic blocks (`BBlock`s), each of which +comprise a sequence of statements. Moreover, each `BBlock` has its own unique `ID`, as does +each statment. + +The consequence of this is that new basic blocks can be inserted into a `BBCode`. This is +distinct from `IRCode`, in which to create a new basic block, one must insert additional +statments which you know will create a new basic block -- this is generally quite an +unreliable process, while inserting a new `BBlock` into `BBCode` is entirely predictable. +Furthermore, inserting a new `BBlock` does not change the `ID` associated to the other +blocks, meaning that you can safely assume that references from existing basic block +terminators / phi nodes to other blocks will not be modified by inserting a new basic block. + +Additionally, since each statment in each basic block has its own unique `ID`, new +statments can be inserted without changing references between other blocks. `IRCode` also +has some support for this via its `new_nodes` field, but eventually all statements will be +renamed upon `compact!`ing the `IRCode`, meaning that the name of any given statement will +eventually change. + +Finally, note that the basic blocks in a `BBCode` support the custom `Switch` statement. +This statement is not valid in `IRCode`, and is therefore lowered into a collection of +`GotoIfNot`s and `GotoNode`s when a `BBCode` is converted back into an `IRCode`. +""" +struct BBCode + blocks::Vector{BBlock} + argtypes::Vector{Any} + sptypes::Vector{CC.VarState} + linetable::Vector{Core.LineInfoNode} + meta::Vector{Expr} +end + +""" + BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{Block}) + +Make a new `BBCode` whose `blocks` is given by `new_blocks`, and fresh copies are made of +all other fields from `ir`. +""" +function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock}) + return BBCode( + new_blocks, + CC.copy(ir.argtypes), + CC.copy(ir.sptypes), + CC.copy(ir.linetable), + CC.copy(ir.meta), + ) +end + +# Makes use of the above outer constructor for `BBCode`. +Base.copy(ir::BBCode) = BBCode(ir, copy(ir.blocks)) + +""" + compute_all_successors(ir::BBCode)::Dict{ID, Vector{ID}} + +Compute a map from the `ID` of each `BBlock` in `ir` to its possible successors. +""" +compute_all_successors(ir::BBCode)::Dict{ID,Vector{ID}} = _compute_all_successors(ir.blocks) + +""" + _compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} + +Internal method implementing [`compute_all_successors`](@ref). This method is easier to +construct test cases for because it only requires the collection of `BBlocks`, not all of +the other stuff that goes into a `BBCode`. +""" +@noinline function _compute_all_successors(blks::Vector{BBlock})::Dict{ID,Vector{ID}} + succs = map(enumerate(blks)) do (n, blk) + is_final_block = n == length(blks) + t = terminator(blk) + if t === nothing + return is_final_block ? ID[] : ID[blks[n + 1].id] + elseif t isa IDGotoNode + return [t.label] + elseif t isa IDGotoIfNot + return is_final_block ? ID[t.dest] : ID[t.dest, blks[n + 1].id] + elseif t isa ReturnNode + return ID[] + elseif t isa Switch + return vcat(t.dests, t.fallthrough_dest) + else + error("Unhandled terminator $t") + end + end + return Dict{ID,Vector{ID}}((b.id, succ) for (b, succ) in zip(blks, succs)) +end + +""" + compute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}} + +Compute a map from the `ID` of each `BBlock` in `ir` to its possible predecessors. +""" +function compute_all_predecessors(ir::BBCode)::Dict{ID,Vector{ID}} + return _compute_all_predecessors(ir.blocks) +end + +""" + _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} + +Internal method implementing [`compute_all_predecessors`](@ref). This method is easier to +construct test cases for because it only requires the collection of `BBlocks`, not all of +the other stuff that goes into a `BBCode`. +""" +function _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID,Vector{ID}} + successor_map = _compute_all_successors(blks) + + # Initialise predecessor map to be empty. + ks = collect(keys(successor_map)) + predecessor_map = Dict{ID,Vector{ID}}(zip(ks, map(_ -> ID[], ks))) + + # Find all predecessors by iterating through the successor map. + for (k, succs) in successor_map + for succ in succs + push!(predecessor_map[succ], k) + end + end + + return predecessor_map +end + +""" + collect_stmts(ir::BBCode)::Vector{IDInstPair} + +Produce a `Vector` containing all of the statements in `ir`. These are returned in +order, so it is safe to assume that element `n` refers to the `nth` element of the `IRCode` +associated to `ir`. +""" +collect_stmts(ir::BBCode)::Vector{IDInstPair} = reduce(vcat, map(collect_stmts, ir.blocks)) + +""" + id_to_line_map(ir::BBCode) + +Produces a `Dict` mapping from each `ID` associated with a line in `ir` to its line number. +This is isomorphic to mapping to its `SSAValue` in `IRCode`. Terminators do not have `ID`s +associated to them, so not every line in the original `IRCode` is mapped to. +""" +function id_to_line_map(ir::BBCode) + lines = collect_stmts(ir) + lines_and_line_numbers = collect(zip(lines, eachindex(lines))) + ids_and_line_numbers = map(x -> (x[1][1], x[2]), lines_and_line_numbers) + return Dict(ids_and_line_numbers) +end + +concatenate_ids(bb_code::BBCode) = reduce(vcat, map(b -> b.inst_ids, bb_code.blocks)) +concatenate_stmts(bb_code::BBCode) = reduce(vcat, map(b -> b.insts, bb_code.blocks)) + +""" + control_flow_graph(bb_code::BBCode)::Core.Compiler.CFG + +Computes the `Core.Compiler.CFG` object associated to this `bb_code`. +""" +control_flow_graph(bb_code::BBCode)::Core.Compiler.CFG = _control_flow_graph(bb_code.blocks) + +""" + _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG + +Internal function, used to implement [`control_flow_graph`](@ref). Easier to write test +cases for because there is no need to construct an ensure BBCode object, just the `BBlock`s. +""" +function _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG + + # Get IDs of predecessors and successors. + preds_ids = _compute_all_predecessors(blks) + succs_ids = _compute_all_successors(blks) + + # Construct map from block ID to block number. + block_ids = map(b -> b.id, blks) + id_to_num = Dict{ID,Int}(zip(block_ids, collect(eachindex(block_ids)))) + + # Convert predecessor and successor IDs to numbers. + preds = map(id -> sort(map(p -> id_to_num[p], preds_ids[id])), block_ids) + succs = map(id -> sort(map(s -> id_to_num[s], succs_ids[id])), block_ids) + + index = vcat(0, cumsum(map(length, blks))) .+ 1 + basic_blocks = map(eachindex(blks)) do n + stmt_range = Core.Compiler.StmtRange(index[n], index[n + 1] - 1) + return Core.Compiler.BasicBlock(stmt_range, preds[n], succs[n]) + end + return Core.Compiler.CFG(basic_blocks, index[2:(end - 1)]) +end + +""" + _instructions_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector + +Pulls out the instructions from `insts`, and calls `__line_numbers_to_block_numbers!`. +""" +function _lines_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector + stmts = __line_numbers_to_block_numbers!(Any[x.stmt for x in insts], cfg) + return map((inst, stmt) -> NewInstruction(inst; stmt), insts, stmts) +end + +""" + __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) + +Converts any edges in `GotoNode`s, `GotoIfNot`s, `PhiNode`s, and `:enter` expressions which +refer to line numbers into references to block numbers. The `cfg` provides the information +required to perform this conversion. + +For context, `CodeInfo` objects have references to line numbers, while `IRCode` uses +block numbers. + +This code is copied over directly from the body of `Core.Compiler.inflate_ir!`. +""" +function __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) + for i in eachindex(insts) + stmt = insts[i] + if isa(stmt, GotoNode) + insts[i] = GotoNode(CC.block_for_inst(cfg, stmt.label)) + elseif isa(stmt, GotoIfNot) + insts[i] = GotoIfNot(stmt.cond, CC.block_for_inst(cfg, stmt.dest)) + elseif isa(stmt, PhiNode) + insts[i] = PhiNode( + Int32[CC.block_for_inst(cfg, Int(edge)) for edge in stmt.edges], stmt.values + ) + elseif Meta.isexpr(stmt, :enter) + stmt.args[1] = CC.block_for_inst(cfg, stmt.args[1]::Int) + insts[i] = stmt + end + end + return insts +end + +# +# Converting from IRCode to BBCode +# + +""" + BBCode(ir::IRCode) + +Convert an `ir` into a `BBCode`. Creates a completely independent data structure, so +mutating the `BBCode` returned will not mutate `ir`. + +All `PhiNode`s, `GotoIfNot`s, and `GotoNode`s will be replaced with the `IDPhiNode`s, +`IDGotoIfNot`s, and `IDGotoNode`s respectively. + +See `IRCode` for conversion back to `IRCode`. + +Note that `IRCode(BBCode(ir))` should be equal to the identity function. +""" +function BBCode(ir::IRCode) + + # Produce a new set of statements with `IDs` rather than `SSAValues` and block numbers. + insts = new_inst_vec(ir.stmts) + ssa_ids, stmts = _ssas_to_ids(insts) + block_ids, stmts = _block_nums_to_ids(stmts, ir.cfg) + + # Chop up the new statements into `BBlocks`, according to the `CFG` in `ir`. + blocks = map(zip(ir.cfg.blocks, block_ids)) do (bb, id) + return BBlock(id, ssa_ids[bb.stmts], stmts[bb.stmts]) + end + return BBCode(ir, blocks) +end + +""" + new_inst_vec(x::CC.InstructionStream) + +Convert an `Compiler.InstructionStream` into a list of `Compiler.NewInstruction`s. +""" +function new_inst_vec(x::CC.InstructionStream) + stmt = @static VERSION < v"1.11.0-rc4" ? x.inst : x.stmt + return map((v...,) -> NewInstruction(v...), stmt, x.type, x.info, x.line, x.flag) +end + +# Maps from positional names (SSAValues for nodes, Integers for basic blocks) to IDs. +const SSAToIdDict = Dict{SSAValue,ID} +const BlockNumToIdDict = Dict{Integer,ID} + +""" + _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector} + +Assigns an ID to each line in `stmts`, and replaces each instance of an `SSAValue` in each +line with the corresponding `ID`. For example, a call statement of the form +`Expr(:call, :f, %4)` is be replaced with `Expr(:call, :f, id_assigned_to_%4)`. +""" +function _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID},InstVector} + ids = map(_ -> ID(), insts) + val_id_map = SSAToIdDict(zip(SSAValue.(eachindex(insts)), ids)) + return ids, map(Base.Fix1(_ssa_to_ids, val_id_map), insts) +end + +""" + _ssa_to_ids(d::SSAToIdDict, inst::NewInstruction) + +Produce a new instance of `inst` in which all instances of `SSAValue`s are replaced with +the `ID`s prescribed by `d`, all basic block numbers are replaced with the `ID`s +prescribed by `d`, and `GotoIfNot`, `GotoNode`, and `PhiNode` instances are replaced with +the corresponding `ID` versions. +""" +function _ssa_to_ids(d::SSAToIdDict, inst::NewInstruction) + return NewInstruction(inst; stmt=_ssa_to_ids(d, inst.stmt)) +end +function _ssa_to_ids(d::SSAToIdDict, x::ReturnNode) + return isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x +end +_ssa_to_ids(d::SSAToIdDict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) +_ssa_to_ids(d::SSAToIdDict, x::PiNode) = PiNode(get(d, x.val, x.val), get(d, x.typ, x.typ)) +_ssa_to_ids(d::SSAToIdDict, x::QuoteNode) = x +_ssa_to_ids(d::SSAToIdDict, x) = x +function _ssa_to_ids(d::SSAToIdDict, x::PhiNode) + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = get(d, x.values[n], x.values[n]) + end + end + return PhiNode(x.edges, new_values) +end +_ssa_to_ids(d::SSAToIdDict, x::GotoNode) = x +_ssa_to_ids(d::SSAToIdDict, x::GotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), x.dest) + +""" + _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector} + +Assign to each basic block in `cfg` an `ID`. Replace all integers referencing block numbers +in `insts` with the corresponding `ID`. Return the `ID`s and the updated instructions. +""" +function _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID},InstVector} + ids = map(_ -> ID(), cfg.blocks) + block_num_id_map = BlockNumToIdDict(zip(eachindex(cfg.blocks), ids)) + return ids, map(Base.Fix1(_block_num_to_ids, block_num_id_map), insts) +end + +function _block_num_to_ids(d::BlockNumToIdDict, x::NewInstruction) + return NewInstruction(x; stmt=_block_num_to_ids(d, x.stmt)) +end +function _block_num_to_ids(d::BlockNumToIdDict, x::PhiNode) + return IDPhiNode(ID[d[e] for e in x.edges], x.values) +end +_block_num_to_ids(d::BlockNumToIdDict, x::GotoNode) = IDGotoNode(d[x.label]) +_block_num_to_ids(d::BlockNumToIdDict, x::GotoIfNot) = IDGotoIfNot(x.cond, d[x.dest]) +_block_num_to_ids(d::BlockNumToIdDict, x) = x + +# +# Converting from BBCode to IRCode +# + +""" + IRCode(bb_code::BBCode) + +Produce an `IRCode` instance which is equivalent to `bb_code`. The resulting `IRCode` +shares no memory with `bb_code`, so can be safely mutated without modifying `bb_code`. + +All `IDPhiNode`s, `IDGotoIfNot`s, and `IDGotoNode`s are converted into `PhiNode`s, +`GotoIfNot`s, and `GotoNode`s respectively. + +In the resulting `bb_code`, any `Switch` nodes are lowered into a semantically-equivalent +collection of `GotoIfNot` nodes. +""" +function CC.IRCode(bb_code::BBCode) + bb_code = _lower_switch_statements(bb_code) + bb_code = _remove_double_edges(bb_code) + insts = _ids_to_line_numbers(bb_code) + cfg = control_flow_graph(bb_code) + insts = _lines_to_blocks(insts, cfg) + return IRCode( + CC.InstructionStream( + map(x -> x.stmt, insts), + map(x -> x.type, insts), + map(x -> x.info, insts), + map(x -> x.line, insts), + map(x -> x.flag, insts), + ), + cfg, + CC.copy(bb_code.linetable), + CC.copy(bb_code.argtypes), + CC.copy(bb_code.meta), + CC.copy(bb_code.sptypes), + ) +end + +""" + _lower_switch_statements(bb_code::BBCode) + +Converts all `Switch`s into a semantically-equivalent collection of `GotoIfNot`s. See the +`Switch` docstring for an explanation of what is going on here. +""" +function _lower_switch_statements(bb_code::BBCode) + new_blocks = Vector{BBlock}(undef, 0) + for block in bb_code.blocks + t = terminator(block) + if t isa Switch + + # Create new block without the `Switch`. + bb = BBlock(block.id, block.inst_ids[1:(end - 1)], block.insts[1:(end - 1)]) + push!(new_blocks, bb) + + # Create new blocks for each `GotoIfNot` from the `Switch`. + foreach(t.conds, t.dests) do cond, dest + blk = BBlock(ID(), [ID()], [new_inst(IDGotoIfNot(cond, dest), Any)]) + push!(new_blocks, blk) + end + + # Create a new block for the fallthrough dest. + fallthrough_inst = new_inst(IDGotoNode(t.fallthrough_dest), Any) + push!(new_blocks, BBlock(ID(), [ID()], [fallthrough_inst])) + else + push!(new_blocks, block) + end + end + return BBCode(bb_code, new_blocks) +end + +""" + _ids_to_line_numbers(bb_code::BBCode)::InstVector + +For each statement in `bb_code`, returns a `NewInstruction` in which every `ID` is replaced +by either an `SSAValue`, or an `Int64` / `Int32` which refers to an `SSAValue`. +""" +function _ids_to_line_numbers(bb_code::BBCode)::InstVector + + # Construct map from `ID`s to `SSAValue`s. + block_ids = [b.id for b in bb_code.blocks] + block_lengths = map(length, bb_code.blocks) + block_start_ssas = SSAValue.(vcat(1, cumsum(block_lengths)[1:(end - 1)] .+ 1)) + line_ids = concatenate_ids(bb_code) + line_ssas = SSAValue.(eachindex(line_ids)) + id_to_ssa_map = Dict(zip(vcat(block_ids, line_ids), vcat(block_start_ssas, line_ssas))) + + # Apply map. + return [_to_ssas(id_to_ssa_map, stmt) for stmt in concatenate_stmts(bb_code)] +end + +""" + _to_ssas(d::Dict, inst::NewInstruction) + +Like `_ssas_to_ids`, but in reverse. Converts IDs to SSAValues / (integers corresponding +to ssas). +""" +_to_ssas(d::Dict, inst::NewInstruction) = NewInstruction(inst; stmt=_to_ssas(d, inst.stmt)) +_to_ssas(d::Dict, x::ReturnNode) = isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x +_to_ssas(d::Dict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) +_to_ssas(d::Dict, x::PiNode) = PiNode(get(d, x.val, x.val), get(d, x.typ, x.typ)) +_to_ssas(d::Dict, x::QuoteNode) = x +_to_ssas(d::Dict, x) = x +function _to_ssas(d::Dict, x::IDPhiNode) + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = get(d, x.values[n], x.values[n]) + end + end + return PhiNode(map(e -> Int32(getindex(d, e).id), x.edges), new_values) +end +_to_ssas(d::Dict, x::IDGotoNode) = GotoNode(d[x.label].id) +_to_ssas(d::Dict, x::IDGotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), d[x.dest].id) + +""" + _remove_double_edges(ir::BBCode)::BBCode + +If the `dest` field of an `IDGotoIfNot` node in block `n` of `ir` points towards the `n+1`th +block then we have two edges from block `n` to block `n+1`. This transformation replaces all +such `IDGotoIfNot` nodes with unconditional `IDGotoNode`s pointing towards the `n+1`th block +in `ir`. +""" +function _remove_double_edges(ir::BBCode) + new_blks = map(enumerate(ir.blocks)) do (n, blk) + t = terminator(blk) + if t isa IDGotoIfNot && t.dest == ir.blocks[n + 1].id + new_insts = vcat(blk.insts[1:(end - 1)], NewInstruction(t; stmt=IDGotoNode(t.dest))) + return BBlock(blk.id, blk.inst_ids, new_insts) + else + return blk + end + end + return BBCode(ir, new_blks) +end + +""" + _build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}} + +Builds a `SimpleDiGraph`, `g`, representing of the CFG associated to `blks`, where `blks` +comprises the collection of basic blocks associated to a `BBCode`. +This is a type from Graphs.jl, so constructing `g` makes it straightforward to analyse the +control flow structure of `ir` using algorithms from Graphs.jl. + +Returns a 2-tuple, whose first element is `g`, and whose second element is a map from +the `ID` associated to each basic block in `ir`, to the `Int` corresponding to its node +index in `g`. +""" +function _build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph,Dict{ID,Int}} + node_ints = collect(eachindex(blks)) + id_to_int = Dict(zip(map(blk -> blk.id, blks), node_ints)) + successors = _compute_all_successors(blks) + g = SimpleDiGraph(length(blks)) + for blk in blks, successor in successors[blk.id] + add_edge!(g, id_to_int[blk.id], id_to_int[successor]) + end + return g, id_to_int +end + +""" + _distance_to_entry(blks::Vector{BBlock})::Vector{Int} + +For each basic block in `blks`, compute the distance from it to the entry point (the first +block. The distance is `typemax(Int)` if no path from the entry point to a given node. +""" +function _distance_to_entry(blks::Vector{BBlock})::Vector{Int} + g, id_to_int = _build_graph_of_cfg(blks) + return dijkstra_shortest_paths(g, id_to_int[blks[1].id]).dists +end + +""" + sort_blocks!(ir::BBCode)::BBCode + +Ensure that blocks appear in order of distance-from-entry-point, where distance the +distance from block b to the entry point is defined to be the minimum number of basic +blocks that must be passed through in order to reach b. + +For reasons unknown (to me, Will), the compiler / optimiser needs this for inference to +succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem +there. + +WARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic +blocks in `ir` is valid. Notably, this does not hold if you have any `IDGotoIfNot` nodes in +`ir`. +""" +function sort_blocks!(ir::BBCode)::BBCode + I = sortperm(_distance_to_entry(ir.blocks)) + ir.blocks .= ir.blocks[I] + return ir +end + +""" + characterise_unique_predecessor_blocks(blks::Vector{BBlock}) -> + Tuple{Dict{ID, Bool}, Dict{ID, Bool}} + +We call a block `b` a _unique_ _predecessor_ in the control flow graph associated to `blks` +if it is the only predecessor to all of its successors. Put differently we call `b` a unique +predecessor if, whenever control flow arrives in any of the successors of `b`, we know for +certain that the previous block must have been `b`. + +Returns two `Dict`s. A value in the first `Dict` is `true` if the block associated to its +key is a unique precessor, and is `false` if not. A value in the second `Dict` is `true` if +it has a single predecessor, and that predecessor is a unique predecessor. + +*Context*: + +This information is important for optimising AD because knowing that `b` is a unique +predecessor means that +1. on the forwards-pass, there is no need to push the ID of `b` to the block stack when + passing through it, and +2. on the reverse-pass, there is no need to pop the block stack when passing through one of + the successors to `b`. + +Utilising this reduces the overhead associated to doing AD. It is quite important when +working with cheap loops -- loops where the operations performed at each iteration +are inexpensive -- for which minimising memory pressure is critical to performance. It is +also important for single-block functions, because it can be used to entirely avoid using a +block stack at all. +""" +function characterise_unique_predecessor_blocks( + blks::Vector{BBlock} +)::Tuple{Dict{ID,Bool},Dict{ID,Bool}} + + # Obtain the block IDs in order -- this ensures that we get the entry block first. + blk_ids = ID[b.id for b in blks] + preds = _compute_all_predecessors(blks) + succs = _compute_all_successors(blks) + + # The bulk of blocks can be hanled by this general loop. + is_unique_pred = Dict{ID,Bool}() + for id in blk_ids + ss = succs[id] + is_unique_pred[id] = !isempty(ss) && all(s -> length(preds[s]) == 1, ss) + end + + # If there is a single reachable return node, then that block is treated as a unique + # pred, since control can only pass "out" of the function via this block. Conversely, + # if there are multiple reachable return nodes, then execution can return to the calling + # function via any of them, so they are not unique predecessors. + # Note that the previous block sets is_unique_pred[id] to false for all nodes which + # end with a reachable return node, so the value only needs changing if there is a + # unique reachable return node. + reachable_return_blocks = filter(blks) do blk + is_reachable_return_node(terminator(blk)) + end + if length(reachable_return_blocks) == 1 + is_unique_pred[only(reachable_return_blocks).id] = true + end + + # pred_is_unique_pred is true if the unique predecessor to a block is a unique pred. + pred_is_unique_pred = Dict{ID,Bool}() + for id in blk_ids + pred_is_unique_pred[id] = length(preds[id]) == 1 && is_unique_pred[only(preds[id])] + end + + # If the entry block has no predecessors, then it can only be entered once, when the + # function is first entered. In this case, we treat it as having a unique predecessor. + entry_id = blk_ids[1] + pred_is_unique_pred[entry_id] = isempty(preds[entry_id]) + + return is_unique_pred, pred_is_unique_pred +end + +""" + is_reachable_return_node(x::ReturnNode) + +Determine whether `x` is a `ReturnNode`, and if it is, if it is also reachable. This is +purely a function of whether or not its `val` field is defined or not. +""" +is_reachable_return_node(x::ReturnNode) = isdefined(x, :val) +is_reachable_return_node(x) = false + +""" + characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool} + +For each line in `stmts`, determine whether it is referenced anywhere else in the code. +Returns a dictionary containing the results. An element is `false` if the corresponding +`ID` is unused, and `true` if is used. +""" +function characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID,Bool} + ids = first.(stmts) + insts = last.(stmts) + + # Initialise to false. + is_used = Dict{ID,Bool}(zip(ids, fill(false, length(ids)))) + + # Hunt through the instructions, flipping a value in is_used to true whenever an ID + # is encountered which corresponds to an SSA. + for inst in insts + _find_id_uses!(is_used, inst.stmt) + end + return is_used +end + +""" + _find_id_uses!(d::Dict{ID, Bool}, x) + +Helper function used in [`characterise_used_ids`](@ref). For all uses of `ID`s in `x`, set +the corresponding value of `d` to `true`. + +For example, if `x = ReturnNode(ID(5))`, then this function sets `d[ID(5)] = true`. +""" +function _find_id_uses!(d::Dict{ID,Bool}, x::Expr) + for arg in x.args + in(arg, keys(d)) && setindex!(d, true, arg) + end +end +function _find_id_uses!(d::Dict{ID,Bool}, x::IDGotoIfNot) + return in(x.cond, keys(d)) && setindex!(d, true, x.cond) +end +_find_id_uses!(::Dict{ID,Bool}, ::IDGotoNode) = nothing +function _find_id_uses!(d::Dict{ID,Bool}, x::PiNode) + return in(x.val, keys(d)) && setindex!(d, true, x.val) +end +function _find_id_uses!(d::Dict{ID,Bool}, x::IDPhiNode) + v = x.values + for n in eachindex(v) + isassigned(v, n) && in(v[n], keys(d)) && setindex!(d, true, v[n]) + end +end +function _find_id_uses!(d::Dict{ID,Bool}, x::ReturnNode) + return isdefined(x, :val) && in(x.val, keys(d)) && setindex!(d, true, x.val) +end +_find_id_uses!(d::Dict{ID,Bool}, x::QuoteNode) = nothing +_find_id_uses!(d::Dict{ID,Bool}, x) = nothing + +""" + _is_reachable(blks::Vector{BBlock})::Vector{Bool} + +Computes a `Vector` whose length is `length(blks)`. The `n`th element is `true` iff it is +possible for control flow to reach the `n`th block. +""" +_is_reachable(blks::Vector{BBlock})::Vector{Bool} = _distance_to_entry(blks) .< typemax(Int) + +""" + remove_unreachable_blocks!(ir::BBCode)::BBCode + +If a basic block in `ir` cannot possibly be reached during execution, then it can be safely +removed from `ir` without changing its functionality. +A block is unreachable if either: +1. it has no predecessors _and_ it is not the first block, or +2. all of its predecessors are themselves unreachable. + +For example, consider the following IR: +```jldoctest remove_unreachable_blocks +julia> ir = ircode( + Any[Core.ReturnNode(nothing), Expr(:call, sin, 5), Core.ReturnNode(Core.SSAValue(2))], + Any[Any, Any, Any], + ); +``` +There is no possible way to reach the second basic block (lines 2 and 3). Applying this +function will therefore remove it, yielding the following: +```jldoctest remove_unreachable_blocks +julia> IRCode(remove_unreachable_blocks!(BBCode(ir))) +1 1 ─ return nothing +``` + +In the blocks which have not been removed, there may be references to blocks which have been +removed. For example, the `edge`s in a `PhiNode` may contain a reference to a removed block. +These references are removed in-place from these remaining blocks, so this function will (in +general) modify `ir`. +""" +remove_unreachable_blocks!(ir::BBCode) = BBCode(ir, _remove_unreachable_blocks!(ir.blocks)) + +function _remove_unreachable_blocks!(blks::Vector{BBlock}) + + # Figure out which blocks are reachable. + is_reachable = _is_reachable(blks) + + # Collect all blocks which are reachable. + remaining_blks = blks[is_reachable] + + # For each reachable block, remove any references to removed blocks. These can appear in + # `PhiNode`s with edges that come from remove blocks. + removed_block_ids = map(idx -> blks[idx].id, findall(!, is_reachable)) + for blk in remaining_blks, inst in blk.insts + stmt = inst.stmt + stmt isa IDPhiNode || continue + for n in reverse(1:length(stmt.edges)) + if stmt.edges[n] in removed_block_ids + deleteat!(stmt.edges, n) + deleteat!(stmt.values, n) + end + end + end + + return remaining_blks +end + +end diff --git a/src/copyable_task.jl b/src/copyable_task.jl index 947eafd9..fd74a664 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -45,9 +45,9 @@ function build_callable(sig::Type{<:Tuple}) ir = Base.code_ircode_by_type(sig)[1][1] bb, refs, types = derive_copyable_task_ir(BBCode(ir)) unoptimised_ir = IRCode(bb) - optimised_ir = Mooncake.optimise_ir!(unoptimised_ir) + optimised_ir = optimise_ir!(unoptimised_ir) mc_ret_type = callable_ret_type(sig, types) - mc = Mooncake.misty_closure(mc_ret_type, optimised_ir, refs...; do_compile=true) + mc = misty_closure(mc_ret_type, optimised_ir, refs...; do_compile=true) mc_cache[key] = mc return mc, refs[end] end @@ -191,7 +191,7 @@ function TapedTask(taped_globals::Any, fargs...; kwargs...) end function fresh_copy(mc::T) where {T<:MistyClosure} - new_captures = Mooncake.tuple_map(mc.oc.captures) do r + new_captures = map(mc.oc.captures) do r if eltype(r) <: DynamicCallable return Base.RefValue(DynamicCallable()) elseif eltype(r) <: LazyCallable @@ -202,7 +202,7 @@ function fresh_copy(mc::T) where {T<:MistyClosure} end new_position = new_captures[end] new_position[] = -1 - return Mooncake.replace_captures(mc, new_captures), new_position + return replace_captures(mc, new_captures), new_position end """ @@ -538,7 +538,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} new_bblocks = map(zip(ir.blocks, all_splits, all_split_ids)) do (bb, splits, splits_ids) new_blocks = map(enumerate(splits)) do (n, split) # We'll push ID-NewInstruction pairs to this as we proceed through the split. - inst_pairs = Mooncake.IDInstPair[] + inst_pairs = IDInstPair[] # PhiNodes: # @@ -839,7 +839,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} # Find any `ID`s and replace them with calls to read whatever is stored # in the `Ref`s that they are associated to. - callable_inst_pairs = Mooncake.IDInstPair[] + callable_inst_pairs = IDInstPair[] for (n, arg) in enumerate(callable_args) arg isa ID || continue @@ -891,7 +891,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} push!(resume_block_ids, callable_block_id) set_res = Expr(:call, set_resume_block!, refs_id, callable_block_id.id) return_id = ID() - produced_block_inst_pairs = Mooncake.IDInstPair[ + produced_block_inst_pairs = IDInstPair[ (ID(), new_inst(set_res)), (return_id, new_inst(ReturnNode(result_id))), ] @@ -907,7 +907,7 @@ function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} else set_ref = nothing end - not_produced_block_inst_pairs = Mooncake.IDInstPair[ + not_produced_block_inst_pairs = IDInstPair[ (ID(), new_inst(set_ref)) (ID(), new_inst(IDGotoNode(next_block_id))) ] @@ -1006,7 +1006,7 @@ end DynamicCallable() = DynamicCallable(Dict{Any,Any}()) function (dynamic_callable::DynamicCallable)(args::Vararg{Any,N}) where {N} - sig = Mooncake._typeof(args) + sig = _typeof(args) callable = get(dynamic_callable.cache, sig, nothing) if callable === nothing callable = build_callable(sig) diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 00000000..e4449657 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,204 @@ + +""" + replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure} + +Given an `OpaqueClosure` `oc`, create a new `OpaqueClosure` of the same type, but with new +captured variables. This is needed for efficiency reasons -- if `build_rrule` is called +repeatedly with the same signature and intepreter, it is important to avoid recompiling +the `OpaqueClosure`s that it produces multiple times, because it can be quite expensive to +do so. +""" +function replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure} + return __replace_captures_internal(oc, new_captures) +end + +@eval function __replace_captures_internal(oc::Toc, new_captures) where {Toc<:OpaqueClosure} + return $(Expr( + :new, :(Toc), :new_captures, :(oc.world), :(oc.source), :(oc.invoke), :(oc.specptr) + )) +end + +""" + replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure} + +Same as `replace_captures` for `Core.OpaqueClosure`s, but returns a new `MistyClosure`. +""" +function replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure} + return Tmc(replace_captures(mc.oc, new_captures), mc.ir) +end + +""" + optimise_ir!(ir::IRCode, show_ir=false) + +Run a fairly standard optimisation pass on `ir`. If `show_ir` is `true`, displays the IR +to `stdout` at various points in the pipeline -- this is sometimes useful for debugging. +""" +function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) + if show_ir + println("Pre-optimization") + display(ir) + println() + end + CC.verify_ir(ir) + ir = __strip_coverage!(ir) + ir = CC.compact!(ir) + local_interp = CC.NativeInterpreter() + # local_interp = BugPatchInterpreter() # 319 -- see patch_for_319.jl for context + mi = __get_toplevel_mi_from_ir(ir, @__MODULE__) + ir = __infer_ir!(ir, local_interp, mi) + if show_ir + println("Post-inference") + display(ir) + println() + end + inline_state = CC.InliningState(local_interp) + CC.verify_ir(ir) + if do_inline + ir = CC.ssa_inlining_pass!(ir, inline_state, true) #=propagate_inbounds=# + ir = CC.compact!(ir) + end + ir = __strip_coverage!(ir) + ir = CC.sroa_pass!(ir, inline_state) + + @static if VERSION < v"1.11-" + ir = CC.adce_pass!(ir, inline_state) + else + ir, _ = CC.adce_pass!(ir, inline_state) + end + + ir = CC.compact!(ir) + # CC.verify_ir(ir, true, false, CC.optimizer_lattice(local_interp)) + CC.verify_linetable(ir.linetable, true) + if show_ir + println("Post-optimization") + display(ir) + println() + end + return ir +end + +@static if VERSION < v"1.11.0" + get_inference_world(interp::CC.AbstractInterpreter) = CC.get_world_counter(interp) +else + get_inference_world(interp::CC.AbstractInterpreter) = CC.get_inference_world(interp) +end + +# Given some IR, generates a MethodInstance suitable for passing to infer_ir!, if you don't +# already have one with the right argument types. Credit to @oxinabox: +# https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54 +function __get_toplevel_mi_from_ir(ir, _module::Module) + mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ()) + mi.specTypes = Tuple{map(CC.widenconst, ir.argtypes)...} + mi.def = _module + return mi +end + +# Run type inference and constant propagation on the ir. Credit to @oxinabox: +# https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54 +function __infer_ir!(ir, interp::CC.AbstractInterpreter, mi::CC.MethodInstance) + method_info = CC.MethodInfo(true, nothing) #=propagate_inbounds=# + min_world = world = get_inference_world(interp) + max_world = Base.get_world_counter() + irsv = CC.IRInterpretationState( + interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world + ) + rt = CC._ir_abstract_constant_propagation(interp, irsv) + return ir +end + +# In automatically generated code, it is meaningless to include code coverage effects. +# Moreover, it seems to cause some serious inference probems. Consequently, it makes sense +# to remove such effects before optimising IRCode. +function __strip_coverage!(ir::IRCode) + for n in eachindex(stmt(ir.stmts)) + if Meta.isexpr(stmt(ir.stmts)[n], :code_coverage_effect) + stmt(ir.stmts)[n] = nothing + end + end + return ir +end + +stmt(ir::CC.InstructionStream) = @static VERSION < v"1.11.0-rc4" ? ir.inst : ir.stmt + +""" + opaque_closure( + ret_type::Type, + ir::IRCode, + @nospecialize env...; + isva::Bool=false, + do_compile::Bool=true, + )::Core.OpaqueClosure{<:Tuple, ret_type} + +Construct a `Core.OpaqueClosure`. Almost equivalent to +`Core.OpaqueClosure(ir, env...; isva, do_compile)`, but instead of letting +`Core.compute_oc_rettype` figure out the return type from `ir`, impose `ret_type` as the +return type. + +# Warning + +User beware: if the `Core.OpaqueClosure` produced by this function ever returns anything +which is not an instance of a subtype of `ret_type`, you should expect all kinds of awful +things to happen, such as segfaults. You have been warned! + +# Extended Help + +This is needed because we make extensive use of our ability to know the return +type of a couple of specific `OpaqueClosure`s without actually having constructed them. +Without the capability to specify the return type, we have to guess what type +`compute_ir_rettype` will return for a given `IRCode` before we have constructed +the `IRCode` and run type inference on it. This exposes us to details of type inference, +which are not part of the public interface of the language, and can therefore vary from +Julia version to Julia version (including patch versions). Moreover, even for a fixed Julia +version it can be extremely hard to predict exactly what type inference will infer to be the +return type of a function. + +Failing to correctly guess the return type can happen for a number of reasons, and the kinds +of errors that tend to be generated when this fails tell you very little about the +underlying cause of the problem. + +By specifying the return type ourselves, we remove this dependence. The price we pay for +this is the potential for segfaults etc if we fail to specify `ret_type` correctly. +""" +function opaque_closure( + ret_type::Type, + ir::IRCode, + @nospecialize env...; + isva::Bool=false, + do_compile::Bool=true, +) + # This implementation is copied over directly from `Core.OpaqueClosure`. + ir = CC.copy(ir) + nargs = length(ir.argtypes) - 1 + sig = Base.Experimental.compute_oc_signature(ir, nargs, isva) + src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ()) + src.slotnames = fill(:none, nargs + 1) + src.slotflags = fill(zero(UInt8), length(ir.argtypes)) + src.slottypes = copy(ir.argtypes) + src.rettype = ret_type + src = CC.ir_to_codeinf!(src, ir) + return Base.Experimental.generate_opaque_closure( + sig, Union{}, ret_type, src, nargs, isva, env...; do_compile + )::Core.OpaqueClosure{sig,ret_type} +end + +""" + misty_closure( + ret_type::Type, + ir::IRCode, + @nospecialize env...; + isva::Bool=false, + do_compile::Bool=true, + ) + +Identical to [`opaque_closure`](@ref), but returns a `MistyClosure` closure rather +than a `Core.OpaqueClosure`. +""" +function misty_closure( + ret_type::Type, + ir::IRCode, + @nospecialize env...; + isva::Bool=false, + do_compile::Bool=true, +) + return MistyClosure(opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir)) +end From bc480c287dcc103cc48e029b28225fe35fa631b2 Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 13:45:23 +0100 Subject: [PATCH 63/69] Tidy up further --- Project.toml | 2 - docs/src/internals.md | 5 + src/bbcode.jl | 481 ------------------------------------------ 3 files changed, 5 insertions(+), 483 deletions(-) diff --git a/Project.toml b/Project.toml index a7a110f6..b150bd08 100644 --- a/Project.toml +++ b/Project.toml @@ -6,13 +6,11 @@ repo = "https://github.com/TuringLang/Libtask.jl.git" version = "0.8.8" [deps] -Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8.11" -Graphs = "1.12.1" JuliaFormatter = "1.0.62" MistyClosures = "2.0.0" Test = "1" diff --git a/docs/src/internals.md b/docs/src/internals.md index e92ad4f8..ee6848e8 100644 --- a/docs/src/internals.md +++ b/docs/src/internals.md @@ -9,4 +9,9 @@ Libtask.LazyCallable Libtask.inc_args Libtask.get_type Libtask._typeof +Libtask.replace_captures +Libtask.BasicBlockCode +Libtask.opaque_closure +Libtask.misty_closure +Libtask.optimise_ir! ``` diff --git a/src/bbcode.jl b/src/bbcode.jl index 646aa0e9..9d4f45d2 100644 --- a/src/bbcode.jl +++ b/src/bbcode.jl @@ -6,8 +6,6 @@ Refer to Mooncake's developer docs for context on this file. """ module BasicBlockCode -using Graphs - using Core.Compiler: ReturnNode, PhiNode, @@ -36,7 +34,6 @@ export ID, remove_unreachable_blocks!, characterise_used_ids, characterise_unique_predecessor_blocks, - sort_blocks!, InstVector, IDInstPair, __line_numbers_to_block_numbers!, @@ -45,38 +42,12 @@ export ID, const _id_count::Dict{Int,Int32} = Dict{Int,Int32}() -""" - new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction - -Create a `NewInstruction` with fields: -- `stmt` = `stmt` -- `type` = `type` -- `info` = `CC.NoCallInfo()` -- `line` = `Int32(1)` -- `flag` = `flag` -""" function new_inst(@nospecialize(stmt), @nospecialize(type)=Any, flag=CC.IR_FLAG_REFINED) return NewInstruction(stmt, type, CC.NoCallInfo(), Int32(1), flag) end -""" - const InstVector = Vector{NewInstruction} - -Note: the `CC.NewInstruction` type is used to represent instructions because it has the -correct fields. While it is only used to represent new instrucdtions in `Core.Compiler`, it -is used to represent all instructions in `BBCode`. -""" const InstVector = Vector{NewInstruction} -""" - ID() - -An `ID` (read: unique name) is just a wrapper around an `Int32`. Uniqueness is ensured via a -global counter, which is incremented each time that an `ID` is created. - -This counter can be reset using `seed_id!` if you need to ensure deterministic `ID`s are -produced, in the same way that seed for random number generators can be set. -""" struct ID id::Int32 function ID() @@ -89,23 +60,10 @@ end Base.copy(id::ID) = id -""" - seed_id!() - -Set the global counter used to ensure ID uniqueness to 0. This is useful when you want to -ensure determinism between two runs of the same function which makes use of `ID`s. - -This is akin to setting the random seed associated to a random number generator globally. -""" function seed_id!() return global _id_count[Threads.threadid()] = 0 end -""" - IDPhiNode(edges::Vector{ID}, values::Vector{Any}) - -Like a `PhiNode`, but `edges` are `ID`s rather than `Int32`s. -""" struct IDPhiNode edges::Vector{ID} values::Vector{Any} @@ -115,22 +73,12 @@ Base.:(==)(x::IDPhiNode, y::IDPhiNode) = x.edges == y.edges && x.values == y.val Base.copy(node::IDPhiNode) = IDPhiNode(copy(node.edges), copy(node.values)) -""" - IDGotoNode(label::ID) - -Like a `GotoNode`, but `label` is an `ID` rather than an `Int64`. -""" struct IDGotoNode label::ID end Base.copy(node::IDGotoNode) = IDGotoNode(copy(node.label)) -""" - IDGotoIfNot(cond::Any, dest::ID) - -Like a `GotoIfNot`, but `dest` is an `ID` rather than an `Int64`. -""" struct IDGotoIfNot cond::Any dest::ID @@ -138,26 +86,6 @@ end Base.copy(node::IDGotoIfNot) = IDGotoIfNot(copy(node.cond), copy(node.dest)) -""" - Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID) - -A switch-statement node. These can be inserted in the `BBCode` representation of Julia IR. -`Switch` has the following semantics: -```julia -goto dests[1] if not conds[1] -goto dests[2] if not conds[2] -... -goto dests[N] if not conds[N] -goto fallthrough_dest -``` -where the value associated to each element of `conds` is a `Bool`, and `dests` indicate -which block to jump to. If none of the conditions are met, then we go to whichever block is -specified by `fallthrough_dest`. - -`Switch` statements are lowered into the above sequence of `GotoIfNot`s and `GotoNode`s -when converting `BBCode` back into `IRCode`, because `Switch` statements are not valid -nodes in regular Julia IR. -""" struct Switch conds::Vector{Any} dests::Vector{ID} @@ -168,30 +96,8 @@ struct Switch end end -""" - Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode} - -A Union of the possible types of a terminator node. -""" const Terminator = Union{Switch,IDGotoIfNot,IDGotoNode,ReturnNode} -""" - BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector) - -A basic block data structure (not called `BasicBlock` to avoid accidental confusion with -`CC.BasicBlock`). Forms a single basic block. - -Each `BBlock` has an `ID` (a unique name). This makes it possible to refer to blocks in a -way that does not change when additional `BBlocks` are inserted into a `BBCode`. -This differs from the positional block numbering found in `IRCode`, in which the number -associated to a basic block changes when new blocks are inserted. - -The `n`th line of code in a `BBlock` is associated to `ID` `stmt_ids[n]`, and the `n`th -instruction from `stmts`. - -Note that `PhiNode`s, `GotoIfNot`s, and `GotoNode`s should not appear in a `BBlock` -- -instead an `IDPhiNode`, `IDGotoIfNot`, or `IDGotoNode` should be used. -""" mutable struct BBlock id::ID inst_ids::Vector{ID} @@ -202,17 +108,8 @@ mutable struct BBlock end end -""" - const IDInstPair = Tuple{ID, NewInstruction} -""" const IDInstPair = Tuple{ID,NewInstruction} -""" - BBlock(id::ID, inst_pairs::Vector{IDInstPair}) - -Convenience constructor -- splits `inst_pairs` into a `Vector{ID}` and `InstVector` in order -to build a `BBlock`. -""" function BBlock(id::ID, inst_pairs::Vector{IDInstPair}) return BBlock(id, first.(inst_pairs), last.(inst_pairs)) end @@ -221,12 +118,6 @@ Base.length(bb::BBlock) = length(bb.inst_ids) Base.copy(bb::BBlock) = BBlock(bb.id, copy(bb.inst_ids), copy(bb.insts)) -""" - phi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}} - -Returns all of the `IDPhiNode`s at the start of `bb`, along with their `ID`s. If there are -no `IDPhiNode`s at the start of `bb`, then both vectors will be empty. -""" function phi_nodes(bb::BBlock) n_phi_nodes = findlast(x -> x.stmt isa IDPhiNode, bb.insts) if n_phi_nodes === nothing @@ -235,77 +126,21 @@ function phi_nodes(bb::BBlock) return bb.inst_ids[1:n_phi_nodes], bb.insts[1:n_phi_nodes] end -""" - Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing - -Inserts `stmt` and `id` into `bb` immediately before the `n`th instruction. -""" function Base.insert!(bb::BBlock, n::Int, id::ID, inst::NewInstruction)::Nothing insert!(bb.inst_ids, n, id) insert!(bb.insts, n, inst) return nothing end -""" - terminator(bb::BBlock) - -Returns the terminator associated to `bb`. If the last instruction in `bb` isa -`Terminator` then that is returned, otherwise `nothing` is returned. -""" terminator(bb::BBlock) = isa(bb.insts[end].stmt, Terminator) ? bb.insts[end].stmt : nothing -""" - insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing - -If the final instruction in `bb` is a `Terminator`, insert `inst` immediately before it. -Otherwise, insert `inst` at the end of the block. -""" function insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing insert!(bb, length(bb.insts) + (terminator(bb) === nothing ? 1 : 0), id, inst) return nothing end -""" - collect_stmts(bb::BBlock)::Vector{IDInstPair} - -Returns a `Vector` containing the `ID`s and instructions associated to each line in `bb`. -These should be assumed to be ordered. -""" collect_stmts(bb::BBlock)::Vector{IDInstPair} = collect(zip(bb.inst_ids, bb.insts)) -""" - BBCode( - blocks::Vector{BBlock} - argtypes::Vector{Any} - sptypes::Vector{CC.VarState} - linetable::Vector{Core.LineInfoNode} - meta::Vector{Expr} - ) - -A `BBCode` is a data structure which is similar to `IRCode`, but adds additional structure. - -In particular, a `BBCode` comprises a sequence of basic blocks (`BBlock`s), each of which -comprise a sequence of statements. Moreover, each `BBlock` has its own unique `ID`, as does -each statment. - -The consequence of this is that new basic blocks can be inserted into a `BBCode`. This is -distinct from `IRCode`, in which to create a new basic block, one must insert additional -statments which you know will create a new basic block -- this is generally quite an -unreliable process, while inserting a new `BBlock` into `BBCode` is entirely predictable. -Furthermore, inserting a new `BBlock` does not change the `ID` associated to the other -blocks, meaning that you can safely assume that references from existing basic block -terminators / phi nodes to other blocks will not be modified by inserting a new basic block. - -Additionally, since each statment in each basic block has its own unique `ID`, new -statments can be inserted without changing references between other blocks. `IRCode` also -has some support for this via its `new_nodes` field, but eventually all statements will be -renamed upon `compact!`ing the `IRCode`, meaning that the name of any given statement will -eventually change. - -Finally, note that the basic blocks in a `BBCode` support the custom `Switch` statement. -This statement is not valid in `IRCode`, and is therefore lowered into a collection of -`GotoIfNot`s and `GotoNode`s when a `BBCode` is converted back into an `IRCode`. -""" struct BBCode blocks::Vector{BBlock} argtypes::Vector{Any} @@ -314,12 +149,6 @@ struct BBCode meta::Vector{Expr} end -""" - BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{Block}) - -Make a new `BBCode` whose `blocks` is given by `new_blocks`, and fresh copies are made of -all other fields from `ir`. -""" function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock}) return BBCode( new_blocks, @@ -333,20 +162,8 @@ end # Makes use of the above outer constructor for `BBCode`. Base.copy(ir::BBCode) = BBCode(ir, copy(ir.blocks)) -""" - compute_all_successors(ir::BBCode)::Dict{ID, Vector{ID}} - -Compute a map from the `ID` of each `BBlock` in `ir` to its possible successors. -""" compute_all_successors(ir::BBCode)::Dict{ID,Vector{ID}} = _compute_all_successors(ir.blocks) -""" - _compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} - -Internal method implementing [`compute_all_successors`](@ref). This method is easier to -construct test cases for because it only requires the collection of `BBlocks`, not all of -the other stuff that goes into a `BBCode`. -""" @noinline function _compute_all_successors(blks::Vector{BBlock})::Dict{ID,Vector{ID}} succs = map(enumerate(blks)) do (n, blk) is_final_block = n == length(blks) @@ -368,22 +185,10 @@ the other stuff that goes into a `BBCode`. return Dict{ID,Vector{ID}}((b.id, succ) for (b, succ) in zip(blks, succs)) end -""" - compute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}} - -Compute a map from the `ID` of each `BBlock` in `ir` to its possible predecessors. -""" function compute_all_predecessors(ir::BBCode)::Dict{ID,Vector{ID}} return _compute_all_predecessors(ir.blocks) end -""" - _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}} - -Internal method implementing [`compute_all_predecessors`](@ref). This method is easier to -construct test cases for because it only requires the collection of `BBlocks`, not all of -the other stuff that goes into a `BBCode`. -""" function _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID,Vector{ID}} successor_map = _compute_all_successors(blks) @@ -401,22 +206,8 @@ function _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID,Vector{ID}} return predecessor_map end -""" - collect_stmts(ir::BBCode)::Vector{IDInstPair} - -Produce a `Vector` containing all of the statements in `ir`. These are returned in -order, so it is safe to assume that element `n` refers to the `nth` element of the `IRCode` -associated to `ir`. -""" collect_stmts(ir::BBCode)::Vector{IDInstPair} = reduce(vcat, map(collect_stmts, ir.blocks)) -""" - id_to_line_map(ir::BBCode) - -Produces a `Dict` mapping from each `ID` associated with a line in `ir` to its line number. -This is isomorphic to mapping to its `SSAValue` in `IRCode`. Terminators do not have `ID`s -associated to them, so not every line in the original `IRCode` is mapped to. -""" function id_to_line_map(ir::BBCode) lines = collect_stmts(ir) lines_and_line_numbers = collect(zip(lines, eachindex(lines))) @@ -427,19 +218,8 @@ end concatenate_ids(bb_code::BBCode) = reduce(vcat, map(b -> b.inst_ids, bb_code.blocks)) concatenate_stmts(bb_code::BBCode) = reduce(vcat, map(b -> b.insts, bb_code.blocks)) -""" - control_flow_graph(bb_code::BBCode)::Core.Compiler.CFG - -Computes the `Core.Compiler.CFG` object associated to this `bb_code`. -""" control_flow_graph(bb_code::BBCode)::Core.Compiler.CFG = _control_flow_graph(bb_code.blocks) -""" - _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG - -Internal function, used to implement [`control_flow_graph`](@ref). Easier to write test -cases for because there is no need to construct an ensure BBCode object, just the `BBlock`s. -""" function _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG # Get IDs of predecessors and successors. @@ -462,28 +242,11 @@ function _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG return Core.Compiler.CFG(basic_blocks, index[2:(end - 1)]) end -""" - _instructions_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector - -Pulls out the instructions from `insts`, and calls `__line_numbers_to_block_numbers!`. -""" function _lines_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector stmts = __line_numbers_to_block_numbers!(Any[x.stmt for x in insts], cfg) return map((inst, stmt) -> NewInstruction(inst; stmt), insts, stmts) end -""" - __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) - -Converts any edges in `GotoNode`s, `GotoIfNot`s, `PhiNode`s, and `:enter` expressions which -refer to line numbers into references to block numbers. The `cfg` provides the information -required to perform this conversion. - -For context, `CodeInfo` objects have references to line numbers, while `IRCode` uses -block numbers. - -This code is copied over directly from the body of `Core.Compiler.inflate_ir!`. -""" function __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) for i in eachindex(insts) stmt = insts[i] @@ -507,19 +270,6 @@ end # Converting from IRCode to BBCode # -""" - BBCode(ir::IRCode) - -Convert an `ir` into a `BBCode`. Creates a completely independent data structure, so -mutating the `BBCode` returned will not mutate `ir`. - -All `PhiNode`s, `GotoIfNot`s, and `GotoNode`s will be replaced with the `IDPhiNode`s, -`IDGotoIfNot`s, and `IDGotoNode`s respectively. - -See `IRCode` for conversion back to `IRCode`. - -Note that `IRCode(BBCode(ir))` should be equal to the identity function. -""" function BBCode(ir::IRCode) # Produce a new set of statements with `IDs` rather than `SSAValues` and block numbers. @@ -534,11 +284,6 @@ function BBCode(ir::IRCode) return BBCode(ir, blocks) end -""" - new_inst_vec(x::CC.InstructionStream) - -Convert an `Compiler.InstructionStream` into a list of `Compiler.NewInstruction`s. -""" function new_inst_vec(x::CC.InstructionStream) stmt = @static VERSION < v"1.11.0-rc4" ? x.inst : x.stmt return map((v...,) -> NewInstruction(v...), stmt, x.type, x.info, x.line, x.flag) @@ -548,27 +293,12 @@ end const SSAToIdDict = Dict{SSAValue,ID} const BlockNumToIdDict = Dict{Integer,ID} -""" - _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector} - -Assigns an ID to each line in `stmts`, and replaces each instance of an `SSAValue` in each -line with the corresponding `ID`. For example, a call statement of the form -`Expr(:call, :f, %4)` is be replaced with `Expr(:call, :f, id_assigned_to_%4)`. -""" function _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID},InstVector} ids = map(_ -> ID(), insts) val_id_map = SSAToIdDict(zip(SSAValue.(eachindex(insts)), ids)) return ids, map(Base.Fix1(_ssa_to_ids, val_id_map), insts) end -""" - _ssa_to_ids(d::SSAToIdDict, inst::NewInstruction) - -Produce a new instance of `inst` in which all instances of `SSAValue`s are replaced with -the `ID`s prescribed by `d`, all basic block numbers are replaced with the `ID`s -prescribed by `d`, and `GotoIfNot`, `GotoNode`, and `PhiNode` instances are replaced with -the corresponding `ID` versions. -""" function _ssa_to_ids(d::SSAToIdDict, inst::NewInstruction) return NewInstruction(inst; stmt=_ssa_to_ids(d, inst.stmt)) end @@ -591,12 +321,6 @@ end _ssa_to_ids(d::SSAToIdDict, x::GotoNode) = x _ssa_to_ids(d::SSAToIdDict, x::GotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), x.dest) -""" - _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector} - -Assign to each basic block in `cfg` an `ID`. Replace all integers referencing block numbers -in `insts` with the corresponding `ID`. Return the `ID`s and the updated instructions. -""" function _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID},InstVector} ids = map(_ -> ID(), cfg.blocks) block_num_id_map = BlockNumToIdDict(zip(eachindex(cfg.blocks), ids)) @@ -617,18 +341,6 @@ _block_num_to_ids(d::BlockNumToIdDict, x) = x # Converting from BBCode to IRCode # -""" - IRCode(bb_code::BBCode) - -Produce an `IRCode` instance which is equivalent to `bb_code`. The resulting `IRCode` -shares no memory with `bb_code`, so can be safely mutated without modifying `bb_code`. - -All `IDPhiNode`s, `IDGotoIfNot`s, and `IDGotoNode`s are converted into `PhiNode`s, -`GotoIfNot`s, and `GotoNode`s respectively. - -In the resulting `bb_code`, any `Switch` nodes are lowered into a semantically-equivalent -collection of `GotoIfNot` nodes. -""" function CC.IRCode(bb_code::BBCode) bb_code = _lower_switch_statements(bb_code) bb_code = _remove_double_edges(bb_code) @@ -651,12 +363,6 @@ function CC.IRCode(bb_code::BBCode) ) end -""" - _lower_switch_statements(bb_code::BBCode) - -Converts all `Switch`s into a semantically-equivalent collection of `GotoIfNot`s. See the -`Switch` docstring for an explanation of what is going on here. -""" function _lower_switch_statements(bb_code::BBCode) new_blocks = Vector{BBlock}(undef, 0) for block in bb_code.blocks @@ -683,12 +389,6 @@ function _lower_switch_statements(bb_code::BBCode) return BBCode(bb_code, new_blocks) end -""" - _ids_to_line_numbers(bb_code::BBCode)::InstVector - -For each statement in `bb_code`, returns a `NewInstruction` in which every `ID` is replaced -by either an `SSAValue`, or an `Int64` / `Int32` which refers to an `SSAValue`. -""" function _ids_to_line_numbers(bb_code::BBCode)::InstVector # Construct map from `ID`s to `SSAValue`s. @@ -703,12 +403,6 @@ function _ids_to_line_numbers(bb_code::BBCode)::InstVector return [_to_ssas(id_to_ssa_map, stmt) for stmt in concatenate_stmts(bb_code)] end -""" - _to_ssas(d::Dict, inst::NewInstruction) - -Like `_ssas_to_ids`, but in reverse. Converts IDs to SSAValues / (integers corresponding -to ssas). -""" _to_ssas(d::Dict, inst::NewInstruction) = NewInstruction(inst; stmt=_to_ssas(d, inst.stmt)) _to_ssas(d::Dict, x::ReturnNode) = isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x _to_ssas(d::Dict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) @@ -727,14 +421,6 @@ end _to_ssas(d::Dict, x::IDGotoNode) = GotoNode(d[x.label].id) _to_ssas(d::Dict, x::IDGotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), d[x.dest].id) -""" - _remove_double_edges(ir::BBCode)::BBCode - -If the `dest` field of an `IDGotoIfNot` node in block `n` of `ir` points towards the `n+1`th -block then we have two edges from block `n` to block `n+1`. This transformation replaces all -such `IDGotoIfNot` nodes with unconditional `IDGotoNode`s pointing towards the `n+1`th block -in `ir`. -""" function _remove_double_edges(ir::BBCode) new_blks = map(enumerate(ir.blocks)) do (n, blk) t = terminator(blk) @@ -748,89 +434,6 @@ function _remove_double_edges(ir::BBCode) return BBCode(ir, new_blks) end -""" - _build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}} - -Builds a `SimpleDiGraph`, `g`, representing of the CFG associated to `blks`, where `blks` -comprises the collection of basic blocks associated to a `BBCode`. -This is a type from Graphs.jl, so constructing `g` makes it straightforward to analyse the -control flow structure of `ir` using algorithms from Graphs.jl. - -Returns a 2-tuple, whose first element is `g`, and whose second element is a map from -the `ID` associated to each basic block in `ir`, to the `Int` corresponding to its node -index in `g`. -""" -function _build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph,Dict{ID,Int}} - node_ints = collect(eachindex(blks)) - id_to_int = Dict(zip(map(blk -> blk.id, blks), node_ints)) - successors = _compute_all_successors(blks) - g = SimpleDiGraph(length(blks)) - for blk in blks, successor in successors[blk.id] - add_edge!(g, id_to_int[blk.id], id_to_int[successor]) - end - return g, id_to_int -end - -""" - _distance_to_entry(blks::Vector{BBlock})::Vector{Int} - -For each basic block in `blks`, compute the distance from it to the entry point (the first -block. The distance is `typemax(Int)` if no path from the entry point to a given node. -""" -function _distance_to_entry(blks::Vector{BBlock})::Vector{Int} - g, id_to_int = _build_graph_of_cfg(blks) - return dijkstra_shortest_paths(g, id_to_int[blks[1].id]).dists -end - -""" - sort_blocks!(ir::BBCode)::BBCode - -Ensure that blocks appear in order of distance-from-entry-point, where distance the -distance from block b to the entry point is defined to be the minimum number of basic -blocks that must be passed through in order to reach b. - -For reasons unknown (to me, Will), the compiler / optimiser needs this for inference to -succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem -there. - -WARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic -blocks in `ir` is valid. Notably, this does not hold if you have any `IDGotoIfNot` nodes in -`ir`. -""" -function sort_blocks!(ir::BBCode)::BBCode - I = sortperm(_distance_to_entry(ir.blocks)) - ir.blocks .= ir.blocks[I] - return ir -end - -""" - characterise_unique_predecessor_blocks(blks::Vector{BBlock}) -> - Tuple{Dict{ID, Bool}, Dict{ID, Bool}} - -We call a block `b` a _unique_ _predecessor_ in the control flow graph associated to `blks` -if it is the only predecessor to all of its successors. Put differently we call `b` a unique -predecessor if, whenever control flow arrives in any of the successors of `b`, we know for -certain that the previous block must have been `b`. - -Returns two `Dict`s. A value in the first `Dict` is `true` if the block associated to its -key is a unique precessor, and is `false` if not. A value in the second `Dict` is `true` if -it has a single predecessor, and that predecessor is a unique predecessor. - -*Context*: - -This information is important for optimising AD because knowing that `b` is a unique -predecessor means that -1. on the forwards-pass, there is no need to push the ID of `b` to the block stack when - passing through it, and -2. on the reverse-pass, there is no need to pop the block stack when passing through one of - the successors to `b`. - -Utilising this reduces the overhead associated to doing AD. It is quite important when -working with cheap loops -- loops where the operations performed at each iteration -are inexpensive -- for which minimising memory pressure is critical to performance. It is -also important for single-block functions, because it can be used to entirely avoid using a -block stack at all. -""" function characterise_unique_predecessor_blocks( blks::Vector{BBlock} )::Tuple{Dict{ID,Bool},Dict{ID,Bool}} @@ -875,22 +478,9 @@ function characterise_unique_predecessor_blocks( return is_unique_pred, pred_is_unique_pred end -""" - is_reachable_return_node(x::ReturnNode) - -Determine whether `x` is a `ReturnNode`, and if it is, if it is also reachable. This is -purely a function of whether or not its `val` field is defined or not. -""" is_reachable_return_node(x::ReturnNode) = isdefined(x, :val) is_reachable_return_node(x) = false -""" - characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool} - -For each line in `stmts`, determine whether it is referenced anywhere else in the code. -Returns a dictionary containing the results. An element is `false` if the corresponding -`ID` is unused, and `true` if is used. -""" function characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID,Bool} ids = first.(stmts) insts = last.(stmts) @@ -906,14 +496,6 @@ function characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID,Bool} return is_used end -""" - _find_id_uses!(d::Dict{ID, Bool}, x) - -Helper function used in [`characterise_used_ids`](@ref). For all uses of `ID`s in `x`, set -the corresponding value of `d` to `true`. - -For example, if `x = ReturnNode(ID(5))`, then this function sets `d[ID(5)] = true`. -""" function _find_id_uses!(d::Dict{ID,Bool}, x::Expr) for arg in x.args in(arg, keys(d)) && setindex!(d, true, arg) @@ -938,67 +520,4 @@ end _find_id_uses!(d::Dict{ID,Bool}, x::QuoteNode) = nothing _find_id_uses!(d::Dict{ID,Bool}, x) = nothing -""" - _is_reachable(blks::Vector{BBlock})::Vector{Bool} - -Computes a `Vector` whose length is `length(blks)`. The `n`th element is `true` iff it is -possible for control flow to reach the `n`th block. -""" -_is_reachable(blks::Vector{BBlock})::Vector{Bool} = _distance_to_entry(blks) .< typemax(Int) - -""" - remove_unreachable_blocks!(ir::BBCode)::BBCode - -If a basic block in `ir` cannot possibly be reached during execution, then it can be safely -removed from `ir` without changing its functionality. -A block is unreachable if either: -1. it has no predecessors _and_ it is not the first block, or -2. all of its predecessors are themselves unreachable. - -For example, consider the following IR: -```jldoctest remove_unreachable_blocks -julia> ir = ircode( - Any[Core.ReturnNode(nothing), Expr(:call, sin, 5), Core.ReturnNode(Core.SSAValue(2))], - Any[Any, Any, Any], - ); -``` -There is no possible way to reach the second basic block (lines 2 and 3). Applying this -function will therefore remove it, yielding the following: -```jldoctest remove_unreachable_blocks -julia> IRCode(remove_unreachable_blocks!(BBCode(ir))) -1 1 ─ return nothing -``` - -In the blocks which have not been removed, there may be references to blocks which have been -removed. For example, the `edge`s in a `PhiNode` may contain a reference to a removed block. -These references are removed in-place from these remaining blocks, so this function will (in -general) modify `ir`. -""" -remove_unreachable_blocks!(ir::BBCode) = BBCode(ir, _remove_unreachable_blocks!(ir.blocks)) - -function _remove_unreachable_blocks!(blks::Vector{BBlock}) - - # Figure out which blocks are reachable. - is_reachable = _is_reachable(blks) - - # Collect all blocks which are reachable. - remaining_blks = blks[is_reachable] - - # For each reachable block, remove any references to removed blocks. These can appear in - # `PhiNode`s with edges that come from remove blocks. - removed_block_ids = map(idx -> blks[idx].id, findall(!, is_reachable)) - for blk in remaining_blks, inst in blk.insts - stmt = inst.stmt - stmt isa IDPhiNode || continue - for n in reverse(1:length(stmt.edges)) - if stmt.edges[n] in removed_block_ids - deleteat!(stmt.edges, n) - deleteat!(stmt.values, n) - end - end - end - - return remaining_blks -end - end From 35e9b7a750990596e7e21844d62062aee8c45692 Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 13:48:22 +0100 Subject: [PATCH 64/69] Tidy up further --- test/copyable_task.jl | 26 +++++++++++++++++++++++++ test/issues.jl | 44 ------------------------------------------- test/runtests.jl | 1 - 3 files changed, 26 insertions(+), 45 deletions(-) delete mode 100644 test/issues.jl diff --git a/test/copyable_task.jl b/test/copyable_task.jl index 5b3feb27..3bb01b5d 100644 --- a/test/copyable_task.jl +++ b/test/copyable_task.jl @@ -229,4 +229,30 @@ @test consume(a) == 4 end end + @testset "Issue: PR-86 (DynamicPPL.jl/pull/261)" begin + function f() + t = Array{Int}(undef, 1) + t[1] = 0 + for _ in 1:4000 + produce(t[1]) + t[1] + t[1] = 1 + t[1] + end + end + + ttask = TapedTask(nothing, f) + + ex = try + for _ in 1:999 + consume(ttask) + consume(ttask) + a = copy(ttask) + consume(a) + consume(a) + end + catch ex + ex + end + @test ex === nothing + end end diff --git a/test/issues.jl b/test/issues.jl deleted file mode 100644 index d649a151..00000000 --- a/test/issues.jl +++ /dev/null @@ -1,44 +0,0 @@ -@testset "Issues" begin - @testset "Issue: PR-86 (DynamicPPL.jl/pull/261)" begin - function f() - t = Array{Int}(undef, 1) - t[1] = 0 - for _ in 1:4000 - produce(t[1]) - t[1] - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - - ex = try - for _ in 1:999 - consume(ttask) - consume(ttask) - a = copy(ttask) - consume(a) - consume(a) - end - catch ex - ex - end - @test ex === nothing - end - - # TODO: this test will need to change because I'm going to modify the interface _very_ - # slightly. - @testset "Issue-140, copy unstarted task" begin - function f(x) - for i in 1:3 - produce(i + x) - end - end - - ttask = TapedTask(f, 3) - ttask2 = copy(ttask) - @test consume(ttask2) == 4 - ttask3 = copy(ttask; args=(4,)) - @test consume(ttask3) == 5 - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 54876973..e6400098 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,5 +5,4 @@ include("front_matter.jl") @test JuliaFormatter.format(Libtask; verbose=false, overwrite=false) end include("copyable_task.jl") - # include("issues.jl") end From c54dc5c462ae875814dd4f457e7e8b7bb8a8e7db Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 13:53:25 +0100 Subject: [PATCH 65/69] Remove undefined export --- src/bbcode.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bbcode.jl b/src/bbcode.jl index 9d4f45d2..843877a8 100644 --- a/src/bbcode.jl +++ b/src/bbcode.jl @@ -31,7 +31,6 @@ export ID, collect_stmts, compute_all_predecessors, BBCode, - remove_unreachable_blocks!, characterise_used_ids, characterise_unique_predecessor_blocks, InstVector, From b2e65b86e02c2fb301b29de31dee15e1834e71ef Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 13:56:50 +0100 Subject: [PATCH 66/69] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b150bd08..f7ab46e0 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.8.8" +version = "0.9.0" [deps] MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" From 01a0b34afe67a7cf688414cfc1f88d27ba75014d Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:32:45 +0100 Subject: [PATCH 67/69] Test produce global performance --- src/test_utils.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 59ec70f0..961f635b 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -139,8 +139,11 @@ function test_cases() [Ptr{UInt8}, Ptr{UInt8}], allocs, ), - Testcase("dynamic scope 1", 5, (taped_globals_tester_1,), nothing, [5], allocs), - Testcase("dynamic scope 2", 6, (taped_globals_tester_1,), nothing, [6], none), + Testcase("globals tester 1", 5, (taped_globals_tester_1,), nothing, [5], allocs), + Testcase("globals tester 2", 6, (taped_globals_tester_1,), nothing, [6], none), + Testcase( + "globals tester 3", 6, (while_loop_with_globals,), nothing, fill(6, 9), allocs + ), Testcase( "nested (static)", nothing, (static_nested_outer,), nothing, [true, false], none ), @@ -291,6 +294,15 @@ function taped_globals_tester_1() return nothing end +function while_loop_with_globals() + t = 1 + while t < 10 + produce(get_taped_globals(Int)) + t = 1 + t + end + return nothing +end + @noinline function nested_inner() produce(true) return 1 From ebe8f91a184094f4ed9c58b10005629c60f4d65a Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:32:57 +0100 Subject: [PATCH 68/69] Remove more references to dynamic scope --- README.md | 2 +- src/copyable_task.jl | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a6cd2f90..8a429e5c 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Libtask Testing](https://github.com/TuringLang/Libtask.jl/workflows/Libtask%20Testing/badge.svg)](https://github.com/TuringLang/Libtask.jl/actions?branch=main) -Resumable and copyable functions in Julia, with optional dynamic scope. +Resumable and copyable functions in Julia, with optional function-specific globals. See the docs for example usage. Used in the [Turing](https://github.com/TuringLang/Turing.jl) probabilistic programming language to implement various particle-based inference methods, for example those in [AdvancedPS.jl](https://github.com/TuringLang/AdvancedPS.jl/). diff --git a/src/copyable_task.jl b/src/copyable_task.jl index fd74a664..a73ac29b 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -1,8 +1,8 @@ """ get_taped_globals(T::Type) -Returns the dynamic scope associated to `Libtask`. If called from inside a `TapedTask`, this -will return whatever is contained in its `taped_globals` field. +When called from inside a call to a `TapedTask`, this will return whatever is contained in +its `taped_globals` field. The type `T` is required for optimal performance. If you know that the result of this operation must return a specific type, specific `T`. If you do not know what type it will @@ -10,7 +10,8 @@ return, pass `Any` -- this will typically yield type instabilities, but will run See also [`set_taped_globals!`](@ref). """ -get_taped_globals(::Type{T}) where {T} = typeassert(task_local_storage(:task_variable), T) +@noinline get_taped_globals(::Type{T}) where {T} = + typeassert(task_local_storage(:task_variable), T) __v::Int = 5 From 2907f032da3d5c8d21c8c33c738602b47c3c85f4 Mon Sep 17 00:00:00 2001 From: Will Tebbutt <3628294+willtebbutt@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:38:18 +0100 Subject: [PATCH 69/69] Document type assertion --- src/copyable_task.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/copyable_task.jl b/src/copyable_task.jl index a73ac29b..0dc49940 100644 --- a/src/copyable_task.jl +++ b/src/copyable_task.jl @@ -9,9 +9,21 @@ operation must return a specific type, specific `T`. If you do not know what typ return, pass `Any` -- this will typically yield type instabilities, but will run correctly. See also [`set_taped_globals!`](@ref). + +# Extended Help + + """ -@noinline get_taped_globals(::Type{T}) where {T} = - typeassert(task_local_storage(:task_variable), T) +@noinline function get_taped_globals(::Type{T}) where {T} + # This function is `@noinline`d to ensure that the type-unstable items in here do not + # appear in a calling function, and cause allocations. + # + # The return type of `task_local_storage(:task_variable)` is `Any`. To ensure that this + # type instability does not propagate through the rest of the code, we `typeassert` the + # result to be `T`. By doing this, callers of this function will (hopefully) think + # carefully about how they can figure out what type they have put in global storage. + return typeassert(task_local_storage(:task_variable), T) +end __v::Int = 5