diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml
new file mode 100644
index 00000000..323237ba
--- /dev/null
+++ b/.JuliaFormatter.toml
@@ -0,0 +1 @@
+style = "blue"
diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml
new file mode 100644
index 00000000..0ec2baa2
--- /dev/null
+++ b/.github/workflows/Documentation.yml
@@ -0,0 +1,32 @@
+name: Documentation
+
+on:
+  push:
+    branches:
+      - main
+    tags: '*'
+  pull_request:
+
+jobs:
+  build:
+    permissions:
+      contents: write
+      pull-requests: read
+      statuses: write
+      actions: write
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: julia-actions/setup-julia@v2
+        with:
+          version: '1'
+          arch: x64
+          include-all-prereleases: false
+      - name: Install dependencies
+        run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.update(); Pkg.instantiate()'
+      - name: Build and deploy
+        env:
+          GKSwstype: nul # turn off GR's interactive plotting for notebooks
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token
+          DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key
+        run: julia --project=docs/ docs/make.jl
diff --git a/.github/workflows/Testing.yaml b/.github/workflows/Testing.yaml
index 09c2390c..136e7f39 100644
--- a/.github/workflows/Testing.yaml
+++ b/.github/workflows/Testing.yaml
@@ -11,7 +11,6 @@ jobs:
     strategy:
       matrix:
         version:
-          - '1.7'
           - '1.10'
           - '1'
          - 'nightly'
@@ -26,12 +25,12 @@
         - os: macOS-latest
           arch: x86
     steps:
-      - uses: actions/checkout@v2
-      - uses: julia-actions/setup-julia@v1
+      - uses: actions/checkout@v4
+      - uses: julia-actions/setup-julia@v2
        with:
          version: ${{ matrix.version }}
          arch: ${{ matrix.arch }}
-      - uses: actions/cache@v1
+      - uses: actions/cache@v4
        env:
          cache-name: cache-artifacts
        with:
diff --git a/.gitignore b/.gitignore
index d039efaf..bb8f9cf6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,60 +1,4 @@
-# Prerequisites
-*.d
-
-# Object files
-*.o
-*.ko
-*.obj
-*.elf
-
-# Linker output
-*.ilk
-*.map
-*.exp
-
-# Precompiled Headers
-*.gch
-*.pch
-
-# Libraries
-*.lib
-*.a
-*.la
-*.lo
-
-# Shared objects (inc. Windows DLLs)
-*.dll
-*.so
-*.so.*
-*.dylib
-
-# Executables
-*.exe
-*.out
-*.app
-*.i*86
-*.x86_64
-*.hex
-
-# Debug files
-*.dSYM/
-*.su
-*.idb
-*.pdb
-
-# Kernel Module Compile Results
-*.mod*
-*.cmd
-.tmp_versions/
-modules.order
-Module.symvers
-Mkfile.old
-dkms.conf
-
 # Projects files
-Manifest.toml
-deps/build.log
-deps/deps.jl
-deps/usr/
-deps/tmp-build.jl
+Manifest*
 *.cov
+docs/build
\ No newline at end of file
diff --git a/NEWS.md b/NEWS.md
new file mode 100644
index 00000000..db7b2b8b
--- /dev/null
+++ b/NEWS.md
@@ -0,0 +1,8 @@
+- From v0.6.0, Libtask is implemented by recording all of the computation to a
+  tape and copying that tape. Before that version, it was based on a tricky
+  hack on the Julia internals. You can check the commit history of this repo
+  to see the details.
+
+- From version 0.9.0, the old `TArray` and `TRef` types are completely removed, where
+  previously they were only deprecated. Additionally, the internals have been completely
+  overhauled, and the public interface more precisely defined. See the docs for more info.
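+
+  A minimal sketch of the new interface (the `TapedTask` docstring is the
+  authoritative reference; everything below uses only the exported names
+  `TapedTask`, `consume`, `produce`, `get_taped_globals`, and
+  `set_taped_globals!`):
+
+  ```julia
+  using Libtask
+
+  function f()
+      produce(1)                       # yield 1 to the caller
+      produce(get_taped_globals(Int))  # yield this task's global value
+      return nothing
+  end
+
+  t = TapedTask(2, f)        # `2` becomes the task's taped globals
+  consume(t)                 # returns 1
+  t2 = copy(t)               # completely independent copy
+  set_taped_globals!(t2, 3)  # only affects the copy
+  consume(t)                 # returns 2
+  consume(t2)                # returns 3
+  consume(t)                 # returns nothing -- `f` has finished
+  ```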
diff --git a/Project.toml b/Project.toml index 41385419..f7ab46e0 100644 --- a/Project.toml +++ b/Project.toml @@ -3,22 +3,23 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.8.8" +version = "0.9.0" [deps] -FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" -LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -FunctionWrappers = "1.1" -LRUCache = "1.3" -julia = "1.7" +Aqua = "0.8.11" +JuliaFormatter = "1.0.62" +MistyClosures = "2.0.0" +Test = "1" +julia = "1.10.8" [extras] -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "BenchmarkTools"] +test = ["Aqua", "JuliaFormatter", "Test"] diff --git a/README.md b/README.md index 57902cac..8a429e5c 100644 --- a/README.md +++ b/README.md @@ -2,98 +2,8 @@ [![Libtask Testing](https://github.com/TuringLang/Libtask.jl/workflows/Libtask%20Testing/badge.svg)](https://github.com/TuringLang/Libtask.jl/actions?branch=main) -Tape based task copying in Turing -## Getting Started +Resumable and copyable functions in Julia, with optional function-specific globals. +See the docs for example usage. -Stack allocated objects are always deep copied: - -```julia -using Libtask - -function f() - t = 0 - for _ in 1:10 - produce(t) - t = 1 + t - end -end - -ttask = TapedTask(f) - -@show consume(ttask) # 0 -@show consume(ttask) # 1 - -a = copy(ttask) -@show consume(a) # 2 -@show consume(a) # 3 - -@show consume(ttask) # 2 -@show consume(ttask) # 3 -``` - -Heap-allocated Array and Ref objects are deep copied by default: - -```julia -using Libtask - -function f() - t = [0 1 2] - for _ in 1:10 - produce(t[1]) - t[1] = 1 + t[1] - end -end - -ttask = TapedTask(f) - -@show consume(ttask) # 0 -@show consume(ttask) # 1 - -a = copy(ttask) -@show consume(a) # 2 -@show consume(a) # 3 - -@show consume(ttask) # 2 -@show consume(ttask) # 3 -``` - -Other heap-allocated objects (e.g., `Dict`) are shallow copied, by default: - -```julia -using Libtask - -function f() - t = Dict(1=>10, 2=>20) - while true - produce(t[1]) - t[1] = 1 + t[1] - end -end - -ttask = TapedTask(f) - -@show consume(ttask) # 10 -@show consume(ttask) # 11 - -a = copy(ttask) -@show consume(a) # 12 -@show consume(a) # 13 - -@show consume(ttask) # 14 -@show consume(ttask) # 15 -``` - -Notes: - -- The [Turing](https://github.com/TuringLang/Turing.jl) probabilistic - programming language uses this task copying feature in an efficient - implementation of the [particle - filtering](https://en.wikipedia.org/wiki/Particle_filter) sampling - algorithm for arbitrary order [Markov - processes](https://en.wikipedia.org/wiki/Markov_model#Hidden_Markov_model). - -- From v0.6.0, Libtask is implemented by recording all the computing - to a tape and copying that tape. Before that version, it is based on - a tricky hack on the Julia internals. You can check the commit - history of this repo to see the details. 
+Used in the [Turing](https://github.com/TuringLang/Turing.jl) probabilistic programming language to implement various particle-based inference methods, for example those in [AdvancedPS.jl](https://github.com/TuringLang/AdvancedPS.jl/).
diff --git a/docs/Project.toml b/docs/Project.toml
new file mode 100644
index 00000000..acf52482
--- /dev/null
+++ b/docs/Project.toml
@@ -0,0 +1,3 @@
+[deps]
+Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
diff --git a/docs/make.jl b/docs/make.jl
new file mode 100644
index 00000000..318d04ea
--- /dev/null
+++ b/docs/make.jl
@@ -0,0 +1,9 @@
+using Documenter, Libtask
+
+DocMeta.setdocmeta!(Libtask, :DocTestSetup, :(using Libtask); recursive=true)
+
+makedocs(;
+    sitename="Libtask", doctest=true, pages=["index.md", "internals.md"], modules=[Libtask]
+)
+
+deploydocs(; repo="github.com/TuringLang/Libtask.jl.git", push_preview=true)
diff --git a/docs/src/index.md b/docs/src/index.md
new file mode 100644
index 00000000..9beb0078
--- /dev/null
+++ b/docs/src/index.md
@@ -0,0 +1,24 @@
+# Libtask
+
+Libtask is best explained by the docstring for [`TapedTask`](@ref):
+```@docs; canonical=true
+Libtask.TapedTask
+```
+
+The functions discussed in the above docstring (in addition to [`TapedTask`](@ref) itself)
+form the public interface of Libtask.jl.
+They divide neatly into two kinds of functions: those which are used to manipulate
+[`TapedTask`](@ref)s, and those which are intended to be used _inside_ a
+[`TapedTask`](@ref).
+First, manipulation of [`TapedTask`](@ref)s:
+```@docs; canonical=true
+Libtask.consume
+Base.copy(::Libtask.TapedTask)
+Libtask.set_taped_globals!
+```
+
+Functions for use inside a [`TapedTask`](@ref) are:
+```@docs; canonical=true
+Libtask.produce
+Libtask.get_taped_globals
+```
diff --git a/docs/src/internals.md b/docs/src/internals.md
new file mode 100644
index 00000000..ee6848e8
--- /dev/null
+++ b/docs/src/internals.md
@@ -0,0 +1,17 @@
+# Internals
+
+```@docs; canonical=true
+Libtask.produce_value
+Libtask.is_produce_stmt
+Libtask.might_produce
+Libtask.stmt_might_produce
+Libtask.LazyCallable
+Libtask.inc_args
+Libtask.get_type
+Libtask._typeof
+Libtask.replace_captures
+Libtask.BasicBlockCode
+Libtask.opaque_closure
+Libtask.misty_closure
+Libtask.optimise_ir!
+```
diff --git a/perf/Project.toml b/perf/Project.toml
index 9e9ab49b..6522964d 100644
--- a/perf/Project.toml
+++ b/perf/Project.toml
@@ -8,7 +8,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
 
 [compat]
-julia = "1.3"
+julia = "1.10.8"
 
 [targets]
 test = ["Test", "BenchmarkTools"]
diff --git a/perf/benchmark.jl b/perf/benchmark.jl
index dcfb3638..a78a59bf 100644
--- a/perf/benchmark.jl
+++ b/perf/benchmark.jl
@@ -24,15 +24,15 @@ function benchmark_driver!(f, x...; f_displayname=string(f))
     GC.gc()
 
     print("  Run TapedTask: ")
-    x = (x[1:end-1]..., produce)
+    x = (x[1:(end - 1)]..., produce)
     # show the number of produce calls inside `f`
     function f_task(f, x; verbose=false)
         tt = TapedTask(f, x...)
         c = 0
-        while consume(tt)!==nothing
-            c+=1
+        while consume(tt) !== nothing
+            c += 1
         end
-        verbose && print("#produce=", c, "; ")
+        return verbose && print("#produce=", c, "; ")
     end
     # Note that we need to pass `f` instead of `tf` to avoid
    # default continuation in `TapedTask` constructor, see, e.g.
@@ -40,15 +40,15 @@ function benchmark_driver!(f, x...; f_displayname=string(f)) f_task(f, x; verbose=true) # print #produce calls @btime $f_task($f, $x) GC.gc() + return nothing end #################################################################### - function rosenbrock(x, callback=nothing) i = x[2:end] - j = x[1:end-1] - ret = sum((1 .- j).^2 + 100*(i - j.^2).^2) + j = x[1:(end - 1)] + ret = sum((1 .- j) .^ 2 + 100 * (i - j .^ 2) .^ 2) callback !== nothing && callback(ret) return ret end @@ -59,17 +59,20 @@ benchmark_driver!(rosenbrock, x) #################################################################### function ackley(x::AbstractVector, callback=nothing) - a, b, c = 20.0, -0.2, 2.0*π + a, b, c = 20.0, -0.2, 2.0 * π len_recip = inv(length(x)) sum_sqrs = zero(eltype(x)) sum_cos = sum_sqrs for i in x - sum_cos += cos(c*i) + sum_cos += cos(c * i) sum_sqrs += i^2 callback !== nothing && callback(sum_sqrs) end - return (-a * exp(b * sqrt(len_recip*sum_sqrs)) - - exp(len_recip*sum_cos) + a + MathConstants.e) + return ( + -a * exp(b * sqrt(len_recip * sum_sqrs)) - exp(len_recip * sum_cos) + + a + + MathConstants.e + ) end x = rand(100000) @@ -79,8 +82,8 @@ benchmark_driver!(ackley, x) function generate_matrix_test(n) return (x, callback=nothing) -> begin # @assert length(x) == 2n^2 + n - a = reshape(x[1:n^2], n, n) - b = reshape(x[n^2 + 1:2n^2], n, n) + a = reshape(x[1:(n^2)], n, n) + b = reshape(x[(n^2 + 1):(2n^2)], n, n) ret = log.((a * b) + a - b) callback !== nothing && callback(ret) return ret @@ -94,7 +97,7 @@ benchmark_driver!(matrix_test, x; f_displayname="matrix_test") #################################################################### relu(x) = log.(1.0 .+ exp.(x)) -sigmoid(n) = 1. / (1. + exp(-n)) +sigmoid(n) = 1.0 / (1.0 + exp(-n)) function neural_net(w1, w2, w3, x1, callback=nothing) x2 = relu(w1 * x1) @@ -104,7 +107,7 @@ function neural_net(w1, w2, w3, x1, callback=nothing) return ret end -xs = (randn(10,10), randn(10,10), randn(10), rand(10)) +xs = (randn(10, 10), randn(10, 10), randn(10), rand(10)) benchmark_driver!(neural_net, xs...) #################################################################### @@ -114,11 +117,11 @@ println("======= breakdown benchmark =======") x = rand(100000) tf = Libtask.TapedFunction(ackley, x, nothing) tf(x, nothing); -idx = findlast((x)->isa(x, Libtask.Instruction), tf.tape) +idx = findlast((x) -> isa(x, Libtask.Instruction), tf.tape) ins = tf.tape[idx] b = ins.input[1] -@show ins.input |> length +@show length(ins.input) @btime map(x -> Libtask._lookup(tf, x), ins.input) @btime Libtask._lookup(tf, b) @btime tf.binding_values[b] diff --git a/perf/p0.jl b/perf/p0.jl index 757d8b3a..23a71250 100644 --- a/perf/p0.jl +++ b/perf/p0.jl @@ -5,17 +5,16 @@ using BenchmarkTools @model gdemo(x, y) = begin # Assumptions - σ ~ InverseGamma(2,3) - μ ~ Normal(0,sqrt(σ)) + σ ~ InverseGamma(2, 3) + μ ~ Normal(0, sqrt(σ)) # Observations x ~ Normal(μ, sqrt(σ)) y ~ Normal(μ, sqrt(σ)) end - # Case 1: Sample from the prior. rng = MersenneTwister() -m = Turing.Core.TracedModel(gdemo(1.5, 2.), SampleFromPrior(), VarInfo(), rng) +m = Turing.Core.TracedModel(gdemo(1.5, 2.0), SampleFromPrior(), VarInfo(), rng) f = m.evaluator[1]; args = m.evaluator[2:end]; @@ -28,7 +27,7 @@ println("Run a tape...") @btime t.tf(args...) 
# Case 2: SMC sampler -m = Turing.Core.TracedModel(gdemo(1.5, 2.), Sampler(SMC(50)), VarInfo(), rng) +m = Turing.Core.TracedModel(gdemo(1.5, 2.0), Sampler(SMC(50)), VarInfo(), rng) f = m.evaluator[1]; args = m.evaluator[2:end]; diff --git a/perf/p1.jl b/perf/p1.jl index 4ecd2ec8..34797f3c 100644 --- a/perf/p1.jl +++ b/perf/p1.jl @@ -2,25 +2,22 @@ using Turing, Test, AbstractMCMC, DynamicPPL, Random import AbstractMCMC.AbstractSampler -function check_numerical(chain, - symbols::Vector, - exact_vals::Vector; - atol=0.2, - rtol=0.0) +function check_numerical(chain, symbols::Vector, exact_vals::Vector; atol=0.2, rtol=0.0) for (sym, val) in zip(symbols, exact_vals) - E = val isa Real ? - mean(chain[sym]) : - vec(mean(chain[sym], dims=1)) + E = val isa Real ? mean(chain[sym]) : vec(mean(chain[sym]; dims=1)) @info (symbol=sym, exact=val, evaluated=E) - @test E ≈ val atol=atol rtol=rtol + @test E ≈ val atol = atol rtol = rtol end end function check_MoGtest_default(chain; atol=0.2, rtol=0.0) - check_numerical(chain, - [:z1, :z2, :z3, :z4, :mu1, :mu2], - [1.0, 1.0, 2.0, 2.0, 1.0, 4.0], - atol=atol, rtol=rtol) + return check_numerical( + chain, + [:z1, :z2, :z3, :z4, :mu1, :mu2], + [1.0, 1.0, 2.0, 2.0, 1.0, 4.0]; + atol=atol, + rtol=rtol, + ) end @model gdemo_d(x, y) = begin @@ -36,4 +33,4 @@ chain = sample(gdemo_d(1.5, 2.0), alg, 5_000) @show chain -check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) +check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1) diff --git a/src/Libtask.jl b/src/Libtask.jl index 8fa79533..ff4692f5 100644 --- a/src/Libtask.jl +++ b/src/Libtask.jl @@ -1,26 +1,21 @@ module Libtask -using FunctionWrappers: FunctionWrapper -using LRUCache +# We'll emit `MistyClosure`s rather than `OpaqueClosure`s. +using MistyClosures -export TapedTask, consume, produce +# Import some names from the compiler. +const CC = Core.Compiler +using Core: OpaqueClosure +using Core.Compiler: Argument, IRCode, ReturnNode -export TArray, tzeros, tfill, TRef # legacy types back compat +# IR-related functionality from Mooncake. +include("utils.jl") +include("bbcode.jl") +using .BasicBlockCode +include("copyable_task.jl") +include("test_utils.jl") -@static if isdefined(Core, :TypedSlot) || isdefined(Core.Compiler, :TypedSlot) - # Julia v1.10 moved Core.TypedSlot to Core.Compiler - # Julia v1.11 removed Core.Compiler.TypedSlot - const TypedSlot = @static if isdefined(Core, :TypedSlot) - Core.TypedSlot - else - Core.Compiler.TypedSlot - end -else - struct TypedSlot end # Dummy -end - -include("tapedfunction.jl") -include("tapedtask.jl") +export TapedTask, consume, produce, get_taped_globals, set_taped_globals! end diff --git a/src/bbcode.jl b/src/bbcode.jl new file mode 100644 index 00000000..843877a8 --- /dev/null +++ b/src/bbcode.jl @@ -0,0 +1,522 @@ +""" + module BasicBlockCode + +Copied over from Mooncake.jl in order to avoid making this package depend on Mooncake. +Refer to Mooncake's developer docs for context on this file. 
+""" +module BasicBlockCode + +using Core.Compiler: + ReturnNode, + PhiNode, + GotoIfNot, + GotoNode, + NewInstruction, + IRCode, + SSAValue, + PiNode, + Argument +const CC = Core.Compiler + +export ID, + seed_id!, + IDPhiNode, + IDGotoNode, + IDGotoIfNot, + Switch, + BBlock, + phi_nodes, + terminator, + insert_before_terminator!, + collect_stmts, + compute_all_predecessors, + BBCode, + characterise_used_ids, + characterise_unique_predecessor_blocks, + InstVector, + IDInstPair, + __line_numbers_to_block_numbers!, + is_reachable_return_node, + new_inst + +const _id_count::Dict{Int,Int32} = Dict{Int,Int32}() + +function new_inst(@nospecialize(stmt), @nospecialize(type)=Any, flag=CC.IR_FLAG_REFINED) + return NewInstruction(stmt, type, CC.NoCallInfo(), Int32(1), flag) +end + +const InstVector = Vector{NewInstruction} + +struct ID + id::Int32 + function ID() + current_thread_id = Threads.threadid() + id_count = get(_id_count, current_thread_id, Int32(0)) + _id_count[current_thread_id] = id_count + Int32(1) + return new(id_count) + end +end + +Base.copy(id::ID) = id + +function seed_id!() + return global _id_count[Threads.threadid()] = 0 +end + +struct IDPhiNode + edges::Vector{ID} + values::Vector{Any} +end + +Base.:(==)(x::IDPhiNode, y::IDPhiNode) = x.edges == y.edges && x.values == y.values + +Base.copy(node::IDPhiNode) = IDPhiNode(copy(node.edges), copy(node.values)) + +struct IDGotoNode + label::ID +end + +Base.copy(node::IDGotoNode) = IDGotoNode(copy(node.label)) + +struct IDGotoIfNot + cond::Any + dest::ID +end + +Base.copy(node::IDGotoIfNot) = IDGotoIfNot(copy(node.cond), copy(node.dest)) + +struct Switch + conds::Vector{Any} + dests::Vector{ID} + fallthrough_dest::ID + function Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID) + @assert length(conds) == length(dests) + return new(conds, dests, fallthrough_dest) + end +end + +const Terminator = Union{Switch,IDGotoIfNot,IDGotoNode,ReturnNode} + +mutable struct BBlock + id::ID + inst_ids::Vector{ID} + insts::InstVector + function BBlock(id::ID, inst_ids::Vector{ID}, insts::InstVector) + @assert length(inst_ids) == length(insts) + return new(id, inst_ids, insts) + end +end + +const IDInstPair = Tuple{ID,NewInstruction} + +function BBlock(id::ID, inst_pairs::Vector{IDInstPair}) + return BBlock(id, first.(inst_pairs), last.(inst_pairs)) +end + +Base.length(bb::BBlock) = length(bb.inst_ids) + +Base.copy(bb::BBlock) = BBlock(bb.id, copy(bb.inst_ids), copy(bb.insts)) + +function phi_nodes(bb::BBlock) + n_phi_nodes = findlast(x -> x.stmt isa IDPhiNode, bb.insts) + if n_phi_nodes === nothing + n_phi_nodes = 0 + end + return bb.inst_ids[1:n_phi_nodes], bb.insts[1:n_phi_nodes] +end + +function Base.insert!(bb::BBlock, n::Int, id::ID, inst::NewInstruction)::Nothing + insert!(bb.inst_ids, n, id) + insert!(bb.insts, n, inst) + return nothing +end + +terminator(bb::BBlock) = isa(bb.insts[end].stmt, Terminator) ? bb.insts[end].stmt : nothing + +function insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing + insert!(bb, length(bb.insts) + (terminator(bb) === nothing ? 
1 : 0), id, inst) + return nothing +end + +collect_stmts(bb::BBlock)::Vector{IDInstPair} = collect(zip(bb.inst_ids, bb.insts)) + +struct BBCode + blocks::Vector{BBlock} + argtypes::Vector{Any} + sptypes::Vector{CC.VarState} + linetable::Vector{Core.LineInfoNode} + meta::Vector{Expr} +end + +function BBCode(ir::Union{IRCode,BBCode}, new_blocks::Vector{BBlock}) + return BBCode( + new_blocks, + CC.copy(ir.argtypes), + CC.copy(ir.sptypes), + CC.copy(ir.linetable), + CC.copy(ir.meta), + ) +end + +# Makes use of the above outer constructor for `BBCode`. +Base.copy(ir::BBCode) = BBCode(ir, copy(ir.blocks)) + +compute_all_successors(ir::BBCode)::Dict{ID,Vector{ID}} = _compute_all_successors(ir.blocks) + +@noinline function _compute_all_successors(blks::Vector{BBlock})::Dict{ID,Vector{ID}} + succs = map(enumerate(blks)) do (n, blk) + is_final_block = n == length(blks) + t = terminator(blk) + if t === nothing + return is_final_block ? ID[] : ID[blks[n + 1].id] + elseif t isa IDGotoNode + return [t.label] + elseif t isa IDGotoIfNot + return is_final_block ? ID[t.dest] : ID[t.dest, blks[n + 1].id] + elseif t isa ReturnNode + return ID[] + elseif t isa Switch + return vcat(t.dests, t.fallthrough_dest) + else + error("Unhandled terminator $t") + end + end + return Dict{ID,Vector{ID}}((b.id, succ) for (b, succ) in zip(blks, succs)) +end + +function compute_all_predecessors(ir::BBCode)::Dict{ID,Vector{ID}} + return _compute_all_predecessors(ir.blocks) +end + +function _compute_all_predecessors(blks::Vector{BBlock})::Dict{ID,Vector{ID}} + successor_map = _compute_all_successors(blks) + + # Initialise predecessor map to be empty. + ks = collect(keys(successor_map)) + predecessor_map = Dict{ID,Vector{ID}}(zip(ks, map(_ -> ID[], ks))) + + # Find all predecessors by iterating through the successor map. + for (k, succs) in successor_map + for succ in succs + push!(predecessor_map[succ], k) + end + end + + return predecessor_map +end + +collect_stmts(ir::BBCode)::Vector{IDInstPair} = reduce(vcat, map(collect_stmts, ir.blocks)) + +function id_to_line_map(ir::BBCode) + lines = collect_stmts(ir) + lines_and_line_numbers = collect(zip(lines, eachindex(lines))) + ids_and_line_numbers = map(x -> (x[1][1], x[2]), lines_and_line_numbers) + return Dict(ids_and_line_numbers) +end + +concatenate_ids(bb_code::BBCode) = reduce(vcat, map(b -> b.inst_ids, bb_code.blocks)) +concatenate_stmts(bb_code::BBCode) = reduce(vcat, map(b -> b.insts, bb_code.blocks)) + +control_flow_graph(bb_code::BBCode)::Core.Compiler.CFG = _control_flow_graph(bb_code.blocks) + +function _control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG + + # Get IDs of predecessors and successors. + preds_ids = _compute_all_predecessors(blks) + succs_ids = _compute_all_successors(blks) + + # Construct map from block ID to block number. + block_ids = map(b -> b.id, blks) + id_to_num = Dict{ID,Int}(zip(block_ids, collect(eachindex(block_ids)))) + + # Convert predecessor and successor IDs to numbers. 
+ preds = map(id -> sort(map(p -> id_to_num[p], preds_ids[id])), block_ids) + succs = map(id -> sort(map(s -> id_to_num[s], succs_ids[id])), block_ids) + + index = vcat(0, cumsum(map(length, blks))) .+ 1 + basic_blocks = map(eachindex(blks)) do n + stmt_range = Core.Compiler.StmtRange(index[n], index[n + 1] - 1) + return Core.Compiler.BasicBlock(stmt_range, preds[n], succs[n]) + end + return Core.Compiler.CFG(basic_blocks, index[2:(end - 1)]) +end + +function _lines_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector + stmts = __line_numbers_to_block_numbers!(Any[x.stmt for x in insts], cfg) + return map((inst, stmt) -> NewInstruction(inst; stmt), insts, stmts) +end + +function __line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG) + for i in eachindex(insts) + stmt = insts[i] + if isa(stmt, GotoNode) + insts[i] = GotoNode(CC.block_for_inst(cfg, stmt.label)) + elseif isa(stmt, GotoIfNot) + insts[i] = GotoIfNot(stmt.cond, CC.block_for_inst(cfg, stmt.dest)) + elseif isa(stmt, PhiNode) + insts[i] = PhiNode( + Int32[CC.block_for_inst(cfg, Int(edge)) for edge in stmt.edges], stmt.values + ) + elseif Meta.isexpr(stmt, :enter) + stmt.args[1] = CC.block_for_inst(cfg, stmt.args[1]::Int) + insts[i] = stmt + end + end + return insts +end + +# +# Converting from IRCode to BBCode +# + +function BBCode(ir::IRCode) + + # Produce a new set of statements with `IDs` rather than `SSAValues` and block numbers. + insts = new_inst_vec(ir.stmts) + ssa_ids, stmts = _ssas_to_ids(insts) + block_ids, stmts = _block_nums_to_ids(stmts, ir.cfg) + + # Chop up the new statements into `BBlocks`, according to the `CFG` in `ir`. + blocks = map(zip(ir.cfg.blocks, block_ids)) do (bb, id) + return BBlock(id, ssa_ids[bb.stmts], stmts[bb.stmts]) + end + return BBCode(ir, blocks) +end + +function new_inst_vec(x::CC.InstructionStream) + stmt = @static VERSION < v"1.11.0-rc4" ? x.inst : x.stmt + return map((v...,) -> NewInstruction(v...), stmt, x.type, x.info, x.line, x.flag) +end + +# Maps from positional names (SSAValues for nodes, Integers for basic blocks) to IDs. +const SSAToIdDict = Dict{SSAValue,ID} +const BlockNumToIdDict = Dict{Integer,ID} + +function _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID},InstVector} + ids = map(_ -> ID(), insts) + val_id_map = SSAToIdDict(zip(SSAValue.(eachindex(insts)), ids)) + return ids, map(Base.Fix1(_ssa_to_ids, val_id_map), insts) +end + +function _ssa_to_ids(d::SSAToIdDict, inst::NewInstruction) + return NewInstruction(inst; stmt=_ssa_to_ids(d, inst.stmt)) +end +function _ssa_to_ids(d::SSAToIdDict, x::ReturnNode) + return isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x +end +_ssa_to_ids(d::SSAToIdDict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) 
+_ssa_to_ids(d::SSAToIdDict, x::PiNode) = PiNode(get(d, x.val, x.val), get(d, x.typ, x.typ)) +_ssa_to_ids(d::SSAToIdDict, x::QuoteNode) = x +_ssa_to_ids(d::SSAToIdDict, x) = x +function _ssa_to_ids(d::SSAToIdDict, x::PhiNode) + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = get(d, x.values[n], x.values[n]) + end + end + return PhiNode(x.edges, new_values) +end +_ssa_to_ids(d::SSAToIdDict, x::GotoNode) = x +_ssa_to_ids(d::SSAToIdDict, x::GotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), x.dest) + +function _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID},InstVector} + ids = map(_ -> ID(), cfg.blocks) + block_num_id_map = BlockNumToIdDict(zip(eachindex(cfg.blocks), ids)) + return ids, map(Base.Fix1(_block_num_to_ids, block_num_id_map), insts) +end + +function _block_num_to_ids(d::BlockNumToIdDict, x::NewInstruction) + return NewInstruction(x; stmt=_block_num_to_ids(d, x.stmt)) +end +function _block_num_to_ids(d::BlockNumToIdDict, x::PhiNode) + return IDPhiNode(ID[d[e] for e in x.edges], x.values) +end +_block_num_to_ids(d::BlockNumToIdDict, x::GotoNode) = IDGotoNode(d[x.label]) +_block_num_to_ids(d::BlockNumToIdDict, x::GotoIfNot) = IDGotoIfNot(x.cond, d[x.dest]) +_block_num_to_ids(d::BlockNumToIdDict, x) = x + +# +# Converting from BBCode to IRCode +# + +function CC.IRCode(bb_code::BBCode) + bb_code = _lower_switch_statements(bb_code) + bb_code = _remove_double_edges(bb_code) + insts = _ids_to_line_numbers(bb_code) + cfg = control_flow_graph(bb_code) + insts = _lines_to_blocks(insts, cfg) + return IRCode( + CC.InstructionStream( + map(x -> x.stmt, insts), + map(x -> x.type, insts), + map(x -> x.info, insts), + map(x -> x.line, insts), + map(x -> x.flag, insts), + ), + cfg, + CC.copy(bb_code.linetable), + CC.copy(bb_code.argtypes), + CC.copy(bb_code.meta), + CC.copy(bb_code.sptypes), + ) +end + +function _lower_switch_statements(bb_code::BBCode) + new_blocks = Vector{BBlock}(undef, 0) + for block in bb_code.blocks + t = terminator(block) + if t isa Switch + + # Create new block without the `Switch`. + bb = BBlock(block.id, block.inst_ids[1:(end - 1)], block.insts[1:(end - 1)]) + push!(new_blocks, bb) + + # Create new blocks for each `GotoIfNot` from the `Switch`. + foreach(t.conds, t.dests) do cond, dest + blk = BBlock(ID(), [ID()], [new_inst(IDGotoIfNot(cond, dest), Any)]) + push!(new_blocks, blk) + end + + # Create a new block for the fallthrough dest. + fallthrough_inst = new_inst(IDGotoNode(t.fallthrough_dest), Any) + push!(new_blocks, BBlock(ID(), [ID()], [fallthrough_inst])) + else + push!(new_blocks, block) + end + end + return BBCode(bb_code, new_blocks) +end + +function _ids_to_line_numbers(bb_code::BBCode)::InstVector + + # Construct map from `ID`s to `SSAValue`s. + block_ids = [b.id for b in bb_code.blocks] + block_lengths = map(length, bb_code.blocks) + block_start_ssas = SSAValue.(vcat(1, cumsum(block_lengths)[1:(end - 1)] .+ 1)) + line_ids = concatenate_ids(bb_code) + line_ssas = SSAValue.(eachindex(line_ids)) + id_to_ssa_map = Dict(zip(vcat(block_ids, line_ids), vcat(block_start_ssas, line_ssas))) + + # Apply map. + return [_to_ssas(id_to_ssa_map, stmt) for stmt in concatenate_stmts(bb_code)] +end + +_to_ssas(d::Dict, inst::NewInstruction) = NewInstruction(inst; stmt=_to_ssas(d, inst.stmt)) +_to_ssas(d::Dict, x::ReturnNode) = isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x +_to_ssas(d::Dict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) 
+_to_ssas(d::Dict, x::PiNode) = PiNode(get(d, x.val, x.val), get(d, x.typ, x.typ))
+_to_ssas(d::Dict, x::QuoteNode) = x
+_to_ssas(d::Dict, x) = x
+function _to_ssas(d::Dict, x::IDPhiNode)
+    new_values = Vector{Any}(undef, length(x.values))
+    for n in eachindex(x.values)
+        if isassigned(x.values, n)
+            new_values[n] = get(d, x.values[n], x.values[n])
+        end
+    end
+    return PhiNode(map(e -> Int32(getindex(d, e).id), x.edges), new_values)
+end
+_to_ssas(d::Dict, x::IDGotoNode) = GotoNode(d[x.label].id)
+_to_ssas(d::Dict, x::IDGotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), d[x.dest].id)
+
+function _remove_double_edges(ir::BBCode)
+    new_blks = map(enumerate(ir.blocks)) do (n, blk)
+        t = terminator(blk)
+        if t isa IDGotoIfNot && t.dest == ir.blocks[n + 1].id
+            new_insts = vcat(blk.insts[1:(end - 1)], NewInstruction(t; stmt=IDGotoNode(t.dest)))
+            return BBlock(blk.id, blk.inst_ids, new_insts)
+        else
+            return blk
+        end
+    end
+    return BBCode(ir, new_blks)
+end
+
+function characterise_unique_predecessor_blocks(
+    blks::Vector{BBlock}
+)::Tuple{Dict{ID,Bool},Dict{ID,Bool}}
+
+    # Obtain the block IDs in order -- this ensures that we get the entry block first.
+    blk_ids = ID[b.id for b in blks]
+    preds = _compute_all_predecessors(blks)
+    succs = _compute_all_successors(blks)
+
+    # The bulk of blocks can be handled by this general loop.
+    is_unique_pred = Dict{ID,Bool}()
+    for id in blk_ids
+        ss = succs[id]
+        is_unique_pred[id] = !isempty(ss) && all(s -> length(preds[s]) == 1, ss)
+    end
+
+    # If there is a single reachable return node, then that block is treated as a unique
+    # pred, since control can only pass "out" of the function via this block. Conversely,
+    # if there are multiple reachable return nodes, then execution can return to the calling
+    # function via any of them, so they are not unique predecessors.
+    # Note that the previous block sets is_unique_pred[id] to false for all nodes which
+    # end with a reachable return node, so the value only needs changing if there is a
+    # unique reachable return node.
+    reachable_return_blocks = filter(blks) do blk
+        is_reachable_return_node(terminator(blk))
+    end
+    if length(reachable_return_blocks) == 1
+        is_unique_pred[only(reachable_return_blocks).id] = true
+    end
+
+    # pred_is_unique_pred is true if the unique predecessor to a block is a unique pred.
+    pred_is_unique_pred = Dict{ID,Bool}()
+    for id in blk_ids
+        pred_is_unique_pred[id] = length(preds[id]) == 1 && is_unique_pred[only(preds[id])]
+    end
+
+    # If the entry block has no predecessors, then it can only be entered once, when the
+    # function is first entered. In this case, we treat it as having a unique predecessor.
+    entry_id = blk_ids[1]
+    pred_is_unique_pred[entry_id] = isempty(preds[entry_id])
+
+    return is_unique_pred, pred_is_unique_pred
+end
+
+is_reachable_return_node(x::ReturnNode) = isdefined(x, :val)
+is_reachable_return_node(x) = false
+
+function characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID,Bool}
+    ids = first.(stmts)
+    insts = last.(stmts)
+
+    # Initialise to false.
+    is_used = Dict{ID,Bool}(zip(ids, fill(false, length(ids))))
+
+    # Hunt through the instructions, flipping a value in is_used to true whenever an ID
+    # is encountered which corresponds to an SSA.
+    for inst in insts
+        _find_id_uses!(is_used, inst.stmt)
+    end
+    return is_used
+end
+
+function _find_id_uses!(d::Dict{ID,Bool}, x::Expr)
+    for arg in x.args
+        in(arg, keys(d)) && setindex!(d, true, arg)
+    end
+end
+function _find_id_uses!(d::Dict{ID,Bool}, x::IDGotoIfNot)
+    return in(x.cond, keys(d)) && setindex!(d, true, x.cond)
+end
+_find_id_uses!(::Dict{ID,Bool}, ::IDGotoNode) = nothing
+function _find_id_uses!(d::Dict{ID,Bool}, x::PiNode)
+    return in(x.val, keys(d)) && setindex!(d, true, x.val)
+end
+function _find_id_uses!(d::Dict{ID,Bool}, x::IDPhiNode)
+    v = x.values
+    for n in eachindex(v)
+        isassigned(v, n) && in(v[n], keys(d)) && setindex!(d, true, v[n])
+    end
+end
+function _find_id_uses!(d::Dict{ID,Bool}, x::ReturnNode)
+    return isdefined(x, :val) && in(x.val, keys(d)) && setindex!(d, true, x.val)
+end
+_find_id_uses!(d::Dict{ID,Bool}, x::QuoteNode) = nothing
+_find_id_uses!(d::Dict{ID,Bool}, x) = nothing
+
+end
diff --git a/src/copyable_task.jl b/src/copyable_task.jl
new file mode 100644
index 00000000..0dc49940
--- /dev/null
+++ b/src/copyable_task.jl
@@ -0,0 +1,1029 @@
+"""
+    get_taped_globals(T::Type)
+
+When called from inside a call to a `TapedTask`, this will return whatever is contained in
+its `taped_globals` field.
+
+The type `T` is required for optimal performance. If you know that this operation must
+return a specific type, specify `T`. If you do not know what type it will
+return, pass `Any` -- this will typically yield type instabilities, but will run correctly.
+
+See also [`set_taped_globals!`](@ref).
+
+# Extended Help
+
+
+"""
+@noinline function get_taped_globals(::Type{T}) where {T}
+    # This function is `@noinline`d to ensure that the type-unstable items in here do not
+    # appear in a calling function, and cause allocations.
+    #
+    # The return type of `task_local_storage(:task_variable)` is `Any`. To ensure that this
+    # type instability does not propagate through the rest of the code, we `typeassert` the
+    # result to be `T`. By doing this, callers of this function will (hopefully) think
+    # carefully about how they can figure out what type they have put in global storage.
+    return typeassert(task_local_storage(:task_variable), T)
+end
+
+__v::Int = 5
+
+"""
+    produce(x)
+
+When run inside a [`TapedTask`](@ref), will immediately yield to the caller, producing value
+`x`.
+
+See also: [`Libtask.consume`](@ref)
+"""
+@noinline function produce(x)
+    global __v = 4 # silly side-effect to prevent this call getting constant-folded away. Should really use the effects system.
+    return x
+end
+
+function callable_ret_type(sig, types)
+    produce_type = Union{}
+    for t in types
+        p = isconcretetype(t) ? ProducedValue{t} : ProducedValue{T} where {T<:t}
+        produce_type = CC.tmerge(p, produce_type)
+    end
+    return Union{Base.code_ircode_by_type(sig)[1][2],produce_type}
+end
+
+function build_callable(sig::Type{<:Tuple})
+    key = CacheKey(Base.get_world_counter(), sig)
+    if haskey(mc_cache, key)
+        v = fresh_copy(mc_cache[key])
+        return v
+    else
+        ir = Base.code_ircode_by_type(sig)[1][1]
+        bb, refs, types = derive_copyable_task_ir(BBCode(ir))
+        unoptimised_ir = IRCode(bb)
+        optimised_ir = optimise_ir!(unoptimised_ir)
+        mc_ret_type = callable_ret_type(sig, types)
+        mc = misty_closure(mc_ret_type, optimised_ir, refs...; do_compile=true)
+        mc_cache[key] = mc
+        return mc, refs[end]
+    end
+end
+
+mutable struct TapedTask{Ttaped_globals,Tfargs,Tmc<:MistyClosure}
+    taped_globals::Ttaped_globals
+    const fargs::Tfargs
+    const mc::Tmc
+    const position::Base.RefValue{Int32}
+end
+
+struct CacheKey
+    world_age::UInt
+    key::Any
+end
+
+const mc_cache = Dict{CacheKey,MistyClosure}()
+
+"""
+    TapedTask(taped_globals::Any, f, args...; kwargs...)
+
+Construct a `TapedTask` with the specified `taped_globals`, for function `f`, positional
+arguments `args`, and keyword arguments `kwargs`.
+
+# Extended Help
+
+There are three central features of a `TapedTask`, which we demonstrate via three examples.
+
+## Resumption
+
+The function [`Libtask.produce`](@ref) has a special meaning in Libtask. You can insert it
+into regular Julia functions anywhere that you like. For example
+```jldoctest tt
+julia> function f()
+           for t in 1:2
+               produce(t)
+               t += 1
+           end
+           return nothing
+       end
+f (generic function with 1 method)
+```
+
+If you construct a `TapedTask` from `f`, and call [`Libtask.consume`](@ref) on it, you'll
+see
+```jldoctest tt
+julia> t = TapedTask(nothing, f);
+
+julia> consume(t)
+1
+```
+The semantics of this are that [`Libtask.consume`](@ref) runs the function `f` until it
+reaches the call to [`Libtask.produce`](@ref), at which point it will return the argument
+to [`Libtask.produce`](@ref).
+
+Subsequent calls to [`Libtask.consume`](@ref) will _resume_ execution of `f` immediately
+after the last [`Libtask.produce`](@ref) statement that was hit.
+```jldoctest tt
+julia> consume(t)
+2
+```
+
+When there are no more [`Libtask.produce`](@ref) statements to hit, calling
+[`Libtask.consume`](@ref) will return `nothing`:
+```jldoctest tt
+julia> consume(t)
+
+```
+
+## Copying
+
+[`TapedTask`](@ref)s can be copied. Doing so creates a completely independent object.
+For example:
+```jldoctest tt
+julia> t2 = TapedTask(nothing, f);
+
+julia> consume(t2)
+1
+```
+
+If we make a copy and advance its state, it produces the same value that the original would
+have produced:
+```jldoctest tt
+julia> t3 = copy(t2);
+
+julia> consume(t3)
+2
+```
+
+Moreover, advancing the state of the copy has not advanced the state of the original,
+because they are completely independent copies:
+```jldoctest tt
+julia> consume(t2)
+2
+```
+
+## TapedTask-Specific Globals
+
+It is often desirable to permit a copy of a task and the original to differ in very specific
+ways. For example, in the context of Sequential Monte Carlo, you might want the only
+difference between two copies to be their random number generator.
+
+A generic mechanism is available to achieve this. [`Libtask.get_taped_globals`](@ref) and
+[`Libtask.set_taped_globals!`](@ref) let you set and retrieve a variable which is specific
+to a given [`Libtask.TapedTask`](@ref). The former can be called inside a function:
+```jldoctest sv
+julia> function f()
+           produce(get_taped_globals(Int))
+           produce(get_taped_globals(Int))
+           return nothing
+       end
+f (generic function with 1 method)
+```
+
+The first argument to [`Libtask.TapedTask`](@ref) is the value that
+[`Libtask.get_taped_globals`](@ref) will return:
+```jldoctest sv
+julia> t = TapedTask(1, f);
+
+julia> consume(t)
+1
+```
+
+The value that it returns can be changed between [`Libtask.consume`](@ref) calls:
+```jldoctest sv
+julia> set_taped_globals!(t, 2)
+
+julia> consume(t)
+2
+```
+
+`Int`s have been used here, but it is permissible to set the value returned by
+[`Libtask.get_taped_globals`](@ref) to anything you like.
+"""
+function TapedTask(taped_globals::Any, fargs...; kwargs...)
+    all_args = isempty(kwargs) ? fargs : (Core.kwcall, getfield(kwargs, :data), fargs...)
+    seed_id!()
+    mc, count_ref = build_callable(typeof(all_args))
+    return TapedTask(taped_globals, all_args, mc, count_ref)
+end
+
+function fresh_copy(mc::T) where {T<:MistyClosure}
+    new_captures = map(mc.oc.captures) do r
+        if eltype(r) <: DynamicCallable
+            return Base.RefValue(DynamicCallable())
+        elseif eltype(r) <: LazyCallable
+            return _typeof(r)(eltype(r)())
+        else
+            return _typeof(r)()
+        end
+    end
+    new_position = new_captures[end]
+    new_position[] = -1
+    return replace_captures(mc, new_captures), new_position
+end
+
+"""
+    set_taped_globals!(t::TapedTask, new_taped_globals)::Nothing
+
+Set the `taped_globals` of `t` to `new_taped_globals`. Any calls to
+[`get_taped_globals`](@ref) in future calls to `consume(t)` (either directly, or implicitly
+via iteration) will see this new value.
+"""
+function set_taped_globals!(t::TapedTask{T}, new_taped_globals::T)::Nothing where {T}
+    t.taped_globals = new_taped_globals
+    return nothing
+end
+
+"""
+    Base.copy(t::TapedTask)
+
+Makes a completely independent copy of `t`. `consume` can be applied to either the copy of
+`t` or the original without advancing the state of the other.
+"""
+Base.copy(t::T) where {T<:TapedTask} = deepcopy(t)
+
+"""
+    consume(t::TapedTask)
+
+Run `t` until it makes a call to `produce`. If this is the first time that `t` has been
+called, it starts execution from the entry point. If `consume` has previously been called
+on `t`, it will resume from the last `produce` call. If there are no more `produce` calls,
+`nothing` will be returned.
+"""
+@inline function consume(t::TapedTask)
+    task_local_storage(:task_variable, t.taped_globals)
+    v = t.mc.oc(t.fargs...)
+    return v isa ProducedValue ? v[] : nothing
+end
+
+"""
+    might_produce(sig::Type{<:Tuple})::Bool
+
+`true` if a call to a method with signature `sig` is permitted to contain
+`Libtask.produce` statements.
+
+This is an opt-in mechanism. The fallback method of this function returns `false`,
+indicating that, by default, we assume that calls do not contain `Libtask.produce`
+statements.
+"""
+might_produce(::Type{<:Tuple}) = false
+
+# Helper struct used in `derive_copyable_task_ir`.
+struct TupleRef
+    n::Int
+end
+
+# Unclear whether this is needed.
+get_value(x::GlobalRef) = getglobal(x.mod, x.name)
+get_value(x::QuoteNode) = x.value
+get_value(x) = x
+
+"""
+    is_produce_stmt(x)::Bool
+
+`true` if `x` is an expression of the form `Expr(:call, produce, %x)` or a similar `:invoke`
+expression, otherwise `false`.
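+
+For example, with the statements written out as explicit `Expr`s (an illustrative
+sketch -- in real use these come from lowered IR):
+```julia
+julia> Libtask.is_produce_stmt(Expr(:call, Libtask.produce, Core.SSAValue(1)))
+true
+
+julia> Libtask.is_produce_stmt(Expr(:call, println, Core.SSAValue(1)))
+false
+```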
+""" +function is_produce_stmt(x)::Bool + if Meta.isexpr(x, :invoke) && length(x.args) == 3 + return get_value(x.args[2]) === produce + elseif Meta.isexpr(x, :call) && length(x.args) == 2 + return get_value(x.args[1]) === produce + else + return false + end +end + +""" + stmt_might_produce(x, ret_type::Type)::Bool + +`true` if `x` might contain a call to `produce`, and `false` otherwise. +""" +function stmt_might_produce(x, ret_type::Type)::Bool + + # Statement will terminate in an unusual fashion, so don't bother recursing. + # This isn't _strictly_ correct (there could be a `produce` statement before the + # `throw` call is hit), but this seems unlikely to happen in practice. If it does, the + # user should get a sensible error message anyway. + ret_type == Union{} && return false + + # Statement will terminate in the usual fashion, so _do_ bother recusing. + is_produce_stmt(x) && return true + Meta.isexpr(x, :invoke) && return might_produce(x.args[1].specTypes) + if Meta.isexpr(x, :call) + # This is a hack -- it's perfectly possible for `DataType` calls to produce in general. + f = get_function(x.args[1]) + _might_produce = !isa(f, Union{Core.IntrinsicFunction,Core.Builtin,DataType}) + return _might_produce + end + return false +end + +get_function(x) = x +get_function(x::Expr) = eval(x) +get_function(x::GlobalRef) = isconst(x) ? getglobal(x.mod, x.name) : x.binding + +""" + produce_value(x::Expr) + +Returns the value that a `produce` statement returns. For example, for the statment +`produce(%x)`, this function will return `%x`. +""" +function produce_value(x::Expr) + is_produce_stmt(x) || throw(error("Not a produce statement. Please report this error.")) + Meta.isexpr(x, :invoke) && return x.args[3] + return x.args[2] # must be a `:call` Expr. +end + +struct ProducedValue{T} + x::T +end +ProducedValue(::Type{T}) where {T} = ProducedValue{Type{T}}(T) + +@inline Base.getindex(x::ProducedValue) = x.x + +""" + inc_args(stmt::T)::T where {T} + +Returns a new `T` which is equal to `stmt`, except any `Argument`s present in `stmt` are +incremented by `1`. For example +```jldoctest +julia> Libtask.inc_args(Core.ReturnNode(Core.Argument(1))) +:(return _2) +``` +""" +inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...) +inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x +inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest) +inc_args(x::IDGotoNode) = x +function inc_args(x::IDPhiNode) + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = __inc(x.values[n]) + end + end + return IDPhiNode(x.edges, new_values) +end +inc_args(::Nothing) = nothing +inc_args(x::GlobalRef) = x +inc_args(x::Core.PiNode) = Core.PiNode(__inc(x.val), __inc(x.typ)) + +__inc(x::Argument) = Argument(x.n + 1) +__inc(x) = x + +const TypeInfo = Tuple{Vector{Any},Dict{ID,Type}} + +""" + _typeof(x) + +Central definition of typeof, which is specific to the use-required in this package. +""" +_typeof(x) = Base._stable_typeof(x) +_typeof(x::Tuple) = Tuple{map(_typeof, x)...} +_typeof(x::NamedTuple{names}) where {names} = NamedTuple{names,_typeof(Tuple(x))} + +""" + get_type(info::ADInfo, x) + +Returns the static / inferred type associated to `x`. 
+""" +get_type(info::TypeInfo, x::Argument) = info[1][x.n - 1] +get_type(info::TypeInfo, x::ID) = CC.widenconst(info[2][x]) +get_type(::TypeInfo, x::QuoteNode) = _typeof(x.value) +get_type(::TypeInfo, x) = _typeof(x) +function get_type(::TypeInfo, x::GlobalRef) + return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty +end +function get_type(::TypeInfo, x::Expr) + x.head === :boundscheck && return Bool + return error("Unrecognised expression $x found in argument slot.") +end + +function derive_copyable_task_ir(ir::BBCode)::Tuple{BBCode,Tuple,Vector{Any}} + + # The location from which all state can be retrieved. Since we're using `OpaqueClosure`s + # to implement `TapedTask`s, this appears via the first argument. + refs_id = Argument(1) + + # Increment all arguments by 1. + for bb in ir.blocks, (n, inst) in enumerate(bb.insts) + bb.insts[n] = CC.NewInstruction( + inc_args(inst.stmt), inst.type, inst.info, inst.line, inst.flag + ) + end + + # Construct map between SSA IDs and their index in the state data structure and back. + # Also construct a map from each ref index to its type. We only construct `Ref`s + # for statements which return a value e.g. `IDGotoIfNot`s do not have a meaningful + # return value, so there's no need to allocate a `Ref` for them. + ssa_id_to_ref_index_map = Dict{ID,Int}() + ref_index_to_ssa_id_map = Dict{Int,ID}() + ref_index_to_type_map = Dict{Int,Type}() + id_to_type_map = Dict{ID,Type}() + is_used_dict = characterise_used_ids(collect_stmts(ir)) + n = 0 + for bb in ir.blocks + for (id, stmt) in zip(bb.inst_ids, bb.insts) + id_to_type_map[id] = CC.widenconst(stmt.type) + stmt.stmt isa IDGotoNode && continue + stmt.stmt isa IDGotoIfNot && continue + stmt.stmt === nothing && continue + stmt.stmt isa ReturnNode && continue + is_used_dict[id] || continue + n += 1 + ssa_id_to_ref_index_map[id] = n + ref_index_to_ssa_id_map[n] = id + ref_index_to_type_map[n] = CC.widenconst(stmt.type) + end + end + + # Specify data structure containing `Ref`s for all of the SSAs. + _refs = Any[Ref{ref_index_to_type_map[p]}() for p in 1:length(ref_index_to_ssa_id_map)] + + # Ensure that each basic block ends with a non-producing statement. This is achieved by + # replacing any fall-through terminators with `IDGotoNode`s. This is not strictly + # necessary, but simplifies later stages of the pipeline, as discussed variously below. + for (n, block) in enumerate(ir.blocks) + if terminator(block) === nothing + # Fall-through terminator, so next block in `ir.blocks` is the unique successor + # block of `block`. Final block cannot have a fall-through terminator, so asking + # for element `n + 1` is always going to be valid. + successor_id = ir.blocks[n + 1].id + push!(block.insts, new_inst(IDGotoNode(successor_id))) + push!(block.inst_ids, ID()) + end + end + + # For each existing basic block, create a sequence of `NamedTuple`s which + # define the manner in which it must be split. + # A block will in general be split as follows: + # 1 - %1 = φ(...) + # 1 - %2 = φ(...) + # 1 - %3 = call_which_must_not_produce(...) + # 1 - %4 = produce(%3) + # 2 - %5 = call_which_must_not_produce(...) + # 2 - %6 = call_which_might_produce(...) + # 3 - %7 = call_which_must_not_produce(...) + # 3 - terminator (GotoIfNot, GotoNode, etc) + # + # The numbers on the left indicate which split each statement falls. The first + # split comprises all statements up until the first produce / call-which-might-produce. + # Consequently, the first split will always contain any `PhiNode`s present in the block. 
+    # The next set of statements up until the next produce / call-which-might-produce form
+    # the second split, and so on.
+    # We enforced above the condition that the final statement in a basic block must not
+    # produce. This ensures that the final split does not produce. While not strictly
+    # necessary, this simplifies the implementation (see below).
+    #
+    # As a result of the above, a basic block will be associated to exactly one split if it
+    # does not contain any statements which may produce.
+    #
+    # Each `NamedTuple` contains a `start` index and `last` index, indicating the position
+    # in the block at which the corresponding split starts and finishes.
+    all_splits = map(ir.blocks) do block
+        split_ends = vcat(
+            findall(
+                inst -> stmt_might_produce(inst.stmt, CC.widenconst(inst.type)),
+                block.insts,
+            ),
+            length(block),
+        )
+        return map(enumerate(split_ends)) do (n, split_end)
+            return (start=(n == 1 ? 0 : split_ends[n - 1]) + 1, last=split_end)
+        end
+    end
+
+    # Owing to splitting blocks up, we will need to re-label some `GotoNode`s and
+    # `GotoIfNot`s. To understand this, consider the following block, whose original `ID`
+    # we assume to be `ID(old_id)`.
+    # ID(new_id) - %1 = φ(ID(3) => ...)
+    # ID(new_id) - %3 = call_which_must_not_produce(...)
+    # ID(new_id) - %4 = produce(%3)
+    # ID(old_id) - GotoNode(ID(5))
+    #
+    # In the above, the entire block was originally associated to a single ID, `ID(old_id)`,
+    # but is now split into two sections. We keep the original ID for the final split, and
+    # assign a new one to the first split. As a result, any `PhiNode`s in other blocks
+    # which have edges incoming from `ID(old_id)` will remain valid.
+    # However, if we adopt this strategy for all blocks, `ID(5)` in the `GotoNode` at the
+    # end of the block will refer to the wrong block if the block originally associated to
+    # `ID(5)` was itself split, since the "top" of that block will have a new `ID`.
+    #
+    # To resolve this, we:
+    # 1. Associate an ID to each split in each block, ensuring that the ID for the final
+    #    split of each block is the same ID as that of the original block.
+    all_split_ids = map(zip(ir.blocks, all_splits)) do (block, splits)
+        return vcat([ID() for _ in splits[1:(end - 1)]], block.id)
+    end
+
+    # 2. Construct a map between the ID of each block and the ID associated to its first
+    #    split.
+    top_split_id_map = Dict{ID,ID}(b.id => x[1] for (b, x) in zip(ir.blocks, all_split_ids))
+
+    # 3. Update all `GotoNode`s and `GotoIfNot`s to refer to these new names.
+    for block in ir.blocks
+        t = terminator(block)
+        if t isa IDGotoNode
+            block.insts[end] = new_inst(IDGotoNode(top_split_id_map[t.label]))
+        elseif t isa IDGotoIfNot
+            block.insts[end] = new_inst(IDGotoIfNot(t.cond, top_split_id_map[t.dest]))
+        end
+    end
+
+    # A set of blocks from which we might wish to resume computation.
+    resume_block_ids = Vector{ID}()
+
+    # A list onto which we'll push the type of any statement which might produce.
+    possible_produce_types = Any[]
+
+    # This is where most of the action happens.
+    #
+    # For each split of each block, we must
+    # 1. translate all statements which accept any SSAs as arguments, or return a value,
+    #    into statements which read in data from the `Ref`s containing the value associated
+    #    to each SSA, and write the result to `Ref`s associated to the SSA of the line in
+    #    question.
+    # 2. add additional code at the end of the split to handle the possibility that the
+    #    last statement produces (per the definition of the splits above). This applies to
+    #    all splits except the last, which cannot produce by construction. Exactly what
+    #    happens here depends on whether the last statement is a `produce` call, or a
+    #    call-which-might-produce -- see below for specifics.
+    #
+    # This code transforms each block (and its splits) into a new collection of blocks.
+    # Note that the total number of new blocks may be greater than the total number of
+    # splits, because each split ending in a call-which-might-produce requires more than a
+    # single block to implement the required resumption functionality.
+    new_bblocks = map(zip(ir.blocks, all_splits, all_split_ids)) do (bb, splits, splits_ids)
+        new_blocks = map(enumerate(splits)) do (n, split)
+            # We'll push ID-NewInstruction pairs to this as we proceed through the split.
+            inst_pairs = IDInstPair[]
+
+            # PhiNodes:
+            #
+            # A single `PhiNode`
+            #
+            # ID(%1) = φ(ID(#1) => 1, ID(#2) => ID(%n))
+            #
+            # sets `ID(%1)` to either `1` or whatever value is currently associated to
+            # `ID(%n)`, depending upon whether the predecessor block was `ID(#1)` or
+            # `ID(#2)`. Consequently, a single `PhiNode` can be transformed into something
+            # along the lines of:
+            #
+            # ID(%1) = φ(ID(#1) => 1, ID(#2) => TupleRef(ref_ind_for_ID(%n)))
+            # ID(%2) = deref_phi(refs, ID(%1))
+            # set_ref_at!(refs, ref_ind_for_ID(%1), ID(%2))
+            #
+            # where `deref_phi` retrieves the value in position `ref_ind_for_ID(%n)` if
+            # ID(%1) is a `TupleRef`, and `1` otherwise, and `set_ref_at!` sets the `Ref`
+            # at position `ref_ind_for_ID(%1)` to the value of `ID(%2)`. See the actual
+            # implementations below.
+            #
+            # If we have multiple `PhiNode`s at the start of a block, we must run all of
+            # them, then dereference all of their values, and finally write all of the
+            # de-referenced values to the appropriate locations. This is because
+            # a. we require all `PhiNode`s appear together at the top of a given basic
+            #    block, and
+            # b. the semantics of `PhiNode`s is that they are all "run" simultaneously. This
+            #    only matters if one `PhiNode` in the block can refer to the value stored in
+            #    the SSA associated to another. For example, something along the lines of:
+            #
+            #    ID(%1) = φ(ID(#1) => 1, ID(#2) => ID(%2))
+            #    ID(%2) = φ(ID(#1) => 1, ID(#2) => 2)
+            #
+            #    (we leave it as an exercise for the reader to figure out why this particular
+            #    semantic feature of `PhiNode`s is relevant in this specific case).
+            #
+            # So, in general, the code produced by this block will look roughly like
+            #
+            # ID(%1) = φ(...)
+            # ID(%2) = φ(...)
+            # ID(%3) = φ(...)
+            # ID(%4) = deref_phi(refs, ID(%1))
+            # ID(%5) = deref_phi(refs, ID(%2))
+            # ID(%6) = deref_phi(refs, ID(%3))
+            # set_ref_at!(refs, ref_ind_for_ID(%1), ID(%4))
+            # set_ref_at!(refs, ref_ind_for_ID(%2), ID(%5))
+            # set_ref_at!(refs, ref_ind_for_ID(%3), ID(%6))
+            if n == 1
+                # Find all PhiNodes in the block -- will definitely be in this split.
+                phi_inds = findall(x -> x.stmt isa IDPhiNode, bb.insts)
+
+                # Replace SSA IDs with `TupleRef`s, and record these instructions.
+                phi_ids = map(phi_inds) do n
+                    phi = bb.insts[n].stmt
+                    for i in eachindex(phi.values)
+                        isassigned(phi.values, i) || continue
+                        v = phi.values[i]
+                        v isa ID || continue
+                        phi.values[i] = TupleRef(ssa_id_to_ref_index_map[v])
+                    end
+                    phi_id = ID()
+                    push!(inst_pairs, (phi_id, new_inst(phi, Any)))
+                    return phi_id
+                end
+
+                # De-reference values associated to `IDPhiNode`s.
+ deref_ids = map(phi_inds) do n + id = bb.inst_ids[n] + phi_id = phi_ids[n] + push!( + inst_pairs, + (id, new_inst(Expr(:call, deref_phi, refs_id, phi_id))), + ) + return id + end + + # Update values stored in `Ref`s associated to `PhiNode`s. + for n in phi_inds + ref_ind = ssa_id_to_ref_index_map[bb.inst_ids[n]] + expr = Expr(:call, set_ref_at!, refs_id, ref_ind, deref_ids[n]) + push!(inst_pairs, (ID(), new_inst(expr))) + end + end + + # Statements which do not produce: + # + # Iterate every statement in the split other than the final one, replacing uses + # of SSAs with de-referenced `Ref`s, and writing the results of statements to + # the corresponding `Ref`s. + _ids = view(bb.inst_ids, (split.start):(split.last - 1)) + _insts = view(bb.insts, (split.start):(split.last - 1)) + for (id, inst) in zip(_ids, _insts) + stmt = inst.stmt + if Meta.isexpr(stmt, :invoke) || + Meta.isexpr(stmt, :call) || + Meta.isexpr(stmt, :new) || + Meta.isexpr(stmt, :foreigncall) + + # Find any `ID`s and replace them with calls to read whatever is stored + # in the `Ref`s that they are associated to. + for (n, arg) in enumerate(stmt.args) + arg isa ID || continue + + new_id = ID() + ref_ind = ssa_id_to_ref_index_map[arg] + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (new_id, new_inst(expr))) + stmt.args[n] = new_id + end + + # Push the target instruction to the list. + push!(inst_pairs, (id, inst)) + + # If we know it is not possible for this statement to contain any calls + # to produce, then simply write out the result to its `Ref`. If it is + # never used, then there is no need to store it. + if is_used_dict[id] + out_ind = ssa_id_to_ref_index_map[id] + set_ref = Expr(:call, set_ref_at!, refs_id, out_ind, id) + push!(inst_pairs, (ID(), new_inst(set_ref))) + end + elseif Meta.isexpr(stmt, :boundscheck) + push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :code_coverage_effect) + push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :gc_preserve_begin) + push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :gc_preserve_end) + push!(inst_pairs, (id, inst)) + elseif Meta.isexpr(stmt, :throw_undef_if_not) + push!(inst_pairs, (id, inst)) + elseif stmt isa Nothing + push!(inst_pairs, (id, inst)) + elseif stmt isa GlobalRef + ref_ind = ssa_id_to_ref_index_map[id] + expr = Expr(:call, set_ref_at!, refs_id, ref_ind, stmt) + push!(inst_pairs, (id, new_inst(expr))) + elseif stmt isa Core.PiNode + if stmt.val isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.val] + val_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (val_id, new_inst(expr))) + push!(inst_pairs, (id, new_inst(Core.PiNode(val_id, stmt.typ)))) + else + push!(inst_pairs, (id, inst)) + end + set_ind = ssa_id_to_ref_index_map[id] + set_expr = Expr(:call, set_ref_at!, refs_id, set_ind, id) + push!(inst_pairs, (ID(), new_inst(set_expr))) + elseif stmt isa IDPhiNode + # do nothing -- we've already handled any `PhiNode`s. + else + throw(error("Unhandled stmt $stmt of type $(typeof(stmt))")) + end + end + + # TODO: explain this better. + new_blocks = BBlock[] + + # Produce and Terminators: + # + # Handle the last statement in the split. + id = bb.inst_ids[split.last] + inst = bb.insts[split.last] + stmt = inst.stmt + if n == length(splits) + # This is the last split in the block, so it must end with a non-producing + # terminator. We handle this in a similar way to the statements above. + + if stmt isa ReturnNode + # If returning an SSA, it might be one whose value was restored from + # before. 
Therefore, grab it out of storage, rather than assuming that + # it is def-ed. + if isdefined(stmt, :val) && stmt.val isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.val] + val_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (val_id, new_inst(expr))) + push!(inst_pairs, (ID(), new_inst(ReturnNode(val_id)))) + else + push!(inst_pairs, (id, inst)) + end + elseif stmt isa IDGotoIfNot + # If the condition is an SSA, it might be one whose value was restored + # from before. Therefore, grab it out of storage, rather than assuming + # that it is defined. + if stmt.cond isa ID + ref_ind = ssa_id_to_ref_index_map[stmt.cond] + cond_id = ID() + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (cond_id, new_inst(expr))) + push!(inst_pairs, (ID(), new_inst(IDGotoIfNot(cond_id, stmt.dest)))) + else + push!(inst_pairs, (id, inst)) + end + elseif stmt isa IDGotoNode + push!(inst_pairs, (id, inst)) + else + error("Unexpected terminator $stmt") + end + push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) + elseif is_produce_stmt(stmt) + # This is a statement of the form + # %n = produce(arg) + # + # We transform this into + # Libtask.set_resume_block!(refs_id, id_of_next_block) + # return ProducedValue(arg) + # + # The point is to ensure that, next time that this `TapedTask` is called, + # computation is resumed from the statement _after_ this produce statement, + # and to return whatever this produce statement returns. + + # Log the result type of this statement. + arg = stmt.args[Meta.isexpr(stmt, :invoke) ? 3 : 2] + push!(possible_produce_types, get_type((ir.argtypes, id_to_type_map), arg)) + + # When this TapedTask is next called, we should resume from the first + # statement of the next split. + resume_id = splits_ids[n + 1] + push!(resume_block_ids, resume_id) + + # Insert statement to enforce correct resumption behaviour. + resume_stmt = Expr(:call, set_resume_block!, refs_id, resume_id.id) + push!(inst_pairs, (ID(), new_inst(resume_stmt))) + + # Insert statement to construct a `ProducedValue` from the value. + # Could be that the produce references an SSA, in which case we need to + # de-reference, rather than just return the thing. + prod_val = produce_value(stmt) + if prod_val isa ID + deref_id = ID() + ref_ind = ssa_id_to_ref_index_map[prod_val] + expr = Expr(:call, get_ref_at, refs_id, ref_ind) + push!(inst_pairs, (deref_id, new_inst(expr))) + prod_val = deref_id + end + + # Construct a `ProducedValue`. + val_id = ID() + push!(inst_pairs, (val_id, new_inst(Expr(:call, ProducedValue, prod_val)))) + + # Insert statement to return the `ProducedValue`. + push!(inst_pairs, (ID(), new_inst(ReturnNode(val_id)))) + + # Construct a single new basic block from all of the inst-pairs. + push!(new_blocks, BBlock(splits_ids[n], inst_pairs)) + else + # The final statement is one which might produce, but is not itself a + # `produce` statement. For example + # y = f(x) + # + # becomes (morally speaking) + # y = f(x) + # if y isa ProducedValue + # set_resume_block!(refs_id, id_of_current_block) + # return y + # end + # + # The point is to ensure that, if `f` "produces" (as indicated by `y` being + # a `ProducedValue`) then the next time that this TapedTask is called, we + # must resume from the call to `f`, as subsequent runs might also produce. + # On the other hand, if anything other than a `ProducedValue` is returned, + # we know that `f` has nothing else to produce, and execution can safely + # continue to the next split. 
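+                #
+                # As a concrete (illustrative) example, consider
+                #     @noinline inner() = (produce(true); return 1)
+                #     outer() = (y = inner(); produce(y); nothing)
+                # with `inner` opted in via `Libtask.might_produce`. The first `consume`
+                # of a `TapedTask` running `outer` enters `inner` and yields the `true`
+                # which `inner` produces. The second `consume` resumes at the call to
+                # `inner`, which now runs to completion and returns `1`, so execution
+                # continues past the call and `produce(y)` yields `1`.
+                #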
+                # In addition to the above, we must do the usual thing and ensure that
+                # any SSAs are read from storage, and write the result of this
+                # computation to storage before continuing to the next instruction.
+                #
+                # You should look at the IR generated by a simple example in the test
+                # suite which involves calls that might produce, in order to get a sense
+                # of what the resulting code looks like prior to digging into the code
+                # below.

+                # At present, we're not able to properly infer the values which might
+                # potentially be produced by a call-which-might-produce. Consequently, we
+                # have to assume they can produce anything.
+                push!(possible_produce_types, Any)

+                # Create a new basic block from the existing statements, since all new
+                # statements need to live in their own basic blocks.
+                callable_block_id = ID()
+                push!(inst_pairs, (ID(), new_inst(IDGotoNode(callable_block_id))))
+                push!(new_blocks, BBlock(splits_ids[n], inst_pairs))

+                # Derive a callable for this statement: a `LazyCallable` when the
+                # signature is known (`:invoke`), or a `DynamicCallable` otherwise.
+                (callable, callable_args) = if Meta.isexpr(stmt, :invoke)
+                    sig = stmt.args[1].specTypes
+                    v = Any[Any]
+                    (LazyCallable{sig,callable_ret_type(sig, v)}(), stmt.args[2:end])
+                elseif Meta.isexpr(stmt, :call)
+                    (DynamicCallable(), stmt.args)
+                else
+                    display(stmt)
+                    println()
+                    error("unhandled statement which might produce: $stmt")
+                end

+                # Find any `ID`s and replace them with calls to read whatever is stored
+                # in the `Ref`s that they are associated to.
+                callable_inst_pairs = IDInstPair[]
+                for (n, arg) in enumerate(callable_args)
+                    arg isa ID || continue

+                    new_id = ID()
+                    ref_ind = ssa_id_to_ref_index_map[arg]
+                    expr = Expr(:call, get_ref_at, refs_id, ref_ind)
+                    push!(callable_inst_pairs, (new_id, new_inst(expr)))
+                    callable_args[n] = new_id
+                end

+                # Allocate a slot in the _refs vector for this callable.
+                push!(_refs, Ref(callable))
+                callable_ind = length(_refs)

+                # Retrieve the callable from the refs.
+                callable_id = ID()
+                callable_stmt = Expr(:call, get_ref_at, refs_id, callable_ind)
+                push!(callable_inst_pairs, (callable_id, new_inst(callable_stmt)))

+                # Call the callable.
+                result_id = ID()
+                result_stmt = Expr(:call, callable_id, callable_args...)
+                push!(callable_inst_pairs, (result_id, new_inst(result_stmt)))

+                # Determine whether the call returned something other than a
+                # `ProducedValue`.
+                not_produced_id = ID()
+                not_produced_stmt = Expr(:call, not_a_produced, result_id)
+                push!(callable_inst_pairs, (not_produced_id, new_inst(not_produced_stmt)))

+                # Go to a block which just returns the `ProducedValue`, if a
+                # `ProducedValue` is returned, otherwise continue to the next split.
+                is_produced_block_id = ID()
+                is_not_produced_block_id = ID()
+                switch = Switch(
+                    Any[not_produced_id],
+                    [is_produced_block_id],
+                    is_not_produced_block_id,
+                )
+                push!(callable_inst_pairs, (ID(), new_inst(switch)))

+                # Push the above statements onto a new block.
+                push!(new_blocks, BBlock(callable_block_id, callable_inst_pairs))

+                # Construct block which handles the case that we got a `ProducedValue`. If
+                # this happens, it means that `callable` has more things to produce still.
+                # This means that we need to call it again next time we enter this function.
+                # To achieve this, we set the resume block to the `callable_block_id`,
+                # and return the `ProducedValue` currently located in `result_id`.
+                push!(resume_block_ids, callable_block_id)
+                set_res = Expr(:call, set_resume_block!, refs_id, callable_block_id.id)
+                return_id = ID()
+                produced_block_inst_pairs = IDInstPair[
+                    (ID(), new_inst(set_res)),
+                    (return_id, new_inst(ReturnNode(result_id))),
+                ]
+                push!(new_blocks, BBlock(is_produced_block_id, produced_block_inst_pairs))

+                # Construct block which handles the case that we did not get a
+                # `ProducedValue`. In this case, we must first push the result to the `Ref`
+                # associated to the call, and goto the next split.
+                next_block_id = splits_ids[n + 1] # safe since the last split ends with a terminator
+                if is_used_dict[id]
+                    result_ref_ind = ssa_id_to_ref_index_map[id]
+                    set_ref = Expr(:call, set_ref_at!, refs_id, result_ref_ind, result_id)
+                else
+                    set_ref = nothing
+                end
+                not_produced_block_inst_pairs = IDInstPair[
+                    (ID(), new_inst(set_ref)),
+                    (ID(), new_inst(IDGotoNode(next_block_id))),
+                ]
+                push!(
+                    new_blocks,
+                    BBlock(is_not_produced_block_id, not_produced_block_inst_pairs),
+                )
+            end
+            return new_blocks
+        end
+        return reduce(vcat, new_blocks)
+    end
+    new_bblocks = reduce(vcat, new_bblocks)

+    # Insert statements at the top.
+    cases = map(resume_block_ids) do id
+        return ID(), id, Expr(:call, resume_block_is, refs_id, id.id)
+    end
+    cond_ids = ID[x[1] for x in cases]
+    cond_dests = ID[x[2] for x in cases]
+    cond_stmts = Any[x[3] for x in cases]
+    switch_stmt = Switch(Any[x for x in cond_ids], cond_dests, first(new_bblocks).id)
+    entry_stmts = vcat(cond_stmts, nothing, switch_stmt)
+    entry_block = BBlock(ID(), vcat(cond_ids, ID(), ID()), map(new_inst, entry_stmts))
+    new_bblocks = vcat(entry_block, new_bblocks)

+    # The new argument types are the same as the old ones, except that the tuple of
+    # `Ref`s (with a final `Ref{Int32}` appended to hold the resume-block id) is
+    # prepended as a new first argument.
+    refs = (_refs..., Ref{Int32}(-1))
+    new_argtypes = vcat(typeof(refs), copy(ir.argtypes))

+    # Return BBCode and the `Ref`s.
+    new_ir = BBCode(new_bblocks, new_argtypes, ir.sptypes, ir.linetable, ir.meta)
+    return new_ir, refs, possible_produce_types
+end

+# Helper used in `derive_copyable_task_ir`.
+@inline get_ref_at(refs::R, n::Int) where {R<:Tuple} = refs[n][]

+# Helper used in `derive_copyable_task_ir`.
+@inline function set_ref_at!(refs::R, n::Int, val) where {R<:Tuple}
+    refs[n][] = val
+    return nothing
+end

+# Helper used in `derive_copyable_task_ir`.
+@inline function set_resume_block!(refs::R, id::Int32) where {R<:Tuple}
+    refs[end][] = id
+    return nothing
+end

+# Helper used in `derive_copyable_task_ir`.
+@inline resume_block_is(refs::R, id::Int32) where {R<:Tuple} = !(refs[end][] === id)

+# Helper used in `derive_copyable_task_ir`.
+@inline deref_phi(refs::R, n::TupleRef) where {R<:Tuple} = refs[n.n][]
+@inline deref_phi(::R, x) where {R<:Tuple} = x

+# Helper used in `derive_copyable_task_ir`.
+@inline not_a_produced(x) = !(isa(x, ProducedValue))

+# Implement iterator interface.
+function Base.iterate(t::TapedTask, state::Nothing=nothing)
+    v = consume(t)
+    return v === nothing ? nothing : (v, nothing)
+end
+Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown()
+Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown()

+"""
+    LazyCallable{sig<:Tuple,Tret}()

+A callable which, the first time it is called, builds (and caches) the `MistyClosure`
+associated to signature `sig`, and on subsequent calls forwards its arguments directly to
+the cached closure.
+"""
+mutable struct LazyCallable{sig<:Tuple,Tret}
+    mc::MistyClosure
+    position::Base.RefValue{Int32}
+    LazyCallable{sig,Tret}() where {sig,Tret} = new{sig,Tret}()
+end

+function (l::LazyCallable)(args::Vararg{Any,N}) where {N}
+    isdefined(l, :mc) || construct_callable!(l)
+    return l.mc(args...)
+end + +function construct_callable!(l::LazyCallable{sig}) where {sig} + mc, pos = build_callable(sig) + l.mc = mc + l.position = pos + return nothing +end + +mutable struct DynamicCallable{V} + cache::V +end + +DynamicCallable() = DynamicCallable(Dict{Any,Any}()) + +function (dynamic_callable::DynamicCallable)(args::Vararg{Any,N}) where {N} + sig = _typeof(args) + callable = get(dynamic_callable.cache, sig, nothing) + if callable === nothing + callable = build_callable(sig) + dynamic_callable.cache[sig] = callable + end + return callable[1](args...) +end diff --git a/src/tapedfunction.jl b/src/tapedfunction.jl deleted file mode 100644 index 5e85b87b..00000000 --- a/src/tapedfunction.jl +++ /dev/null @@ -1,514 +0,0 @@ -#= -`TapedFunction` converts a Julia function to a friendly tape for user-specified interpreters. -With this tape-like abstraction for functions, we gain some control over how a function is -executed, like capturing continuations, caching variables, injecting additional control flows -(i.e., produce/consume) between instructions on the tape, etc. - -Under the hood, we first used Julia's compiler API to get the IR code of the original function. -We use the unoptimized typed code in a non-strict SSA form. Then we convert each IR instruction -to a Julia data structure (an object of a subtype of AbstractInstruction). All the operands -(i.e., the variables) these instructions use are stored in a data structure called `Bindings`. -This conversion/binding process is performed at compile-time / tape-recording time and is only -done once for each function. - -In a nutshell, there are two types of instructions (or primitives) on a tape: - - Ordinary function call - - Control-flow instruction: GotoInstruction and CondGotoInstruction, ReturnInstruction - -Once the tape is recorded, we can run the tape just like calling the original function. -We first plugin the arguments, run each instruction on the tape and stop after encountering -a ReturnInstruction. We also provide a mechanism to add a callback after each instruction. -This API allowed us to implement the `produce/consume` mechanism in TapedTask. And exploiting -these features, we implemented a fork mechanism for TapedTask. - -Some potentially sharp edges of this implementation: - - 1. GlobalRef is evaluated at the tape-recording time (compile-time). Most times, - the value/object associated with a GlobalRef does not change at run time. - So this works well. But, if you do something like `module A v=1 end; make tapedfunction; A.eval(:(v=2)); run tf;`, - The assignment won't work. - 2. QuoteNode is also evaluated at the tape-recording time (compile-time). Primarily - the result of evaluating a QuoteNode is a Symbol, which usually works well. - 3. Each Instruction execution contains one unnecessary allocation at the moment. - So writing a function with vectorized computation will be more performant, - for example, using broadcasting instead of a loop. 
-=# - -const LOGGING = Ref(false) - -## Instruction and TapedFunction -abstract type AbstractInstruction end -const RawTape = Vector{AbstractInstruction} - -function _infer(f, args_type) - # `code_typed` returns a vector: [Pair{Core.CodeInfo, DataType}] - ir0 = code_typed(f, Tuple{args_type...}, optimize=false)[1][1] - return ir0 -end - -const Bindings = Vector{Any} - -mutable struct TapedFunction{F, TapeType} - func::F # maybe a function, a constructor, or a callable object - arity::Int - ir::Core.CodeInfo - tape::TapeType - counter::Int - binding_values::Bindings - arg_binding_slots::Vector{Int} # arg indices in binding_values - retval_binding_slot::Int # 0 indicates the function has not returned - deepcopy_types::Type # use a Union type for multiple types - - function TapedFunction{F, T}(f::F, args...; cache=false, deepcopy_types=Union{}) where {F, T} - args_type = _accurate_typeof.(args) - cache_key = (f, deepcopy_types, args_type...) - - if cache && haskey(TRCache, cache_key) # use cache - cached_tf = TRCache[cache_key]::TapedFunction{F, T} - tf = copy(cached_tf) - tf.counter = 1 - return tf - end - ir = _infer(f, args_type) - binding_values, slots, tape = translate!(RawTape(), ir) - - tf = new{F, T}(f, length(args), ir, tape, 1, binding_values, slots, 0, deepcopy_types) - TRCache[cache_key] = tf # set cache - return tf - end - - TapedFunction(f, args...; cache=false, deepcopy_types=Union{}) = - TapedFunction{typeof(f), RawTape}(f, args...; cache=cache, deepcopy_types=deepcopy_types) - - function TapedFunction{F, T0}(tf::TapedFunction{F, T1}) where {F, T0, T1} - new{F, T0}(tf.func, tf.arity, tf.ir, tf.tape, - tf.counter, tf.binding_values, tf.arg_binding_slots, 0, tf.deepcopy_types) - end - - TapedFunction(tf::TapedFunction{F, T}) where {F, T} = TapedFunction{F, T}(tf) -end - -const TRCache = LRU{Tuple, TapedFunction}(maxsize=10) -const CompiledTape = Vector{FunctionWrapper{Nothing, Tuple{TapedFunction}}} - -function Base.convert(::Type{CompiledTape}, tape::RawTape) - ctape = CompiledTape(undef, length(tape)) - for idx in 1:length(tape) - ctape[idx] = FunctionWrapper{Nothing, Tuple{TapedFunction}}(tape[idx]) - end - return ctape -end - -compile(tf::TapedFunction{F, RawTape}) where {F} = TapedFunction{F, CompiledTape}(tf) - -@inline _lookup(tf::TapedFunction, v::Int) = @inbounds tf.binding_values[v] -@inline _update_var!(tf::TapedFunction, v::Int, c) = @inbounds tf.binding_values[v] = c - -""" - Instruction - -An `Instruction` stands for a function call -""" -struct Instruction{F, N} <: AbstractInstruction - func::F - input::NTuple{N, Int} - output::Int -end - -struct GotoInstruction <: AbstractInstruction - # we enusre a 1-to-1 mapping between ir.code and instruction - # so here we can use the index directly. 
- dest::Int -end - -struct CondGotoInstruction <: AbstractInstruction - condition::Int - dest::Int -end - -struct ReturnInstruction <: AbstractInstruction - arg::Int -end - -struct NOOPInstruction <: AbstractInstruction end - -@inline result(t::TapedFunction) = t.binding_values[t.retval_binding_slot] -@inline function _arg(tf::TapedFunction, i::Int; default=nothing) - length(tf.arg_binding_slots) < i && return default - tf.arg_binding_slots[i] > 0 && return tf.binding_values[tf.arg_binding_slots[i]] - return default -end -@inline function _arg!(tf::TapedFunction, i::Int, v) - length(tf.arg_binding_slots) >= i && - tf.arg_binding_slots[i] > 0 && _update_var!(tf, tf.arg_binding_slots[i], v) -end - -function (tf::TapedFunction)(args...; callback=nothing, continuation=false) - if !continuation # reset counter and retval_binding_slot to run from the start - tf.counter = 1 - tf.retval_binding_slot = 0 - end - - # set args - if tf.counter <= 1 - # The first slot in `binding_values` is assumed to be `tf.func`. - _arg!(tf, 1, tf.func) - for i in 1:length(args) # the subsequent arg_binding_slots are arguments - slot = i + 1 - _arg!(tf, slot, args[i]) - end - end - - # run the raw tape - while true - ins = tf.tape[tf.counter] - ins(tf) - callback !== nothing && callback() - tf.retval_binding_slot != 0 && break - end - return result(tf) -end - -function Base.show(io::IO, tf::TapedFunction) - # we use an extra IOBuffer to collect all the data and then - # output it once to avoid output interrupt during task context - # switching - buf = IOBuffer() - println(buf, "TapedFunction:") - println(buf, "* .func => $(tf.func)") - println(buf, "* .ir =>") - println(buf, "------------------") - println(buf, tf.ir) - println(buf, "------------------") - print(io, String(take!(buf))) -end - -function Base.show(io::IO, rtape::RawTape) - buf = IOBuffer() - print(buf, length(rtape), "-element RawTape") - isempty(rtape) || println(buf, ":") - i = 1 - for instr in rtape - print(buf, "\t", i, " => ") - show(buf, instr) - i += 1 - end - print(io, String(take!(buf))) -end - -## methods for Instruction -Base.show(io::IO, instr::AbstractInstruction) = println(io, "A ", typeof(instr)) - -function Base.show(io::IO, instr::Instruction) - println(io, "Instruction(", instr.output, "=", instr.func, instr.input) -end - -function Base.show(io::IO, instr::GotoInstruction) - println(io, "GotoInstruction(dest=", instr.dest, ")") -end - -function Base.show(io::IO, instr::CondGotoInstruction) - println(io, "CondGotoInstruction(", instr.condition, ", dest=", instr.dest, ")") -end - -function (instr::Instruction{F})(tf::TapedFunction) where F - # catch run-time exceptions / errors. - try - func = F === Int ? _lookup(tf, instr.func) : instr.func - inputs = map(x -> _lookup(tf, x), instr.input) - output = func(inputs...) 
- _update_var!(tf, instr.output, output) - tf.counter += 1 - catch e - println("counter=", tf.counter) - println("tf=", tf) - println(e, catch_backtrace()); - rethrow(e); - end -end - -function (instr::GotoInstruction)(tf::TapedFunction) - tf.counter = instr.dest -end - -function (instr::CondGotoInstruction)(tf::TapedFunction) - cond = _lookup(tf, instr.condition) - if cond - tf.counter += 1 - else # goto dest unless cond - tf.counter = instr.dest - end -end - -function (instr::ReturnInstruction)(tf::TapedFunction) - tf.retval_binding_slot = instr.arg -end - -function (instr::NOOPInstruction)(tf::TapedFunction) - tf.counter += 1 -end - -## internal functions -_accurate_typeof(v) = typeof(v) -_accurate_typeof(::Type{V}) where V = Type{V} -_loose_type(t) = t -_loose_type(::Type{Type{T}}) where T = isa(T, DataType) ? Type{T} : typeof(T) - -""" - __new__(T, args...) - -Return a new instance of `T` with `args` even when there is no inner constructor for these args. -Source: https://discourse.julialang.org/t/create-a-struct-with-uninitialized-fields/6967/5 -""" -@generated function __new__(T, args...) - return Expr(:splatnew, :T, :args) -end - - -## Translation: CodeInfo -> Tape - -const IRVar = Union{Core.SSAValue, Core.SlotNumber} - -function bind_var!(var, bindings::Bindings, ir::Core.CodeInfo) - # for literal constants, and static parameters - var = Meta.isexpr(var, :static_parameter) ? ir.parent.sparam_vals[var.args[1]] : var - push!(bindings, var) - idx = length(bindings) - return idx -end -function bind_var!(var::GlobalRef, bindings::Bindings, ir::Core.CodeInfo) - in(var.mod, (Base, Core)) || - LOGGING[] && @info "evaluating GlobalRef $var at compile time" - bind_var!(getproperty(var.mod, var.name), bindings, ir) -end -function bind_var!(var::QuoteNode, bindings::Bindings, ir::Core.CodeInfo) - LOGGING[] && @info "evaluating QuoteNode $var at compile time" - bind_var!(eval(var), bindings, ir) -end -function bind_var!(var::TypedSlot, bindings::Bindings, ir::Core.CodeInfo) - # turn TypedSlot to SlotNumber - bind_var!(Core.SlotNumber(var.id), bindings, ir) -end -function bind_var!(var::Core.SlotNumber, bindings::Bindings, ir::Core.CodeInfo) - get!(bindings[1], var, allocate_binding!(var, bindings, ir.slottypes[var.id])) -end -function bind_var!(var::Core.SSAValue, bindings::Bindings, ir::Core.CodeInfo) - get!(bindings[1], var, allocate_binding!(var, bindings, ir.ssavaluetypes[var.id])) -end -allocate_binding!(var, bindings::Bindings, c::Core.Const) = - allocate_binding!(var, bindings, _loose_type(Type{_accurate_typeof(c.val)})) - -allocate_binding!(var, bindings::Bindings, c::Core.PartialStruct) = - allocate_binding!(var, bindings, _loose_type(c.typ)) -function allocate_binding!(var, bindings::Bindings, ::Type{T}) where T - # we may use the type info (T) here - push!(bindings, nothing) - idx = length(bindings) - return idx -end - -function translate!(tape::RawTape, ir::Core.CodeInfo) - binding_values = Bindings() - sizehint!(binding_values, 128) - bcache = Dict{IRVar, Int}() - # the first slot of binding_values is used to store a cache at compile time - push!(binding_values, bcache) - slots = Dict{Int, Int}() - - for (idx, line) in enumerate(ir.code) - isa(line, Core.Const) && (line = line.val) # unbox Core.Const - isconst = isa(ir.ssavaluetypes[idx], Core.Const) - ins = translate!!(Core.SSAValue(idx), line, binding_values, isconst, ir) - push!(tape, ins) - end - for (k, v) in bcache - isa(k, Core.SlotNumber) && (slots[k.id] = v) - end - arg_binding_slots = fill(0, maximum(keys(slots); 
init=0)) - for (k, v) in slots - arg_binding_slots[k] = v - end - binding_values[1] = 0 # drop bcache - return (binding_values, arg_binding_slots, tape) -end - -function _const_instruction(var::IRVar, v, bindings::Bindings, ir) - if isa(var, Core.SSAValue) - box = bind_var!(var, bindings, ir) - bindings[box] = v - return NOOPInstruction() - end - return Instruction(identity, (bind_var!(v, bindings, ir),), bind_var!(var, bindings, ir)) -end - -function translate!!(var::IRVar, line::Core.NewvarNode, - bindings::Bindings, isconst::Bool, @nospecialize(ir)) - # use a no-op to ensure the 1-to-1 mapping from ir.code to instructions on tape. - return NOOPInstruction() -end - -function translate!!(var::IRVar, line::GlobalRef, - bindings::Bindings, isconst::Bool, ir) - if isconst - v = ir.ssavaluetypes[var.id].val - return _const_instruction(var, v, bindings, ir) - end - func() = getproperty(line.mod, line.name) - return Instruction(func, (), bind_var!(var, bindings, ir)) -end - -function translate!!(var::IRVar, line::Core.SlotNumber, - bindings::Bindings, isconst::Bool, ir) - if isconst - v = ir.ssavaluetypes[var.id].val - return _const_instruction(var, v, bindings, ir) - end - func = identity - input = (bind_var!(line, bindings, ir),) - output = bind_var!(var, bindings, ir) - return Instruction(func, input, output) -end - -function translate!!(var::IRVar, line::Number, # literal vars - bindings::Bindings, isconst::Bool, ir) - func = identity - input = (bind_var!(line, bindings, ir),) - output = bind_var!(var, bindings, ir) - return Instruction(func, input, output) -end - -function translate!!(var::IRVar, line::NTuple{N, Symbol}, - bindings::Bindings, isconst::Bool, ir) where {N} - # for syntax (; x, y, z), see Turing.jl#1873 - func = identity - input = (bind_var!(line, bindings, ir),) - output = bind_var!(var, bindings, ir) - return Instruction(func, input, output) -end - -function translate!!(var::IRVar, line::TypedSlot, - bindings::Bindings, isconst::Bool, ir) - input_box = bind_var!(Core.SlotNumber(line.id), bindings, ir) - return Instruction(identity, (input_box,), bind_var!(var, bindings, ir)) -end - -function translate!!(var::IRVar, line::Core.GotoIfNot, - bindings::Bindings, isconst::Bool, ir) - cond = bind_var!(line.cond, bindings, ir) - return CondGotoInstruction(cond, line.dest) -end - -function translate!!(var::IRVar, line::Core.GotoNode, - bindings::Bindings, isconst::Bool, @nospecialize(ir)) - return GotoInstruction(line.label) -end - -function translate!!(var::IRVar, line::Core.ReturnNode, - bindings::Bindings, isconst::Bool, ir) - return ReturnInstruction(bind_var!(line.val, bindings, ir)) -end - -_canbeoptimized(v) = isa(v, DataType) || isprimitivetype(typeof(v)) -function translate!!(var::IRVar, line::Expr, - bindings::Bindings, isconst::Bool, ir::Core.CodeInfo) - head = line.head - _bind_fn = (x) -> bind_var!(x, bindings, ir) - if head === :new - args = map(_bind_fn, line.args) - return Instruction(__new__, args |> Tuple, _bind_fn(var)) - elseif head === :call - # Only some of the function calls can be optimized even though many of their results are - # inferred as constants: we only optimize primitive and datatype constants for now. For - # optimised function calls, we will evaluate the function at compile-time and cache results. 
- if isconst - v = ir.ssavaluetypes[var.id].val - _canbeoptimized(v) && return _const_instruction(var, v, bindings, ir) - end - args = map(_bind_fn, line.args) - # args[1] is the function - func = line.args[1] - if Meta.isexpr(func, :static_parameter) # func is a type parameter - func = ir.parent.sparam_vals[func.args[1]] - elseif isa(func, GlobalRef) - func = getproperty(func.mod, func.name) # Staging out global reference variable (constants). - else # a var? - func = args[1] # a var(box) - end - return Instruction(func, args[2:end] |> Tuple, _bind_fn(var)) - elseif head === :(=) - # line.args[1] (the left hand side) is a SlotNumber, and it should be the output - lhs = line.args[1] - rhs = line.args[2] # the right hand side, maybe a Expr, or a var, or ... - if Meta.isexpr(rhs, (:new, :call)) - return translate!!(lhs, rhs, bindings, false, ir) - else # rhs is a single value - if isconst - v = ir.ssavaluetypes[var.id].val - return Instruction(identity, (_bind_fn(v),), _bind_fn(lhs)) - end - return Instruction(identity, (_bind_fn(rhs),), _bind_fn(lhs)) - end - elseif head === :static_parameter - v = ir.parent.sparam_vals[line.args[1]] - return Instruction(identity, (_bind_fn(v),), _bind_fn(var)) - else - @error "Unknown Expression: " typeof(var) var typeof(line) line - throw(ErrorException("Unknown Expression")) - end -end - -function translate!!(var, line, bindings, isconst, ir) - @error "Unknown IR code: " typeof(var) var typeof(line) line - throw(ErrorException("Unknown IR code")) -end - -## copy Bindings, TapedFunction - -""" - tape_shallowcopy(x) - tape_deepcopy(x) - -Function `tape_shallowcopy` and `tape_deepcopy` are used to copy data -while copying a TapedFunction. A value in the bindings of a -TapedFunction is either `tape_shallowcopy`ed or `tape_deepcopy`ed. For -TapedFunction, all types are shallow copied by default, and you can -specify some types to be deep copied by giving the `deepcopy_types` -kwyword argument while constructing a TapedFunction. - -The default behaviour of `tape_shallowcopy` is, we return its argument -untouched, like `identity` does, i.e., `tape_copy(x) = x`. The default -behaviour of `tape_deepcopy` is, we call `deepcopy` on its argument -and return the result, `tape_deepcopy(x) = deepcopy(x)`. If one wants -some kinds of data to be copied (shallowly or deeply) in a different -way, one can overload these functions. 
- -""" -function tape_shallowcopy end, function tape_deepcopy end - -tape_shallowcopy(x) = x -tape_deepcopy(x) = deepcopy(x) - -# Core.Box is used as closure captured variable container, so we should tape_copy its contents -_tape_copy(box::Core.Box, deepcopy_types) = Core.Box(_tape_copy(box.contents, deepcopy_types)) - -function _tape_copy(v, deepcopy_types) - if isa(v, deepcopy_types) - tape_deepcopy(v) - else - tape_shallowcopy(v) - end -end - -function copy_bindings(old::Bindings, deepcopy_types) - newb = copy(old) - for k in 1:length(old) - newb[k] = _tape_copy(old[k], deepcopy_types) - end - return newb -end - -function Base.copy(tf::TapedFunction) - new_tf = TapedFunction(tf) - new_tf.binding_values = copy_bindings(tf.binding_values, tf.deepcopy_types) - return new_tf -end diff --git a/src/tapedtask.jl b/src/tapedtask.jl deleted file mode 100644 index c96f4720..00000000 --- a/src/tapedtask.jl +++ /dev/null @@ -1,219 +0,0 @@ -struct TapedTaskException - exc::Exception - backtrace::Vector{Any} -end - -struct TapedTask{F, AT<:Tuple} - task::Task - tf::TapedFunction{F} - args::AT - produce_ch::Channel{Any} - consume_ch::Channel{Int} - produced_val::Vector{Any} - - function TapedTask( - t::Task, - tf::TapedFunction{F}, - args::AT, - produce_ch::Channel{Any}, - consume_ch::Channel{Int} - ) where {F, AT<:Tuple} - new{F, AT}(t, tf, args, produce_ch, consume_ch, Any[]) - end -end - -function producer() - ttask = current_task().storage[:tapedtask]::TapedTask - if length(ttask.produced_val) > 0 - val = pop!(ttask.produced_val) - put!(ttask.produce_ch, val) - take!(ttask.consume_ch) # wait for next consumer - end - return nothing -end - -function wrap_task(tf, produce_ch, consume_ch, args...) - try - tf(args...; callback=producer, continuation=true) - catch e - bt = catch_backtrace() - put!(produce_ch, TapedTaskException(e, bt)) - # @error "TapedTask Error: " exception=(e, bt) - rethrow() - finally - @static if VERSION >= v"1.4" - # we don't do this under Julia 1.3, because `isempty` always hangs on - # an empty channel. - while !isempty(produce_ch) - yield() - end - end - close(produce_ch) - close(consume_ch) - end -end - -function TapedTask(tf::TapedFunction, args...) - produce_ch = Channel() - consume_ch = Channel{Int}() - task = @task wrap_task(tf, produce_ch, consume_ch, args...) - t = TapedTask(task, tf, args, produce_ch, consume_ch) - task.storage === nothing && (task.storage = IdDict()) - task.storage[:tapedtask] = t - return t -end - -BASE_COPY_TYPES = Union{Array, Ref} - -# NOTE: evaluating model without a trace, see -# https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329 -function TapedTask(f, args...; deepcopy_types=nothing) # deepcoy Array and Ref by default. - if isnothing(deepcopy_types) - deepcopy = BASE_COPY_TYPES - else - deepcopy = Union{BASE_COPY_TYPES, deepcopy_types} - end - tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy) - TapedTask(tf, args...) -end - -TapedTask(finfo::Tuple{Any, Type}, args...) = TapedTask(finfo[1], args...; deepcopy_types=finfo[2]) -TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...; deepcopy_types=t.tf.deepcopy_types) -func(t::TapedTask) = t.tf.func - -#= -# ** Approach (A) to implement `produce`: -# Make`produce` a standalone instturction. 
This approach does NOT -# support `produce` in a nested call -function internal_produce(instr::Instruction, val) - tf = instr.tape - ttask = tf.owner - put!(ttask.produce_ch, val) - take!(ttask.consume_ch) # wait for next consumer -end - -function produce(val) - error("Libtask.produce can only be directly called in a task!") -end - -function (instr::Instruction{typeof(produce)})() - args = val(instr.input[1]) - internal_produce(instr, args) -end -=# - - -# ** Approach (B) to implement `produce`: -# This way has its caveat: -# `produce` may deeply hide in an instruction, but not be an instruction -# itself, and when we copy a task, the newly copied task will resume from -# the instruction after the one which contains this `produce` call. If the -# call to `produce` is not the last expression in the instuction, that -# instruction will not be whole executed in the copied task. -@inline function is_in_tapedtask() - ct = current_task() - ct.storage === nothing && return false - haskey(ct.storage, :tapedtask) || return false - # check if we are recording a tape - ttask = ct.storage[:tapedtask]::TapedTask - return !isempty(ttask.tf.tape) -end - -function produce(val) - is_in_tapedtask() || return nothing - ttask = current_task().storage[:tapedtask]::TapedTask - length(ttask.produced_val) > 1 && - error("There is a produced value which is not consumed.") - push!(ttask.produced_val, val) - return nothing -end - -function consume(ttask::TapedTask) - if istaskstarted(ttask.task) - # tell producer that a consumer is coming - put!(ttask.consume_ch, 0) - else - schedule(ttask.task) - end - - val = try - take!(ttask.produce_ch) - catch e - isa(e, InvalidStateException) || rethrow() - istaskfailed(ttask.task) && throw(ttask.task.exception) - # TODO: we return nothing to indicate the end of a task, - # remove this when AdvancedPS is udpated. - istaskdone(ttask.task) && return nothing - end - - # yield to let the task resume, this is necessary when there's - # an exception is thrown in the task, it gives the task the chance - # to rethow the exception and set its proper status: - yield() - isa(val, TapedTaskException) && throw(val.exc) - return val -end - -# Iteration interface. -function Base.iterate(t::TapedTask, state=nothing) - try - consume(t), nothing - catch ex - !isa(ex, InvalidStateException) && rethrow - nothing - end -end -Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown() -Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown() - - -# copy the task - -function Base.copy(t::TapedTask; args=()) - length(args) > 0 && t.tf.counter >1 && - error("can't copy started task with new arguments") - tf = copy(t.tf) - task_args = if length(args) > 0 - # this cond implies t.tf.counter == 0, i.e., the task is not started yet - typeof(args) == typeof(t.args) || error("bad arguments") - args - else - if t.tf.counter > 1 - # the task is running, we find the real args from the copied binding_values - map(1:length(t.args)) do i - s = i + 1 - _arg(tf, s; default=t.args[i]) - end - else - # the task is not started yet, but no args is given - map(a -> _tape_copy(a, t.tf.deepcopy_types), t.args) - end - end - new_t = TapedTask(tf, task_args...) - storage = t.task.storage::IdDict{Any,Any} - new_t.task.storage = copy(storage) - new_t.task.storage[:tapedtask] = new_t - return new_t -end - -# TArray and TRef back-compat -function TArray(args...) - Base.depwarn("`TArray` is deprecated, please use `Array` instead.", :TArray) - Array(args...) 
-end -function TArray(T::Type, dim) - Base.depwarn("`TArray` is deprecated, please use `Array` instead.", :TArray) - Array{T}(undef, dim) -end -function tzeros(args...) - Base.depwarn("`tzeros` is deprecated, please use `zeros` instead.", :tzeros) - zeros(args...) -end -function tfill(args...) - Base.depwarn("`tfill` is deprecated, please use `fill` instead.", :tzeros) - fill(args...) -end -function TRef(x) - Base.depwarn("`TRef` is deprecated, please use `Ref` instead.", :TRef) - Ref(x) -end diff --git a/src/test_utils.jl b/src/test_utils.jl new file mode 100644 index 00000000..961f635b --- /dev/null +++ b/src/test_utils.jl @@ -0,0 +1,357 @@ +module TestUtils + +using ..Libtask +using Test +using ..Libtask: TapedTask + +# Function barrier to ensure inference in value types. +function count_allocs(f::F, x::Vararg{Any,N}) where {F,N} + @allocations f(x...) +end + +@enum PerfFlag none allocs + +struct Testcase + name::String + taped_globals::Any + fargs::Tuple + kwargs::Union{NamedTuple,Nothing} + expected_iteration_results::Vector + perf::PerfFlag +end + +function (case::Testcase)() + testset = @testset "$(case.name)" begin + + # Display some information. + @info "$(case.name)" + + # Construct the task. + if case.kwargs === nothing + t = TapedTask(case.taped_globals, case.fargs...) + else + t = TapedTask(case.taped_globals, case.fargs...; case.kwargs...) + end + + # Iterate through t. Record the results, and take a copy after each iteration. + iteration_results = [] + t_copies = [copy(t)] + for val in t + push!(iteration_results, val) + push!(t_copies, copy(t)) + end + + # Check that iterating the original task gives the expected results. + @test iteration_results == case.expected_iteration_results + + # Check that iterating the copies yields the correct results. + for (n, t_copy) in enumerate(t_copies) + @test iteration_results[n:end] == collect(t_copy) + end + + # Check no allocations if requested. + if case.perf == allocs + + # Construct the task. + if case.kwargs === nothing + t = TapedTask(case.taped_globals, case.fargs...) + else + t = TapedTask(case.taped_globals, case.fargs...; case.kwargs...) 
+            end

+            for _ in iteration_results
+                @test count_allocs(consume, t) == 0
+            end
+        end
+    end
+    return testset
+end

+function test_cases()
+    return Testcase[
+        Testcase(
+            "single block",
+            nothing,
+            (single_block, 5.0),
+            nothing,
+            [sin(5.0), sin(sin(5.0)), sin(sin(sin(5.0))), sin(sin(sin(sin(5.0))))],
+            allocs,
+        ),
+        Testcase(
+            "produce old",
+            nothing,
+            (produce_old_value, 5.0),
+            nothing,
+            [sin(5.0), sin(5.0)],
+            allocs,
+        ),
+        Testcase(
+            "branch on old value l",
+            nothing,
+            (branch_on_old_value, 2.0),
+            nothing,
+            [true, 2.0],
+            allocs,
+        ),
+        Testcase(
+            "branch on old value r",
+            nothing,
+            (branch_on_old_value, -1.0),
+            nothing,
+            [false, -2.0],
+            allocs,
+        ),
+        Testcase("no produce", nothing, (no_produce_test, 5.0, 4.0), nothing, [], allocs),
+        Testcase(
+            "new object",
+            nothing,
+            (new_object_test, 5, 4),
+            nothing,
+            [C(5, 4), C(5, 4)],
+            none,
+        ),
+        Testcase(
+            "branching test l",
+            nothing,
+            (branching_test, 5.0, 4.0),
+            nothing,
+            [complex(sin(5.0))],
+            allocs,
+        ),
+        Testcase(
+            "branching test r",
+            nothing,
+            (branching_test, 4.0, 5.0),
+            nothing,
+            [sin(4.0) * cos(5.0)],
+            allocs,
+        ),
+        Testcase(
+            "unused argument test", nothing, (unused_argument_test, 3), nothing, [1], allocs
+        ),
+        Testcase("test with const", nothing, (test_with_const,), nothing, [1], allocs),
+        Testcase("while loop", nothing, (while_loop,), nothing, collect(1:9), allocs),
+        Testcase(
+            "foreigncall tester",
+            nothing,
+            (foreigncall_tester, "hi"),
+            nothing,
+            [Ptr{UInt8}, Ptr{UInt8}],
+            allocs,
+        ),
+        Testcase("globals tester 1", 5, (taped_globals_tester_1,), nothing, [5], allocs),
+        Testcase("globals tester 2", 6, (taped_globals_tester_1,), nothing, [6], none),
+        Testcase(
+            "globals tester 3", 6, (while_loop_with_globals,), nothing, fill(6, 9), allocs
+        ),
+        Testcase(
+            "nested (static)", nothing, (static_nested_outer,), nothing, [true, false], none
+        ),
+        Testcase(
+            "nested (static + used)",
+            nothing,
+            (static_nested_outer_use_produced,),
+            nothing,
+            [true, 1],
+            none,
+        ),
+        Testcase(
+            "nested (dynamic)",
+            nothing,
+            (dynamic_nested_outer, Ref{Any}(nested_inner)),
+            nothing,
+            [true, false],
+            none,
+        ),
+        Testcase(
+            "nested (dynamic + used)",
+            nothing,
+            (dynamic_nested_outer_use_produced, Ref{Any}(nested_inner)),
+            nothing,
+            [true, 1],
+            none,
+        ),
+        Testcase(
+            "callable struct", nothing, (CallableStruct(5), 4), nothing, [5, 4, 9], allocs
+        ),
+        Testcase(
+            "kwarg tester 1",
+            nothing,
+            (Core.kwcall, (; y=5.0), kwarg_tester, 4.0),
+            nothing,
+            [],
+            allocs,
+        ),
+        Testcase("kwarg tester 2", nothing, (kwarg_tester, 4.0), (; y=5.0), [], allocs),
+        Testcase(
+            "default kwarg tester",
+            nothing,
+            (default_kwarg_tester, 4.0),
+            nothing,
+            [],
+            allocs,
+        ),
+        Testcase(
+            "default kwarg tester (empty kwargs)",
+            nothing,
+            (default_kwarg_tester, 4.0),
+            (;),
+            [],
+            allocs,
+        ),
+        Testcase(
+            "final statement produce",
+            nothing,
+            (final_statement_produce,),
+            nothing,
+            [1, 2],
+            allocs,
+        ),
+    ]
+end

+function single_block(x::Float64)
+    x1 = sin(x)
+    produce(x1)
+    x2 = sin(x1)
+    produce(x2)
+    x3 = sin(x2)
+    produce(x3)
+    x4 = sin(x3)
+    produce(x4)
+    return cos(x4)
+end

+function produce_old_value(x::Float64)
+    v = sin(x)
+    produce(v)
+    produce(v)
+    return nothing
+end

+function branch_on_old_value(x::Float64)
+    b = x > 0
+    produce(b)
+    produce(b ? x : 2x)
+    return nothing
+end

+function no_produce_test(x, y)
+    c = x + y
+    return (; c, x, y)
+end

+# Old test case without any produce statements, used to test TapedFunction. Since that
+# doesn't exist as a distinct entity anymore, it is not clear that this test case is useful.
+mutable struct C
+    i::Int
+    C(x, y) = new(x + y)
+end

+Base.:(==)(c::C, d::C) = c.i == d.i

+function new_object_test(x, y)
+    c = C(x, y)
+    produce(c)
+    produce(c)
+    return nothing
+end

+function branching_test(x, y)
+    if x > y
+        produce(complex(sin(x)))
+    else
+        produce(sin(x) * cos(y))
+    end
+    return nothing
+end

+function unused_argument_test(x)
+    produce(1)
+    return nothing
+end

+function test_with_const()
+    # this line generates: %1 = 1::Core.Const(1)
+    r = (a = 1)
+    produce(r)
+    return nothing
+end

+function while_loop()
+    t = 1
+    while t < 10
+        produce(t)
+        t = 1 + t
+    end
+    return nothing
+end

+function foreigncall_tester(s::String)
+    ptr = ccall(:jl_string_ptr, Ptr{UInt8}, (Any,), s)
+    produce(typeof(ptr))
+    produce(typeof(ptr))
+    return nothing
+end

+function taped_globals_tester_1()
+    produce(Libtask.get_taped_globals(Int))
+    return nothing
+end

+function while_loop_with_globals()
+    t = 1
+    while t < 10
+        produce(get_taped_globals(Int))
+        t = 1 + t
+    end
+    return nothing
+end

+@noinline function nested_inner()
+    produce(true)
+    return 1
+end

+Libtask.might_produce(::Type{Tuple{typeof(nested_inner)}}) = true

+function static_nested_outer()
+    nested_inner()
+    produce(false)
+    return nothing
+end

+function static_nested_outer_use_produced()
+    y = nested_inner()
+    produce(y)
+    return nothing
+end

+function dynamic_nested_outer(f::Ref{Any})
+    f[]()
+    produce(false)
+    return nothing
+end

+function dynamic_nested_outer_use_produced(f::Ref{Any})
+    y = f[]()
+    produce(y)
+    return nothing
+end

+struct CallableStruct{T}
+    x::T
+end

+function (c::CallableStruct)(y)
+    produce(c.x)
+    produce(y)
+    produce(c.x + y)
+    return nothing
+end

+kwarg_tester(x; y) = x + y

+default_kwarg_tester(x; y=5.0) = x * y

+function final_statement_produce()
+    produce(1)
+    return produce(2)
+end

+end
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 00000000..e4449657
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,204 @@
+
+"""
+    replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}

+Given an `OpaqueClosure` `oc`, create a new `OpaqueClosure` of the same type, but with new
+captured variables. This is needed for efficiency reasons -- if `build_callable` is called
+repeatedly with the same signature, it is important to avoid recompiling the
+`OpaqueClosure`s that it produces multiple times, because it can be quite expensive to do
+so.
+"""
+function replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}
+    return __replace_captures_internal(oc, new_captures)
+end

+@eval function __replace_captures_internal(oc::Toc, new_captures) where {Toc<:OpaqueClosure}
+    return $(Expr(
+        :new, :(Toc), :new_captures, :(oc.world), :(oc.source), :(oc.invoke), :(oc.specptr)
+    ))
+end

+"""
+    replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}

+Same as `replace_captures` for `Core.OpaqueClosure`s, but returns a new `MistyClosure`.
+"""
+function replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}
+    return Tmc(replace_captures(mc.oc, new_captures), mc.ir)
+end

+"""
+    optimise_ir!(ir::IRCode; show_ir=false, do_inline=true)

+Run a fairly standard optimisation pass on `ir`. If `show_ir` is `true`, displays the IR
+to `stdout` at various points in the pipeline -- this is sometimes useful for debugging.
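+
+A hypothetical usage sketch (assuming `ir` is an `IRCode` you have obtained elsewhere):
+
+    ir = optimise_ir!(ir; show_ir=false, do_inline=true)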
+""" +function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) + if show_ir + println("Pre-optimization") + display(ir) + println() + end + CC.verify_ir(ir) + ir = __strip_coverage!(ir) + ir = CC.compact!(ir) + local_interp = CC.NativeInterpreter() + # local_interp = BugPatchInterpreter() # 319 -- see patch_for_319.jl for context + mi = __get_toplevel_mi_from_ir(ir, @__MODULE__) + ir = __infer_ir!(ir, local_interp, mi) + if show_ir + println("Post-inference") + display(ir) + println() + end + inline_state = CC.InliningState(local_interp) + CC.verify_ir(ir) + if do_inline + ir = CC.ssa_inlining_pass!(ir, inline_state, true) #=propagate_inbounds=# + ir = CC.compact!(ir) + end + ir = __strip_coverage!(ir) + ir = CC.sroa_pass!(ir, inline_state) + + @static if VERSION < v"1.11-" + ir = CC.adce_pass!(ir, inline_state) + else + ir, _ = CC.adce_pass!(ir, inline_state) + end + + ir = CC.compact!(ir) + # CC.verify_ir(ir, true, false, CC.optimizer_lattice(local_interp)) + CC.verify_linetable(ir.linetable, true) + if show_ir + println("Post-optimization") + display(ir) + println() + end + return ir +end + +@static if VERSION < v"1.11.0" + get_inference_world(interp::CC.AbstractInterpreter) = CC.get_world_counter(interp) +else + get_inference_world(interp::CC.AbstractInterpreter) = CC.get_inference_world(interp) +end + +# Given some IR, generates a MethodInstance suitable for passing to infer_ir!, if you don't +# already have one with the right argument types. Credit to @oxinabox: +# https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54 +function __get_toplevel_mi_from_ir(ir, _module::Module) + mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ()) + mi.specTypes = Tuple{map(CC.widenconst, ir.argtypes)...} + mi.def = _module + return mi +end + +# Run type inference and constant propagation on the ir. Credit to @oxinabox: +# https://gist.github.com/oxinabox/cdcffc1392f91a2f6d80b2524726d802#file-example-jl-L54 +function __infer_ir!(ir, interp::CC.AbstractInterpreter, mi::CC.MethodInstance) + method_info = CC.MethodInfo(true, nothing) #=propagate_inbounds=# + min_world = world = get_inference_world(interp) + max_world = Base.get_world_counter() + irsv = CC.IRInterpretationState( + interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world + ) + rt = CC._ir_abstract_constant_propagation(interp, irsv) + return ir +end + +# In automatically generated code, it is meaningless to include code coverage effects. +# Moreover, it seems to cause some serious inference probems. Consequently, it makes sense +# to remove such effects before optimising IRCode. +function __strip_coverage!(ir::IRCode) + for n in eachindex(stmt(ir.stmts)) + if Meta.isexpr(stmt(ir.stmts)[n], :code_coverage_effect) + stmt(ir.stmts)[n] = nothing + end + end + return ir +end + +stmt(ir::CC.InstructionStream) = @static VERSION < v"1.11.0-rc4" ? ir.inst : ir.stmt + +""" + opaque_closure( + ret_type::Type, + ir::IRCode, + @nospecialize env...; + isva::Bool=false, + do_compile::Bool=true, + )::Core.OpaqueClosure{<:Tuple, ret_type} + +Construct a `Core.OpaqueClosure`. Almost equivalent to +`Core.OpaqueClosure(ir, env...; isva, do_compile)`, but instead of letting +`Core.compute_oc_rettype` figure out the return type from `ir`, impose `ret_type` as the +return type. 
+ +# Warning + +User beware: if the `Core.OpaqueClosure` produced by this function ever returns anything +which is not an instance of a subtype of `ret_type`, you should expect all kinds of awful +things to happen, such as segfaults. You have been warned! + +# Extended Help + +This is needed because we make extensive use of our ability to know the return +type of a couple of specific `OpaqueClosure`s without actually having constructed them. +Without the capability to specify the return type, we have to guess what type +`compute_ir_rettype` will return for a given `IRCode` before we have constructed +the `IRCode` and run type inference on it. This exposes us to details of type inference, +which are not part of the public interface of the language, and can therefore vary from +Julia version to Julia version (including patch versions). Moreover, even for a fixed Julia +version it can be extremely hard to predict exactly what type inference will infer to be the +return type of a function. + +Failing to correctly guess the return type can happen for a number of reasons, and the kinds +of errors that tend to be generated when this fails tell you very little about the +underlying cause of the problem. + +By specifying the return type ourselves, we remove this dependence. The price we pay for +this is the potential for segfaults etc if we fail to specify `ret_type` correctly. +""" +function opaque_closure( + ret_type::Type, + ir::IRCode, + @nospecialize env...; + isva::Bool=false, + do_compile::Bool=true, +) + # This implementation is copied over directly from `Core.OpaqueClosure`. + ir = CC.copy(ir) + nargs = length(ir.argtypes) - 1 + sig = Base.Experimental.compute_oc_signature(ir, nargs, isva) + src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ()) + src.slotnames = fill(:none, nargs + 1) + src.slotflags = fill(zero(UInt8), length(ir.argtypes)) + src.slottypes = copy(ir.argtypes) + src.rettype = ret_type + src = CC.ir_to_codeinf!(src, ir) + return Base.Experimental.generate_opaque_closure( + sig, Union{}, ret_type, src, nargs, isva, env...; do_compile + )::Core.OpaqueClosure{sig,ret_type} +end + +""" + misty_closure( + ret_type::Type, + ir::IRCode, + @nospecialize env...; + isva::Bool=false, + do_compile::Bool=true, + ) + +Identical to [`opaque_closure`](@ref), but returns a `MistyClosure` closure rather +than a `Core.OpaqueClosure`. +""" +function misty_closure( + ret_type::Type, + ir::IRCode, + @nospecialize env...; + isva::Bool=false, + do_compile::Bool=true, +) + return MistyClosure(opaque_closure(ret_type, ir, env...; isva, do_compile), Ref(ir)) +end diff --git a/test/copyable_task.jl b/test/copyable_task.jl new file mode 100644 index 00000000..3bb01b5d --- /dev/null +++ b/test/copyable_task.jl @@ -0,0 +1,258 @@ +@testset "copyable_task" begin + for case in Libtask.TestUtils.test_cases() + case() + end + @testset "set_taped_globals!" 
begin + function f() + produce(Libtask.get_taped_globals(Int)) + produce(Libtask.get_taped_globals(Int)) + return nothing + end + t = TapedTask(5, f) + @test consume(t) == 5 + Libtask.set_taped_globals!(t, 6) + @test consume(t) == 6 + @test consume(t) === nothing + end + @testset "iteration" begin + function f() + t = 1 + while true + produce(t) + t = 1 + t + end + end + + ttask = TapedTask(nothing, f) + + next = iterate(ttask) + @test next === (1, nothing) + + val, state = next + next = iterate(ttask, state) + @test next === (2, nothing) + + val, state = next + next = iterate(ttask, state) + @test next === (3, nothing) + + a = collect(Iterators.take(ttask, 7)) + @test eltype(a) === Int + @test a == 4:10 + end + + # Test of `Exception`. + @testset "Exception" begin + @testset "method error" begin + function f() + t = 0 + while true + t[3] = 1 + produce(t) + t = t + 1 + end + end + + ttask = TapedTask(nothing, f) + try + consume(ttask) + catch ex + @test ex isa MethodError + end + end + + @testset "error test" begin + function f() + x = 1 + while true + error("error test") + produce(x) + x += 1 + end + end + + ttask = TapedTask(nothing, f) + try + consume(ttask) + catch ex + @test ex isa ErrorException + end + end + + @testset "OutOfBounds Test Before" begin + function f() + x = zeros(2) + while true + x[1] = 1 + x[2] = 2 + x[3] = 3 + produce(x[1]) + end + end + + ttask = TapedTask(nothing, f) + try + consume(ttask) + catch ex + @test ex isa BoundsError + end + end + + @testset "OutOfBounds Test After `produce`" begin + function f() + x = zeros(2) + while true + x[1] = 1 + x[2] = 2 + produce(x[2]) + x[3] = 3 + end + end + + ttask = TapedTask(nothing, f) + @test consume(ttask) == 2 + try + consume(ttask) + catch ex + @test ex isa BoundsError + end + end + + @testset "OutOfBounds Test After `copy`" begin + function f() + x = zeros(2) + while true + x[1] = 1 + x[2] = 2 + produce(x[2]) + x[3] = 3 + end + end + + ttask = TapedTask(nothing, f) + @test consume(ttask) == 2 + ttask2 = copy(ttask) + try + consume(ttask2) + catch ex + @test ex isa BoundsError + end + end + end + + @testset "copying" begin + # Test case 1: stack allocated objects are deep copied. + @testset "stack allocated objects shallow copy" begin + function f() + t = 0 + while true + produce(t) + t = 1 + t + end + end + + ttask = TapedTask(nothing, f) + @test consume(ttask) == 0 + @test consume(ttask) == 1 + a = copy(ttask) + @test consume(a) == 2 + @test consume(a) == 3 + @test consume(ttask) == 2 + @test consume(ttask) == 3 + end + + # Test case 2: Array objects are deeply copied. 
+ @testset "Array objects deep copy" begin + function f() + t = [0 1 2] + while true + produce(t[1]) + t[1] = 1 + t[1] + end + end + + ttask = TapedTask(nothing, f) + @test consume(ttask) == 0 + @test consume(ttask) == 1 + a = copy(ttask) + @test consume(a) == 2 + @test consume(a) == 3 + @test consume(ttask) == 2 + @test consume(ttask) == 3 + @test consume(ttask) == 4 + @test consume(ttask) == 5 + end + + @testset "Array deep copy 2" begin + function f() + t = Array{Int}(undef, 1) + t[1] = 0 + while true + produce(t[1]) + t[1] + t[1] = 1 + t[1] + end + end + + ttask = TapedTask(nothing, f) + + consume(ttask) + consume(ttask) + a = copy(ttask) + consume(a) + consume(a) + + @test consume(ttask) == 2 + @test consume(a) == 4 + end + + @testset "ref of dictionary deep copy" begin + function f() + t = Ref(Dict("A" => 1, 5 => "B")) + t[]["A"] = 0 + for _ in 1:6 + produce(t[]["A"]) + t[]["A"] += 1 + end + end + + ctask = TapedTask(nothing, f) + + consume(ctask) + consume(ctask) + + a = copy(ctask) + consume(a) + consume(a) + + @test consume(ctask) == 2 + @test consume(a) == 4 + end + end + @testset "Issue: PR-86 (DynamicPPL.jl/pull/261)" begin + function f() + t = Array{Int}(undef, 1) + t[1] = 0 + for _ in 1:4000 + produce(t[1]) + t[1] + t[1] = 1 + t[1] + end + end + + ttask = TapedTask(nothing, f) + + ex = try + for _ in 1:999 + consume(ttask) + consume(ttask) + a = copy(ttask) + consume(a) + consume(a) + end + catch ex + ex + end + @test ex === nothing + end +end diff --git a/test/front_matter.jl b/test/front_matter.jl new file mode 100644 index 00000000..92a26f35 --- /dev/null +++ b/test/front_matter.jl @@ -0,0 +1 @@ +using Aqua, JuliaFormatter, Libtask, Test diff --git a/test/issues.jl b/test/issues.jl deleted file mode 100644 index 370c0235..00000000 --- a/test/issues.jl +++ /dev/null @@ -1,83 +0,0 @@ -@testset "Issues" begin - @testset "Issue: PR-86 (DynamicPPL.jl/pull/261)" begin - function f() - t = Array{Int}(undef, 1) - t[1] = 0 - for _ in 1:4000 - produce(t[1]) - t[1] - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - - ex = try - for _ in 1:999 - consume(ttask) - consume(ttask) - a = copy(ttask) - consume(a) - consume(a) - end - catch ex - ex - end - @test ex === nothing - end - - @testset "Issue-140, copy unstarted task" begin - function f(x) - for i in 1:3 - produce(i + x) - end - end - - ttask = TapedTask(f, 3) - ttask2 = copy(ttask) - @test consume(ttask2) == 4 - ttask3 = copy(ttask; args=(4,)) - @test consume(ttask3) == 5 - end - - @testset "Issue-148, unused argument" begin - function f(x) - produce(1) - end - - ttask = TapedTask(f, 2) - @test consume(ttask) == 1 - end - - @testset "Issue-Turing-1873, NamedTuple syntax" begin - function g(x, y) - c = x + y - return (; c, x, y) - end - - tf = Libtask.TapedFunction(g, 1, 2) - r = tf(1, 2) - @test r == (c=3, x=1, y=2) - end - - @testset "Issue-Libtask-174, SSAValue=Int and static parameter" begin - # SSAValue = Int - function f() - # this line generates: %1 = 1::Core.Const(1) - r = (a = 1) - return nothing - end - tf = Libtask.TapedFunction(f) - r = tf() - @test r == nothing - - # static parameter - function g(::Type{T}) where {T} - a = zeros(T, 10) - end - tf = Libtask.TapedFunction(g, Float64) - r = tf(Float64) - @test r == zeros(Float64, 10) - end - -end diff --git a/test/runtests.jl b/test/runtests.jl index a045454f..e6400098 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,11 +1,8 @@ -using Libtask -using Test - -include("tf.jl") -include("tapedtask.jl") -include("tape_copy.jl") -include("issues.jl") - -if 
haskey(ENV, "BENCHMARK") - include("benchmarks.jl") +include("front_matter.jl") +@testset "Libtask" begin + @testset "quality" begin + Aqua.test_all(Libtask) + @test JuliaFormatter.format(Libtask; verbose=false, overwrite=false) + end + include("copyable_task.jl") end diff --git a/test/tape_copy.jl b/test/tape_copy.jl deleted file mode 100644 index 6edc3f7c..00000000 --- a/test/tape_copy.jl +++ /dev/null @@ -1,195 +0,0 @@ -@testset "tape copy" begin - # Test case 1: stack allocated objects are deep copied. - @testset "stack allocated objects shallow copy" begin - function f() - t = 0 - while true - produce(t) - t = 1 + t - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 0 - @test consume(ttask) == 1 - a = copy(ttask) - @test consume(a) == 2 - @test consume(a) == 3 - @test consume(ttask) == 2 - @test consume(ttask) == 3 - - @inferred Libtask.TapedFunction(f) - end - - # Test case 2: Array objects are deeply copied. - @testset "Array objects deep copy" begin - function f() - t = [0 1 2] - while true - produce(t[1]) - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 0 - @test consume(ttask) == 1 - a = copy(ttask) - @test consume(a) == 2 - @test consume(a) == 3 - @test consume(ttask) == 2 - @test consume(ttask) == 3 - @test consume(ttask) == 4 - @test consume(ttask) == 5 - end - - # Test case 3: Dict objects are shallowly copied. - @testset "Dict objects shallow copy" begin - function f() - t = Dict(1=>10, 2=>20) - while true - produce(t[1]) - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - - @test consume(ttask) == 10 - @test consume(ttask) == 11 - - a = copy(ttask) - @test consume(a) == 12 - @test consume(a) == 13 - - @test consume(ttask) == 14 - @test consume(ttask) == 15 - end - - @testset "Array deep copy 2" begin - function f() - t = Array{Int}(undef, 1) - t[1] = 0 - while true - produce(t[1]) - t[1] - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f) - - consume(ttask) - consume(ttask) - a = copy(ttask) - consume(a) - consume(a) - - @test consume(ttask) == 2 - @test consume(a) == 4 - - DATA = Dict{Task, Array}() - function g() - ta = zeros(UInt64, 4) - for i in 1:4 - ta[i] = hash(current_task()) - DATA[current_task()] = ta - produce(ta[i]) - end - end - - ttask = TapedTask(g) - @test consume(ttask) == hash(ttask.task) # index = 1 - @test consume(ttask) == hash(ttask.task) # index = 2 - - a = copy(ttask) - @test consume(a) == hash(a.task) # index = 3 - @test consume(a) == hash(a.task) # index = 4 - - @test consume(ttask) == hash(ttask.task) # index = 3 - - @test DATA[ttask.task] == [hash(ttask.task), hash(ttask.task), hash(ttask.task), 0] - @test DATA[a.task] == [hash(ttask.task), hash(ttask.task), hash(a.task), hash(a.task)] - end - - # Test atomic values. - @testset "ref atomic" begin - function f() - t = Ref(1) - t[] = 0 - for _ in 1:6 - produce(t[]) - t[] - t[] += 1 - end - end - - ctask = TapedTask(f) - - consume(ctask) - consume(ctask) - - a = copy(ctask) - consume(a) - consume(a) - - @test consume(ctask) == 2 - @test consume(a) == 4 - end - - @testset "ref of dictionary deep copy" begin - function f() - t = Ref(Dict("A" => 1, 5 => "B")) - t[]["A"] = 0 - for _ in 1:6 - produce(t[]["A"]) - t[]["A"] += 1 - end - end - - ctask = TapedTask(f) - - consume(ctask) - consume(ctask) - - a = copy(ctask) - consume(a) - consume(a) - - @test consume(ctask) == 2 - @test consume(a) == 4 - end - - @testset "ref of array deep copy" begin - # Create a TRef storing a matrix. 
- x = TRef([1 2 3; 4 5 6]) - x[][1, 3] = 900 - @test x[][1,3] == 900 - - # TRef holding an array. - y = TRef([1,2,3]) - y[][2] = 19 - @test y[][2] == 19 - end - - @testset "override deepcopy_types #57" begin - struct DummyType end - - function f(start::Int) - t = [start] - while true - produce(t[1]) - t[1] = 1 + t[1] - end - end - - ttask = TapedTask(f, 0; deepcopy_types=DummyType) - consume(ttask) - - ttask2 = copy(ttask) - consume(ttask2) - - @test consume(ttask) == 1 - @test consume(ttask2) == 2 - end -end diff --git a/test/tapedtask.jl b/test/tapedtask.jl deleted file mode 100644 index f55f83e1..00000000 --- a/test/tapedtask.jl +++ /dev/null @@ -1,159 +0,0 @@ -@testset "tapedtask" begin - @testset "construction" begin - function f() - t = 1 - while true - produce(t) - t = 1 + t - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 1 - - ttask = TapedTask((f, Union{})) - @test consume(ttask) == 1 - end - - @testset "iteration" begin - function f() - t = 1 - while true - produce(t) - t = 1 + t - end - end - - ttask = TapedTask(f) - - next = iterate(ttask) - @test next === (1, nothing) - - val, state = next - next = iterate(ttask, state) - @test next === (2, nothing) - - val, state = next - next = iterate(ttask, state) - @test next === (3, nothing) - - a = collect(Iterators.take(ttask, 7)) - @test eltype(a) === Int - @test a == 4:10 - end - - # Test of `Exception`. - @testset "Exception" begin - @testset "method error" begin - function f() - t = 0 - while true - t[3] = 1 - produce(t) - t = t + 1 - end - end - - ttask = TapedTask(f) - try - consume(ttask) - catch ex - @test ex isa MethodError - end - if VERSION >= v"1.5" - @test ttask.task.exception isa MethodError - end - end - - @testset "error test" begin - function f() - x = 1 - while true - error("error test") - produce(x) - x += 1 - end - end - - ttask = TapedTask(f) - try - consume(ttask) - catch ex - @test ex isa ErrorException - end - if VERSION >= v"1.5" - @test ttask.task.exception isa ErrorException - end - end - - @testset "OutOfBounds Test Before" begin - function f() - x = zeros(2) - while true - x[1] = 1 - x[2] = 2 - x[3] = 3 - produce(x[1]) - end - end - - ttask = TapedTask(f) - try - consume(ttask) - catch ex - @test ex isa BoundsError - end - if VERSION >= v"1.5" - @test ttask.task.exception isa BoundsError - end - end - - @testset "OutOfBounds Test After `produce`" begin - function f() - x = zeros(2) - while true - x[1] = 1 - x[2] = 2 - produce(x[2]) - x[3] = 3 - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 2 - try - consume(ttask) - catch ex - @test ex isa BoundsError - end - if VERSION >= v"1.5" - @test ttask.task.exception isa BoundsError - end - end - - @testset "OutOfBounds Test After `copy`" begin - function f() - x = zeros(2) - while true - x[1] = 1 - x[2] = 2 - produce(x[2]) - x[3] = 3 - end - end - - ttask = TapedTask(f) - @test consume(ttask) == 2 - ttask2 = copy(ttask) - try - consume(ttask2) - catch ex - @test ex isa BoundsError - end - @test ttask.task.exception === nothing - if VERSION >= v"1.5" - @test ttask2.task.exception isa BoundsError - end - end - end -end diff --git a/test/tf.jl b/test/tf.jl deleted file mode 100644 index f1fd87c5..00000000 --- a/test/tf.jl +++ /dev/null @@ -1,34 +0,0 @@ -using Libtask - -@testset "tapedfunction" begin - # Test case 1: stack allocated objects are deep copied. 
- @testset "Instruction{typeof(__new__)}" begin - mutable struct S - i::Int - S(x, y) = new(x + y) - end - - tf = Libtask.TapedFunction(S, 1, 2) - s1 = tf(1, 2) - @test s1.i == 3 - newins = findall(x -> isa(x, Libtask.Instruction{typeof(Libtask.__new__)}), tf.tape) - @test length(newins) == 1 - end - - @testset "Compiled Tape" begin - function g(x, y) - if x>y - r= string(sin(x)) - else - r= sin(x) * cos(y) - end - return r - end - - tf = Libtask.TapedFunction(g, 1., 2.) - ctf = Libtask.compile(tf) - r = ctf(1., 2.) - - @test typeof(r) === Float64 - end -end
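
For reference, a minimal usage sketch of the interface these tests exercise. This is illustrative only, and is based solely on the calls that appear in the test suite above (`TapedTask(taped_globals, f)` with `nothing` for no globals, plus `produce`, `consume`, and `copy`); see the package docs for the authoritative interface.

```julia
using Libtask

function f()
    t = [0 1 2]
    while true
        produce(t[1])    # yield the current value to the consumer
        t[1] = 1 + t[1]
    end
end

# The first argument carries the optional task-local globals; `nothing` means none.
ttask = TapedTask(nothing, f)

consume(ttask)  # 0
consume(ttask)  # 1

# A copy resumes independently: mutable state such as this array is deep copied,
# matching the expectations in the "Array objects deep copy" testset above.
a = copy(ttask)
consume(a)      # 2
consume(ttask)  # 2
```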