diff --git a/Project.toml b/Project.toml
index 8fa384e42..49a913fd2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,9 @@ version = "0.2.4"
 
 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
+ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
+Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 EarlyStopping = "792122b4-ca99-40de-a6bc-6742525f08b6"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -19,6 +22,7 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
 TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
@@ -26,6 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 BSON = "0.2, 0.3"
+DataFrames = "1.3"
 DataStructures = "0.18"
 EarlyStopping = "0.1, 0.2, 0.3"
 Flux = "0.11, 0.12, 0.13"
diff --git a/src/FluxTraining.jl b/src/FluxTraining.jl
index 2001cb0d8..c85a13cf2 100644
--- a/src/FluxTraining.jl
+++ b/src/FluxTraining.jl
@@ -3,6 +3,10 @@ module FluxTraining
 
 using Graphs
 using BSON: @load, @save
+using Colors: Colors, Color, @colorant_str
+using ColorSchemes: ColorScheme, colorschemes
+using DataFrames: DataFrame, groupby, select, subset, combine
+using DataFrames.PooledArrays: PooledArray
 using Flux
 using Flux: Params, onecold
 using Flux.Optimise: update!
@@ -17,7 +21,7 @@ import OnlineStats
 using OnlineStats: EqualWeight, Mean, OnlineStat
 using Parameters
 using ProgressMeter: Progress, next!
-using Statistics: mean
+using Statistics: mean, median
 using UUIDs
 using Zygote
 using ParameterSchedulers
@@ -26,6 +30,7 @@ using Zygote: Grads, gradient
 using ValueHistories
 using DataStructures: DefaultDict
 using PrettyTables
+using StructArrays
 
 # functional
 include("./functional/metrics.jl")
@@ -37,6 +42,7 @@ include("./callbacks/events.jl")
 include("./callbacks/callback.jl")
 include("./callbacks/graph.jl")
 include("./callbacks/execution.jl")
+include("./callbacks/runners/profiler.jl")
 
 # logging
 include("./callbacks/logging/Loggables.jl")
@@ -105,5 +111,6 @@ export AbstractCallback,
     step!,
     onecycle,
     loadmodel,
-    savemodel
+    savemodel,
+    ProfileRunner
 end # module
diff --git a/src/callbacks/runners/profiler.jl b/src/callbacks/runners/profiler.jl
new file mode 100644
index 000000000..c8d922831
--- /dev/null
+++ b/src/callbacks/runners/profiler.jl
@@ -0,0 +1,234 @@
+
+# Data structures
+
+struct TimingBetween
+    phase::Any
+    eventstart::Any
+    eventend::Any
+    timestart::Any
+    timeend::Any
+    duration::Any
+end
+
+TimingBetween(phase, es, ee, ts, te) = TimingBetween(phase, es, ee, ts, te, te - ts)
+
+struct TimingCallback
+    phase::Any
+    cb::Any
+    event::Any
+    timestart::Any
+    timeend::Any
+    duration::Any
+end
+
+TimingCallback(phase, cb, e, ts, te) = TimingCallback(phase, cb, e, ts, te, te - ts)
+
+
+# Runner
+
+"""
+    ProfileRunner() <: CallbackRunner
+
+A profiling callback runner that measures times for callback
+handlers and times between events. This allows for granular
+benchmarking of any training loop.
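+
+Timings are collected into two `DataFrame` fields: `df_fit` records the time
+between consecutive events, and `df_cb` records the time spent inside each
+callback handler, with one row per measurement.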
+
+## Examples
+
+To use, pass as the `cbrunner` argument to `Learner`:
+
+```julia
+cbrunner = ProfileRunner()
+learner = Learner(model, data, opt, lossfn; cbrunner=cbrunner)
+fit!(learner, 10)
+```
+
+After training, you can access the timings through the fields:
+
+- `cbrunner.df_fit`: stores timings between events
+- `cbrunner.df_cb`: stores timings for callback handlers
+"""
+mutable struct ProfileRunner <: FluxTraining.CallbackRunner
+    df_fit::DataFrame
+    df_cb::DataFrame
+    _last::Any
+end
+
+
+ProfileRunner() = ProfileRunner(_new_df_fit(), _new_df_cb(), nothing)
+
+
+function Base.show(io::IO, runner::ProfileRunner)
+    print(io, "ProfileRunner(df_fit = ")
+    summary(io, runner.df_fit)
+    print(io, ", df_cb = ")
+    summary(io, runner.df_cb)
+    print(io, ")")
+end
+
+
+_new_df_fit() = DataFrame(
+    phase = PooledArray(Type{<:Phase}[], UInt8),
+    eventstart = PooledArray(Type{<:Event}[], UInt8),
+    eventend = PooledArray(Type{<:Event}[], UInt8),
+    timestart = Float64[],
+    timeend = Float64[],
+)
+
+_new_df_cb() = DataFrame(
+    phase = PooledArray(Type{<:Phase}[], UInt8),
+    event = PooledArray(Type{<:Event}[], UInt8),
+    callback = FluxTraining.Callback[],
+    timestart = Float64[],
+    timeend = Float64[],
+)
+
+
+function FluxTraining.handle(
+    runner::ProfileRunner,
+    event::E,
+    phase::P,
+    learner,
+) where {E<:Event,P<:Phase}
+    # Add a timing for the span between the previous event and this one.
+    last = runner._last
+    if last !== nothing
+        timeend = Zygote.ignore(() -> Base.time())
+        lastevent, lasttime, lastphase = last
+        if lastphase == P
+            Zygote.ignore() do
+                push!(
+                    runner.df_fit,
+                    (;
+                        phase = P,
+                        eventstart = lastevent,
+                        eventend = E,
+                        timestart = lasttime,
+                        timeend = timeend,
+                    ),
+                )
+            end
+        end
+    end
+
+    # Execute each callback handler and record a timing for it.
+    idxs = Zygote.ignore() do
+        Graphs.topological_sort_by_dfs(learner.callbacks.graph)
+    end
+    for i in idxs
+        cb = learner.callbacks.cbs[i]
+        timestart = Zygote.ignore(() -> Base.time())
+        FluxTraining._on(event, phase, cb, learner)
+        Zygote.ignore() do
+            timeend = Base.time()
+            push!(
+                runner.df_cb,
+                (;
+                    phase = P,
+                    event = E,
+                    callback = cb,
+                    timestart = timestart,
+                    timeend = timeend,
+                ),
+            )
+        end
+    end
+
+    # Update `_last` so the next between-event time can be measured.
+    runner._last = (E, Zygote.ignore(() -> Base.time()), P)
+    return nothing
+end
+
+
+# ### Data transformations
+#
+# Get the data into a usable shape for further analysis.
+
+"""
+    getsteptimings(profilerunner[, Phase]) -> GroupedDataFrame
+
+Group the step timings by the pair of events they were measured between,
+excluding epoch boundaries and keeping only phases of type `Phase`.
+"""
+function getsteptimings(runner::ProfileRunner, P = AbstractTrainingPhase)
+    return groupby(
+        subset(
+            combine(
+                runner.df_fit,
+                [:timeend, :timestart] => ((e, s) -> e - s) => :duration,
+                :phase,
+                :eventstart,
+                :eventend,
+            ),
+            :phase => (ps -> ps .<: P),
+            :eventstart => (es -> ((es .!= EpochBegin) .& (es .!= EpochEnd))),
+            :eventend => (es -> ((es .!= EpochBegin) .& (es .!= EpochEnd))),
+        ),
+        [:eventstart, :eventend],
+    )
+end
+
+# ### Analysis and visualization
+#
+# Provide helpful analyses that show the most important timings and help with
+# benchmarking and identifying bottlenecks.
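+#
+# A hypothetical REPL session (assuming a `Learner` was trained with
+# `cbrunner = ProfileRunner()` as in the `ProfileRunner` docstring; `mean` is
+# from `Statistics`):
+#
+#     julia> showsteptimings(cbrunner)   # step timings for training phases
+#
+#     julia> showsteptimings(stdout, cbrunner, ValidationPhase; metrics = [mean, median])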
+ +""" + showsteptimings(profilerunner) + showsteptimings(io, profilerunner, P = AbstractTrainingPhase; metrics = [...]) + + +""" +function showsteptimings( + io::IO, + runner::ProfileRunner, + P = AbstractTrainingPhase; + metrics = [median, minimum, maximum], +) + gdf = getsteptimings(runner, P) + rownames = ["$(k.eventstart) => $(k.eventend)" for k in keys(gdf)] + rowdata = [metricfn(eventdf.duration .* 1000) for eventdf in gdf, metricfn in metrics] + pretty_table( + io, + rowdata, + header = (string.(metrics), repeat(["ms"], length(metrics))), + row_names = rownames, + row_name_column_title = "Event", + highlighters = _timinghighlighter(), + formatters = ft_printf("%5.3f"), + ) +end +showsteptimings(args...; kwargs...) = showsteptimings(stdout, args...; kwargs...) + + +# #### PrettyTables.jl utilities + +_timinghighlighter() = Highlighter( + (data, i, j) -> true, + function (h, data, i, j) + ext = extrema(data[:, j]) + ext = 0., ext[2] + return Crayon( + background = _cvtcolor( + get( + ColorScheme(range(colorant"black", colorant"darkorange4")), + data[i, j], + ext, + ), + ), + foreground = _cvtcolor( + get( + ColorScheme(range(colorant"gray", colorant"white")), + data[i, j], + ext, + ), + ), + ) + end, +) + +_cvtcolor(c::Color) = ( + round(Int, Colors.red(c) * 255), + round(Int, Colors.green(c) * 255), + round(Int, Colors.blue(c) * 255), +) diff --git a/src/training.jl b/src/training.jl index 4356252d2..d542d3687 100644 --- a/src/training.jl +++ b/src/training.jl @@ -64,8 +64,9 @@ end function step!(learner, phase::ValidationPhase, batch) xs, ys = batch - runstep(learner, phase, (;xs=xs, ys=ys)) do _, state + runstep(learner, phase, (;xs=xs, ys=ys)) do handle, state state.ŷs = learner.model(state.xs) + handle(LossBegin()) state.loss = learner.lossfn(state.ŷs, state.ys) end end diff --git a/test/profilerunner.jl b/test/profilerunner.jl new file mode 100644 index 000000000..1cd23f7f2 --- /dev/null +++ b/test/profilerunner.jl @@ -0,0 +1,2 @@ +@testset "ProfileRunner" begin +end