Skip to content

Commit f6ba7a4

Browse files
committed
perf: setup to use profiler
1 parent 84ccdc7 commit f6ba7a4

File tree

8 files changed

+121
-158
lines changed

8 files changed

+121
-158
lines changed

benchmark/misc/common.jl

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,7 @@ using Reactant: Reactant, @compile
22
using Chairmarks: @b
33
using Printf: @sprintf
44

5-
function get_backend()
6-
# To run benchmarks on a specific backend
7-
BENCHMARK_GROUP = get(ENV, "BENCHMARK_GROUP", nothing)
8-
9-
if BENCHMARK_GROUP == "CUDA"
10-
Reactant.set_default_backend("gpu")
11-
@info "Running CUDA benchmarks" maxlog = 1
12-
elseif BENCHMARK_GROUP == "TPU"
13-
Reactant.set_default_backend("tpu")
14-
elseif BENCHMARK_GROUP == "CPU"
15-
Reactant.set_default_backend("cpu")
16-
@info "Running CPU benchmarks" maxlog = 1
17-
else
18-
BENCHMARK_GROUP = String(split(string(first(Reactant.devices())), ":")[1])
19-
@info "Running $(BENCHMARK_GROUP) benchmarks" maxlog = 1
20-
end
21-
return BENCHMARK_GROUP
22-
end
5+
include("../utils.jl")
236

247
struct BenchmarkConfiguration
258
name::String

benchmark/misc/runbenchmarks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Miscellaneous Benchmarks Runner
22
# This script runs all misc benchmarks and stores results to a JSON file
33

4-
include("../utils.jl")
4+
include("common.jl")
55

66
@info sprint(io -> versioninfo(io; verbose=true))
77

@@ -14,7 +14,6 @@ using Random: Random
1414
using Printf: @sprintf
1515

1616
# Include benchmark modules
17-
include("common.jl")
1817
include("newton_schulz.jl")
1918
include("bloch_rf_optimization.jl")
2019

