Skip to content

Commit 3b7f2ae

Browse files
authored
Fix Parallel transform in the presence of metadata (#122)
* Fix Parallel transform in the presence of metadata * Update src/transforms/parallel.jl * Update src/transforms/parallel.jl
1 parent d4fa63c commit 3b7f2ae

File tree

1 file changed

+43
-13
lines changed

1 file changed

+43
-13
lines changed

src/transforms/parallel.jl

+43-13
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,19 @@ function apply(p::Parallel, table)
4141
tables = first.(vals)
4242
caches = last.(vals)
4343

44-
# table with concatenated columns
45-
newtable = tablehcat(tables)
44+
# features and metadata
45+
splits = divide.(tables)
46+
feats = first.(splits)
47+
metas = last.(splits)
48+
49+
# table with concatenated features
50+
newfeat = tablehcat(feats)
51+
52+
# propagate metadata
53+
newmeta = first(metas)
54+
55+
# attach new features and metatada
56+
newtable = attach(newfeat, newmeta)
4657

4758
# find first revertible transform
4859
ind = findfirst(isrevertible, p.transforms)
@@ -51,9 +62,9 @@ function apply(p::Parallel, table)
5162
rinfo = if isnothing(ind)
5263
nothing
5364
else
54-
tcols = Tables.columns.(tables)
55-
tnames = Tables.columnnames.(tcols)
56-
ncols = length.(tnames)
65+
fcols = Tables.columns.(feats)
66+
fnames = Tables.columnnames.(fcols)
67+
ncols = length.(fnames)
5768
nrcols = ncols[ind]
5869
start = sum(ncols[1:ind-1]) + 1
5970
finish = start + nrcols - 1
@@ -71,24 +82,32 @@ function revert(p::Parallel, newtable, cache)
7182

7283
@assert !isnothing(rinfo) "transform is not revertible"
7384

85+
# features and metadata
86+
newfeat, newmeta = divide(newtable)
87+
7488
# retrieve info to revert transform
7589
ind = rinfo[1]
7690
range = rinfo[2]
7791
rtrans = p.transforms[ind]
7892
rcache = caches[ind]
7993

8094
# columns of transformed table
81-
cols = Tables.columns(newtable)
82-
names = Tables.columnnames(cols)
95+
fcols = Tables.columns(newfeat)
96+
names = Tables.columnnames(fcols)
8397

8498
# retrieve subtable to revert
85-
rcols = [Tables.getcolumn(cols, j) for j in range]
8699
rnames = names[range]
87-
𝒯 = (; zip(rnames, rcols)...)
88-
rtable = 𝒯 |> Tables.materializer(newtable)
100+
rcols = [Tables.getcolumn(fcols, j) for j in range]
101+
rfeat = (; zip(rnames, rcols)...) |> Tables.materializer(newfeat)
89102

90103
# revert transform on subtable
91-
revert(rtrans, rtable, rcache)
104+
feat = revert(rtrans, rfeat, rcache)
105+
106+
# propagate metadata
107+
meta = newmeta
108+
109+
# attach features and metadata
110+
attach(feat, meta)
92111
end
93112

94113
function reapply(p::Parallel, table, cache)
@@ -100,8 +119,19 @@ function reapply(p::Parallel, table, cache)
100119
itr = zip(p.transforms, caches)
101120
tables = tcollect(f(t, c) for (t, c) in itr)
102121

103-
# table with concatenated columns
104-
tablehcat(tables)
122+
# features and metadata
123+
splits = divide.(tables)
124+
feats = first.(splits)
125+
metas = last.(splits)
126+
127+
# table with concatenated features
128+
newfeat = tablehcat(feats)
129+
130+
# metadata of the first table
131+
newmeta = first(metas)
132+
133+
# attach new features and metatada
134+
attach(newfeat, newmeta)
105135
end
106136

107137
function tablehcat(tables)

0 commit comments

Comments
 (0)