Skip to content

SpeedyWeather + Reactant #970

Draft
maximilian-gelbrecht wants to merge 68 commits intomk/matrix_transformfrom
mg/reactant
Draft

SpeedyWeather + Reactant #970
maximilian-gelbrecht wants to merge 68 commits intomk/matrix_transformfrom
mg/reactant

Conversation

@maximilian-gelbrecht
Copy link
Member

@maximilian-gelbrecht maximilian-gelbrecht commented Jan 27, 2026

Initial preparations to run Speedy with Reactant.

This is based on mk/matrix_transform as I am thinking that it's easier to avoid FFT plans for the first attempt.

@maximilian-gelbrecht maximilian-gelbrecht added gpu 🖼️ Everthing GPU related array types 🔢 LowerTriangularMatrices and RingGrids labels Jan 27, 2026
@maximilian-gelbrecht maximilian-gelbrecht added the differentiability 🤖 Making the model differentiable via AD label Jan 27, 2026
@maximilian-gelbrecht maximilian-gelbrecht added the skip-gpu-ci Skip tests on GPU label Jan 27, 2026
@maximilian-gelbrecht maximilian-gelbrecht added the skip-docs Who needs a documentation anyway? label Jan 27, 2026
@maximilian-gelbrecht
Copy link
Member Author

maximilian-gelbrecht commented Jan 28, 2026

Test script

import Pkg 
Pkg.activate("SpeedyWeather")
using SpeedyWeather, Reactant, LinearAlgebra
#using CUDA 

arch = SpeedyWeather.ReactantDevice()
nlayers = 1 
spectral_grid = SpectralGrid(; architecture = arch, nlayers)
M = MatrixSpectralTransform(spectral_grid)

spec = randn(Complex{Float32}, spectral_grid.spectrum, nlayers)
field = zeros(Float32, spectral_grid.grid, nlayers)
   
tr = @compile transform!(spec, field, M)

# model construction works 
#model = PrimitiveWetModel(spectral_grid; spectral_transform=M)

# model initialization currently fails (because of the transforms used)
#simulation = initialize!(model)

@maximilian-gelbrecht
Copy link
Member Author

maximilian-gelbrecht commented Jan 28, 2026

At the moment, the model construction of all models works.
We can also run e.g. the spectral transform in matrix form with Reactant thanks to @mofeing

Next, we have to think about the full model initialisation. That's more a question for our own code organisation. When we construct the model with Reactant Arrays, we should probably do all the pre-compute with it as well, but how to do this best? The alternative would be doing the pre-compute without Reactant and then transferring and converting later, but that's a bit annoying in practice. I have to have a think about the best approach.

@milankl
Copy link
Member

milankl commented Jan 28, 2026

Btw calling NN parameterizations that are compiled with Reactant already works on the CPU, see #973

@milankl milankl changed the title Speedy + Reactant SpeedyWeather + Reactant Jan 28, 2026
"""$(TYPEDSIGNATURES)
Precomputes the hyper diffusion terms for all layers based on the
model time step in `L`, the vertical level sigma level in `G`."""
function initialize!(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to do really do all the initialization in kernels, with Reactant you also get warnings about scalar indexing here

We just have one or two more of those that transfer back and forth, I'll do the others later.

@maximilian-gelbrecht
Copy link
Member Author

maximilian-gelbrecht commented Jan 30, 2026

Good news before the weekend:

Both initialization and time stepping of the BarotropicModel (our simplest model) with Reactant compiles and runs without errors 🎉

I haven't done any form of correctness checks, and there are issues with the initial conditions, at the moment I tested it with just with zero ICs.

import Pkg 
Pkg.activate("SpeedyWeather")
using SpeedyWeather, Reactant, LinearAlgebra, CUDA # cuda is loaded for reactant but this is still on CPU

arch = SpeedyWeather.ReactantDevice()
nlayers = 1 
spectral_grid = SpectralGrid(; architecture = arch, nlayers)
M = MatrixSpectralTransform(spectral_grid)

spec = zeros(Complex{Float32}, spectral_grid.spectrum, nlayers)
field = rand(Float32, spectral_grid.grid, nlayers)
    
# model construction works 
#model = PrimitiveWetModel(spectral_grid; spectral_transform=M)
initial_conditions = InitialConditions(; vordiv=ZeroInitially())
# feedback == nothing is needed currently due to issues with Reactant 
model = BarotropicModel(spectral_grid; spectral_transform=M, feedback=nothing, initial_conditions=initial_conditions)
simulation = initialize!(model)

@jit SpeedyWeather.timestep!(simulation)

@maximilian-gelbrecht
Copy link
Member Author

maximilian-gelbrecht commented Jan 30, 2026

@milankl For proper correctness of our models with Reactant we need be to aware of the functions that do have significant branch divergences / control flow those limits/conditions change while running the model. Those need to be explicitly traced by Reactant. I don't have the best overview of that currently. I can see those e.g. in convection and condensation scheme, can you think about more like these?

@mofeing
Copy link

mofeing commented Jan 30, 2026

For proper correctness of our models with Reactant we need be to aware of the functions that do have significant branch divergences / control flow those limits/conditions change while running the model. Those need to be explicitly traced by Reactant.

for this, you can add the ReactantCore dependency to SpeedyWeather (which just adds the @trace macro) and you can add it wherever you need in the main code. but IMO I would just add it in the following cases:

  • control-flow where the predicate is traced (e.g. a if pred where pred::TracedRNumber{Bool})
  • loops where the partial evaluation unrolls the loop and the resulting code is tooo large

@maximilian-gelbrecht
Copy link
Member Author

Thanks!

  • loops where the partial evaluation unrolls the loop and the resulting code is tooo large

Just to be sure, what do you mean with the "the resulting code is too large", when do we have to watch out there? The typical situation in our model is that loop limits depend only on the model configuration that isn't changed at runtime, so I thought we don't need to trace them, with typical loop lengths of ~8-80 (for vertical layers). There are just some very few exceptions in our parametrizations.

@maximilian-gelbrecht
Copy link
Member Author

Works on GPU as well, expect I get a repeated warning

'86' is not a recognized feature for this target (ignoring feature)

Not sure what that means.

@maximilian-gelbrecht
Copy link
Member Author

While, the forward model is working, the reverse model / autodiff isn't.
It already needed a few adjustments of making types proper parametric, but the correct stopper is Dates, see: EnzymeAD/Reactant.jl#2046

@maximilian-gelbrecht
Copy link
Member Author

After talking to @mofeing , a custom Dates seems to be only real way to make the differentiation of the time stepping Reactant compatible. Just ditching using date types at all is the alternative, but we use them so much throughout the model that this might be even more annoying. With some help from Claude, I just set up a custom TracableDates submodule now that we load in with import TracableDates as Dates as a drop-in replacement for Dates.

@milankl
Copy link
Member

milankl commented Feb 25, 2026

After talking to @mofeing , a custom Dates seems to be only real way to make the differentiation of the time stepping Reactant compatible. Just ditching using date types at all is the alternative, but we use them so much throughout the model that this might be even more annoying. With some help from Claude, I just set up a custom TracableDates submodule now that we load in with import TracableDates as Dates as a drop-in replacement for Dates.

Sorry only trying to catch up with the updates here now. Ditching Dates would be a bummer though I might be happy to reinvent Dates to make it Reactant-compatible. However there is a quite a lot of functionality we use, including in physics/radiation/zenith.jl ... I hope this is not too much work. I guess this also removes our type piracy from #746 as a positive side effect ...

@maximilian-gelbrecht
Copy link
Member Author

maximilian-gelbrecht commented Feb 25, 2026

With regards to Dates I just made a PR for Reactant EnzymeAD/Reactant.jl#2540 that implements the custom dates format there to make Dates Reactant-compatible, so that we don't need to have ithe custom date format here anymore. I am not totally finished yet with the PR, I need to test if it actually works here with what we do in Speedy as well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

array types 🔢 LowerTriangularMatrices and RingGrids differentiability 🤖 Making the model differentiable via AD gpu 🖼️ Everthing GPU related skip-docs Who needs a documentation anyway? skip-gpu-ci Skip tests on GPU

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants