Skip to content

Commit 4295d83

Browse files
committed
Package cleanup, improve test coverage, and fix tests
1 parent 80717e3 commit 4295d83

File tree

5 files changed

+78
-31
lines changed

5 files changed

+78
-31
lines changed

src/Tables.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,11 @@ Obviously every table type is different, but via a combination of `Tables.rows`
115115
abstract type Table end
116116

117117
# default definitions
118+
istable(x::T) where {T} = istable(T)
118119
istable(::Type{T}) where {T} = false
120+
rowaccess(x::T) where {T} = rowaccess(T)
119121
rowaccess(::Type{T}) where {T} = false
122+
columnaccess(x::T) where {T} = columnaccess(T)
120123
columnaccess(::Type{T}) where {T} = false
121124
schema(x) = nothing
122125

src/datavalues.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,31 @@ unwrap(x::DataValue) = isna(x) ? missing : DataValues.unsafe_get(x)
1414
datavaluetype(::Type{T}) where {T <: DataValue} = T
1515
datavaluetype(::Type{T}) where {T} = DataValue{T}
1616
datavaluetype(::Type{Union{T, Missing}}) where {T} = DataValue{T}
17-
Base.@pure function datavaluetype(::Type{NT}) where {NT <: NamedTuple{names}} where {names}
18-
TT = Tuple{Any[ datavaluetype(fieldtype(NT, i)) for i = 1:fieldcount(NT) ]...}
17+
Base.@pure function datavaluetype(::Tables.Schema{names, types}) where {names, types}
18+
TT = Tuple{Any[ datavaluetype(fieldtype(types, i)) for i = 1:fieldcount(types) ]...}
1919
return NamedTuple{names, TT}
2020
end
2121

2222
struct DataValueRowIterator{NT, S}
2323
x::S
2424
end
25+
DataValueRowIterator(::Type{NT}, x::S) where {NT <: NamedTuple, S} = DataValueRowIterator{NT, S}(x)
26+
27+
"Returns a DataValue-based NamedTuple-iterator"
28+
DataValueRowIterator(::Type{Schema{names, types}}, x::S) where {names, types, S} = DataValueRowIterator{datavaluetype(NamedTuple{names, types}), S}(x)
29+
function datavaluerows(x)
30+
r = Tables.rows(x)
31+
#TODO: add support for unknown schema
32+
return DataValueRowIterator(datavaluetype(Tables.schema(r)), r)
33+
end
2534

26-
# Should maybe make this return a custom DataValueRow type to allow lazier
27-
# DataValue wrapping; but need to make sure Query/QueryOperators support first
2835
Base.eltype(rows::DataValueRowIterator{NT, S}) where {NT, S} = NT
2936
Base.IteratorSize(::Type{DataValueRowIterator{NT, S}}) where {NT, S} = Base.IteratorSize(S)
3037
Base.length(rows::DataValueRowIterator) = length(rows.x)
3138

