
Kernel function interpolate_primary_atmospheric_state! calls intrinsic_vector, which can lead to a sind call inside a GPU kernel #749

@jlk9

Description


This was discovered through EnzymeAD/Reactant.jl#2345.

The kernel function `_interpolate_primary_atmospheric_state!` calls the Oceananigans function `intrinsic_vector`:

```julia
@kernel function _interpolate_primary_atmospheric_state!(surface_atmos_state,
                                                         space_fractional_indices,
                                                         time_interpolator,
                                                         exchange_grid,
                                                         atmos_velocities,
                                                         atmos_tracers,
                                                         atmos_pressure,
                                                         downwelling_radiation,
                                                         prescribed_freshwater_flux,
                                                         atmos_backend,
                                                         atmos_time_indexing)

    i, j = @index(Global, NTuple)

    ii = space_fractional_indices.i
    jj = space_fractional_indices.j
    fi = get_fractional_index(i, j, ii)
    fj = get_fractional_index(i, j, jj)

    x_itp = FractionalIndices(fi, fj, nothing)
    t_itp = time_interpolator
    atmos_args = (x_itp, t_itp, atmos_backend, atmos_time_indexing)

    uₐ = interp_atmos_time_series(atmos_velocities.u, atmos_args...)
    vₐ = interp_atmos_time_series(atmos_velocities.v, atmos_args...)
    Tₐ = interp_atmos_time_series(atmos_tracers.T, atmos_args...)
    qₐ = interp_atmos_time_series(atmos_tracers.q, atmos_args...)
    pₐ = interp_atmos_time_series(atmos_pressure, atmos_args...)

    Qs = interp_atmos_time_series(downwelling_radiation.shortwave, atmos_args...)
    Qℓ = interp_atmos_time_series(downwelling_radiation.longwave, atmos_args...)

    # Usually precipitation
    Mh = interp_atmos_time_series(prescribed_freshwater_flux, atmos_args...)

    # Convert atmosphere velocities (usually defined on a latitude-longitude grid) to
    # the frame of reference of the native grid
    kᴺ = size(exchange_grid, 3) # index of the top ocean cell
    uₐ, vₐ = intrinsic_vector(i, j, kᴺ, exchange_grid, uₐ, vₐ)

    # ... (remainder of the kernel not shown)
end
```
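For context, the last step of the kernel is a change of frame. Assuming `intrinsic_vector` applies a plain 2D rotation by the local angle $\theta$ between the grid's $x$-direction and east (this sign convention is an assumption for illustration, not copied from Oceananigans), the conversion amounts to

$$
\begin{aligned}
u_i &= u_a \cos\theta + v_a \sin\theta, \\
v_i &= -u_a \sin\theta + v_a \cos\theta,
\end{aligned}
$$

so under this assumption each grid point only needs $\cos\theta$ and $\sin\theta$; the trig calls are needed only to compute those two values.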

The problem is that, at least in some situations (such as the script in the linked issue), the `exchange_grid` passed to `intrinsic_vector` is an `OrthogonalSphericalShellGrid`. The `intrinsic_vector` method for that grid type calls `sind`, `cosd`, and `atand` (some of them through `rotation_angle`), none of which are supported in GPU kernels:

https://github.com/CliMA/Oceananigans.jl/blob/ae70538358c5ddf9e38267cccbd2bbf05bf7b66a/src/Operators/vector_rotation_operators.jl#L96-L109
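One way to sidestep the degree-based functions, sketched here under the assumption that radian-based `sin`, `cos`, and `atan` are available inside the kernel for the backend in question (untested under Reactant), is to convert through `deg2rad`/`rad2deg`. The helper names below are hypothetical and are not part of Oceananigans.

```julia
# Hypothetical radian-based stand-ins for sind / cosd / atand.
# They use only Base functions: sin, cos, atan, deg2rad, rad2deg.
sind_rad(x) = sin(deg2rad(x))
cosd_rad(x) = cos(deg2rad(x))
atand_rad(y, x) = rad2deg(atan(y, x))   # two-argument form, result in degrees

# Hypothetical: angle (in degrees) of the local grid x-direction relative to
# east, given that direction's eastward and northward components.
local_angle_degrees(east, north) = atand_rad(north, east)

println(sind_rad(30.0), " ", sind(30.0))
println(local_angle_degrees(0.0, 1.0))
```

A trade-off of this rewrite: Base's `sind`/`cosd` return exact results at multiples of 90°, while the radian versions can be off in the last few bits (e.g. `sind_rad(30.0)` is `0.49999999999999994` rather than `0.5`).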

As a result, running `_interpolate_primary_atmospheric_state!` with a Reactant or GPU backend can error.

One fix would be to modify `_interpolate_primary_atmospheric_state!` so that it does not call `intrinsic_vector`. Alternatively, the `OrthogonalSphericalShellGrid` method of `intrinsic_vector` could be modified to avoid the problematic trig functions.
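A possible shape for the second option, offered only as a hypothetical sketch: compute cos θ and sin θ once on the host (where `cosd`/`sind` are fine), store them in arrays, and have the per-point code do only lookups and multiply-adds. A plain CPU loop stands in for the GPU kernel here, the angle field is made up, and the rotation sign convention is assumed; none of the names below exist in Oceananigans or ClimaOcean.

```julia
# Hypothetical host-side precomputation from a field of angles in degrees.
# Trig is evaluated here, outside any kernel.
function precompute_rotation(θ_degrees::AbstractMatrix)
    cosθ = cosd.(θ_degrees)
    sinθ = sind.(θ_degrees)
    return cosθ, sinθ
end

# Per-point "kernel body": no trig calls, only array lookups and multiply-adds.
# Projects an (east, north) vector onto axes rotated by θ (assumed convention).
@inline function rotate_point(uₑ, vₑ, cosθ, sinθ, i, j)
    c, s = cosθ[i, j], sinθ[i, j]
    return uₑ * c + vₑ * s, -uₑ * s + vₑ * c
end

# CPU loop standing in for a GPU kernel launch over all (i, j).
function rotate_field!(uᵢ, vᵢ, uₑ, vₑ, cosθ, sinθ)
    for j in axes(uₑ, 2), i in axes(uₑ, 1)
        uᵢ[i, j], vᵢ[i, j] = rotate_point(uₑ[i, j], vₑ[i, j], cosθ, sinθ, i, j)
    end
    return nothing
end

# Usage: a purely eastward unit vector on a 2×2 grid with made-up angles.
θ = [0.0 90.0; 180.0 45.0]
cosθ, sinθ = precompute_rotation(θ)
uₑ = ones(2, 2)
vₑ = zeros(2, 2)
uᵢ = similar(uₑ)
vᵢ = similar(vₑ)
rotate_field!(uᵢ, vᵢ, uₑ, vₑ, cosθ, sinθ)
```

The cost of this design is two extra stored fields per grid, in exchange for a kernel that no longer depends on which trig functions a backend supports.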

Metadata


Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests