|
1 | | -using Oceananigans |
2 | | -ENV["GKSwstype"] = "100" |
| 1 | +using Oceananigans, Printf, Statistics, SeawaterPolynomials, CUDA, Reactant, Enzyme |
3 | 2 |
|
4 | | -using Printf |
5 | | -using Statistics |
6 | | - |
7 | | -using Oceananigans |
8 | 3 | using Oceananigans.Units |
9 | 4 | using Oceananigans.OutputReaders: FieldTimeSeries |
10 | 5 | using Oceananigans.Grids: xnode, ynode, znode |
11 | 6 | using Oceananigans.TurbulenceClosures: CATKEVerticalDiffusivity, HorizontalFormulation |
12 | | - |
13 | | -using SeawaterPolynomials |
14 | | - |
15 | | -using CUDA |
16 | | - |
17 | | -using Reactant |
18 | 7 | using Oceananigans.Architectures: ReactantState |
19 | 8 |
|
20 | | -using Enzyme |
21 | | - |
22 | 9 | Oceananigans.defaults.FloatType = Float64 |
23 | 10 |
|
| 11 | +include("../utils.jl") |
| 12 | + |
24 | 13 | graph_directory = "run_abernathy_model_ad_spinup100_100steps/" |
25 | 14 |
|
26 | 15 | # number of grid points |
@@ -368,91 +357,85 @@ end |
368 | 357 | ##### Actually creating our model and using these functions to run it: |
369 | 358 | ##### |
370 | 359 |
|
371 | | -# Architecture |
372 | | -architecture = ReactantState() |
373 | | - |
374 | | -# Timestep size: |
375 | | -Δt₀ = 2.5minutes |
376 | | - |
377 | | -# Make the grid: |
378 | | -grid = make_grid(architecture, Nx, Ny, Nz, z_faces) |
379 | | -model = build_model(grid, Δt₀, parameters) |
380 | | -T_flux = T_flux_init(model.grid, parameters) |
381 | | -u_wind_stress = u_wind_stress_init(model.grid, parameters) |
382 | | -v_wind_stress = v_wind_stress_init(model.grid, parameters) |
383 | | -Tᵢ, Sᵢ = temperature_salinity_init(model.grid, parameters) |
384 | | -mld = Field{Center,Center,Nothing}(model.grid) # Not used for now |
385 | | -Δz = Reactant.to_rarray(Δz) |
386 | | - |
387 | | -dmodel = Enzyme.make_zero(model) |
388 | | -dTᵢ = Field{Center,Center,Center}(model.grid) |
389 | | -dSᵢ = Field{Center,Center,Center}(model.grid) |
390 | | -du_wind_stress = Field{Face,Center,Nothing}(model.grid) |
391 | | -dv_wind_stress = Field{Center,Face,Nothing}(model.grid) |
392 | | -dT_flux = Field{Center,Center,Nothing}(model.grid) |
393 | | -dmld = Field{Center,Center,Nothing}(model.grid) |
394 | | -dΔz = Enzyme.make_zero(Δz) |
395 | | - |
396 | | -# Trying zonal transport: |
397 | | - |
398 | | -tic = time() |
399 | | -rspinup_reentrant_channel_model! = @compile raise_first = true raise = true sync = true spinup_reentrant_channel_model!( |
400 | | - model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux |
401 | | -) |
402 | | -#restimate_tracer_error = @compile raise_first=true raise=true sync=true estimate_tracer_error(model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux, Δz, mld) |
403 | | -rdifferentiate_tracer_error = @compile raise_first = true raise = true sync = true differentiate_tracer_error( |
404 | | - model, |
405 | | - Tᵢ, |
406 | | - Sᵢ, |
407 | | - u_wind_stress, |
408 | | - v_wind_stress, |
409 | | - T_flux, |
410 | | - Δz, |
411 | | - mld, |
412 | | - dmodel, |
413 | | - dTᵢ, |
414 | | - dSᵢ, |
415 | | - du_wind_stress, |
416 | | - dv_wind_stress, |
417 | | - dT_flux, |
418 | | - dΔz, |
419 | | - dmld, |
420 | | -) |
421 | | -compile_toc = time() - tic |
422 | | - |
423 | | -@show compile_toc |
| 360 | +function run_abernathey_channel_benchmark!(results::Dict{String,Float64}, backend::String) |
| 361 | + architecture = ReactantState() |
| 362 | + |
| 363 | + Δt₀ = 2.5minutes |
| 364 | + |
| 365 | + # Make the grid: |
| 366 | + grid = make_grid(architecture, Nx, Ny, Nz, z_faces) |
| 367 | + model = build_model(grid, Δt₀, parameters) |
| 368 | + T_flux = T_flux_init(model.grid, parameters) |
| 369 | + u_wind_stress = u_wind_stress_init(model.grid, parameters) |
| 370 | + v_wind_stress = v_wind_stress_init(model.grid, parameters) |
| 371 | + Tᵢ, Sᵢ = temperature_salinity_init(model.grid, parameters) |
| 372 | + mld = Field{Center,Center,Nothing}(model.grid) # Not used for now |
| 373 | + Δz = Reactant.to_rarray(Δz) |
| 374 | + |
| 375 | + dmodel = Enzyme.make_zero(model) |
| 376 | + dTᵢ = Field{Center,Center,Center}(model.grid) |
| 377 | + dSᵢ = Field{Center,Center,Center}(model.grid) |
| 378 | + du_wind_stress = Field{Face,Center,Nothing}(model.grid) |
| 379 | + dv_wind_stress = Field{Center,Face,Nothing}(model.grid) |
| 380 | + dT_flux = Field{Center,Center,Nothing}(model.grid) |
| 381 | + dmld = Field{Center,Center,Nothing}(model.grid) |
| 382 | + dΔz = Enzyme.make_zero(Δz) |
| 383 | + |
| 384 | + # Profile and time the spinup_reentrant_channel_model! |
| 385 | + time_spinup_reentrant_channel_model! = Reactant.Profiler.profile_with_xprof( |
| 386 | + spinup_reentrant_channel_model!, |
| 387 | + model, |
| 388 | + Tᵢ, |
| 389 | + Sᵢ, |
| 390 | + u_wind_stress, |
| 391 | + v_wind_stress, |
| 392 | + T_flux; |
| 393 | + nrepeat=10, |
| 394 | + warmup=1, |
| 395 | + compile_options=CompileOptions(; raise=true, raise_first=true), |
| 396 | + ) |
| 397 | + results["Oceananigans/SpinUpReentrantChannelModel/$(backend)/Primal"] = |
| 398 | + time_spinup_reentrant_channel_model!.profiling_result.runtime_ns / 1e9 |
424 | 399 |
|
425 | | -# Spinup the model for a sufficient amount of time, save the T and S from this state: |
426 | | -tic = time() |
427 | | -rspinup_reentrant_channel_model!(model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux) |
428 | | -@allowscalar set!(Tᵢ, model.tracers.T) |
429 | | -@allowscalar set!(Sᵢ, model.tracers.S) |
430 | | -spinup_toc = time() - tic |
431 | | -@show spinup_toc |
| 400 | + # Spinup the model for a sufficient amount of time, save the T and S from this state: |
| 401 | + rspinup_reentrant_channel_model! = @compile raise_first = true raise = true sync = true spinup_reentrant_channel_model!( |
| 402 | + model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux |
| 403 | + ) |
| 404 | + rspinup_reentrant_channel_model!(model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux) |
| 405 | + @allowscalar set!(Tᵢ, model.tracers.T) |
| 406 | + @allowscalar set!(Sᵢ, model.tracers.S) |
432 | 407 |
|
433 | | -tic = time() |
434 | | -#output = restimate_tracer_error(model, Tᵢ, Sᵢ, u_wind_stress, v_wind_stress, T_flux, Δz, mld) |
435 | | -dedν = rdifferentiate_tracer_error( |
436 | | - model, |
437 | | - Tᵢ, |
438 | | - Sᵢ, |
439 | | - u_wind_stress, |
440 | | - v_wind_stress, |
441 | | - T_flux, |
442 | | - Δz, |
443 | | - mld, |
444 | | - dmodel, |
445 | | - dTᵢ, |
446 | | - dSᵢ, |
447 | | - du_wind_stress, |
448 | | - dv_wind_stress, |
449 | | - dT_flux, |
450 | | - dΔz, |
451 | | - dmld, |
452 | | -) |
453 | | -run_toc = time() - tic |
| 408 | + # Profile and time the differentiate_tracer_error |
| 409 | + time_differentiate_tracer_error = Reactant.Profiler.profile_with_xprof( |
| 410 | + differentiate_tracer_error, |
| 411 | + model, |
| 412 | + Tᵢ, |
| 413 | + Sᵢ, |
| 414 | + u_wind_stress, |
| 415 | + v_wind_stress, |
| 416 | + T_flux, |
| 417 | + Δz, |
| 418 | + mld, |
| 419 | + dmodel, |
| 420 | + dTᵢ, |
| 421 | + dSᵢ, |
| 422 | + du_wind_stress, |
| 423 | + dv_wind_stress, |
| 424 | + dT_flux, |
| 425 | + dΔz, |
| 426 | + dmld; |
| 427 | + nrepeat=10, |
| 428 | + warmup=1, |
| 429 | + compile_options=CompileOptions(; raise=true, raise_first=true), |
| 430 | + ) |
| 431 | + results["Oceananigans/DifferentiateTracerError/$(backend)/Reverse"] = |
| 432 | + time_differentiate_tracer_error!.profiling_result.runtime_ns / 1e9 |
454 | 433 |
|
455 | | -@show run_toc |
456 | | -#@show output |
| 434 | + return nothing |
| 435 | +end |
457 | 436 |
|
458 | | -@show dedν |
| 437 | +if abspath(PROGRAM_FILE) == @__FILE__ |
| 438 | + backend = get_backend() |
| 439 | + results = Dict() |
| 440 | + run_abernathey_channel_benchmark!(results, backend) |
| 441 | +end |
0 commit comments