Skip to content

Commit c6c91fc

Browse files
authored
Implement revert for Sample (#119)
1 parent e73b2cc commit c6c91fc

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

src/transforms/sample.jl

+28-5
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Sample(size::Int, weights::AbstractWeights;
4646
Sample(size::Int, weights; kwargs...) =
4747
Sample(size, Weights(collect(weights)); kwargs...)
4848

49-
isrevertible(::Type{<:Sample}) = false
49+
isrevertible(::Type{<:Sample}) = true
5050

5151
function preprocess(transform::Sample, table)
5252
# retrieve valid indices
@@ -60,24 +60,47 @@ function preprocess(transform::Sample, table)
6060
rng = transform.rng
6161

6262
# sample a subset of indices
63-
if isnothing(weights)
63+
sinds = if isnothing(weights)
6464
sample(rng, inds, size; replace, ordered)
6565
else
6666
sample(rng, inds, weights, size; replace, ordered)
6767
end
68+
rinds = setdiff(inds, sinds)
69+
70+
sinds, rinds
6871
end
6972

7073
function applyfeat(::Sample, table, prep)
7174
# collect all rows
7275
rows = Tables.rowtable(table)
7376

7477
# preprocessed indices
75-
sinds = prep
78+
sinds, rinds = prep
7679

7780
# select rows
7881
srows = view(rows, sinds)
82+
rrows = view(rows, rinds)
83+
84+
stable = srows |> Tables.materializer(table)
85+
86+
stable, (sinds, rinds, rrows)
87+
end
88+
89+
function revertfeat(::Sample, newtable, fcache)
90+
# collect all rows
91+
rows = Tables.rowtable(newtable)
92+
93+
sinds, rinds, rrows = fcache
7994

80-
newtable = srows |> Tables.materializer(table)
95+
uinds = sort(unique(sinds))
96+
urows = map(uinds) do i
97+
j = findfirst(==(i), sinds)
98+
rows[j]
99+
end
100+
101+
for (i, row) in zip(rinds, rrows)
102+
insert!(urows, i, row)
103+
end
81104

82-
newtable, nothing
105+
urows |> Tables.materializer(newtable)
83106
end

test/transforms.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -678,20 +678,27 @@
678678
b = [8, 5, 1, 2, 3, 4]
679679
c = [1, 8, 5, 2, 9, 4]
680680
t = Table(; a, b, c)
681-
trows = Tables.rowtable(t)
682681

683682
T = Sample(30, replace=true)
684683
n, c = apply(T, t)
685684
@test length(n.a) == 30
686685

686+
@test isrevertible(T)
687+
r = revert(T, n, c)
688+
@test r == t
689+
687690
T = Sample(6, replace=false)
688691
n, c = apply(T, t)
689692
@test n.a ⊆ t.a
690693
@test n.b ⊆ t.b
691694
@test n.c ⊆ t.c
692695

696+
r = revert(T, n, c)
697+
@test r == t
698+
693699
T = Sample(30, replace=true, ordered=true)
694700
n, c = apply(T, t)
701+
trows = Tables.rowtable(t)
695702
@test unique(Tables.rowtable(n)) == trows
696703

697704
T = Sample(6, replace=false, ordered=true)

0 commit comments

Comments
 (0)