Skip to content

Commit 97d7344

Browse files
Allow FieldTimeSeries to pass keyword arguments to jldopen (#3739)
Co-authored-by: Simone Silvestri <[email protected]>
1 parent 9ffbee3 commit 97d7344

6 files changed

+122
-70
lines changed

src/OutputReaders/field_dataset.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
struct FieldDataset{F, M, P}
2-
fields :: F
3-
metadata :: M
4-
filepath :: P
1+
struct FieldDataset{F, M, P, KW}
2+
fields :: F
3+
metadata :: M
4+
filepath :: P
5+
reader_kw :: KW
56
end
67

78
"""
@@ -22,17 +23,24 @@ linearly.
2223
`file["metadata"]`.
2324
2425
- `grid`: May be specified to override the grid used in the JLD2 file.
26+
27+
- `reader_kw`: A dictionary of keyword arguments to pass to the reader (currently only JLD2)
28+
to be used when opening files.
2529
"""
2630
function FieldDataset(filepath;
27-
architecture=CPU(), grid=nothing, backend=InMemory(), metadata_paths=["metadata"])
31+
architecture = CPU(),
32+
grid = nothing,
33+
backend = InMemory(),
34+
metadata_paths = ["metadata"],
35+
reader_kw = Dict{Symbol, Any}())
2836

29-
file = jldopen(filepath)
37+
file = jldopen(filepath; reader_kw...)
3038

3139
field_names = keys(file["timeseries"])
3240
filter!(k -> k != "t", field_names) # Time is not a field.
3341

3442
ds = Dict{String, FieldTimeSeries}(
35-
name => FieldTimeSeries(filepath, name; architecture, backend, grid)
43+
name => FieldTimeSeries(filepath, name; architecture, backend, grid, reader_kw)
3644
for name in field_names
3745
)
3846

