Skip to content

Commit 9f49c87

Browse files
authored
Make write_array require less specialization (#2)
* Make `write_array` require less specialization

  This massages the implementation of `write_array` to be fully inferred even for `Array`s of unknown type / dimensions. There's very little we're doing here other than looking up our dims and passing a pointer to C, so the main trick is to avoid over-specialization in other parts of Base (e.g. we have no way to type-stably convert `NTuple{N,Int64}` into `Vector{Int}` right now).

* Work around JET false positive

  JET thinks that possible method errors are "dynamic dispatches", so work around it by type-asserting to avoid the method error.
1 parent db82cc9 commit 9f49c87

3 files changed

Lines changed: 55 additions & 59 deletions

File tree

src/SimpleHDF5.jl

Lines changed: 51 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,18 @@ write_array(file_id, "matrix", rand(10, 10))
118118
close_file(file_id)
119119
```
120120
"""
121-
function write_array(file_id::API.hid_t, dataset_name::String, data::AbstractArray{T}) where T
121+
function write_array(file_id::API.hid_t, dataset_name::String, data::AbstractArray)
122122
# Convert to regular Array if it's not already one
123-
array_data = Array(data)
124-
125-
dims = size(array_data)
126-
rank = length(dims)
123+
return write_array(file_id, dataset_name, Array(data)::Array)
124+
end
127125

128-
dims_hsize_t = [reverse(dims)] # HDF5 uses C ordering (last dimension varies fastest)
129-
dataspace_id = API.h5s_create_simple(rank, dims_hsize_t, C_NULL)
126+
function write_array(file_id::API.hid_t, dataset_name::String, @nospecialize(data::Array))
127+
rank = length(size(data))
128+
dims = _convert(Vector{Int}, size(data))
130129

131-
datatype_id = _get_h5_datatype(T)
130+
reverse!(dims) # HDF5 uses C ordering (last dimension varies fastest)
131+
dataspace_id = API.h5s_create_simple(rank, dims, C_NULL)
132+
datatype_id = _get_h5_datatype(eltype(data)::Type)
132133

133134
# Create dataset
134135
dataset_id = API.h5d_create(
@@ -148,7 +149,7 @@ function write_array(file_id::API.hid_t, dataset_name::String, data::AbstractArr
148149
API.H5S_ALL,
149150
API.H5S_ALL,
150151
API.H5P_DEFAULT,
151-
array_data
152+
data
152153
)
153154

154155
# Clean up
@@ -348,56 +349,48 @@ function list_datasets(file_id::API.hid_t, path::String="/")
348349
return datasets
349350
end
350351

351-
# Helper function to get HDF5 datatype from Julia type
352-
function _get_h5_datatype(::Type{Float64})
353-
return API.h5t_copy(API.H5T_NATIVE_DOUBLE)
354-
end
355-
356-
function _get_h5_datatype(::Type{Float32})
357-
return API.h5t_copy(API.H5T_NATIVE_FLOAT)
358-
end
359-
360-
function _get_h5_datatype(::Type{Int64})
361-
return API.h5t_copy(API.H5T_NATIVE_INT64)
362-
end
363-
364-
function _get_h5_datatype(::Type{Int32})
365-
return API.h5t_copy(API.H5T_NATIVE_INT32)
366-
end
367-
368-
function _get_h5_datatype(::Type{Int16})
369-
return API.h5t_copy(API.H5T_NATIVE_INT16)
370-
end
371-
372-
function _get_h5_datatype(::Type{Int8})
373-
return API.h5t_copy(API.H5T_NATIVE_INT8)
374-
end
375-
376-
function _get_h5_datatype(::Type{UInt64})
377-
return API.h5t_copy(API.H5T_NATIVE_UINT64)
378-
end
379-
380-
function _get_h5_datatype(::Type{UInt32})
381-
return API.h5t_copy(API.H5T_NATIVE_UINT32)
382-
end
383-
384-
function _get_h5_datatype(::Type{UInt16})
385-
return API.h5t_copy(API.H5T_NATIVE_UINT16)
386-
end
387-
388-
function _get_h5_datatype(::Type{UInt8})
389-
return API.h5t_copy(API.H5T_NATIVE_UINT8)
390-
end
391-
392-
function bool_type()
393-
# Encode Bool as bitfield (UInt8-based) with precision 1
394-
bool_type = API.h5t_copy(API.H5T_NATIVE_B8)
395-
API.h5t_set_precision(bool_type, 1)
396-
return bool_type
397-
end
352+
"""
    _convert(::Type{Vector{T}}, tup::NTuple{N,U}) -> Vector{T}

Copy the elements of the homogeneous tuple `tup` into a newly allocated
`Vector{T}`. Elements are stored as-is when `U === T` and passed through
`Base.convert(T, x)` otherwise.

`tup` is marked `@nospecialize`, so the compiler is asked not to specialize on
the tuple's length; only `T` and `U` are static parameters of the method.

An empty tuple is handled by a separate method (below) and yields an empty
`Vector{T}`.
"""
function _convert(::Type{Vector{T}}, @nospecialize(tup::NTuple{N,U} where N)) where {T,U}
    len = length(tup)
    out = Vector{T}(undef, len)
    for i in 1:len
        if T === U
            # Same element type: no conversion call needed.
            out[i] = tup[i]
        else
            out[i] = Base.convert(T, tup[i])
        end
    end
    return out
end

# `()` matches `NTuple{N,U} where N` with N == 0 but leaves the static parameter
# `U` unbound, so evaluating `T === U` in the method above would throw an
# `UndefVarError`. `size` of a zero-dimensional `Array` is `()`, so give the
# empty tuple its own, more specific method.
_convert(::Type{Vector{T}}, ::Tuple{}) where {T} = Vector{T}(undef, 0)
398365

399-
function _get_h5_datatype(::Type{Bool})
400-
return bool_type()
366+
"""
    _get_h5_datatype(T::Type)

Return the HDF5 datatype for the Julia element type `T`, obtained by copying the
matching native type with `API.h5t_copy`.

Supported types are `Float64`, `Float32`, `Int8`/`Int16`/`Int32`/`Int64`,
`UInt8`/`UInt16`/`UInt32`/`UInt64` and `Bool`. `Bool` is encoded as a copy of
`API.H5T_NATIVE_B8` whose precision is set to 1.

The lookup is a single `@nospecialize`d method with an `if`/`elseif` chain
rather than one method per type, so it can be called with an element type that
is not known at compile time.

# Throws
- `ArgumentError` if `T` is not one of the supported types.
"""
function _get_h5_datatype(@nospecialize(T::Type))
    if T === Float64
        return API.h5t_copy(API.H5T_NATIVE_DOUBLE)
    elseif T === Float32
        return API.h5t_copy(API.H5T_NATIVE_FLOAT)
    elseif T === Int64
        return API.h5t_copy(API.H5T_NATIVE_INT64)
    elseif T === Int32
        return API.h5t_copy(API.H5T_NATIVE_INT32)
    elseif T === Int16
        return API.h5t_copy(API.H5T_NATIVE_INT16)
    elseif T === Int8
        return API.h5t_copy(API.H5T_NATIVE_INT8)
    elseif T === UInt64
        return API.h5t_copy(API.H5T_NATIVE_UINT64)
    elseif T === UInt32
        return API.h5t_copy(API.H5T_NATIVE_UINT32)
    elseif T === UInt16
        return API.h5t_copy(API.H5T_NATIVE_UINT16)
    elseif T === UInt8
        return API.h5t_copy(API.H5T_NATIVE_UINT8)
    elseif T === Bool
        # Encode Bool as bitfield (UInt8-based) with precision 1
        bool_type = API.h5t_copy(API.H5T_NATIVE_B8)
        API.h5t_set_precision(bool_type, 1)
        return bool_type
    else
        # An unsupported element type is a caller error, not an internal
        # invariant: `@assert` may be disabled, so raise a real exception.
        # The message is a constant on purpose, so building it needs no
        # dispatch on the unspecialized `T`.
        throw(ArgumentError("unsupported datatype"))
    end
end
402395

403396
function _get_julia_type(datatype_id::API.hid_t)

src/api/functions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ end
924924
925925
See `libhdf5` documentation for [`H5Dwrite`](https://docs.hdfgroup.org/hdf5/v1_14/group___h5_d.html#ga98f44998b67587662af8b0d8a0a75906).
926926
"""
927-
function h5d_write(dataset_id, mem_type_id, mem_space_id, file_space_id, xfer_plist_id, buf)
927+
function h5d_write(dataset_id, mem_type_id, mem_space_id, file_space_id, xfer_plist_id, @nospecialize(buf::Array))
928928
lock(liblock)
929929
var"#status#" = try
930930
ccall((:H5Dwrite, libhdf5), herr_t, (hid_t, hid_t, hid_t, hid_t, hid_t, Ptr{Cvoid}), dataset_id, mem_type_id, mem_space_id, file_space_id, xfer_plist_id, buf)

test/test_type_stability.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ using JET
3535
JET.@test_opt write_array(file_id, "float_array", float_array)
3636
JET.@test_opt write_array(file_id, "bool_array", bool_array)
3737

38+
# write_array should be optimized even with an unknown array type
39+
JET.test_opt(write_array, (typeof(file_id), String, Array))
40+
3841
# Actually write the arrays for subsequent tests
3942
write_array(file_id, "int_array", int_array)
4043
write_array(file_id, "float_array", float_array)

0 commit comments

Comments
 (0)