Skip to content

Commit f8dd61a

Browse files
vepadulanodpiparo
authored andcommitted
[df] Refresh input TTree when changing spec
In-between distributed tasks the input data source of the RDataFrame is changed (via RLoopManager::ChangeSpec), thus invalidating the input TTree from the previous task. Ensure the Snapshot action helper is always aware of the latest, currently valid TTree by retrieving it from the input RLoopManager directly. This fixes the failures in the graph caching test. The test was also extended to check that the output snapshot from the distributed tasks also contains the right data as it was read from the input TTree.
1 parent de63ccc commit f8dd61a

File tree

5 files changed

+95
-57
lines changed

5 files changed

+95
-57
lines changed

bindings/experimental/distrdf/test/backend/test_graph_caching.py

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
21
import os
32
import unittest
3+
from array import array
44

5+
import ROOT
6+
from DistRDF._graph_cache import _ACTIONS_REGISTER, _RDF_REGISTER
57
from DistRDF.Backends import Base
68
from DistRDF.DataFrame import RDataFrame
79
from DistRDF.HeadNode import get_headnode
8-
from DistRDF._graph_cache import _ACTIONS_REGISTER, _RDF_REGISTER
910

1011

1112
def clear_caches():
@@ -110,42 +111,74 @@ def test_count_tchain(self):
110111
# The RDataFrame should have run as many times as partitions
111112
self.assertEqual(cached_rdf.GetNRuns(), npartitions)
112113

113-
def test_snapshot(self):
114-
"""The cache is used to Snapshot data."""
115-
treename = "myTree"
116-
filenames = ["4clusters.root"] * 5
117-
nentries = 5000
118-
backend = GraphCaching.TestBackend()
119-
114+
def _test_snapshot_impl(self, backend, treename, filenames, nentries):
120115
for npartitions in [1, 2, 4, 8, 16]:
121116
# Start from a fresh cache at each subtest iteration
122117
clear_caches()
123118

124119
output_basename = "test_graph_caching_test_snapshot"
125120
output_filenames = [f"{output_basename}_{i}.root" for i in range(npartitions)]
121+
try:
122+
with self.subTest(npartitions=npartitions):
123+
headnode = get_headnode(backend, npartitions, treename, filenames)
124+
distrdf = RDataFrame(headnode)
125+
126+
output_branch = "b1"
127+
snapdf = distrdf.Snapshot(treename, f"{output_basename}.root", (output_branch,))
128+
129+
# There should be exactly one cached RDF and set of actions
130+
self.assertEqual(len(_RDF_REGISTER), 1)
131+
self.assertEqual(len(_RDF_REGISTER), len(_ACTIONS_REGISTER))
132+
cached_rdf = tuple(_RDF_REGISTER.values())[0]
133+
# The RDataFrame should have run as many times as partitions
134+
self.assertEqual(cached_rdf.GetNRuns(), npartitions)
135+
# All the correct output files should be present
136+
for output_filename in output_filenames:
137+
self.assertTrue(os.path.exists(output_filename))
138+
# Make sure we have the correct data in the snapshot output
139+
self.assertListEqual(list(snapdf.AsNumpy()[output_branch]), list(range(500)))
140+
finally:
141+
# Remove output files at each iteration
142+
for output_filename in output_filenames:
143+
try:
144+
os.remove(output_filename)
145+
except OSError:
146+
pass
126147

127-
with self.subTest(npartitions=npartitions):
128-
headnode = get_headnode(backend, npartitions, treename, filenames)
129-
distrdf = RDataFrame(headnode)
130-
131-
output_branches = ("b1", )
132-
snapdf = distrdf.Snapshot(treename, f"{output_basename}.root", ["b1", ])
148+
def test_snapshot(self):
149+
"""The cache is used to Snapshot data."""
133150