benchmark/nn/common.jl

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -239,21 +239,4 @@ function run_benchmark!(
239239
return nothing
240240
end
241241

242-
function get_backend()
243-
# To run benchmarks on a specific backend
244-
BENCHMARK_GROUP = get(ENV, "BENCHMARK_GROUP", nothing)
245-
246-
if BENCHMARK_GROUP == "CUDA"
247-
Reactant.set_default_backend("gpu")
248-
@info "Running CUDA benchmarks" maxlog = 1
249-
elseif BENCHMARK_GROUP == "TPU"
250-
Reactant.set_default_backend("tpu")
251-
elseif BENCHMARK_GROUP == "CPU"
252-
Reactant.set_default_backend("cpu")
253-
@info "Running CPU benchmarks" maxlog = 1
254-
else
255-
BENCHMARK_GROUP = String(split(string(first(Reactant.devices())), ":")[1])
256-
@info "Running $(BENCHMARK_GROUP) benchmarks" maxlog = 1
257-
end
258-
return BENCHMARK_GROUP
259-
end
242+
include("../utils.jl")

benchmark/nn/runbenchmarks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
# Neural Network Benchmarks Runner
22
# This script runs all NN benchmarks and stores results to a JSON file
33

4-
include("../utils.jl")
4+
include("common.jl")
55

66
@info sprint(io -> versioninfo(io; verbose=true))
77

88
backend = get_backend()
99

1010
# Include benchmark modules
11-
include("common.jl")
1211
include("vision.jl")
1312
include("neural_operators.jl")
1413
include("dgcnn.jl")

benchmark/oceananigans/abernathey_channel.jl

Lines changed: 81 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,15 @@
1-
using Oceananigans
2-
ENV["GKSwstype"] = "100"
1+
using Oceananigans, Printf, Statistics, SeawaterPolynomials, CUDA, Reactant, Enzyme
32

4-
using Printf
5-
using Statistics
6-
7-
using Oceananigans
83
using Oceananigans.Units
94
using Oceananigans.OutputReaders: FieldTimeSeries
105
using Oceananigans.Grids: xnode, ynode, znode
116
using Oceananigans.TurbulenceClosures: CATKEVerticalDiffusivity, HorizontalFormulation
12-
13-
using SeawaterPolynomials
14-
15-
using CUDA
16-
17-
using Reactant
187
using Oceananigans.Architectures: ReactantState
198

20-
using Enzyme
21-
229
Oceananigans.defaults.FloatType = Float64
2310

11+
include("../utils.jl")
12+
2413
graph_directory = "run_abernathy_model_ad_spinup100_100steps/"
2514

2615
# number of grid points
@@ -368,91 +357,85 @@ end
368357
##### Actually creating our model and using these functions to run it:
369358
#####
370359

371-
# Architecture
372-
architecture = ReactantState()
373-
374-
# Timestep size:
375-
Δt₀ = 2.5minutes
376-
377-
# Make the grid:
378-
grid = make_grid(architecture, Nx, Ny, Nz, z_faces)
379-
model = build_model(grid, Δt₀, parameters)
380-
T_flux = T_flux_init(model.grid, parameters)
381-
u_wind_stress = u_wind_stress_init(model.grid, parameters)
382-
v_wind_stress = v_wind_stress_init(model.grid, parameters)
383-
Tᵢ, Sᵢ = temperature_salinity_init(model.grid, parameters)
384-
mld = Field{Center,Center,Nothing}(model.grid) # Not used for now
385-
Δz = Reactant.to_rarray(Δz)
386-
387-
dmodel = Enzyme.make_zero(model)
388-
dTᵢ = Field{Center,Center,Center}(model.grid)
389-
dSᵢ = Field{Center,Center,Center}(model.grid)
390-
du_wind_stress = Field{Face,Center,Nothing}(model.grid)
391-
dv_wind_stress = Field{Center,Face,Nothing}(model.grid)
392-
dT_flux = Field{Center,Center,Nothing}(model.grid)
393-
dmld = Field{Center,Center,Nothing}(model.grid)
394-
dΔz = Enzyme.make_zero(Δz)
395-
396-
# Trying zonal transport:
397-
398-
tic = time()
399-
rspinup_reentrant_channel_model! = @compile raise_first = true raise = true sync = true spinup_reentrant_channel_model!(
400-
model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux
401-
)
402-
#restimate_tracer_error = @compile raise_first=true raise=true sync=true estimate_tracer_error(model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux, Δz, mld)
403-
rdifferentiate_tracer_error = @compile raise_first = true raise = true sync = true differentiate_tracer_error(
404-
model,
405-
Tᵢ,
406-
Sᵢ,
407-
u_wind_stress,
408-
v_wind_stress,
409-
T_flux,
410-
Δz,
411-
mld,
412-
dmodel,
413-
dTᵢ,
414-
dSᵢ,
415-
du_wind_stress,
416-
dv_wind_stress,
417-
dT_flux,
418-
dΔz,
419-
dmld,
420-
)
421-
compile_toc = time() - tic
422-
423-
@show compile_toc
360+
function run_abernathey_channel_benchmark!(results::Dict{String,Float64}, backend::String)
361+
architecture = ReactantState()
362+
363+
Δt₀ = 2.5minutes
364+
365+
# Make the grid:
366+
grid = make_grid(architecture, Nx, Ny, Nz, z_faces)
367+
model = build_model(grid, Δt₀, parameters)
368+
T_flux = T_flux_init(model.grid, parameters)
369+
u_wind_stress = u_wind_stress_init(model.grid, parameters)
370+
v_wind_stress = v_wind_stress_init(model.grid, parameters)
371+
Tᵢ, Sᵢ = temperature_salinity_init(model.grid, parameters)
372+
mld = Field{Center,Center,Nothing}(model.grid) # Not used for now
373+
Δz = Reactant.to_rarray(Δz)
374+
375+
dmodel = Enzyme.make_zero(model)
376+
dTᵢ = Field{Center,Center,Center}(model.grid)
377+
dSᵢ = Field{Center,Center,Center}(model.grid)
378+
du_wind_stress = Field{Face,Center,Nothing}(model.grid)
379+
dv_wind_stress = Field{Center,Face,Nothing}(model.grid)
380+
dT_flux = Field{Center,Center,Nothing}(model.grid)
381+
dmld = Field{Center,Center,Nothing}(model.grid)
382+
dΔz = Enzyme.make_zero(Δz)
383+
384+
# Profile and time the spinup_reentrant_channel_model!
385+
time_spinup_reentrant_channel_model! = Reactant.Profiler.profile_with_xprof(
386+
spinup_reentrant_channel_model!,
387+
model,
388+
Tᵢ,
389+
Sᵢ,
390+
u_wind_stress,
391+
v_wind_stress,
392+
T_flux;
393+
nrepeat=10,
394+
warmup=1,
395+
compile_options=CompileOptions(; raise=true, raise_first=true),
396+
)
397+
results["Oceananigans/SpinUpReentrantChannelModel/$(backend)/Primal"] =
398+
time_spinup_reentrant_channel_model!.profiling_result.runtime_ns / 1e9
424399

425-
# Spinup the model for a sufficient amount of time, save the T and S from this state:
426-
tic = time()
427-
rspinup_reentrant_channel_model!(model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux)
428-
@allowscalar set!(Tᵢ, model.tracers.T)
429-
@allowscalar set!(Sᵢ, model.tracers.S)
430-
spinup_toc = time() - tic
431-
@show spinup_toc
400+
# Spinup the model for a sufficient amount of time, save the T and S from this state:
401+
rspinup_reentrant_channel_model! = @compile raise_first = true raise = true sync = true spinup_reentrant_channel_model!(
402+
model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux
403+
)
404+
rspinup_reentrant_channel_model!(model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux)
405+
@allowscalar set!(Tᵢ, model.tracers.T)
406+
@allowscalar set!(Sᵢ, model.tracers.S)
432407

433-
tic = time()
434-
#output = restimate_tracer_error(model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux, Δz, mld)
435-
dedν = rdifferentiate_tracer_error(
436-
model,
437-
Tᵢ,
438-
Sᵢ,
439-
u_wind_stress,
440-
v_wind_stress,
441-
T_flux,
442-
Δz,
443-
mld,
444-
dmodel,
445-
dTᵢ,
446-
dSᵢ,
447-
du_wind_stress,
448-
dv_wind_stress,
449-
dT_flux,
450-
dΔz,
451-
dmld,
452-
)
453-
run_toc = time() - tic
408+
# Profile and time the differentiate_tracer_error
409+
time_differentiate_tracer_error = Reactant.Profiler.profile_with_xprof(
410+
differentiate_tracer_error,
411+
model,
412+
Tᵢ,
413+
Sᵢ,
414+
u_wind_stress,
415+
v_wind_stress,
416+
T_flux,
417+
Δz,
418+
mld,
419+
dmodel,
420+
dTᵢ,
421+
dSᵢ,
422+
du_wind_stress,
423+
dv_wind_stress,
424+
dT_flux,
425+
dΔz,
426+
dmld;
427+
nrepeat=10,
428+
warmup=1,
429+
compile_options=CompileOptions(; raise=true, raise_first=true),
430+
)
431+
results["Oceananigans/DifferentiateTracerError/$(backend)/Reverse"] =
432+
time_differentiate_tracer_error!.profiling_result.runtime_ns / 1e9
454433

455-
@show run_toc
456-
#@show output
434+
return nothing
435+
end
457436

458-
@show dedν
437+
if abspath(PROGRAM_FILE) == @__FILE__
438+
backend = get_backend()
439+
results = Dict()
440+
run_abernathey_channel_benchmark!(results, backend)
441+
end
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Oceananigans Benchmarks Runner
2+
# This script runs all oceananigans benchmarks and stores results to a JSON file
3+
4+
include("../utils.jl")
5+
6+
@info sprint(io -> versioninfo(io; verbose=true))
7+
8+
backend = get_backend()
9+
10+
# Load dependencies used in benchmarks
11+
using Reactant, LinearAlgebra, Enzyme
12+
using Chairmarks: @b
13+
using Random: Random
14+
using Printf: @sprintf
15+
16+
# Include benchmark modules
17+
module AbernatheyChannel
18+
19+
include("abernathey_channel.jl")
20+
21+
end
22+
23+
# Run all benchmarks
24+
function run_all_benchmarks(backend::String)
25+
results = Dict{String,Float64}()
26+
27+
AbernatheyChannel.run_abernathey_channel_benchmark!(results, backend)
28+
29+
return results
30+
end
31+
32+
results = run_all_benchmarks(backend)
33+
34+
save_results(results, joinpath(@__DIR__, "results"), "oceananigans", backend)

benchmark/polybench/common.jl

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,7 @@ using LinearAlgebra: norm
33
using Chairmarks: @b
44
using Printf: @sprintf
55

6-
function get_backend()
7-
# To run benchmarks on a specific backend
8-
BENCHMARK_GROUP = get(ENV, "BENCHMARK_GROUP", nothing)
9-
10-
if BENCHMARK_GROUP == "CUDA"
11-
Reactant.set_default_backend("gpu")
12-
@info "Running CUDA benchmarks" maxlog = 1
13-
elseif BENCHMARK_GROUP == "TPU"
14-
Reactant.set_default_backend("tpu")
15-
elseif BENCHMARK_GROUP == "CPU"
16-
Reactant.set_default_backend("cpu")
17-
@info "Running CPU benchmarks" maxlog = 1
18-
else
19-
BENCHMARK_GROUP = String(split(string(first(Reactant.devices())), ":")[1])
20-
@info "Running $(BENCHMARK_GROUP) benchmarks" maxlog = 1
21-
end
22-
return BENCHMARK_GROUP
23-
end
6+
include("../utils.jl")
247

258
function recursive_check(x::AbstractArray, y::AbstractArray; kwargs...)
269
res = isapprox(x, y; norm=Base.Fix2(norm, Inf), kwargs...)

benchmark/polybench/runbenchmarks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Polybench Benchmarks Runner
22
# This script runs all polybench benchmarks and stores results to a JSON file
33

4-
include("../utils.jl")
4+
include("common.jl")
55

66
@info sprint(io -> versioninfo(io; verbose=true))
77

@@ -14,7 +14,6 @@ using Random: Random
1414
using Printf: @sprintf
1515

1616
# Include benchmark modules
17-
include("common.jl")
1817
include("stencil.jl")
1918
include("data_mining.jl")
2019
include("blas.jl")

0 commit comments

Comments
 (0)