@@ -44,7 +52,7 @@ function FieldDataset(filepath;
4452

4553
close(file)
4654

47-
return FieldDataset(ds, metadata, abspath(filepath))
55+
return FieldDataset(ds, metadata, abspath(filepath), reader_kw)
4856
end
4957

5058
Base.getindex(fds::FieldDataset, inds...) = Base.getindex(fds.fields, inds...)

src/OutputReaders/field_time_series.jl

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ period = t[end] - t[1] + Δt
8585
"""
8686
struct Cyclical{FT}
8787
period :: FT
88-
end
88+
end
8989

9090
Cyclical() = Cyclical(nothing)
9191

@@ -164,7 +164,7 @@ Nt = 5
164164
backend = InMemory(4, 3) # so we have (4, 5, 1)
165165
n = 1 # so, the right answer is m̃ = 3
166166
m = 1 - (4 - 1) # = -2
167-
m̃ = mod1(-2, 5) # = 3 ✓
167+
m̃ = mod1(-2, 5) # = 3 ✓
168168
```
169169
170170
# Another shifting + wrapping example
@@ -213,7 +213,7 @@ Base.length(backend::PartlyInMemory) = backend.length
213213
##### FieldTimeSeries
214214
#####
215215

216-
mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: AbstractField{LX, LY, LZ, G, ET, 4}
216+
mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW} <: AbstractField{LX, LY, LZ, G, ET, 4}
217217
data :: D
218218
grid :: G
219219
backend :: K
@@ -223,16 +223,18 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A
223223
path :: P
224224
name :: N
225225
time_indexing :: TI
226-
226+
reader_kw :: KW
227+
227228
function FieldTimeSeries{LX, LY, LZ}(data::D,
228229
grid::G,
229230
backend::K,
230231
bcs::B,
231-
indices::I,
232+
indices::I,
232233
times,
233234
path,
234235
name,
235-
time_indexing) where {LX, LY, LZ, K, D, G, B, I}
236+
time_indexing,
237+
reader_kw) where {LX, LY, LZ, K, D, G, B, I}
236238

237239
ET = eltype(data)
238240

@@ -250,7 +252,7 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A
250252

251253
times = on_architecture(architecture(grid), times)
252254
end
253-
255+
254256
if time_indexing isa Cyclical{Nothing} # we have to infer the period
255257
Δt = @allowscalar times[end] - times[end-1]
256258
period = @allowscalar times[end] - times[1] + Δt
@@ -261,23 +263,25 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A
261263
TI = typeof(time_indexing)
262264
P = typeof(path)
263265
N = typeof(name)
266+
KW = typeof(reader_kw)
264267

265-
return new{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N}(data, grid, backend, bcs,
266-
indices, times, path, name,
267-
time_indexing)
268+
return new{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW}(data, grid, backend, bcs,
269+
indices, times, path, name,
270+
time_indexing, reader_kw)
268271
end
269272
end
270273

271-
on_architecture(to, fts::FieldTimeSeries{LX, LY, LZ}) where {LX, LY, LZ} =
274+
on_architecture(to, fts::FieldTimeSeries{LX, LY, LZ}) where {LX, LY, LZ} =
272275
FieldTimeSeries{LX, LY, LZ}(on_architecture(to, fts.data),
273276
on_architecture(to, fts.grid),
274277
on_architecture(to, fts.backend),
275278
on_architecture(to, fts.bcs),
276-
on_architecture(to, fts.indices),
279+
on_architecture(to, fts.indices),
277280
on_architecture(to, fts.times),
278281
on_architecture(to, fts.path),
279282
on_architecture(to, fts.name),
280-
on_architecture(to, fts.time_indexing))
283+
on_architecture(to, fts.time_indexing),
284+
on_architecture(to, fts.reader_kw))
281285

282286
#####
283287
##### Minimal implementation of FieldTimeSeries for use in GPU kernels
@@ -290,7 +294,7 @@ struct GPUAdaptedFieldTimeSeries{LX, LY, LZ, TI, K, ET, D, χ} <: AbstractField{
290294
times :: χ
291295
backend :: K
292296
time_indexing :: TI
293-
297+
294298
function GPUAdaptedFieldTimeSeries{LX, LY, LZ}(data::D,
295299
times:,
296300
backend::K,
@@ -313,7 +317,7 @@ const FTS{LX, LY, LZ, TI, K} = FieldTimeSeries{LX, LY, LZ, TI, K} w
313317
const GPUFTS{LX, LY, LZ, TI, K} = GPUAdaptedFieldTimeSeries{LX, LY, LZ, TI, K} where {LX, LY, LZ, TI, K}
314318

315319
const FlavorOfFTS{LX, LY, LZ, TI, K} = Union{GPUFTS{LX, LY, LZ, TI, K},
316-
FTS{LX, LY, LZ, TI, K}} where {LX, LY, LZ, TI, K}
320+
FTS{LX, LY, LZ, TI, K}} where {LX, LY, LZ, TI, K}
317321

318322
const InMemoryFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:AbstractInMemoryBackend}
319323
const OnDiskFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:OnDisk}
@@ -345,7 +349,7 @@ instantiate(T::Type) = T()
345349
new_data(FT, grid, loc, indices, ::Nothing) = nothing
346350

347351
# Apparently, not explicitly specifying Int64 in here makes this function
348-
# fail on x86 processors where `Int` is implied to be `Int32`
352+
# fail on x86 processors where `Int` is implied to be `Int32`
349353
# see ClimaOcean commit 3c47d887659d81e0caed6c9df41b7438e1f1cd52 at https://github.com/CliMA/ClimaOcean.jl/actions/runs/8804916198/job/24166354095)
350354
function new_data(FT, grid, loc, indices, Nt::Union{Int, Int64})
351355
space_size = total_size(grid, loc, indices)
@@ -360,12 +364,13 @@ time_indices_length(backend::PartlyInMemory, times) = length(backend)
360364
time_indices_length(::OnDisk, times) = nothing
361365

362366
function FieldTimeSeries(loc, grid, times=();
363-
indices = (:, :, :),
367+
indices = (:, :, :),
364368
backend = InMemory(),
365-
path = nothing,
369+
path = nothing,
366370
name = nothing,
367371
time_indexing = Linear(),
368-
boundary_conditions = nothing)
372+
boundary_conditions = nothing,
373+
reader_kw = Dict{Symbol, Any}())
369374

370375
LX, LY, LZ = loc
371376

@@ -376,9 +381,9 @@ function FieldTimeSeries(loc, grid, times=();
376381
isnothing(path) && error(ArgumentError("Must provide the keyword argument `path` when `backend=OnDisk()`."))
377382
isnothing(name) && error(ArgumentError("Must provide the keyword argument `name` when `backend=OnDisk()`."))
378383
end
379-
380-
return FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions,
381-
indices, times, path, name, time_indexing)
384+
385+
return FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, indices,
386+
times, path, name, time_indexing, reader_kw)
382387
end
383388

384389
"""
@@ -405,10 +410,16 @@ end
405410
struct UnspecifiedBoundaryConditions end
406411

407412
"""
408-
FieldTimeSeries(path, name, backend = InMemory();
413+
FieldTimeSeries(path, name;
414+
backend = InMemory(),
415+
architecture = nothing,
409416
grid = nothing,
417+
location = nothing,
418+
boundary_conditions = UnspecifiedBoundaryConditions(),
419+
time_indexing = Linear(),
410420
iterations = nothing,
411-
times = nothing)
421+
times = nothing,
422+
reader_kw = Dict{Symbol, Any}())
412423
413424
Return a `FieldTimeSeries` containing a time-series of the field `name`
414425
load from JLD2 output located at `path`.
@@ -427,6 +438,9 @@ Keyword arguments
427438
- `times`: Save times to load, as determined through an approximate floating point
428439
comparison to recorded save times. Defaults to times associated with `iterations`.
429440
Takes precedence over `iterations` if `times` is specified.
441+
442+
- `reader_kw`: A dictionary of keyword arguments to pass to the reader (currently only JLD2)
443+
to be used when opening files.
430444
"""
431445
function FieldTimeSeries(path::String, name::String;
432446
backend = InMemory(),
@@ -436,9 +450,10 @@ function FieldTimeSeries(path::String, name::String;
436450
boundary_conditions = UnspecifiedBoundaryConditions(),
437451
time_indexing = Linear(),
438452
iterations = nothing,
439-
times = nothing)
453+
times = nothing,
454+
reader_kw = Dict{Symbol, Any}())
440455

441-
file = jldopen(path)
456+
file = jldopen(path; reader_kw...)
442457

443458
# Defaults
444459
isnothing(iterations) && (iterations = parse.(Int, keys(file["timeseries/t"])))
@@ -520,8 +535,8 @@ function FieldTimeSeries(path::String, name::String;
520535
Nt = time_indices_length(backend, times)
521536
data = new_data(eltype(grid), grid, loc, indices, Nt)
522537

523-
time_series = FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions,
524-
indices, times, path, name, time_indexing)
538+
time_series = FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, indices,
539+
times, path, name, time_indexing, reader_kw)
525540