134-
# There should be exactly one cached RDF and set of actions
135-
self.assertEqual(len(_RDF_REGISTER), 1)
136-
self.assertEqual(len(_RDF_REGISTER), len(_ACTIONS_REGISTER))
137-
cached_rdf = tuple(_RDF_REGISTER.values())[0]
138-
# The RDataFrame should have run as many times as partitions
139-
self.assertEqual(cached_rdf.GetNRuns(), npartitions)
140-
# All the correct output files should be present
141-
for output_filename in output_filenames:
142-
self.assertTrue(os.path.exists(output_filename))
143-
# The snapshotted dataframe should be usable
144-
self.assertEqual(snapdf.Count().GetValue(), nentries)
151+
def write_data(treenames, filenames):
152+
# Create a dataset with 100 entries per file with sequential values
153+
b1 = array("i", [0])
154+
for i in range(1, len(filenames) + 1):
155+
with ROOT.TFile.Open(filenames[i - 1], "recreate") as f:
156+
t = ROOT.TTree(treenames[i - 1], treenames[i - 1])
157+
t.Branch("b1", b1, "b1/I")
158+
nentries = 0
159+
for val in range((i - 1) * 100, i * 100):
160+
b1[0] = val
161+
t.Fill()
162+
nentries += 1
163+
# 5 clusters per file
164+
if nentries % 20 == 0:
165+
t.FlushBaskets()
166+
f.Write()
167+
168+
treename = "Events"
169+
filenames = [f"test_graph_caching_test_snapshot_input_{i}.root" for i in range(1, 6)]
170+
nentries = 100 * len(filenames)
171+
backend = GraphCaching.TestBackend()
145172

146-
# Remove output files at each iteration
147-
for output_filename in output_filenames:
148-
os.remove(output_filename)
173+
try:
174+
write_data([treename] * len(filenames), filenames)
175+
self._test_snapshot_impl(backend, treename, filenames, nentries)
176+
finally:
177+
for fn in filenames:
178+
try:
179+
os.remove(fn)
180+
except OSError:
181+
pass
149182

150183
def test_multiple_graphs(self):
151184
"""The caches are used with multiple executions."""

tree/dataframe/inc/ROOT/RDF/ActionHelpers.hxx

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,14 +1528,14 @@ class R__CLING_PTRCHECK(off) SnapshotTTreeHelper : public RActionImpl<SnapshotTT
15281528
RBranchSet fOutputBranches;
15291529
std::vector<bool> fIsDefine;
15301530
ROOT::Detail::RDF::RLoopManager *fOutputLoopManager;
1531-
ROOT::RDF::RDataSource *fInputDataSource;
1531+
ROOT::Detail::RDF::RLoopManager *fInputLoopManager;
15321532

15331533
public:
15341534
using ColumnTypes_t = TypeList<ColTypes...>;
15351535
SnapshotTTreeHelper(std::string_view filename, std::string_view dirname, std::string_view treename,
15361536
const ColumnNames_t &vbnames, const ColumnNames_t &bnames, const RSnapshotOptions &options,
15371537
std::vector<bool> &&isDefine, ROOT::Detail::RDF::RLoopManager *loopManager,
1538-
ROOT::RDF::RDataSource *inputDataSource)
1538+
ROOT::Detail::RDF::RLoopManager *inputLM)
15391539
: fFileName(filename),
15401540
fDirName(dirname),
15411541
fTreeName(treename),
@@ -1546,7 +1546,7 @@ public:
15461546
fBranchAddresses(vbnames.size(), nullptr),
15471547
fIsDefine(std::move(isDefine)),
15481548
fOutputLoopManager(loopManager),
1549-
fInputDataSource(inputDataSource)
1549+
fInputLoopManager(inputLM)
15501550
{
15511551
EnsureValidSnapshotTTreeOutput(fOptions, fTreeName, fFileName);
15521552
}
@@ -1569,12 +1569,11 @@ public:
15691569
}
15701570
}
15711571

1572-
void InitTask(TTreeReader *r, unsigned int /* slot */)
1572+
void InitTask(TTreeReader * /*treeReader*/, unsigned int /* slot */)
15731573
{
1574-
if (r)
1575-
fInputTree = r->GetTree();
1576-
else if (auto treeDS = dynamic_cast<ROOT::Internal::RDF::RTTreeDS *>(fInputDataSource))
1577-
fInputTree = treeDS->GetTree();
1574+
// We ask the input RLoopManager if it has a TTree. We cannot rely on getting this information when constructing
1575+
// this action helper, since the TTree might change e.g. when ChangeSpec is called in-between distributed tasks.
1576+
fInputTree = fInputLoopManager->GetTree();
15781577
fBranchAddressesNeedReset = true;
15791578
}
15801579

@@ -1689,7 +1688,7 @@ public:
16891688
fOptions,
16901689
std::vector<bool>(fIsDefine),
16911690
fOutputLoopManager,
1692-
fInputDataSource};
1691+
fInputLoopManager};
16931692
}
16941693
};
16951694

@@ -1715,15 +1714,15 @@ class R__CLING_PTRCHECK(off) SnapshotTTreeHelperMT : public RActionImpl<Snapshot
17151714
std::vector<RBranchSet> fOutputBranches;
17161715
std::vector<bool> fIsDefine;
17171716
ROOT::Detail::RDF::RLoopManager *fOutputLoopManager;
1718-
ROOT::RDF::RDataSource *fInputDataSource;
1717+
ROOT::Detail::RDF::RLoopManager *fInputLoopManager;
17191718

