Skip to content

Commit 3ea33a5

Browse files
authored
Refactor Filter implementation: avoid Tables.rowtable (#240)
* Refactor 'Filter' implementation: avoid 'Tables.rowtable' * Update tests
1 parent 169ad55 commit 3ea33a5

File tree

4 files changed

+40
-17
lines changed

4 files changed

+40
-17
lines changed

src/transforms/filter.jl

+20-11
Original file line numberDiff line numberDiff line change
@@ -47,29 +47,38 @@ function preprocess(transform::Filter, feat)
4747
end
4848

4949
function applyfeat(::Filter, feat, prep)
50-
# collect all rows
51-
rows = Tables.rowtable(feat)
52-
5350
# preprocessed indices
5451
sinds, rinds = prep
5552

56-
# select/reject rows
57-
srows = view(rows, sinds)
58-
rrows = view(rows, rinds)
53+
# selected/rejected rows
54+
srows = Tables.subset(feat, sinds, viewhint=true)
55+
rrows = Tables.subset(feat, rinds, viewhint=true)
5956

6057
newfeat = srows |> Tables.materializer(feat)
6158

6259
newfeat, (rinds, rrows)
6360
end
6461

6562
function revertfeat(::Filter, newfeat, fcache)
66-
# collect all rows
67-
rows = Tables.rowtable(newfeat)
63+
cols = Tables.columns(newfeat)
64+
names = Tables.columnnames(cols)
6865

6966
rinds, rrows = fcache
70-
for (i, row) in zip(rinds, rrows)
71-
insert!(rows, i, row)
67+
68+
# columns with selected rows
69+
columns = map(names) do name
70+
collect(Tables.getcolumn(cols, name))
71+
end
72+
73+
# insert rejected rows into columns
74+
rrcols = Tables.columns(rrows)
75+
for (name, x) in zip(names, columns)
76+
r = Tables.getcolumn(rrcols, name)
77+
for (i, v) in zip(rinds, r)
78+
insert!(x, i, v)
79+
end
7280
end
7381

74-
rows |> Tables.materializer(newfeat)
82+
𝒯 = (; zip(names, columns)...)
83+
𝒯 |> Tables.materializer(newfeat)
7584
end

src/transforms/sample.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ function preprocess(transform::Sample, feat)
5858
else
5959
sample(rng, inds, weights, size; replace, ordered)
6060
end
61+
62+
# rejected indices
6163
rinds = setdiff(inds, sinds)
6264

6365
sinds, rinds
@@ -67,7 +69,7 @@ function applyfeat(::Sample, feat, prep)
6769
# preprocessed indices
6870
sinds, rinds = prep
6971

70-
# selected and removed rows
72+
# selected/rejected rows
7173
srows = Tables.subset(feat, sinds, viewhint=true)
7274
rrows = Tables.subset(feat, rinds, viewhint=true)
7375

@@ -78,6 +80,7 @@ end
7880
function revertfeat(::Sample, newfeat, fcache)
7981
cols = Tables.columns(newfeat)
8082
names = Tables.columnnames(cols)
83+
8184
sinds, rinds, rrows = fcache
8285

8386
# columns with selected rows in original order
@@ -87,7 +90,7 @@ function revertfeat(::Sample, newfeat, fcache)
8790
[y[i] for i in uinds]
8891
end
8992

90-
# insert removed rows into columns
93+
# insert rejected rows into columns
9194
rrcols = Tables.columns(rrows)
9295
for (name, x) in zip(names, columns)
9396
r = Tables.getcolumn(rrcols, name)

test/transforms/filter.jl

+10
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,14 @@
123123
@test Tables.isrowtable(n)
124124
rtₒ = revert(T, n, c)
125125
@test rt == rtₒ
126+
127+
# performance tests
128+
trng = MersenneTwister(2) # test rng
129+
x = rand(trng, 100_000)
130+
y = rand(trng, 100_000)
131+
c = CoDaArray((a=rand(trng, 100_000), b=rand(trng, 100_000), c=rand(trng, 100_000)))
132+
t = (; x, y, c)
133+
134+
T = Filter(row -> row.x > 0.5)
135+
@test @elapsed(apply(T, t)) < 0.5
126136
end

test/transforms/sample.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,11 @@
6262
@test isapprox(count(==(trows[6]), nrows) / 10_000, 6 / 21, atol=0.01)
6363

6464
# performance tests
65-
x = rand(100_000)
66-
y = rand(100_000)
67-
c = CoDaArray((a=rand(100_000), b=rand(100_000), c=rand(100_000)))
68-
t = Table(; x, y, c)
65+
trng = MersenneTwister(2) # test rng
66+
x = rand(trng, 100_000)
67+
y = rand(trng, 100_000)
68+
c = CoDaArray((a=rand(trng, 100_000), b=rand(trng, 100_000), c=rand(trng, 100_000)))
69+
t = (; x, y, c)
6970

7071
T = Sample(10_000)
7172
@test @elapsed(apply(T, t)) < 0.5

0 commit comments

Comments
 (0)