526541
set!(time_series, path, name)
527542

@@ -533,7 +548,8 @@ end
533548
grid = nothing,
534549
architecture = nothing,
535550
indices = (:, :, :),
536-
boundary_conditions = nothing)
551+
boundary_conditions = nothing,
552+
reader_kw = Dict{Symbol, Any}())
537553
538554
Load a field called `name` saved in a JLD2 file at `path` at `iter`ation.
539555
Unless specified, the `grid` is loaded from `path`.
@@ -542,7 +558,8 @@ function Field(location, path::String, name::String, iter;
542558
grid = nothing,
543559
architecture = nothing,
544560
indices = (:, :, :),
545-
boundary_conditions = nothing)
561+
boundary_conditions = nothing,
562+
reader_kw = Dict{Symbol, Any}())
546563

547564
# Default to CPU if neither architecture nor grid is specified
548565
if isnothing(architecture)
@@ -552,9 +569,9 @@ function Field(location, path::String, name::String, iter;
552569
architecture = Architectures.architecture(grid)
553570
end
554571
end
555-
572+
556573
# Load the grid and data from file
557-
file = jldopen(path)
574+
file = jldopen(path; reader_kw...)
558575

559576
isnothing(grid) && (grid = file["serialized/grid"])
560577
raw_data = file["timeseries/$name/$iter"]
@@ -565,7 +582,7 @@ function Field(location, path::String, name::String, iter;
565582
grid = on_architecture(architecture, grid)
566583
raw_data = on_architecture(architecture, raw_data)
567584
data = offset_data(raw_data, grid, location, indices)
568-
585+
569586
return Field(location, grid; boundary_conditions, indices, data)
570587
end
571588

@@ -625,4 +642,3 @@ function fill_halo_regions!(fts::InMemoryFTS)
625642

626643
return nothing
627644
end
628-

src/OutputReaders/field_time_series_indexing.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import Oceananigans.Fields: interpolate
1414
# Cyclical implementation if out-of-bounds (wrap around the time-series)
1515
@inline function interpolating_time_indices(ti::Cyclical, times, t)
1616
Nt = length(times)
17-
= first(times)
17+
= first(times)
1818
tᴺ = last(times)
1919

2020
T = ti.period
@@ -32,14 +32,14 @@ import Oceananigans.Fields: interpolate
3232
uncycled_indices = (ñ, n₁, n₂)
3333

3434
return ifelse(cycling, cycled_indices, uncycled_indices)
35-
end
35+
end
3636

3737
# Clamp mode if out-of-bounds, i.e get the neareast neighbor
3838
@inline function interpolating_time_indices(::Clamp, times, t)
3939
n, n₁, n₂ = time_index_binary_search(times, t)
4040

4141
beyond_indices = (0, n₂, n₂) # Beyond the last time: return n₂
42-
before_indices = (0, n₁, n₁) # Before the first time: return n₁
42+
before_indices = (0, n₁, n₁) # Before the first time: return n₁
4343
unclamped_indices = (n, n₁, n₂) # Business as usual
4444

4545
Nt = length(times)
@@ -53,13 +53,13 @@ end
5353
@inline function time_index_binary_search(times, t)
5454
Nt = length(times)
5555

56-
# n₁ and n₂ are the index to interpolate inbetween and
56+
# n₁ and n₂ are the index to interpolate inbetween and
5757
# n is a fractional index where 0 ≤ n ≤ 1
5858
n₁, n₂ = index_binary_search(times, t, Nt)
5959

6060
@inbounds begin
61-
t₁ = times[n₁]
62-
t₂ = times[n₂]
61+
t₁ = times[n₁]
62+
t₂ = times[n₂]
6363
end
6464

6565
# "Fractional index" ñ ∈ (0, 1)
@@ -79,7 +79,7 @@ import Base: getindex
7979
function getindex(fts::OnDiskFTS, n::Int)
8080
# Load data
8181
arch = architecture(fts)
82-
file = jldopen(fts.path)
82+
file = jldopen(fts.path; fts.reader_kw...)
8383
iter = keys(file["timeseries/t"])[n]
8484
raw_data = on_architecture(arch, file["timeseries/$(fts.name)/$iter"])
8585
close(file)
@@ -117,7 +117,7 @@ const YZFTS = FlavorOfFTS{Nothing, <:Any, <:Any, <:Any, <:Any}
117117

118118
@inline function interpolating_getindex(fts, i, j, k, time_index)
119119
ñ, n₁, n₂ = interpolating_time_indices(fts.time_indexing, fts.times, time_index.time)
120-
120+
121121
@inbounds begin
122122
ψ₁ = getindex(fts, i, j, k, n₁)
123123
ψ₂ = getindex(fts, i, j, k, n₂)
@@ -229,14 +229,14 @@ end
229229
##### FieldTimeSeries updating
230230
#####
231231

232-
# Let's make sure `times` is available on the CPU. This is always the case
233-
# for ranges. if `times` is a vector that resides on the GPU, it has to be moved to the CPU for safe indexing.
232+
# Let's make sure `times` is available on the CPU. This is always the case
233+
# for ranges. if `times` is a vector that resides on the GPU, it has to be moved to the CPU for safe indexing.
234234
# TODO: Copying the whole array is a bit unclean, maybe find a way that avoids the penalty of allocating and copying memory.
235235
# This would require refactoring `FieldTimeSeries` to include a cpu-allocated times array
236236
cpu_interpolating_time_indices(::CPU, times, time_indexing, t, arch) = interpolating_time_indices(time_indexing, times, t)
237237
cpu_interpolating_time_indices(::CPU, times::AbstractVector, time_indexing, t) = interpolating_time_indices(time_indexing, times, t)
238238

239-
function cpu_interpolating_time_indices(::GPU, times::AbstractVector, time_indexing, t)
239+
function cpu_interpolating_time_indices(::GPU, times::AbstractVector, time_indexing, t)
240240
cpu_times = on_architecture(CPU(), times)
241241
return interpolating_time_indices(time_indexing, cpu_times, t)
242242
end
@@ -279,4 +279,3 @@ function getindex(fts::InMemoryFTS, n::Int)
279279

280280
return Field(location(fts), fts.grid; data, fts.boundary_conditions, fts.indices)
281281
end
282-

0 commit comments

Comments
 (0)