Skip to content

Dates extension#2540

Open
maximilian-gelbrecht wants to merge 44 commits intomainfrom
mg/dates-extension
Open

Dates extension#2540
maximilian-gelbrecht wants to merge 44 commits intomainfrom
mg/dates-extension

Conversation

@maximilian-gelbrecht
Copy link
Collaborator

@maximilian-gelbrecht maximilian-gelbrecht commented Feb 24, 2026

Julia stdlib's Dates are currently not traceable as they are hard coded to have Int64 fields. Unfortunately a lot of functions defined for its datatypes are also defined on these concrete types and not on abstract type. So, we have to reimplement a large part of Dates for this to work. This is currently not a complete implementation of really everything that Dates does. It's a surprisingly large package.

My personal implementation for Speedy so far was just a drop-in replacement for Dates. This here works different, because it's supposed to work for users just working with regular Dates.

I am not that familiar with Reactant yet, so I also have a few questions, if I actually implemented the right things ;)
Currently I extend make_tracer and traced_type_inner for all types that may be traced.

Doing

using Reactant, Dates 
dt = Dates.DateTime(1999, 12, 27)
res = @jit(dt + Dates.Minute(1))

Then, res <: TracedRDateTime. I guess that's expected, or can we convert this back into the untraced type for the user? Is this even something we would want?

Aside from this issue, I have to add one or two slightly more complex unit tests and also see if this actually still works with what we need for Speedy or if I need to add more functionality. Folder structure and definitions follow those of Dates.

Edit: Oops, I just jlfmt but I think with the wrong style...

Should also resolve #2046 (cc @langestefan)

maximilian-gelbrecht and others added 29 commits February 24, 2026 19:44
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@maximilian-gelbrecht
Copy link
Collaborator Author

maximilian-gelbrecht commented Feb 26, 2026

Something isn't working yet when tracing Dates, I don't understand why.

Setting up an integration example modelled after the time stepping we do in Speedy. When I add a @trace the dates are not working correctly anymore. Without it it's fine and as expected.

using Reactant
using Dates

mutable struct Clock{I, T, TS}
    n_timesteps::I
    time::T
    time_step::TS
end

# Minimal state
struct State{T, C}
    x::T
    clock::C
end

function timestepping!(state::State)
    (; clock) = state

    for i in 1:(clock.n_timesteps)
        timestep!(state)
    end

    return state
end

function timestepping_trace!(state::State)
    (; clock) = state

    @trace for i in 1:(clock.n_timesteps)
        timestep!(state)
    end

    return state
end

function timestep!(state::State)
    state.x .+= 1.0
    state.clock.time += state.clock.time_step
    return nothing
end

clock = Clock(5, DateTime(2002, 1, 1), Dates.Day(1))
state = State(zeros(1), clock)
state = timestepping!(state)

clock_jit = Clock(5, DateTime(2002, 1, 1), Dates.Day(1))
state_jit = State(Reactant.to_rarray(zeros(1)), clock_jit)
state_jit = @jit timestepping!(state_jit)

clock_jit_trace = Clock(5, DateTime(2002, 1, 1), Dates.Day(1))
state_jit_trace = State(Reactant.to_rarray(zeros(1)), clock_jit_trace)
state_jit_trace = @jit timestepping_trace!(state_jit_trace)

# state_jit.clock ≈ state.clock 
# state_jit_trace.clock != state.clock 
# state_jit_trace.clock is unchanged stays at the initial value DateTime(2002, 1, 1), even though the state itself does iterate and: 
# state_jit_trace.x == state.x 

@langestefan
Copy link

langestefan commented Feb 26, 2026

It's compiling for me now, but Vector{DateTime} seems to still get folded in as a constant. While a single DateTime does work now. That may be something specific to how I'm using it, I'm still investigating.

edit: Nevermind, single datetime is also being folded in as a constant. Even though it is returning a ConcretePJRTNumber. How bizarre.

@maximilian-gelbrecht
Copy link
Collaborator Author

It's compiling for me now, but Vector{DateTime} seems to still get folded in as a constant. While a single DateTime does work now. That may be something specific to how I'm using it, I'm still investigating.

edit: Nevermind, single datetime is also being folded in as a constant. Even though it is returning a ConcretePJRTNumber. How bizarre.

Can you post an example? For me this actually seems to work, I get

using Reactant, Dates 

julia> arr = [DateTime(i,1,1) for i=2000:2010]
11-element Vector{DateTime}:
 2000-01-01T00:00:00
 2001-01-01T00:00:00

