Skip to content

Conversation

@taimoorsohail
Copy link
Collaborator

I have created a wrapper for xESMF (python package) so we can generally regrid any field from source to destination. At the moment, the PR doesn't include vertical regrinding (i.e., source and destinations vertical grids need to be the same).

Note:

  1. There are some speed ups that could be achieved in the actual matrix multiplication step.
  2. I know an extension might be better but perhaps we can work on that together - I am just opening a PR so the collaboration can begin!

@glwagner
Copy link
Member

Great! A few comments off the bat:

  1. We should implement some of this functionality in an extension so that we don't add a dependency on PyCall
  2. I believe we also want to use PythonCall rather than PyCall: https://github.com/JuliaPy/PythonCall.jl. It appears to be actively developed (unlike PyCall) which is why ClimaOcean also uses PythonCall.

I am happy to implement these changes if that's ok @taimoorsohail ?

@glwagner
Copy link
Member

Ah and another thing --- regrid! is defined in the Fields module rather than the Diagnostics module. The idea is that this concept is purely associated with Field, similar to interpolate!. The Diagnostics module is actually model specific; eg it is a place for things that compute quantities that depend on the type of AbstractModel being used, such as the CFL constraint

@taimoorsohail
Copy link
Collaborator Author

Yep @glwagner very happy for you to implement. Note that I did switch to PythonCall at some point and then I reverted back. I can't remember why exactly, it should be totally fine though...

@taimoorsohail
Copy link
Collaborator Author

By the way, xESMF requires calling MPI to work, as far as I could tell. That means that if you are using the regridder you need to call using MPI and MPI.init() even if you don't actually use MPI in the code. Perhaps there's a way to bypass this.

end

function coordinate_dataset(grid::SomeTripolarGrid)
lat = Array(grid.φᶜᶜᵃ[1:grid.Nx, 1:grid.Ny])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, the preferred syntax is

lon, lat, z = nodes(grid, Center(), Center(), Center())

@navidcy
Copy link
Member

navidcy commented Sep 12, 2025

Great! A few comments off the bat:

  1. We should implement some of this functionality in an extension so that we don't add a dependency on PyCall
  2. I believe we also want to use PythonCall rather than PyCall: https://github.com/JuliaPy/PythonCall.jl. It appears to be actively developed (unlike PyCall) which is why ClimaOcean also uses PythonCall.

I am happy to implement these changes if that's ok @taimoorsohail ?

I started working on transitioning to an extension

@glwagner
Copy link
Member

Great! A few comments off the bat:

  1. We should implement some of this functionality in an extension so that we don't add a dependency on PyCall
  2. I believe we also want to use PythonCall rather than PyCall: https://github.com/JuliaPy/PythonCall.jl. It appears to be actively developed (unlike PyCall) which is why ClimaOcean also uses PythonCall.

I am happy to implement these changes if that's ok @taimoorsohail ?

I started working on transitioning to an extension

We are committing on top of each other, so lets coordinate

navidcy and others added 2 commits September 26, 2025 10:17
@navidcy
Copy link
Member

navidcy commented Sep 26, 2025

@simone-silvestri , using NumericalEarth/XESMF.jl#16 with the MWE below I get:

using Oceananigans
using XESMF

tg = TripolarGrid(; size=(360, 170, 1), z = (-1, 0), southernmost_latitude = -80)
llg = LatitudeLongitudeGrid(; size=(360, 180, 1), z = (-1, 0),
                            longitude=(0, 360), latitude=(-82, 90))

src_field = CenterField(tg)
dst_field = CenterField(llg)

λ₀, φ₀ = 150, 30.  # degrees
width = 12         # degrees
set!(src_field, (λ, φ, z) -> exp(-((λ - λ₀)^2 +- φ₀)^2) / 2width^2))

regridder = XESMF.Regridder(dst_field, src_field, method="conservative")

regrid!(dst_field, regridder, src_field)
julia> using Oceananigans

julia> using XESMF

julia> tg = TripolarGrid(; size=(360, 170, 1), z = (-1, 0), southernmost_latitude = -80)
360×170×1 OrthogonalSphericalShellGrid{Float64, Periodic, RightConnected, Bounded} on CPU with 4×4×4 halo and with precomputed metrics
├── centered at (λ, φ) = (70.0, 1.8005)
├── longitude: Periodic  extent 360.156 degrees                        variably spaced with min(Δλ)=0.00798438, max(Δλ)=1.05373
├── latitude:  Oceananigans.Grids.RightConnected  extent 171.0 degrees variably spaced with min(Δφ)=0.0127345, max(Δφ)=1.00592
└── z:         Bounded  z  [-1.0, 0.0]                                regularly spaced with Δz=1.0

