Commit 15b6ae7

perf: more benchmarking results (#1614)
* perf: update resnet benchmarks
* perf: add KAN perf numbers
1 parent 420dce3 commit 15b6ae7

32 files changed · +6062 −3179 lines changed

README.md

Lines changed: 13 additions & 1 deletion
@@ -13,7 +13,7 @@
 [![CI (pre-release)](<https://img.shields.io/github/actions/workflow/status/LuxDL/Lux.jl/CIPreRelease.yml?branch=main&label=CI%20(pre-release)&logo=github>)](https://github.com/LuxDL/Lux.jl/actions/workflows/CIPreRelease.yml)
 [![Build status](https://img.shields.io/buildkite/ba1f9622add5978c2d7b194563fd9327113c9c21e5734be20e/main.svg?label=gpu&branch=main&logo=buildkite)](https://buildkite.com/julialang/lux-dot-jl)
 [![codecov](https://codecov.io/gh/LuxDL/Lux.jl/branch/main/graph/badge.svg?token=IMqBM1e3hz)](https://codecov.io/gh/LuxDL/Lux.jl)
-[![Benchmarks](https://github.com/LuxDL/Lux.jl/actions/workflows/Benchmark.yml/badge.svg?branch=main)](https://lux.csail.mit.edu/benchmarks/)
+<!-- [![Benchmarks](https://github.com/LuxDL/Lux.jl/actions/workflows/Benchmark.yml/badge.svg?branch=main)](https://lux.csail.mit.edu/benchmarks/) -->
 
 [![Downloads](https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLux&query=total_requests&suffix=%2Fmonth&label=Downloads)](https://juliapkgstats.com/pkg/Lux)
 [![Downloads](https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLux&query=total_requests&&label=Total%20Downloads)](https://juliapkgstats.com/pkg/Lux)
@@ -129,6 +129,18 @@ Pkg.add("Lux")
 [downloads-luxtestutils-url]: http://juliapkgstats.com/pkg/LuxTestUtils
 [downloads-luxcuda-url]: http://juliapkgstats.com/pkg/LuxCUDA
 
+## 🚀 Benchmarks
+
+Currently, benchmarks are scattered across a few places:
+
+1. For comparison with other Julia packages like CUDA.jl, take a look
+   at [Lux.jl/perf](./perf/README.md).
+2. <https://enzymead.github.io/Enzyme-JAX/benchmarks/> highlights the
+   performance of EnzymeJAX (the backend for Reactant.jl) against JAX.
+3. <https://enzymead.github.io/Reactant.jl/benchmarks/> highlights the
+   performance of Reactant.jl against default XLA and base Julia
+   compilation.
+
 ## 🤸 Quickstart
 
 ### Reactant & Enzyme

perf/README.md

Lines changed: 15 additions & 1 deletion
@@ -2,7 +2,7 @@
 
 ## ResNet
 
-Benchmarks were run on a single A100 GPU with 40GB of memory.
+Benchmarks were run on a single GeForce RTX 5090 GPU with 32GB of VRAM.
 
 <p align="center">
 <img src="results/resnet/resnet_runtimes.svg#gh-light-mode-only"/>
@@ -13,3 +13,17 @@ Benchmarks were run on a single A100 GPU with 40GB of memory.
 <img src="results/resnet/resnet_speedups.svg#gh-light-mode-only"/>
 <img src="results/resnet/resnet_speedups_dark.svg#gh-dark-mode-only"/>
 </p>
+
+## Kolmogorov-Arnold Networks
+
+Benchmarks were run on a single GeForce RTX 5090 GPU with 32GB of VRAM.
+
+<p align="center">
+<img src="results/kan/kan_runtimes.svg#gh-light-mode-only"/>
+<img src="results/kan/kan_runtimes_dark.svg#gh-dark-mode-only"/>
+</p>
+
+<p align="center">
+<img src="results/kan/kan_speedups.svg#gh-light-mode-only"/>
+<img src="results/kan/kan_speedups_dark.svg#gh-dark-mode-only"/>
+</p>
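
The runtime and speedup plots above are built from the timing JSON files that the benchmark drivers write (see perf/kan/main.jl and perf/resnet/main.jl later in this commit). The plotting script itself is not part of this diff, so the following is only a plausible post-processing sketch; the file paths and key names are inferred from what the KAN driver writes to perf/results/kan/.

```julia
# Illustrative sketch (not from this commit): compare CUDA.jl and Reactant
# timings written by perf/kan/main.jl. Run from the Lux.jl repository root.
using JSON3

cuda     = JSON3.read(read("perf/results/kan/cudajl.json", String))
reactant = JSON3.read(read("perf/results/kan/reactant.json", String))

for model in keys(cuda)               # e.g. :kan_base_act, :kan_no_base_act
    for pass in (:forward, :backward)
        speedup = cuda[model][pass] / reactant[model][pass]
        println("$model $pass speedup (CUDA.jl time / Reactant time): ",
            round(speedup; digits=2), "x")
    end
end
```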

perf/kan/Project.toml

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
KolmogorovArnold = "eec8b66d-f71a-4a43-b228-0fe5d6721cd3"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[sources]
Lux = {path = "../../"}
LuxCUDA = {path = "../../lib/LuxCUDA"}

[compat]
BenchmarkTools = "1.6.0"
Comonicon = "1.0.8"
Enzyme = "0.13.81"
Lux = "1.13.3"
LuxCUDA = "0.3.3"
Random = "1.11"
Reactant = "0.2.190"
julia = "1.11"

perf/kan/main.jl

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
# Taken from https://github.com/vpuri3/KolmogorovArnold.jl/blob/0fc349813be15982365173bce0e9bf3a814a342a/examples/eg3.jl
using KolmogorovArnold
using Comonicon, BenchmarkTools, JSON3
using Random, LinearAlgebra
using Enzyme, Zygote, Lux
using OrderedCollections

# configure BLAS
ncores = min(Sys.CPU_THREADS, length(Sys.cpu_info()))
BLAS.set_num_threads(ncores)

# configure CUDA
using LuxCUDA
CUDA.allowscalar(false)

# configure Reactant
using Reactant
Reactant.set_default_backend("gpu")

rng = Random.default_rng()
Random.seed!(rng, 0)

function toy_loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return MSELoss()(pred, y)
end

function setup_models(; kan_width::Int=128, grid_size::Int=32)
    wK, G = kan_width, grid_size

    basis_func = rbf # rbf, rswaf
    normalizer = softsign # sigmoid(_fast), tanh(_fast), softsign

    kan1 = Chain(
        KDense(1, wK, G; use_base_act=true, basis_func, normalizer),
        KDense(wK, wK, G; use_base_act=true, basis_func, normalizer),
        KDense(wK, 1, G; use_base_act=true, basis_func, normalizer),
    )

    kan2 = Chain(
        KDense(1, wK, G; use_base_act=false, basis_func, normalizer),
        KDense(wK, wK, G; use_base_act=false, basis_func, normalizer),
        KDense(wK, 1, G; use_base_act=false, basis_func, normalizer),
    )

    return [("kan_base_act", kan1), ("kan_no_base_act", kan2)]
end

function run_cuda_benchmarks(; batch_size::Int=128, kwargs...)
    dev = gpu_device(; force=true)

    x = rand32(rng, 1, batch_size)
    y = x .^ 2

    models = setup_models(; kwargs...)
    timings = OrderedDict{String,OrderedDict{String,Float64}}()

    for (name, model) in models
        println("\nCUDA Benchmarking: $name")

        ps, st = Lux.setup(rng, model) |> dev
        x_cu = x |> dev
        y_cu = y |> dev

        println("Param count: $(Lux.parameterlength(ps))")
        println("State count: $(Lux.statelength(st))")

        # Forward pass timing
        fwd_time = @belapsed begin
            pred, _ = $(model)($(x_cu), $(ps), $(Lux.testmode(st)))
            CUDA.synchronize()
        end setup = begin
            GC.gc(true)
            CUDA.reclaim()
        end

        # Backward pass timing (using Zygote)
        fn = (ps, x) -> toy_loss_function(model, ps, st, x, y_cu)

        bwd_time = @belapsed begin
            Zygote.gradient($(fn), $(ps), $(x_cu))
            CUDA.synchronize()
        end setup = begin
            GC.gc(true)
            CUDA.reclaim()
        end

        timings[name] = OrderedDict{String,Float64}(
            "forward" => fwd_time, "backward" => bwd_time
        )

        display(timings[name])
    end

    return timings
end

function run_xla_benchmarks(; kwargs...)
    return run_reactant_benchmarks(;
        kwargs..., compile_options=Reactant.DefaultXLACompileOptions()
    )
end

function run_reactant_benchmarks(;
    batch_size::Int=128,
    compile_options=Reactant.CompileOptions(; optimization_passes=:all),
    kwargs...,
)
    dev = reactant_device(; force=true)

    x = rand32(rng, 1, batch_size)
    y = x .^ 2

    models = setup_models(; kwargs...)
    timings = OrderedDict{String,OrderedDict{String,Float64}}()

    for (name, model) in models
        println("\nReactant Benchmarking: $name")

        ps, st = Lux.setup(rng, model) |> dev
        x_ra = x |> dev
        y_ra = y |> dev

        println("Param count: $(Lux.parameterlength(ps))")
        println("State count: $(Lux.statelength(st))")

        # Forward pass timing
        fwd_time_result = Reactant.Profiler.profile_with_xprof(
            Lux.apply,
            model,
            x_ra,
            ps,
            Lux.testmode(st);
            nrepeat=10,
            warmup=1,
            compile_options,
        )
        fwd_time = fwd_time_result.profiling_result.runtime_ns / 1e9

        # Backward pass timing
        bwd_time_result = Reactant.Profiler.profile_with_xprof(
            Enzyme.gradient,
            Reverse,
            toy_loss_function,
            Const(model),
            ps,
            Const(st),
            Const(x_ra),
            Const(y_ra);
            nrepeat=10,
            warmup=1,
            compile_options,
        )
        bwd_time = bwd_time_result.profiling_result.runtime_ns / 1e9

        timings[name] = OrderedDict{String,Float64}(
            "forward" => fwd_time, "backward" => bwd_time
        )

        display(timings[name])
    end

    return timings
end

Comonicon.@main function main(;
    backend::String="all", batch_size::Int=1024, kan_width::Int=128, grid_size::Int=32
)
    results_path = joinpath(@__DIR__, "../results/kan/")
    mkpath(results_path)

    if backend in ("cuda", "all")
        println("\n" * "="^50)
        println("Running CUDA benchmarks...")
        println("="^50)

        cuda_timings = run_cuda_benchmarks(; batch_size, kan_width, grid_size)

        open(joinpath(results_path, "cudajl.json"), "w") do io
            JSON3.write(io, cuda_timings)
        end

        println("\nCUDA Results:")
        display(cuda_timings)
    end

    if backend in ("reactant", "all")
        println("\n" * "="^50)
        println("Running Reactant benchmarks...")
        println("="^50)

        reactant_timings = run_reactant_benchmarks(; batch_size, kan_width, grid_size)

        open(joinpath(results_path, "reactant.json"), "w") do io
            JSON3.write(io, reactant_timings)
        end

        println("\nReactant Results:")
        display(reactant_timings)
    end

    if backend in ("xla", "all")
        println("\n" * "="^50)
        println("Running XLA benchmarks...")
        println("="^50)

        xla_timings = run_xla_benchmarks(; batch_size, kan_width, grid_size)

        open(joinpath(results_path, "xla.json"), "w") do io
            JSON3.write(io, xla_timings)
        end

        println("\nXLA Results:")
        display(xla_timings)
    end

    return nothing
end
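
For reference, here is a hypothetical way to set up and launch this driver; it is not part of the commit. It assumes the Lux.jl repository root as the working directory, and the `--backend`/`--batch-size` flag names are an assumption based on Comonicon's usual mapping of keyword arguments to CLI flags.

```julia
# Hypothetical invocation of perf/kan/main.jl (not from this commit).
using Pkg
Pkg.activate("perf/kan")   # environment defined by the Project.toml above
Pkg.instantiate()          # resolves the path-based [sources] entries

# Launch the Comonicon CLI in a fresh process, CUDA backend only:
run(`$(Base.julia_cmd()) --project=perf/kan perf/kan/main.jl --backend cuda --batch-size 1024`)
```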

perf/resnet/Project.toml

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -20,5 +21,5 @@ Enzyme = "0.13.81"
 Lux = "1.13.3"
 LuxCUDA = "0.3.3"
 Random = "1.11"
-Reactant = "0.2.170"
+Reactant = "0.2.190"
 julia = "1.11"

perf/resnet/main.jl

Lines changed: 5 additions & 4 deletions

@@ -1,5 +1,6 @@
 using Comonicon, BenchmarkTools, JSON3
 using Lux, LuxCUDA, Random, Zygote
+using OrderedCollections
 
 include("resnet.jl")
 
@@ -13,7 +14,7 @@ Comonicon.@main function main(;
 )
     dev = gpu_device(; force=true)
 
-    timings = Dict{Int,Dict{Int,Dict{String,Float64}}}()
+    timings = OrderedDict{Int,OrderedDict{Int,OrderedDict{String,Float64}}}()
 
     for m in model_size
         println("model_size=$m")
@@ -23,7 +24,7 @@ Comonicon.@main function main(;
         println("Param count: $(Lux.parameterlength(ps))")
         println("State count: $(Lux.statelength(st))")
 
-        timings[m] = Dict{Int,Dict{String,Float64}}()
+        timings[m] = OrderedDict{Int,OrderedDict{String,Float64}}()
 
         for b in batch_size
             x = rand(Float32, 224, 224, 3, b) |> dev
@@ -52,12 +53,12 @@ Comonicon.@main function main(;
                 end
             end
 
-            timings[m][b] = Dict{String,Float64}(
+            timings[m][b] = OrderedDict{String,Float64}(
                 "forward" => fwd_time, "backward" => bwd_time
             )
         end
 
-        println(timings[m])
+        display(timings[m])
     end
 
     results_path = joinpath(@__DIR__, "../results/resnet/")
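
The Dict → OrderedDict switch above (together with the matching `sort_keys=True` change in perf/resnet/main.py below) gives the emitted JSON a stable key order. A minimal illustration, not taken from the repository:

```julia
# Minimal sketch (not from this commit): OrderedDict preserves insertion
# order, so JSON3 writes keys in the order the benchmark loop filled them,
# whereas Dict iterates in hash order.
using OrderedCollections, JSON3

timings = OrderedDict("forward" => 1.2e-3, "backward" => 3.4e-3)
println(JSON3.write(timings))   # keys appear in insertion order: forward, backward
```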

perf/resnet/main.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
 import json
 
 from functools import partial
-from typing import Any, Tuple
+from typing import Any
 from collections.abc import Callable, Sequence
 
 import flax.linen as nn
@@ -153,7 +153,7 @@ def loss_fn(p, x, y):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--batch-size", type=list, default=[1, 4, 32, 128])
-    parser.add_argument("--model-size", type=list, default=[18, 34, 50, 101, 152])
+    parser.add_argument("--model-size", type=list, default=[18, 34, 50, 101])
     args = parser.parse_args()
 
     timings = dict()
@@ -223,4 +223,4 @@ def loss_fn(p, x, y):
     os.makedirs(results_path, exist_ok=True)
 
     with open(os.path.join(results_path, "jax.json"), "w") as f:
-        json.dump(timings, f, indent=4)
+        json.dump(timings, f, indent=4, sort_keys=True)