julia> Reactant.to_rarray(arr)
11-element Vector{ReactantDatesExt.TracedR
DateTime{Int64}}:
 2000-01-01T00:00:00
 2001-01-01T00:00:00

@langestefan
Copy link

langestefan commented Feb 26, 2026

Here is the script I am using:

Script
"""
Benchmark: Reactant.jl @compile vs plain Julia for solar_position (PSA algorithm).
"""

using Pkg

Pkg.activate("./reactant")

using Dates: AbstractDateTime, DateTime, datetime2julian
using Reactant

using BenchmarkTools

# Set backend to CPU
Reactant.set_default_backend("cpu")

# ============================================================================
# Setup
# ============================================================================


struct Observer{T<:AbstractFloat}
    "Geodetic latitude (+N)"
    latitude::T
    "Longitude (+E)"
    longitude::T        # longitude (+E)
    "Altitude above mean sea level (meters)"
    altitude::T         # altitude above MSL
    "Horizon angle in degrees (e.g., for refraction or sunrise/sunset calculations)"
    horizon::T
    "Latitude in radians"
    latitude_rad::T
    "Longitude in radians"
    longitude_rad::T
    "sin(latitude)"
    sin_lat::T
    "cos(latitude)"
    cos_lat::T

    function Observer{T}(
        lat::T,
        lon::T,
        alt::T=zero(T),
        horiz::T=zero(T),
    ) where {T<:AbstractFloat}
        lat_rad = deg2rad(lat)
        lon_rad = deg2rad(lon)
        (sin_lat, cos_lat) = sincos(lat_rad)
        return new{T}(lat, lon, alt, horiz, lat_rad, lon_rad, sin_lat, cos_lat)
    end
end

function Observer(lat::T, lon::T, alt::T=zero(T), horiz::T=zero(T)) where {T<:AbstractFloat}
    return Observer{T}(lat, lon, alt, horiz)
end

obs = Observer(51.5074, -0.1278)  # London
dt = DateTime(2024, 6, 21, 12, 0, 0)

struct SolPos{T}
    "Azimuth (degrees, 0=N, +clockwise, range [-180, 180])"
    azimuth::T
    "Elevation (degrees, range [-90, 90])"
    elevation::T
    "Zenith = 90 - elevation (degrees, range [0, 180])"
    zenith::T
end

# ============================================================================
# PSA Algorithm
# ============================================================================

# dt.instant.periods.value = milliseconds since epoch
@inline fractional_hour(dt::AbstractDateTime) = (dt.instant.periods.value % 86_400_000) / 3_600_000

# constants
const EMR = 6371.01  # Earth Mean Radius in km
const AU = 149597890.0  # Astronomical Unit in km

function _solar_position(obs::Observer{T}, dt::AbstractDateTime) where {T}
    # Get parameters as tuple (allocation-free)
    p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15 = (
        2.267127827,
        -9.300339267e-4,
        4.895036035,
        1.720279602e-2,
        6.239468336,
        1.720200135e-2,
        3.338320972e-2,
        3.497596876e-4,
        -1.544353226e-4,
        -8.68972936e-6,
        4.090904909e-1,
        -6.213605399e-9,
        4.418094944e-5,
        6.697096103,
        6.570984737e-2,
    )

    # elapsed julian days (n) since J2000.0
    jd = datetime2julian(dt)
    n = jd - 2451545.0                                                  # Eq. 2

    # ecliptic coordinates of the sun
    # ecliptic longitude (λₑ), and obliquity of the ecliptic (ϵ)
    Ω = p1 + p2 * n                                                     # Eq. 3
    L = p3 + p4 * n                                                     # Eq. 4
    g = p5 + p6 * n                                                     # Eq. 5
    (sin_Ω, cos_Ω) = sincos(Ω)
    λₑ = L + p7 * sin(g) + p8 * sin(2 * g) + p9 + p10 * sin_Ω           # Eq. 6
    ϵ = p11 + p12 * n + p13 * cos_Ω                                     # Eq. 7

    # celestial right ascension (ra) and declination (d)
    (sin_ϵ, cos_ϵ) = sincos(ϵ)
    (sin_λₑ, cos_λₑ) = sincos(λₑ)
    ra = atan(cos_ϵ * sin_λₑ, cos_λₑ)                                   # Eq. 8
    ra = mod2pi(ra)
    δ = asin(sin_ϵ * sin_λₑ)                                            # Eq. 9

    # computes the local coordinates: azimuth (γ) and zenith angle (θz)
    λt = rad2deg(obs.longitude_rad)
    cos_lat = obs.cos_lat
    sin_lat = obs.sin_lat

    hour = fractional_hour(dt)
    gmst = p14 + p15 * n + hour                                         # Eq. 10
    lmst = deg2rad(gmst * 15 + λt)                                      # Eq. 11
    ω = lmst - ra                                                       # Eq. 12
    (sin_δ, cos_δ) = sincos(δ)
    (sin_ω, cos_ω) = sincos(ω)
    θz = acos(cos_lat * cos_ω * cos_δ + sin_δ * sin_lat)                # Eq. 13
    γ = atan(-sin_ω, (tan(δ) * cos_lat - sin_lat * cos_ω))              # Eq. 14

    # parallax correction
    θz = θz + (EMR / AU) * sin(θz)                                      # Eq. 15,16

    az = mod(rad2deg(γ), 360.0)
    el = rad2deg/ 2 - θz)
    zen = rad2deg(θz)
    return SolPos(az, el, zen)