32-
"Returns a DataValue-based NamedTuple-iterator"
33-
DataValueRowIterator(::Type{NT}, x::S) where {NT <: NamedTuple, S} = DataValueRowIterator{datavaluetype(NT), S}(x)
34-
3539
function Base.iterate(rows::DataValueRowIterator{NT, S}, st=()) where {NT <: NamedTuple{names}, S} where {names}
3640
if @generated
37-
vals = Tuple(:(getproperty(row, $(fieldtype(NT, i)), $i, $(Meta.QuoteNode(names[i])))) for i = 1:fieldcount(NT))
41+
vals = Tuple(:($(fieldtype(NT, i))(getproperty(row, $(nondatavaluetype(fieldtype(NT, i))), $i, $(Meta.QuoteNode(names[i]))))) for i = 1:fieldcount(NT))
3842
q = quote
3943
x = iterate(rows.x, st...)
4044
x === nothing && return nothing
@@ -47,8 +51,7 @@ function Base.iterate(rows::DataValueRowIterator{NT, S}, st=()) where {NT <: Nam
4751
x = iterate(rows.x, st...)
4852
x === nothing && return nothing
4953
row, st = x
50-
return NT(Tuple(getproperty(row, fieldtype(NT, i), i, names[i]) for i = 1:fieldcount(NT))), (st,)
54+
return NT(Tuple(fieldtype(NT, i)(getproperty(row, nondatavaluetype(fieldtype(NT, i)), i, names[i])) for i = 1:fieldcount(NT))), (st,)
5155
end
5256
end
5357

54-
datavaluerows(x) = DataValueRowIterator(schema(x), rows(x))

src/enumerable.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,30 @@
1-
using .QueryOperators
1+
using .QueryOperators: Enumerable
2+
using .DataValues
23

3-
struct DataValueUnwrapRow{T}
4-
row::T
5-
end
6-
7-
Base.getproperty(d::DataValueUnwrapRow, ::Type{T}, col::Int, nm::Symbol) where {T} = unwrap(getproperty(getfield(d, 1), T, col, nm))
8-
Base.getproperty(d::DataValueUnwrapRow, nm::Symbol) = unwrap(getproperty(getfield(d, 1), nm))
9-
Base.propertynames(d::DataValueUnwrapRow) = propertynames(getfield(d, 1))
4+
Tables.istable(::Type{<:Enumerable}) = true
5+
Tables.rowaccess(::Type{<:Enumerable}) = true
6+
Tables.rows(e::Enumerable) = DataValueUnwrapper(e)
107

11-
struct DataValueUnwrapper{NT, S}
8+
struct DataValueUnwrapper{S}
129
x::S
1310
end
1411

12+
Tables.schema(dv::DataValueUnwrapper) = Tables.Schema(nondatavaluetype(eltype(dv.x)))
1513
Base.eltype(rows::DataValueUnwrapper) = DataValueUnwrapRow{eltype(rows.x)}
16-
Base.IteratorSize(::Type{DataValueUnwrapper{NT, S}}) where {NT, S} = Base.IteratorSize(S)
14+
Base.IteratorSize(::Type{DataValueUnwrapper{S}}) where {S} = Base.IteratorSize(S)
1715
Base.length(rows::DataValueUnwrapper) = length(rows.x)
1816

19-
AccessStyle(::Type{E}) where {E <: QueryOperators.Enumerable} = RowAccess()
20-
schema(e::QueryOperators.Enumerable) = nondatavaluetype(eltype(e))
21-
rows(e::E) where {E <: QueryOperators.Enumerable} = DataValueUnwrapper{schema(e), E}(e)
22-
23-
function Base.iterate(rows::DataValueUnwrapper{NT}, st=()) where {NT <: NamedTuple{names}} where {names}
17+
function Base.iterate(rows::DataValueUnwrapper, st=())
2418
x = iterate(rows.x, st...)
2519
x === nothing && return nothing
2620
row, st = x
2721
return DataValueUnwrapRow(row), (st,)
2822
end
23+
24+
struct DataValueUnwrapRow{T}
25+
row::T
26+
end
27+
28+
Base.getproperty(d::DataValueUnwrapRow, ::Type{T}, col::Int, nm::Symbol) where {T} = unwrap(getproperty(getfield(d, 1), T, col, nm))
29+
Base.getproperty(d::DataValueUnwrapRow, nm::Symbol) = unwrap(getproperty(getfield(d, 1), nm))
30+
Base.propertynames(d::DataValueUnwrapRow) = propertynames(getfield(d, 1))

src/fallbacks.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ end
1010

1111
Base.getproperty(c::ColumnsRow, ::Type{T}, col::Int, nm::Symbol) where {T} = getproperty(getfield(c, 1), T, col, nm)[getfield(c, 2)]
1212
Base.getproperty(c::ColumnsRow, nm::Symbol) = getproperty(getfield(c, 1), nm)[getfield(c, 2)]
13-
Base.propertynames(c::ColumnsRow) = propertynames(c.columns)
13+
Base.propertynames(c::ColumnsRow) = propertynames(getfield(c, 1))
1414

1515
struct RowIterator{T}
1616
columns::T
@@ -33,6 +33,8 @@ function rows(x::T) where {T}
3333
end
3434

3535
# build columns from rows
36+
haslength(L) = L isa Union{Base.HasShape, Base.HasLength}
37+
3638
"""
3739
Tables.allocatecolumn(::Type{T}, len) => returns a column type (usually AbstractVector) w/ size to hold `len` elements
3840
@@ -55,7 +57,7 @@ end
5557

5658
@inline function buildcolumns(schema, rowitr::T) where {T}
5759
L = Base.IteratorSize(T)
58-
len = Base.haslength(L) ? length(rowitr) : 0
60+
len = haslength(L) ? length(rowitr) : 0
5961
nt = allocatecolumns(schema, len)
6062
for (i, row) in enumerate(rowitr)
6163
eachcolumn(add!, schema, row, L, nt, i)
@@ -91,7 +93,7 @@ function buildcolumns(::Nothing, rowitr::T) where {T}
9193
row::eltype(rowitr), st = state
9294
names = propertynames(row)
9395
L = Base.IteratorSize(T)
94-
len = Base.haslength(L) ? length(rowitr) : 0
96+
len = haslength(L) ? length(rowitr) : 0
9597
sch = Schema(names, nothing)
9698
columns = NamedTuple{names}(Tuple(Union{}[] for _ = 1:length(names)))
9799
return _buildcolumns(rowitr, row, st, sch, L, columns, 1, len, Ref{Any}(columns))

test/runtests.jl

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ end
8888
@test Tables.buildcolumns(nothing, rt) == nt
8989
rt = [(a=1, b=4.0, c="7"), (a=2.0, b=missing, c="8"), (a=3, b=6.0, c="9")]
9090
@test isequal(Tables.buildcolumns(nothing, rt), (a = Real[1, 2.0, 3], b = Union{Missing, Float64}[4.0, missing, 6.0], c = ["7", "8", "9"]))
91+
92+
nti = Tables.NamedTupleIterator{nothing, typeof(rt)}(rt)
93+
nti2 = collect(nti)
94+
@test isequal(rt, nti2)
95+
96+
@test Tables.columntable(nothing, nt) == nt
9197
end
9298

9399
import Base: ==
@@ -146,8 +152,26 @@ function genericcolumntable(x)
146152
end
147153
==(a::GenericColumnTable, b::GenericColumnTable) = getfield(a, 1) == getfield(b, 1) && getfield(a, 2) == getfield(b, 2)
148154

149-
@testset "Tables.jl" begin
150-
155+
@testset "Tables.jl interface" begin
156+
157+
@test !Tables.istable(1)
158+
@test !Tables.istable(Int)
159+
@test !Tables.rowaccess(1)
160+
@test !Tables.rowaccess(Int)
161+
@test !Tables.columnaccess(1)
162+
@test !Tables.columnaccess(Int)
163+
@test Tables.schema(1) === nothing
164+
165+
sch = Tables.Schema{(:a, :b), Tuple{Int64, Float64}}()
166+
@test Tables.Schema((:a, :b), Tuple{Int64, Float64}) === sch
167+
@test Tables.Schema(NamedTuple{(:a, :b), Tuple{Int64, Float64}}) === sch
168+
@test Tables.Schema((:a, :b), nothing) === Tables.Schema{(:a, :b), nothing}()
169+
@test Tables.Schema([:a, :b], [Int64, Float64]) === sch
170+
show(sch)
171+
@test sch.names == (:a, :b)
172+
@test sch.types == (Int64, Float64)
173+
@test_throws ArgumentError sch.foobar
174+
151175
gr = GenericRowTable([GenericRow(1, 4.0, "7"), GenericRow(2, 5.0, "8"), GenericRow(3, 6.0, "9")])
152176
gc = GenericColumnTable(Dict(:a=>1, :b=>2, :c=>3), [GenericColumn([1,2,3]), GenericColumn([4.0, 5.0, 6.0]), GenericColumn(["7", "8", "9"])])
153177
@test gc == (gr |> genericcolumntable)
@@ -156,9 +180,22 @@ end
156180
end
157181

158182
@static if :Query in Symbol.(Base.loaded_modules_array())
183+
rt = (a = Real[1, 2.0, 3], b = Union{Missing, Float64}[4.0, missing, 6.0], c = ["7", "8", "9"])
184+
185+
dv = Tables.datavaluerows(rt)
186+
@test eltype(dv) == NamedTuple{(:a, :b, :c),Tuple{DataValue{Real},DataValue{Float64},DataValue{String}}}
187+
rt2 = collect(dv)
188+
@test rt2[1] == (a = DataValue{Real}(1), b = DataValue{Float64}(4.0), c = DataValue{String}("7"))
189+
190+
ei = QueryOperators.EnumerableIterable{eltype(dv), typeof(dv)}(dv)
191+
nt = ei |> columntable
192+
@test isequal(rt, nt)
193+
rt3 = ei |> rowtable
194+
@test isequal(rt |> rowtable, rt3)
159195

160196
rt = [(a=1, b=4.0, c="7"), (a=2, b=5.0, c="8"), (a=3, b=6.0, c="9")]
161-
mt = rt |> @map({_.a, _.c})
197+
map(source::Enumerable, f::Function, f_expr::Expr)
198+
mt = ei |> y->QueryOperators.map(y, x->(a=x.a, c=x.c), Expr(:block))
162199
@inferred (mt |> columntable)
163200
@inferred (mt |> rowtable)
164201
end

0 commit comments

Comments
 (0)