Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.0"

[deps]
Catwalk = "860e6890-8a08-4313-9643-fcac6eb69798"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"

[compat]
Expand Down
35 changes: 30 additions & 5 deletions examples/sum.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
using BenchmarkTools
using Catwalk
using Folds
using FoldsCatwalk
using Transducers

type_instability(x) = Val(round(Int, 2x))

# This version emphasizes the possible gain better:
#vals=[Val(i) for i= 1:100]
#type_instability(x) = vals[abs(round(Int, 4x)) % length(vals) + 1]

asint(::Val{x}) where {x} = x::Int

function sum_baseline(xs)
Expand All @@ -16,9 +22,28 @@ function sum_catwalk(xs)
Folds.sum(itr, CatwalkEx())
end

function demo_sum()
xs = randn(1000_000)
@assert sum_baseline(xs) == sum_catwalk(xs)
@btime sum_baseline($xs)
@btime sum_catwalk($xs)
function sum_catwalk_tuned(xs)
itr = xs |> Map(type_instability) |> OptimizeInner() do
boost = Catwalk.CallBoost(:next;
optimizer = Catwalk.TopNOptimizer(15),
profilestrategy = Catwalk.SparseProfile(0))
jit = Catwalk.JIT(boost; explorertype = Catwalk.NoExplorer)
end |> Map(asint)
Folds.sum(itr, CatwalkEx())
end

function demo_sum(n=20_000_000)
xs = randn(n)

println("Baseline:")
baseline_result = @btime sum_baseline($xs)

println("Catwalk defaults:")
catwalk_result = @btime sum_catwalk($xs)

println("Catwalk tuned:")
catwalk_tuned_result = @btime sum_catwalk_tuned($xs)

@assert baseline_result == catwalk_result
@assert baseline_result == catwalk_tuned_result
end
18 changes: 9 additions & 9 deletions src/FoldsCatwalk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ for how to specify what to JIT.
struct CatwalkEx <: Executor
batchsize::Int
end
CatwalkEx() = CatwalkEx(1000)
CatwalkEx() = CatwalkEx(1_000_000)

Transducers.maybe_set_simd(exc::CatwalkEx, _) = exc

Expand Down Expand Up @@ -81,11 +81,12 @@ function Transducers.transduce(xf, rf, init, coll, exc::CatwalkEx)
end

function onebatch(rf::RF, acc, itr, state, counter) where {RF}
jitctx = xform(inner(rf)).jitctx
while counter != 0
y = iterate(itr, state)
y === nothing && return acc, state, true
state = last(y)
val = next(rf, acc, first(y))
val = next(rf, acc, first(y), jitctx)
val isa Reduced && return val, state, true
acc = val
counter -= 1
Expand All @@ -97,14 +98,13 @@ struct OptimizeXF{C} <: Transducer
jitctx::C
end

Transducers.next(rf::R_{OptimizeXF}, acc, input) =
invoke_next(xform(rf).jitctx, inner(rf), acc, input)

@jit jitfun jitarg function invoke_next(jitctx, rf::RF, acc, input) where {RF}
jitarg = (acc, input)
return jitfun(rf, jitarg)
@inline @jit next mapped function Transducers.next(rf::R_{Map}, acc, input, jitctx)
mapped = xform(rf).f(input)
next(inner(rf), acc, mapped)
end

@inline jitfun(rf::RF, jitarg) where {RF} = next(rf, first(jitarg), last(jitarg))
function Transducers.next(rf::R_{OptimizeXF}, acc, input)
next(inner(rf), acc, input)
end

end # module FoldsCatwalk