Skip to content

Commit 9e8f605

Browse files
authored
Merge pull request #224 from JuliaML/fix
Fix #223 + Some refactors in `tablerows` and `Parallel`
2 parents 84453f9 + 9ad346c commit 9e8f605

File tree

4 files changed

+47
-21
lines changed

4 files changed

+47
-21
lines changed

src/tablerows.jl

+16-8
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ struct CTableRows{T}
5252

5353
function CTableRows(table)
5454
cols = Tables.columns(table)
55-
nrows = _nrows(cols)
55+
nrows = _nrows(table)
5656
new{typeof(cols)}(cols, nrows)
5757
end
5858
end
@@ -81,15 +81,17 @@ Tables.getcolumn(row::CTableRow, nm::Symbol) = Tables.getcolumn(getcols(row), nm
8181

8282
struct RTableRows{T}
8383
rows::T
84+
nrows::Int
8485

8586
function RTableRows(table)
8687
rows = Tables.rows(table)
87-
new{typeof(rows)}(rows)
88+
nrows = _nrows(table)
89+
new{typeof(rows)}(rows, nrows)
8890
end
8991
end
9092

9193
# iterator interface
92-
Base.length(rows::RTableRows) = length(rows.rows)
94+
Base.length(rows::RTableRows) = rows.nrows
9395
function Base.iterate(rows::RTableRows, args...)
9496
next = iterate(rows.rows, args...)
9597
if isnothing(next)
@@ -116,9 +118,15 @@ Tables.getcolumn(row::RTableRow, nm::Symbol) = Tables.getcolumn(getrow(row), nm)
116118
# UTILS
117119
#-------
118120

119-
function _nrows(cols)
120-
names = Tables.columnnames(cols)
121-
isempty(names) && return 0
122-
column = Tables.getcolumn(cols, first(names))
123-
length(column)
121+
function _nrows(table)
122+
if Tables.rowaccess(table)
123+
rows = Tables.rows(table)
124+
length(rows)
125+
else
126+
cols = Tables.columns(table)
127+
names = Tables.columnnames(cols)
128+
isempty(names) && return 0
129+
column = Tables.getcolumn(cols, first(names))
130+
length(column)
131+
end
124132
end

src/transforms/parallel.jl

+18-13
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ function apply(p::ParallelTableTransform, table)
5353
feats = first.(splits)
5454
metas = last.(splits)
5555

56+
# check the number of rows of generated tables
57+
@assert allequal(_nrows(f) for f in feats) "parallel branches must produce the same number of rows"
58+
5659
# table with concatenated features
5760
newfeat = tablehcat(feats)
5861

@@ -128,6 +131,9 @@ function reapply(p::ParallelTableTransform, table, cache)
128131
feats = first.(splits)
129132
metas = last.(splits)
130133

134+
# check the number of rows of generated tables
135+
@assert allequal(_nrows(f) for f in feats) "parallel branches must produce the same number of rows"
136+
131137
# table with concatenated features
132138
newfeat = tablehcat(feats)
133139

@@ -140,24 +146,23 @@ end
140146

141147
function tablehcat(tables)
142148
# concatenate columns
143-
allvars, allvals = [], []
144-
varsdict = Set{Symbol}()
145-
for 𝒯 in tables
146-
cols = Tables.columns(𝒯)
147-
vars = Tables.columnnames(cols)
148-
vals = [Tables.getcolumn(cols, var) for var in vars]
149-
for (var, val) in zip(vars, vals)
150-
while var ∈ varsdict
151-
var = Symbol(var, :_)
149+
allnames = Symbol[]
150+
allcolumns = []
151+
for table in tables
152+
cols = Tables.columns(table)
153+
names = Tables.columnnames(cols)
154+
for name in names
155+
column = Tables.getcolumn(cols, name)
156+
while name ∈ allnames
157+
name = Symbol(name, :_)
152158
end
153-
push!(varsdict, var)
154-
push!(allvars, var)
155-
push!(allvals, val)
159+
push!(allnames, name)
160+
push!(allcolumns, column)
156161
end
157162
end
158163

159164
# table with concatenated columns
160-
𝒯 = (; zip(allvars, allvals)...)
165+
𝒯 = (; zip(allnames, allcolumns)...)
161166
𝒯 |> Tables.materializer(first(tables))
162167
end
163168

test/tablerows.jl

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
@test row[:a] == 1
3434
@test row["a"] == 1
3535
# iterator interface
36+
@test length(row) == 2
3637
item, state = iterate(row)
3738
@test item == 1
3839
item, state = iterate(row, state)
@@ -73,6 +74,7 @@
7374
@test row[:b] == 4
7475
@test row["b"] == 4
7576
# iterator interface
77+
@test length(row) == 2
7678
item, state = iterate(row)
7779
@test item == 1
7880
item, state = iterate(row, state)

test/transforms/parallel.jl

+11
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,15 @@
4646
@test t == n
4747
tₒ = revert(T, n, c)
4848
@test tₒ == t
49+
50+
# https://github.com/JuliaML/TableTransforms.jl/issues/223
51+
t = (a=1:4, b=5:8)
52+
left = Select(:a) → Filter(row -> row.a < 4)
53+
right = Select(:b) → Filter(row -> row.b < 7)
54+
T = left ⊔ right
55+
@test_throws AssertionError apply(T, t)
56+
t1 = (a=1:4, b=4:7)
57+
t2 = (a=1:4, b=5:8)
58+
n, c = apply(T, t1)
59+
@test_throws AssertionError reapply(T, t2, c)
4960
end

0 commit comments

Comments
 (0)