Conversation
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
|
Something isn't working yet when tracing Setting up an integration example modelled after the time stepping we do in Speedy. When I add a using Reactant
using Dates
mutable struct Clock{I, T, TS}
n_timesteps::I
time::T
time_step::TS
end
# Minimal state
struct State{T, C}
x::T
clock::C
end
function timestepping!(state::State)
(; clock) = state
for i in 1:(clock.n_timesteps)
timestep!(state)
end
return state
end
function timestepping_trace!(state::State)
(; clock) = state
@trace for i in 1:(clock.n_timesteps)
timestep!(state)
end
return state
end
function timestep!(state::State)
state.x .+= 1.0
state.clock.time += state.clock.time_step
return nothing
end
clock = Clock(5, DateTime(2002, 1, 1), Dates.Day(1))
state = State(zeros(1), clock)
state = timestepping!(state)
clock_jit = Clock(5, DateTime(2002, 1, 1), Dates.Day(1))
state_jit = State(Reactant.to_rarray(zeros(1)), clock_jit)
state_jit = @jit timestepping!(state_jit)
clock_jit_trace = Clock(5, DateTime(2002, 1, 1), Dates.Day(1))
state_jit_trace = State(Reactant.to_rarray(zeros(1)), clock_jit_trace)
state_jit_trace = @jit timestepping_trace!(state_jit_trace)
# state_jit.clock ≈ state.clock
# state_jit_trace.clock != state.clock
# state_jit_trace.clock is unchanged stays at the initial value DateTime(2002, 1, 1), even though the state itself does iterate and:
# state_jit_trace.x == state.x |
|
It's compiling for me now, but edit: Nevermind, single datetime is also being folded in as a constant. Even though it is returning a |
Can you post an example? For me this actually seems to work, I get using Reactant, Dates
julia> arr = [DateTime(i,1,1) for i=2000:2010]
11-element Vector{DateTime}:
2000-01-01T00:00:00
2001-01-01T00:00:00
julia> Reactant.to_rarray(arr)
11-element Vector{ReactantDatesExt.TracedR
DateTime{Int64}}:
2000-01-01T00:00:00
2001-01-01T00:00:00 |
|
Here is the script I am using: Script"""
Benchmark: Reactant.jl @compile vs plain Julia for solar_position (PSA algorithm).
"""
using Pkg
Pkg.activate("./reactant")
using Dates: AbstractDateTime, DateTime, datetime2julian
using Reactant
using BenchmarkTools
# Set backend to CPU
Reactant.set_default_backend("cpu")
# ============================================================================
# Setup
# ============================================================================
struct Observer{T<:AbstractFloat}
"Geodetic latitude (+N)"
latitude::T
"Longitude (+E)"
longitude::T # longitude (+E)
"Altitude above mean sea level (meters)"
altitude::T # altitude above MSL
"Horizon angle in degrees (e.g., for refraction or sunrise/sunset calculations)"
horizon::T
"Latitude in radians"
latitude_rad::T
"Longitude in radians"
longitude_rad::T
"sin(latitude)"
sin_lat::T
"cos(latitude)"
cos_lat::T
function Observer{T}(
lat::T,
lon::T,
alt::T=zero(T),
horiz::T=zero(T),
) where {T<:AbstractFloat}
lat_rad = deg2rad(lat)
lon_rad = deg2rad(lon)
(sin_lat, cos_lat) = sincos(lat_rad)
return new{T}(lat, lon, alt, horiz, lat_rad, lon_rad, sin_lat, cos_lat)
end
end
function Observer(lat::T, lon::T, alt::T=zero(T), horiz::T=zero(T)) where {T<:AbstractFloat}
return Observer{T}(lat, lon, alt, horiz)
end
obs = Observer(51.5074, -0.1278) # London
dt = DateTime(2024, 6, 21, 12, 0, 0)
struct SolPos{T}
"Azimuth (degrees, 0=N, +clockwise, range [-180, 180])"
azimuth::T
"Elevation (degrees, range [-90, 90])"
elevation::T
"Zenith = 90 - elevation (degrees, range [0, 180])"
zenith::T
end
# ============================================================================
# PSA Algorithm
# ============================================================================
# dt.instant.periods.value = milliseconds since epoch
@inline fractional_hour(dt::AbstractDateTime) = (dt.instant.periods.value % 86_400_000) / 3_600_000
# constants
const EMR = 6371.01 # Earth Mean Radius in km
const AU = 149597890.0 # Astronomical Unit in km
function _solar_position(obs::Observer{T}, dt::AbstractDateTime) where {T}
# Get parameters as tuple (allocation-free)
p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15 = (
2.267127827,
-9.300339267e-4,
4.895036035,
1.720279602e-2,
6.239468336,
1.720200135e-2,
3.338320972e-2,
3.497596876e-4,
-1.544353226e-4,
-8.68972936e-6,
4.090904909e-1,
-6.213605399e-9,
4.418094944e-5,
6.697096103,
6.570984737e-2,
)
# elapsed julian days (n) since J2000.0
jd = datetime2julian(dt)
n = jd - 2451545.0 # Eq. 2
# ecliptic coordinates of the sun
# ecliptic longitude (λₑ), and obliquity of the ecliptic (ϵ)
Ω = p1 + p2 * n # Eq. 3
L = p3 + p4 * n # Eq. 4
g = p5 + p6 * n # Eq. 5
(sin_Ω, cos_Ω) = sincos(Ω)
λₑ = L + p7 * sin(g) + p8 * sin(2 * g) + p9 + p10 * sin_Ω # Eq. 6
ϵ = p11 + p12 * n + p13 * cos_Ω # Eq. 7
# celestial right ascension (ra) and declination (d)
(sin_ϵ, cos_ϵ) = sincos(ϵ)
(sin_λₑ, cos_λₑ) = sincos(λₑ)
ra = atan(cos_ϵ * sin_λₑ, cos_λₑ) # Eq. 8
ra = mod2pi(ra)
δ = asin(sin_ϵ * sin_λₑ) # Eq. 9
# computes the local coordinates: azimuth (γ) and zenith angle (θz)
λt = rad2deg(obs.longitude_rad)
cos_lat = obs.cos_lat
sin_lat = obs.sin_lat
hour = fractional_hour(dt)
gmst = p14 + p15 * n + hour # Eq. 10
lmst = deg2rad(gmst * 15 + λt) # Eq. 11
ω = lmst - ra # Eq. 12
(sin_δ, cos_δ) = sincos(δ)
(sin_ω, cos_ω) = sincos(ω)
θz = acos(cos_lat * cos_ω * cos_δ + sin_δ * sin_lat) # Eq. 13
γ = atan(-sin_ω, (tan(δ) * cos_lat - sin_lat * cos_ω)) # Eq. 14
# parallax correction
θz = θz + (EMR / AU) * sin(θz) # Eq. 15,16
az = mod(rad2deg(γ), 360.0)
el = rad2deg(π / 2 - θz)
zen = rad2deg(θz)
return SolPos(az, el, zen)
end
# ============================================================================
# Baseline benchmark
# ============================================================================
dt1 = DateTime(2024, 6, 21, 12, 0, 0)
dt2 = DateTime(2024, 1, 21, 12, 0, 0)
out1 = _solar_position(obs, dt1)
out2 = _solar_position(obs, dt2)
@assert out1 != out2 # Should be different times, but just checking it runs
println("Benchmarking Julia original solar_position:")
b_julia = @benchmark _solar_position($obs, $dt)
display(b_julia)
println()
# ============================================================================
# Reactant.jl benchmark
# ============================================================================
rdt1 = Reactant.to_rarray(dt1);
rdt2 = Reactant.to_rarray(dt2);
println("rdt type: ", typeof(rdt1))
# Compile the Reactant version
f = @compile _solar_position(obs, rdt1)
rout1 = f(obs, rdt1)
rout2 = f(obs, rdt2)
@assert rout1 != rout2 # Should be different times, but they are identicalNo matter what edit: I think the problem is with |
|
to_rarray by default only converts arrays, not numbers. You can use the track number flag to also convert numbers (or of course explicitly construct the object with reactant types of choice). |
|
To look at what the compiled function does (And thus what might have been baked in), I'd recommend looking at |
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
| end | ||
|
|
||
| # custom DateTime, Date, Time types analogues to those defined in Dates.jl | ||
| struct TracedRDateTime{I} <: AbstractDateTime |
There was a problem hiding this comment.
this is fine, but since DateTime is not generic, TracedDateTime does not require to be generic (and can only contain a TracedRNumber{Int64}). same for the other types.
again, this is fine.
There was a problem hiding this comment.
You mean that the only possible type for its field is TracedRNumber{Int64}?
While trying out the extension I saw something different. Sometimes I get TracedRMillisecond{Int64}, some times TracedRMillisecond{TracedRNumber{Int64}}, sometimes TracedRMillisecond{ConcretePJRTNumber{Int,1}.
This comment was marked as resolved.
This comment was marked as resolved.
This comment was marked as resolved.
This comment was marked as resolved.
|
I think this PR might be fine for a review now. I am just having problems actually using it for our model that are related to |
|
Not having much luck with this yet, datetime vectors are still folded in as constants. |
|
I don't understand what you mean by folded in as a constant. Could you condense the example from above down to a MWE of the issue? |
|
The above is an MWE that can be run, but I will try to produce a smaller script.
Essentially reactant compiles an empty function that does nothing. This is the output from |
|
Even after applying a workaround to #2582 , I still encounter a weird scoping issue here that I can't quite condense down to a MWE. Basically, the unit test works unless I put it inside a using Test
using Reactant
using Dates
using Dates: value, UTInstant
#@testset "Minimal timestepper with Dates" begin
# Inspired by the usage in SpeedyWeather.jl
# Minimal clock-like mutable struct
mutable struct Clock{I,T,TS}
n_timesteps::I
time::T
time_step::TS
end
# Minimal state
struct State{C}
clock::C
end
function timestep!(state::State)
state.clock.time += state.clock.time_step
return nothing
end
function timestepping!(state::State)
(; clock) = state
@trace for i in 1:(clock.n_timesteps)
timestep!(state)
end
return nothing
end
clock = Clock(5, DateTime(2002, 1, 1), Dates.Day(1))
state = State(clock)
timestepping!(state)
clock_jit = Reactant.to_rarray(
Clock(5, DateTime(2002, 1, 1), Dates.Day(1)); track_numbers=true
)
state_jit = State(clock_jit)
@jit timestepping!(state_jit)
@test DateTime(state_jit.clock.time) == state.clock.time
# endWith the Got exception outside of a @test
setfield!: immutable struct of type Int64 cannot be changed
Stacktrace:
[1] traced_setfield!
@ ~/.julia/dev/Reactant/src/Compiler.jl:96 [inlined]
[2] traced_setfield_buffer!(::Val{:PJRT}, _cache_dict::IdDict{Union{Reactant.TracedRArray, Reactant.TracedRNumber}, Union{ConcretePJRTArray, ConcretePJRTNumber}}, val::Int64, concrete_res::Tuple{Reactant.XLA.PJRT.AsyncBuffer}, _obj::Millisecond, _field::Int64, path::Tuple{Int64, Symbol, Vararg{Int64, 4}})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:170
[3] traced_setfield_buffer!(runtime::Val{:PJRT}, cache_dict::IdDict{Union{Reactant.TracedRArray, Reactant.TracedRNumber}, Union{ConcretePJRTArray, ConcretePJRTNumber}}, concrete_res::Tuple{Reactant.XLA.PJRT.AsyncBuffer}, obj::Millisecond, field::Int64, path::Tuple{Int64, Symbol, Vararg{Int64, 4}})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:164
[4] macro expansion
@ ~/.julia/dev/Reactant/src/Compiler.jl:3613 [inlined]
[5] (::Reactant.Compiler.Thunk{var"#timestepping!#22"{var"#timestep!#21"}, Symbol("##timestepping!_reactant#250"), true, Tuple{State{Clock{ConcretePJRTNumber{Int64, 1}, ReactantDatesExt.TracedRDateTime{ConcretePJRTNumber{Int64, 1}}, ReactantDatesExt.TracedRDay{ConcretePJRTNumber{Int64, 1}}}}}, Reactant.XLA.PJRT.LoadedExecutable, Reactant.XLA.PJRT.Device, Reactant.XLA.PJRT.Client, Tuple{}, Vector{Bool}})(args::State{Clock{ConcretePJRTNumber{Int64, 1}, ReactantDatesExt.TracedRDateTime{ConcretePJRTNumber{Int64, 1}}, ReactantDatesExt.TracedRDay{ConcretePJRTNumber{Int64, 1}}}})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:4304
[6] macro expansion
@ ~/.julia/dev/Reactant/src/Compiler.jl:3103 [inlined]
[7] macro expansion
@ ~/.julia/packages/LLVM/fEIbx/src/base.jl:97 [inlined]
[8] macro expansion
@ ~/.julia/dev/Reactant/src/Compiler.jl:3095 [inlined]
[9] macro expansion
@ REPL[21]:41 [inlined]
[10] macro expansion
@ ~/.julia/juliaup/julia-1.11.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/Test/src/Test.jl:1709 [inlined]
[11] top-level scope
@ REPL[21]:5
Test Summary: | Error Total Time
Minimal timestepper with Dates | 1 1 1.2s |
|
move the function and struct defn outside of the testset, my guess is that |
|
Ah yes, thanks that worked. Never encountered an issue like that before. |
Julia stdlib's
Datesare currently not traceable as they are hard coded to haveInt64fields. Unfortunately a lot of functions defined for its datatypes are also defined on these concrete types and not on abstract type. So, we have to reimplement a large part ofDatesfor this to work. This is currently not a complete implementation of really everything thatDatesdoes. It's a surprisingly large package.My personal implementation for Speedy so far was just a drop-in replacement for
Dates. This here works different, because it's supposed to work for users just working with regularDates.I am not that familiar with Reactant yet, so I also have a few questions, if I actually implemented the right things ;)
Currently I extend
make_tracerandtraced_type_innerfor all types that may be traced.Doing
Then,
res <: TracedRDateTime. I guess that's expected, or can we convert this back into the untraced type for the user? Is this even something we would want?Aside from this issue, I have to add one or two slightly more complex unit tests and also see if this actually still works with what we need for Speedy or if I need to add more functionality. Folder structure and definitions follow those of
Dates.Edit: Oops, I just
jlfmtbut I think with the wrong style...Should also resolve #2046 (cc @langestefan)