end

# ============================================================================
# Baseline benchmark
# ============================================================================

dt1 = DateTime(2024, 6, 21, 12, 0, 0)
dt2 = DateTime(2024, 1, 21, 12, 0, 0)
out1 = _solar_position(obs, dt1)
out2 = _solar_position(obs, dt2)
@assert out1 != out2  # Should be different times, but just checking it runs

println("Benchmarking Julia original solar_position:")
b_julia = @benchmark _solar_position($obs, $dt)
display(b_julia)
println()

# ============================================================================
# Reactant.jl benchmark
# ============================================================================
rdt1 = Reactant.to_rarray(dt1);
rdt2 = Reactant.to_rarray(dt2);
println("rdt type: ", typeof(rdt1))

# Compile the Reactant version
f = @compile _solar_position(obs, rdt1)

rout1 = f(obs, rdt1)
rout2 = f(obs, rdt2)

@assert rout1 != rout2  # Should be different times, but they are identical

No matter what dt I pass in it will always return the same result.

edit: I think the problem is with dt.instant.periods.value, that can't be traced. Some tools to figure out where reactant stops tracing would be super helpful.

@maximilian-gelbrecht
Copy link
Collaborator Author

Thanks.
I think that might be related to the same issue I also have with the other example.

I suspect it might be due to the way I defined traced_type_inner. It's the piece of code I also commented on above. Maybe @mofeing or @wsmoses can help out.

@wsmoses
Copy link
Member

wsmoses commented Feb 26, 2026

to_rarray by default only converts arrays, not numbers. You can use the track number flag to also convert numbers (or of course explicitly construct the object with reactant types of choice).

@wsmoses
Copy link
Member

wsmoses commented Feb 26, 2026

To look at what the compiled function does (And thus what might have been baked in), I'd recommend looking at @code_hlo _solar_position(obs, rdt1), essentially replacing your earlier @compile command with @code_hlo

maximilian-gelbrecht and others added 5 commits March 2, 2026 10:38
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
end

# custom DateTime, Date, Time types analogues to those defined in Dates.jl
struct TracedRDateTime{I} <: AbstractDateTime
Copy link
Collaborator

@mofeing mofeing Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is fine, but since DateTime is not generic, TracedDateTime does not require to be generic (and can only contain a TracedRNumber{Int64}). same for the other types.

again, this is fine.

Copy link
Collaborator Author

@maximilian-gelbrecht maximilian-gelbrecht Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean that the only possible type for its field is TracedRNumber{Int64}?

While trying out the extension I saw something different. Sometimes I get TracedRMillisecond{Int64}, some times TracedRMillisecond{TracedRNumber{Int64}}, sometimes TracedRMillisecond{ConcretePJRTNumber{Int,1}.

@maximilian-gelbrecht

This comment was marked as resolved.

@maximilian-gelbrecht

This comment was marked as resolved.

@maximilian-gelbrecht
Copy link
Collaborator Author

maximilian-gelbrecht commented Mar 2, 2026

I think this PR might be fine for a review now. I am just having problems actually using it for our model that are related to @trace, but independent of Dates and this extension here.

@langestefan
Copy link

Not having much luck with this yet, datetime vectors are still folded in as constants.

@maximilian-gelbrecht
Copy link
Collaborator Author

I don't understand what you mean by folded in as a constant. Could you condense the example from above down to a MWE of the issue?

@langestefan
Copy link

The above is an MWE that can be run, but I will try to produce a smaller script.

I don't understand what you mean by folded in as a constant.

Essentially reactant compiles an empty function that does nothing. This is the output from @code_hlo

module @reactant__solar_... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main() attributes {enzymexla.memory_effects = []} {
    return
  }
}

@maximilian-gelbrecht
Copy link
Collaborator Author

Even after applying a workaround to #2582 , I still encounter a weird scoping issue here that I can't quite condense down to a MWE.

Basically, the unit test works unless I put it inside a @testset, then I get an error. Isn't that weird?

using Test
using Reactant

using Dates
using Dates: value, UTInstant

#@testset "Minimal timestepper with Dates" begin
        # Inspired by the usage in SpeedyWeather.jl

        # Minimal clock-like mutable struct
        mutable struct Clock{I,T,TS}
            n_timesteps::I
            time::T
            time_step::TS
        end

        # Minimal state
        struct State{C}
            clock::C
        end

        function timestep!(state::State)
            state.clock.time += state.clock.time_step
            return nothing
        end

        function timestepping!(state::State)
            (; clock) = state

            @trace for i in 1:(clock.n_timesteps)
                timestep!(state)
            end

            return nothing
        end

        clock = Clock(5, DateTime(2002, 1, 1), Dates.Day(1))
        state = State(clock)
        timestepping!(state)

        clock_jit = Reactant.to_rarray(
            Clock(5, DateTime(2002, 1, 1), Dates.Day(1)); track_numbers=true
        )

        state_jit = State(clock_jit)
        @jit timestepping!(state_jit)

        @test DateTime(state_jit.clock.time) == state.clock.time
#    end

With the @testset commented out it works. With the @testset I get a

Got exception outside of a @test
  setfield!: immutable struct of type Int64 cannot be changed
  Stacktrace:
    [1] traced_setfield!
      @ ~/.julia/dev/Reactant/src/Compiler.jl:96 [inlined]
    [2] traced_setfield_buffer!(::Val{:PJRT}, _cache_dict::IdDict{Union{Reactant.TracedRArray, Reactant.TracedRNumber}, Union{ConcretePJRTArray, ConcretePJRTNumber}}, val::Int64, concrete_res::Tuple{Reactant.XLA.PJRT.AsyncBuffer}, _obj::Millisecond, _field::Int64, path::Tuple{Int64, Symbol, Vararg{Int64, 4}})
      @ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:170
    [3] traced_setfield_buffer!(runtime::Val{:PJRT}, cache_dict::IdDict{Union{Reactant.TracedRArray, Reactant.TracedRNumber}, Union{ConcretePJRTArray, ConcretePJRTNumber}}, concrete_res::Tuple{Reactant.XLA.PJRT.AsyncBuffer}, obj::Millisecond, field::Int64, path::Tuple{Int64, Symbol, Vararg{Int64, 4}})
      @ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:164
    [4] macro expansion
      @ ~/.julia/dev/Reactant/src/Compiler.jl:3613 [inlined]
    [5] (::Reactant.Compiler.Thunk{var"#timestepping!#22"{var"#timestep!#21"}, Symbol("##timestepping!_reactant#250"), true, Tuple{State{Clock{ConcretePJRTNumber{Int64, 1}, ReactantDatesExt.TracedRDateTime{ConcretePJRTNumber{Int64, 1}}, ReactantDatesExt.TracedRDay{ConcretePJRTNumber{Int64, 1}}}}}, Reactant.XLA.PJRT.LoadedExecutable, Reactant.XLA.PJRT.Device, Reactant.XLA.PJRT.Client, Tuple{}, Vector{Bool}})(args::State{Clock{ConcretePJRTNumber{Int64, 1}, ReactantDatesExt.TracedRDateTime{ConcretePJRTNumber{Int64, 1}}, ReactantDatesExt.TracedRDay{ConcretePJRTNumber{Int64, 1}}}})
      @ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:4304
    [6] macro expansion
      @ ~/.julia/dev/Reactant/src/Compiler.jl:3103 [inlined]
    [7] macro expansion
      @ ~/.julia/packages/LLVM/fEIbx/src/base.jl:97 [inlined]
    [8] macro expansion
      @ ~/.julia/dev/Reactant/src/Compiler.jl:3095 [inlined]
    [9] macro expansion
      @ REPL[21]:41 [inlined]
   [10] macro expansion
      @ ~/.julia/juliaup/julia-1.11.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Test/src/Test.jl:1709 [inlined]
   [11] top-level scope
      @ REPL[21]:5
Test Summary:                  | Error  Total  Time
Minimal timestepper with Dates |     1      1  1.2s

@wsmoses
Copy link
Member

wsmoses commented Mar 3, 2026

move the function and struct defn outside of the testset, my guess is that clock or other variables in the function are shadowing the locals in bad ways

@maximilian-gelbrecht
Copy link
Collaborator Author

Ah yes, thanks that worked. Never encountered an issue like that before.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

to_rarray on Vector{DateTime}

5 participants