17201719
public:
17211720
using ColumnTypes_t = TypeList<ColTypes...>;
17221721

17231722
SnapshotTTreeHelperMT(const unsigned int nSlots, std::string_view filename, std::string_view dirname,
17241723
std::string_view treename, const ColumnNames_t &vbnames, const ColumnNames_t &bnames,
17251724
const RSnapshotOptions &options, std::vector<bool> &&isDefine,
1726-
ROOT::Detail::RDF::RLoopManager *loopManager, ROOT::RDF::RDataSource *inputDataSource)
1725+
ROOT::Detail::RDF::RLoopManager *loopManager, ROOT::Detail::RDF::RLoopManager *inputLM)
17271726
: fNSlots(nSlots),
17281727
fOutputFiles(fNSlots),
17291728
fOutputTrees(fNSlots),
@@ -1740,7 +1739,7 @@ public:
17401739
fOutputBranches(fNSlots),
17411740
fIsDefine(std::move(isDefine)),
17421741
fOutputLoopManager(loopManager),
1743-
fInputDataSource(inputDataSource)
1742+
fInputLoopManager(inputLM)
17441743
{
17451744
EnsureValidSnapshotTTreeOutput(fOptions, fTreeName, fFileName);
17461745
}
@@ -1785,11 +1784,11 @@ public:
17851784
if (fOptions.fAutoFlush)
17861785
fOutputTrees[slot]->SetAutoFlush(fOptions.fAutoFlush);
17871786
if (r) {
1788-
// not an empty-source RDF
1787+
// We could be getting a task-local TTreeReader from the TTreeProcessorMT.
17891788
fInputTrees[slot] = r->GetTree();
1790-
} else if (auto treeDS = dynamic_cast<ROOT::Internal::RDF::RTTreeDS *>(fInputDataSource))
1791-
fInputTrees[slot] = treeDS->GetTree();
1792-
1789+
} else {
1790+
fInputTrees[slot] = fInputLoopManager->GetTree();
1791+
}
17931792
fBranchAddressesNeedReset[slot] = 1; // reset first event flag for this slot
17941793
}
17951794

@@ -1914,7 +1913,7 @@ public:
19141913
fOptions,
19151914
std::vector<bool>(fIsDefine),
19161915
fOutputLoopManager,
1917-
fInputDataSource};
1916+
fInputLoopManager};
19181917
}
19191918
};
19201919

tree/dataframe/inc/ROOT/RDF/InterfaceUtils.hxx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ struct SnapshotHelperArgs {
250250
std::string fTreeName;
251251
std::vector<std::string> fOutputColNames;
252252
ROOT::RDF::RSnapshotOptions fOptions;
253-
ROOT::Detail::RDF::RLoopManager *fLoopManager;
254-
ROOT::RDF::RDataSource *fDataSource;
253+
ROOT::Detail::RDF::RLoopManager *fOutputLoopManager;
254+
ROOT::Detail::RDF::RLoopManager *fInputLoopManager;
255255
bool fToNTuple;
256256
};
257257

@@ -267,8 +267,8 @@ BuildAction(const ColumnNames_t &colNames, const std::shared_ptr<SnapshotHelperA
267267
const auto &treename = snapHelperArgs->fTreeName;
268268
const auto &outputColNames = snapHelperArgs->fOutputColNames;
269269
const auto &options = snapHelperArgs->fOptions;
270-
const auto &lmPtr = snapHelperArgs->fLoopManager;
271-
const auto &dataSource = snapHelperArgs->fDataSource;
270+
const auto &lmPtr = snapHelperArgs->fOutputLoopManager;
271+
const auto &inputLM = snapHelperArgs->fInputLoopManager;
272272

273273
auto sz = sizeof...(ColTypes);
274274
std::vector<bool> isDefine(sz);
@@ -304,14 +304,14 @@ BuildAction(const ColumnNames_t &colNames, const std::shared_ptr<SnapshotHelperA
304304
using Helper_t = SnapshotTTreeHelper<ColTypes...>;
305305
using Action_t = RAction<Helper_t, PrevNodeType>;
306306
actionPtr.reset(new Action_t(Helper_t(filename, dirname, treename, colNames, outputColNames, options,
307-
std::move(isDefine), lmPtr, dataSource),
307+
std::move(isDefine), lmPtr, inputLM),
308308
colNames, prevNode, colRegister));
309309
} else {
310310
// multi-thread snapshot
311311
using Helper_t = SnapshotTTreeHelperMT<ColTypes...>;
312312
using Action_t = RAction<Helper_t, PrevNodeType>;
313313
actionPtr.reset(new Action_t(Helper_t(nSlots, filename, dirname, treename, colNames, outputColNames, options,
314-
std::move(isDefine), lmPtr, dataSource),
314+
std::move(isDefine), lmPtr, inputLM),
315315
colNames, prevNode, colRegister));
316316
}
317317
}