julia> llg = LatitudeLongitudeGrid(; size=(360, 180, 1), z = (-1, 0),
                                   longitude=(0, 360), latitude=(-82, 90))
360×180×1 LatitudeLongitudeGrid{Float64, Periodic, Bounded, Bounded} on CPU with 3×3×1 halo and with precomputed metrics
├── longitude: Periodic λ  [0.0, 360.0)  regularly spaced with Δλ=1.0
├── latitude:  Bounded  φ  [-82.0, 90.0] regularly spaced with Δφ=0.955556
└── z:         Bounded  z  [-1.0, 0.0]   regularly spaced with Δz=1.0

julia> src_field = CenterField(tg)
360×170×1 Field{Center, Center, Center} on OrthogonalSphericalShellGrid on CPU
├── grid: 360×170×1 OrthogonalSphericalShellGrid{Float64, Periodic, RightConnected, Bounded} on CPU with 4×4×4 halo and with precomputed metrics
├── boundary conditions: FieldBoundaryConditions
│   └── west: Periodic, east: Periodic, south: ZeroFlux, north: Zipper(1.0), bottom: ZeroFlux, top: ZeroFlux, immersed: Nothing
└── data: 368×178×9 OffsetArray(::Array{Float64, 3}, -3:364, -3:174, -3:5) with eltype Float64 with indices -3:364×-3:174×-3:5
    └── max=0.0, min=0.0, mean=0.0

julia> dst_field = CenterField(llg)
360×180×1 Field{Center, Center, Center} on LatitudeLongitudeGrid on CPU
├── grid: 360×180×1 LatitudeLongitudeGrid{Float64, Periodic, Bounded, Bounded} on CPU with 3×3×1 halo and with precomputed metrics
├── boundary conditions: FieldBoundaryConditions
│   └── west: Periodic, east: Periodic, south: ZeroFlux, north: Value, bottom: ZeroFlux, top: ZeroFlux, immersed: Nothing
└── data: 366×186×3 OffsetArray(::Array{Float64, 3}, -2:363, -2:183, 0:2) with eltype Float64 with indices -2:363×-2:183×0:2
    └── max=0.0, min=0.0, mean=0.0

julia> λ₀, φ₀ = 150, 30.  # degrees
(150, 30.0)

julia> width = 12         # degrees
12

julia> set!(src_field, (λ, φ, z) -> exp(-((λ - λ₀)^2 +- φ₀)^2) / 2width^2))
360×170×1 Field{Center, Center, Center} on OrthogonalSphericalShellGrid on CPU
├── grid: 360×170×1 OrthogonalSphericalShellGrid{Float64, Periodic, RightConnected, Bounded} on CPU with 4×4×4 halo and with precomputed metrics
├── boundary conditions: FieldBoundaryConditions
│   └── west: Periodic, east: Periodic, south: ZeroFlux, north: Zipper(1.0), bottom: ZeroFlux, top: ZeroFlux, immersed: Nothing
└── data: 368×178×9 OffsetArray(::Array{Float64, 3}, -3:364, -3:174, -3:5) with eltype Float64 with indices -3:364×-3:174×-3:5
    └── max=0.998852, min=3.66505e-85, mean=0.0130418

julia> regridder = XESMF.Regridder(dst_field, src_field, method="conservative")
Conservative Regridder
├── weights: 64800×61200 SparseArrays.SparseMatrixCSC{Float64, Int64} with 263516 stored entries
├── src_temp: 61200-element Vector{Float64}
└── dst_temp: 64800-element Vector{Float64}

julia> regrid!(dst_field, regridder, src_field)
ERROR: UndefVarError: `regridder` not defined
Stacktrace:
 [1] regrid!(dst_field::Field{…}, regrider::XESMF.Regridder{…}, src_field::Field{…})
   @ OceananigansXESMFExt ~/Library/CloudStorage/OneDrive-TheUniversityofMelbourne/Documents/Research/Oceananigans.jl-v4/ext/OceananigansXESMFExt.jl:190
 [2] top-level scope
   @ REPL[34]:1
Some type information was truncated. Use `show(err)` to see complete types.

@navidcy
Copy link
Member

navidcy commented Sep 26, 2025

oh I have a typo... nevermind

@simone-silvestri
Copy link
Collaborator

I get this:

julia> regridder = XESMF.Regridder(dst_field, src_field, method="conservative")
conservative Regridder
├── weights: 64800×61200 SparseArrays.SparseMatrixCSC{Float64, Int64} with 263516 stored entries
├── src_temp: 61200-element Vector{Float64}
└── dst_temp: 64800-element Vector{Float64}

julia> regrid!(dst_field, regridder, src_field)
ERROR: UndefVarError: `topology` not defined
Stacktrace:
 [1] regrid!(dst_field::Field{…}, regrider::XESMF.Regridder{…}, src_field::Field{…})
   @ OceananigansXESMFExt ~/development/TestOceananigans.jl/ext/OceananigansXESMFExt.jl:182
 [2] top-level scope
   @ REPL[13]:1
Some type information was truncated. Use `show(err)` to see complete types.

I think we need

using Oceananigans.Grids: topology

@simone-silvestri
Copy link
Collaborator

just with this

regridder(vec(interior(dst_field)), vec(interior(src_field)))

it works

@xkykai
Copy link
Collaborator

xkykai commented Sep 27, 2025

Looks great! Should we consider the case for regridding in GPU as well? I was trying it out here
https://github.com/CliMA/ClimaOceanCalibration.jl/blob/4c379b8de655006722a1cf024da80e3b6b9a78aa/src/DataWrangling/bilinear_interpolator.jl
and came up with a (not so smart) way to do it temporarily:

function regrid!(dst, weights::CuSparseMatrixCSC, src)
    vec(dst) .= weights * CuArray(vec(src))
end

By the same approach to do it over the entire 3D AbstractField another not so smart way to do it is probably

function regrid!(dst::AbstractField, src::AbstractField, interpolator::BilinearInterpolator)
    weights = interpolator.weights
    
    # Get the interior data
    dst_data = interior(dst)
    src_data = interior(src)
    
    Nz = size(src_data, 3)
    
    for k in 1:Nz
        src_slice = view(src_data, :, :, k)
        dst_slice = view(dst_data, :, :, k)
        regrid!(dst_slice, weights, src_slice)
    end
    
    return dst
end

I haven't put too much thought into this, and it looks inefficient as I had to allocate another CuArray for sparse matrix multiply for CuSparseMatrixCSC to work. This doesn't make use of the kernel approach as well. I haven't managed to get sparse matrix multiply to work within a kernel.

Perhaps a better way to go about this is to actually write our own sparse matrix multiply? Doing that can allow us to dispatch kernels to do the regridding over all vertical levels at the same time. It doesn't seem too hard to write but I also haven't thought very deeply about it.

@navidcy navidcy changed the title Add XESMF.jl extension to use xESMF to compute tracer regridding weights (0.99.2) Add XESMF.jl extension to use xESMF to compute tracer regridding weights Sep 27, 2025
@navidcy
Copy link
Member

navidcy commented Sep 27, 2025

I'm merging this as a good starting point and will continue working on it.

Some outstanding issues are:

  • Cleanup of:
    function x_node_array(x::AbstractVector, Nx, Ny)
    return Array(repeat(view(x, 1:Nx), 1, Ny))'
    end
    function y_node_array(x::AbstractVector, Nx, Ny)
    return Array(repeat(view(x, 1:Ny)', Nx, 1))'
    end
    x_node_array(x::AbstractMatrix, Nx, Ny) = Array(view(x, 1:Nx, 1:Ny))'
    function x_vertex_array(x::AbstractVector, Nx, Ny)
    return Array(repeat(view(x, 1:Nx+1), 1, Ny+1))'
    end
    function y_vertex_array(x::AbstractVector, Nx, Ny)
    return Array(repeat(view(x, 1:Ny+1)', Nx+1, 1))'
    end
    x_vertex_array(x::AbstractMatrix, Nx, Ny) = Array(view(x, 1:Nx+1, 1:Ny+1))'
    y_node_array(x::AbstractMatrix, Nx, Ny) = x_node_array(x, Nx, Ny)
    y_vertex_array(x::AbstractMatrix, Nx, Ny) = x_vertex_array(x, Nx, Ny)

    eg use permutedims instead of '
  • Ensure all is smooth with GPU architectures; add tests for GPU (perhaps move from GitHub actions to buildkite)
  • Does the regridding reproduces the Python xESMF regridding?
  • (Perhaps related to the above) Can the integral before and after regridding be better than rtol = 1e-4?
  • Documentation/Example for how to do a regridding using XESMF.jl (docstrings with examples should suffice for now; convert some example(s) to doctest(s)).

@navidcy navidcy merged commit ff6d022 into main Sep 27, 2025
69 checks passed
@navidcy navidcy deleted the ts/implement-conservative-regridder branch September 27, 2025 14:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

extensions 🧬 feature 🌟 Something new and shiny

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants