Skip to content

Commit

Permalink
Fix generic RowIterator performance by caching columns length (#14)
Browse files Browse the repository at this point in the history
* Fix generic RowIterator performance by caching columns length and fixup some queryverse integration
  • Loading branch information
quinnj authored Sep 6, 2018
1 parent 4295d83 commit 36e9520
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/Tables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export rowtable, columntable

function __init__()
@require DataValues="e7dc6d0d-1eca-5fa6-8ad6-5aecde8b7ea5" include("datavalues.jl")
@require QueryOperators="2aef5ad7-51ca-5a8f-8e88-e75cf067b44b" include("enumerable.jl")
@require Query="1a8c2f83-1ff3-5112-b086-8aa67b057ba1" include("enumerable.jl")
@require CategoricalArrays="324d7699-5711-5eae-9e2f-1d82baa6b597" begin
using .CategoricalArrays
allocatecolumn(::Type{CategoricalString{R}}, rows) where {R} = CategoricalArray{String, 1, R}(undef, rows)
Expand Down
15 changes: 14 additions & 1 deletion src/datavalues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ unwrap(x) = x
# Convert a `DataValue` to its plain-Julia form: an NA value becomes `missing`,
# anything else yields the wrapped value.
function unwrap(x::DataValue)
    isna(x) && return missing
    return DataValues.unsafe_get(x)
end

datavaluetype(::Type{T}) where {T <: DataValue} = T
datavaluetype(::Type{T}) where {T} = DataValue{T}
datavaluetype(::Type{T}) where {T} = T
datavaluetype(::Type{Union{T, Missing}}) where {T} = DataValue{T}
Base.@pure function datavaluetype(::Tables.Schema{names, types}) where {names, types}
TT = Tuple{Any[ datavaluetype(fieldtype(types, i)) for i = 1:fieldcount(types) ]...}
Expand Down Expand Up @@ -55,3 +55,16 @@ function Base.iterate(rows::DataValueRowIterator{NT, S}, st=()) where {NT <: Nam
end
end

# function Base.iterate(rows::DataValueRowIterator{NT}, st=()) where {NT}
# state = iterate(rows.x, st...)
# state === nothing && return nothing
# row, st = state
# return DataValueRow{NT, typeof(row)}(row), (st,)
# end

# struct DataValueRow{NT, T}
# row::T
# end

# @inline Base.getproperty(dvr::DataValueRow{NamedTuple{names, types}}, nm::Symbol) where {names, types} = getproperty(dvr, Tables.columntype(names, types, nm), Tables.columnindex(names, nm), nm)
# @inline Base.getproperty(dvr::DataValueRow, ::Type{T}, col::Int, nm::Symbol) where {T} = T(getproperty(getfield(dvr, 1), nondatavaluetype(T), col, nm))
15 changes: 12 additions & 3 deletions src/enumerable.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using .QueryOperators: Enumerable
using .DataValues
using .Query

@static if isdefined(Query.QueryOperators, :Enumerable)

import .Query.QueryOperators: Enumerable

Tables.istable(::Type{<:Enumerable}) = true
Tables.rowaccess(::Type{<:Enumerable}) = true
Expand All @@ -9,7 +12,11 @@ struct DataValueUnwrapper{S}
x::S
end

Tables.schema(dv::DataValueUnwrapper) = Tables.Schema(nondatavaluetype(eltype(dv.x)))
# Schema of the unwrapped rows. A schema is only built when the wrapped source's
# element type is a NamedTuple; for any other element type `nothing` is returned.
function Tables.schema(dv::DataValueUnwrapper)
    rowtype = eltype(dv.x)
    if rowtype <: NamedTuple
        return Tables.Schema(nondatavaluetype(rowtype))
    else
        return nothing
    end
end
Base.eltype(rows::DataValueUnwrapper) = DataValueUnwrapRow{eltype(rows.x)}
Base.IteratorSize(::Type{DataValueUnwrapper{S}}) where {S} = Base.IteratorSize(S)
Base.length(rows::DataValueUnwrapper) = length(rows.x)
Expand All @@ -28,3 +35,5 @@ end
# Property access on an unwrap row forwards to the wrapped row and passes the
# result through `unwrap`. The wrapped row is read with `getfield(d, 1)` because
# `getproperty` itself is overloaded for this type.
function Base.getproperty(d::DataValueUnwrapRow, ::Type{T}, col::Int, nm::Symbol) where {T}
    inner = getfield(d, 1)
    return unwrap(getproperty(inner, T, col, nm))
end

function Base.getproperty(d::DataValueUnwrapRow, nm::Symbol)
    inner = getfield(d, 1)
    return unwrap(getproperty(inner, nm))
end

# Property names are those of the wrapped row.
function Base.propertynames(d::DataValueUnwrapRow)
    return propertynames(getfield(d, 1))
end

end # isdefined
10 changes: 7 additions & 3 deletions src/fallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
## we'll provide a default implementation of the dual

# generic row iteration of columns
"""
    rowcount(cols)

Return the number of rows in the columns object `cols`, taken as the length of its
first column (the first name reported by `propertynames(cols)`).

A `cols` with no columns has no rows, so `0` is returned instead of indexing into an
empty collection of names (which would throw a `BoundsError`).
"""
function rowcount(cols)
    names = propertynames(cols)
    isempty(names) && return 0
    return length(getproperty(cols, first(names)))
end

struct ColumnsRow{T}
columns::T # a `Columns` object
row::Int
Expand All @@ -14,9 +16,10 @@ Base.propertynames(c::ColumnsRow) = propertynames(getfield(c, 1))

# Wrapper that pairs a columns object with its row count.
struct RowIterator{T}
columns::T # the wrapped columns object
len::Int # row count supplied at construction; this is what `Base.length` returns
end
# Declare the element type on the iterator *type*, which is the form Base's iteration
# interface expects; generic code that asks `eltype(typeof(rows))` would otherwise get
# `Any`. Calling `eltype` on an instance keeps working through Base's
# `eltype(x) = eltype(typeof(x))` fallback.
Base.eltype(::Type{RowIterator{T}}) where {T} = ColumnsRow{T}
Base.length(x::RowIterator) = length(getproperty(x.columns, propertynames(x.columns)[1]))
Base.length(x::RowIterator) = x.len
schema(x::RowIterator) = schema(x.columns)

function Base.iterate(rows::RowIterator, st=1)
Expand All @@ -26,7 +29,8 @@ end

function rows(x::T) where {T}
if columnaccess(T)
return RowIterator(columns(x))
cols = columns(x)
return RowIterator(cols, rowcount(cols))
else
throw(ArgumentError("no default `Tables.rows` implementation for type: $T"))
end
Expand Down Expand Up @@ -90,7 +94,7 @@ end
function buildcolumns(::Nothing, rowitr::T) where {T}
state = iterate(rowitr)
state === nothing && return NamedTuple()
row::eltype(rowitr), st = state
row, st = state
names = propertynames(row)
L = Base.IteratorSize(T)
len = haslength(L) ? length(rowitr) : 0
Expand Down
2 changes: 1 addition & 1 deletion src/namedtuples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function Base.iterate(rows::NamedTupleIterator{Schema{names, T}}, st=()) where {
end

# unknown schema case
function Base.iterate(rows::NamedTupleIterator{nothing, T}, st=()) where {T}
function Base.iterate(rows::NamedTupleIterator{Nothing, T}, st=()) where {T}
x = iterate(rows.x, st...)
x === nothing && return nothing
row, st = x
Expand Down
16 changes: 10 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ end
rt = [(a=1, b=4.0, c="7"), (a=2.0, b=missing, c="8"), (a=3, b=6.0, c="9")]
@test isequal(Tables.buildcolumns(nothing, rt), (a = Real[1, 2.0, 3], b = Union{Missing, Float64}[4.0, missing, 6.0], c = ["7", "8", "9"]))

nti = Tables.NamedTupleIterator{nothing, typeof(rt)}(rt)
nti = Tables.NamedTupleIterator{Nothing, typeof(rt)}(rt)
nti2 = collect(nti)
@test isequal(rt, nti2)

Expand Down Expand Up @@ -183,19 +183,23 @@ end
rt = (a = Real[1, 2.0, 3], b = Union{Missing, Float64}[4.0, missing, 6.0], c = ["7", "8", "9"])

dv = Tables.datavaluerows(rt)
@test eltype(dv) == NamedTuple{(:a, :b, :c),Tuple{DataValue{Real},DataValue{Float64},DataValue{String}}}
@test eltype(dv) == NamedTuple{(:a, :b, :c),Tuple{Real,DataValue{Float64},String}}
rt2 = collect(dv)
@test rt2[1] == (a = DataValue{Real}(1), b = DataValue{Float64}(4.0), c = DataValue{String}("7"))
@test rt2[1] == (a = 1, b = DataValue{Float64}(4.0), c = "7")

ei = QueryOperators.EnumerableIterable{eltype(dv), typeof(dv)}(dv)
nt = ei |> columntable
@test isequal(rt, nt)
rt3 = ei |> rowtable
@test isequal(rt |> rowtable, rt3)

rt = [(a=1, b=4.0, c="7"), (a=2, b=5.0, c="8"), (a=3, b=6.0, c="9")]
map(source::Enumerable, f::Function, f_expr::Expr)
mt = ei |> y->QueryOperators.map(y, x->(a=x.a, c=x.c), Expr(:block))
# rt = [(a=1, b=4.0, c="7"), (a=2, b=5.0, c="8"), (a=3, b=6.0, c="9")]
mt = ei |> y->QueryOperators.map(y, x->(a=x.b, c=x.c), Expr(:block))
@inferred (mt |> columntable)
@inferred (mt |> rowtable)

# uninferrable case
mt = ei |> y->QueryOperators.map(y, x->(a=x.a, c=x.c), Expr(:block))
@test (mt |> columntable) == (a = Real[1, 2.0, 3], c = ["7", "8", "9"])
@test length(mt |> rowtable) == 3
end

0 comments on commit 36e9520

Please sign in to comment.