tree/dataframe/inc/ROOT/RDF/RInterface.hxx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,7 +1359,7 @@ public:
13591359

13601360
auto snapHelperArgs = std::make_shared<RDFInternal::SnapshotHelperArgs>(RDFInternal::SnapshotHelperArgs{
13611361
std::string(filename), std::string(dirname), std::string(treename), colListWithAliasesAndSizeBranches,
1362-
options, newRDF->GetLoopManager(), GetDataSource(), true /* fToNTuple */});
1362+
options, newRDF->GetLoopManager(), GetLoopManager(), true /* fToNTuple */});
13631363

13641364
// The Snapshot helper will use colListNoAliasesWithSizeBranches (with aliases resolved) as input columns, and
13651365
// colListWithAliasesAndSizeBranches (still with aliases in it, passed through snapHelperArgs) as output column
@@ -1388,7 +1388,7 @@ public:
13881388

13891389
auto snapHelperArgs = std::make_shared<RDFInternal::SnapshotHelperArgs>(RDFInternal::SnapshotHelperArgs{
13901390
std::string(filename), std::string(dirname), std::string(treename), colListWithAliasesAndSizeBranches,
1391-
options, newRDF->GetLoopManager(), GetDataSource(), false /* fToRNTuple */});
1391+
options, newRDF->GetLoopManager(), GetLoopManager(), false /* fToRNTuple */});
13921392

13931393
resPtr = CreateAction<RDFInternal::ActionTags::Snapshot, RDFDetail::RInferredType>(
13941394
colListNoAliasesWithSizeBranches, newRDF, snapHelperArgs, fProxiedPtr,
@@ -3247,7 +3247,7 @@ private:
32473247

32483248
auto snapHelperArgs = std::make_shared<RDFInternal::SnapshotHelperArgs>(RDFInternal::SnapshotHelperArgs{
32493249
std::string(filename), std::string(dirname), std::string(treename), columnListWithoutSizeColumns, options,
3250-
newRDF->GetLoopManager(), GetDataSource(), true /* fToRNTuple */});
3250+
newRDF->GetLoopManager(), GetLoopManager(), true /* fToRNTuple */});
32513251

32523252
// The Snapshot helper will use validCols (with aliases resolved) as input columns, and
32533253
// columnListWithoutSizeColumns (still with aliases in it, passed through snapHelperArgs) as output column
@@ -3275,7 +3275,7 @@ private:
32753275

32763276
auto snapHelperArgs = std::make_shared<RDFInternal::SnapshotHelperArgs>(RDFInternal::SnapshotHelperArgs{
32773277
std::string(filename), std::string(dirname), std::string(treename), columnListWithoutSizeColumns, options,
3278-
newRDF->GetLoopManager(), GetDataSource(), false /* fToRNTuple */});
3278+
newRDF->GetLoopManager(), GetLoopManager(), false /* fToRNTuple */});
32793279

32803280
// The Snapshot helper will use validCols (with aliases resolved) as input columns, and
32813281
// columnListWithoutSizeColumns (still with aliases in it, passed through snapHelperArgs) as output column

tree/dataframe/src/RLoopManager.cxx

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,13 @@ const ColumnNames_t &RLoopManager::GetDefaultColumnNames() const
10701070

10711071
TTree *RLoopManager::GetTree() const
10721072
{
1073-
return fTree.get();
1073+
// This is currently called in SnapshotTTreeHelper[MT] to retrieve the task-local input TTree. It is not guaranteed
1074+
// that the same RLoopManager will always have the same input TTree for its entire lifetime, notably it could be
1075+
// changed by ChangeSpec when moving to a different entry range.
1076+
if (auto *treeDS = dynamic_cast<ROOT::Internal::RDF::RTTreeDS *>(fDataSource.get())) {
1077+
return treeDS->GetTree();
1078+
}
1079+
return nullptr;
10741080
}
10751081

10761082
void RLoopManager::Register(RDFInternal::RActionBase *actionPtr)

0 commit comments

Comments
 (0)