Skip to content

Commit 6a478e0

Browse files
committed
Use mktree now for uproot
1 parent b96cf7f commit 6a478e0

1 file changed

Lines changed: 98 additions & 72 deletions

File tree

gatetools/merge_root.py

Lines changed: 98 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@
1919
# -----------------------------------------------------------------------------
2020

2121

22-
import gatetools as gt
22+
import logging
23+
2324
import numpy as np
2425
import tqdm
25-
import logging
26-
logger=logging.getLogger(__name__)
26+
import uproot
27+
28+
import gatetools as gt
29+
30+
logger = logging.getLogger(__name__)
31+
2732

2833
def unicity(root_keys):
2934
"""
@@ -38,34 +43,30 @@ def unicity(root_keys):
3843
name = name[0]
3944
if not name in root_array:
4045
root_array.append(name)
41-
return(root_array)
46+
return root_array
47+
4248

4349
def merge_root(rootfiles, outputfile, incrementRunId=False):
4450
"""
4551
Merge root files in output files
4652
"""
47-
try:
48-
import uproot
49-
except:
50-
print("uproot4 is mandatory to merge root file. Please, do:")
51-
print("pip install uproot")
5253

5354
uproot.default_library = "np"
5455

5556
out = uproot.recreate(outputfile)
5657

57-
#Previous ID values to be able to increment runIn or EventId
58+
# Previous ID values to be able to increment runIn or EventId
5859
previousId = {}
5960

60-
#create the dict reading all input root files
61-
trees = {} #TTree with TBranch
62-
hists = {} #Directory with THist
63-
pbar = tqdm.tqdm(total = len(rootfiles))
61+
# create the dict reading all input root files
62+
trees = {} # TTree with TBranch
63+
hists = {} # Directory with THist
64+
pbar = tqdm.tqdm(total=len(rootfiles))
6465
for file in rootfiles:
6566
root = uproot.open(file)
6667
root_keys = unicity(root.keys())
6768
for tree in root_keys:
68-
if hasattr(root[tree], 'keys'):
69+
if hasattr(root[tree], "keys"):
6970
if not tree in trees:
7071
trees[tree] = {}
7172
trees[tree]["rootDictType"] = {}
@@ -75,67 +76,93 @@ def merge_root(rootfiles, outputfile, incrementRunId=False):
7576
hists[tree]["rootDictValue"] = {}
7677
previousId[tree] = {}
7778
for branch in root[tree].keys():
78-
if isinstance(root[tree],uproot.reading.ReadOnlyDirectory):
79+
if isinstance(root[tree], uproot.reading.ReadOnlyDirectory):
7980
array = root[tree][branch].values()
8081
if len(array) > 0:
8182
branchName = tree + "/" + branch
82-
if type(array[0]) is type('c'):
83+
if type(array[0]) is type("c"):
8384
array = np.array([0 for xi in array])
8485
if not branchName in hists[tree]["rootDictType"]:
85-
hists[tree]["rootDictType"][branchName] = root[tree][branch].to_numpy()
86-
hists[tree]["rootDictValue"][branchName] = np.zeros(array.shape)
86+
hists[tree]["rootDictType"][branchName] = root[tree][
87+
branch
88+
].to_numpy()
89+
hists[tree]["rootDictValue"][branchName] = np.zeros(
90+
array.shape
91+
)
8792
hists[tree]["rootDictValue"][branchName] += array
8893
else:
8994
array = root[tree][branch].array(library="np")
90-
if len(array) > 0 and not (type(array[0]) is type(np.ndarray(2,))):
91-
if type(array[0]) is type('c'):
95+
if len(array) > 0 and (
96+
type(array[0])
97+
is not type(
98+
np.ndarray(
99+
2,
100+
)
101+
)
102+
):
103+
if type(array[0]) is type("c"):
92104
array = np.array([0 for xi in array])
93-
if not branch in trees[tree]["rootDictType"]:
105+
if branch not in trees[tree]["rootDictType"]:
94106
trees[tree]["rootDictType"][branch] = type(array[0])
95107
trees[tree]["rootDictValue"][branch] = np.array([])
96-
if (not incrementRunId and branch.startswith('eventID')) or (incrementRunId and branch.startswith('runID')):
97-
if not branch in previousId[tree]:
108+
if (
109+
not incrementRunId and branch.startswith("eventID")
110+
) or (incrementRunId and branch.startswith("runID")):
111+
if branch not in previousId[tree]:
98112
previousId[tree][branch] = 0
99113
array += previousId[tree][branch]
100-
previousId[tree][branch] = max(array) +1
101-
trees[tree]["rootDictValue"][branch] = np.append(trees[tree]["rootDictValue"][branch], array)
114+
previousId[tree][branch] = max(array) + 1
115+
trees[tree]["rootDictValue"][branch] = np.append(
116+
trees[tree]["rootDictValue"][branch], array
117+
)
102118
pbar.update(1)
103119
pbar.close()
104120

105-
#Set the dict in the output root file
121+
# Set the dict in the output root file
106122
for tree in trees:
107-
if not trees[tree]["rootDictValue"] == {} or not trees[tree]["rootDictType"] == {}:
108-
#out.mktree(tree, trees[tree]["rootDictType"])
109-
out[tree] = trees[tree]["rootDictValue"]
123+
if (
124+
not trees[tree]["rootDictValue"] == {}
125+
or not trees[tree]["rootDictType"] == {}
126+
):
127+
out.mktree(tree, trees[tree]["rootDictType"])
110128
for hist in hists:
111-
if not hists[hist]["rootDictValue"] == {} or not hists[hist]["rootDictType"] == {}:
129+
if (
130+
not hists[hist]["rootDictValue"] == {}
131+
or not hists[hist]["rootDictType"] == {}
132+
):
112133
for branch in hists[hist]["rootDictValue"]:
113134
for i in range(len(hists[hist]["rootDictValue"][branch])):
114-
hists[hist]["rootDictType"][branch][0][i] = hists[hist]["rootDictValue"][branch][i]
115-
out[branch[:-2]] = hists[hist]["rootDictType"][branch]
135+
hists[hist]["rootDictType"][branch][0][i] = hists[hist][
136+
"rootDictValue"
137+
][branch][i]
138+
out.mktree(branch[:-2], hists[hist]["rootDictType"][branch])
116139

117140

118141
#####################################################################################
119-
import unittest
120-
import tempfile
121-
import wget
122142
import os
123143
import shutil
144+
import tempfile
145+
import unittest
146+
124147
import numpy as np
148+
import uproot
149+
import wget
150+
125151
from .logging_conf import LoggedTestCase
126152

153+
127154
class Test_MergeRoot(LoggedTestCase):
128155
def test_merge_root_phsp(self):
129-
try:
130-
import uproot
131-
except:
132-
print("uproot4 is mandatory to merge root file. Please, do:")
133-
print("pip install uproot")
134-
135-
logger.info('Test_MergeRoot test_merge_root_phsp')
156+
logger.info("Test_MergeRoot test_merge_root_phsp")
136157
tmpdirpath = tempfile.mkdtemp()
137-
filenameRoot = wget.download("https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/phsp.root?inline=false", out=tmpdirpath, bar=None)
138-
gt.merge_root([filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"))
158+
filenameRoot = wget.download(
159+
"https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/phsp.root?inline=false",
160+
out=tmpdirpath,
161+
bar=None,
162+
)
163+
gt.merge_root(
164+
[filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root")
165+
)
139166
input = uproot.open(filenameRoot)
140167
output = uproot.open(os.path.join(tmpdirpath, "output.root"))
141168
self.assertTrue(output.keys() == input.keys())
@@ -144,57 +171,56 @@ def test_merge_root_phsp(self):
144171
self.assertTrue(outputTree.keys() == inputTree.keys())
145172
inputBranch = inputTree[inputTree.keys()[1]].array(library="np")
146173
outputBranch = outputTree[outputTree.keys()[1]].array(library="np")
147-
self.assertTrue(2*len(inputBranch) == len(outputBranch))
174+
self.assertTrue(2 * len(inputBranch) == len(outputBranch))
148175
shutil.rmtree(tmpdirpath)
149176

150177
def test_merge_root_pet_incrementEvent(self):
151-
try:
152-
import uproot
153-
except:
154-
print("uproot4 is mandatory to merge root file. Please, do:")
155-
print("pip install uproot")
156-
157-
logger.info('Test_MergeRoot test_merge_root_pet')
178+
logger.info("Test_MergeRoot test_merge_root_pet")
158179
tmpdirpath = tempfile.mkdtemp()
159-
filenameRoot = wget.download("https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false", out=tmpdirpath, bar=None)
160-
gt.merge_root([filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"))
180+
filenameRoot = wget.download(
181+
"https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false",
182+
out=tmpdirpath,
183+
bar=None,
184+
)
185+
gt.merge_root(
186+
[filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root")
187+
)
161188
input = uproot.open(filenameRoot)
162189
output = uproot.open(os.path.join(tmpdirpath, "output.root"))
163190
inputTree = input[input.keys()[0]]
164191
outputTree = output[output.keys()[0]]
165192
inputRunBranch = inputTree[inputTree.keys()[0]].array(library="np")
166193
outputRunBranch = outputTree[outputTree.keys()[0]].array(library="np")
167194
self.assertTrue(max(inputRunBranch) == max(outputRunBranch))
168-
self.assertTrue(2*len(inputRunBranch) == len(outputRunBranch))
195+
self.assertTrue(2 * len(inputRunBranch) == len(outputRunBranch))
169196
inputEventBranch = inputTree[inputTree.keys()[1]].array(library="np")
170197
outputEventBranch = outputTree[outputTree.keys()[1]].array(library="np")
171-
self.assertTrue(2*max(inputEventBranch)+1 == max(outputEventBranch))
172-
self.assertTrue(2*len(inputEventBranch) == len(outputEventBranch))
198+
self.assertTrue(2 * max(inputEventBranch) + 1 == max(outputEventBranch))
199+
self.assertTrue(2 * len(inputEventBranch) == len(outputEventBranch))
173200
shutil.rmtree(tmpdirpath)
174201

175202
def test_merge_root_pet_incrementRun(self):
176-
try:
177-
import uproot
178-
except:
179-
print("uproot4 is mandatory to merge root file. Please, do:")
180-
print("pip install uproot")
181-
182-
logger.info('Test_MergeRoot test_merge_root_pet')
203+
logger.info("Test_MergeRoot test_merge_root_pet")
183204
tmpdirpath = tempfile.mkdtemp()
184205
print(tmpdirpath)
185-
filenameRoot = wget.download("https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false", out=tmpdirpath, bar=None)
186-
gt.merge_root([filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"), True)
206+
filenameRoot = wget.download(
207+
"https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false",
208+
out=tmpdirpath,
209+
bar=None,
210+
)
211+
gt.merge_root(
212+
[filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"), True
213+
)
187214
input = uproot.open(filenameRoot)
188215
output = uproot.open(os.path.join(tmpdirpath, "output.root"))
189216
inputTree = input[input.keys()[0]]
190217
outputTree = output[output.keys()[0]]
191218
inputRunBranch = inputTree[inputTree.keys()[0]].array(library="np")
192219
outputRunBranch = outputTree[outputTree.keys()[0]].array(library="np")
193-
self.assertTrue(2*max(inputRunBranch)+1 == max(outputRunBranch))
194-
self.assertTrue(2*len(inputRunBranch) == len(outputRunBranch))
220+
self.assertTrue(2 * max(inputRunBranch) + 1 == max(outputRunBranch))
221+
self.assertTrue(2 * len(inputRunBranch) == len(outputRunBranch))
195222
inputEventBranch = inputTree[inputTree.keys()[1]].array(library="np")
196223
outputEventBranch = outputTree[outputTree.keys()[1]].array(library="np")
197224
self.assertTrue(max(inputEventBranch) == max(outputEventBranch))
198-
self.assertTrue(2*len(inputEventBranch) == len(outputEventBranch))
199-
#shutil.rmtree(tmpdirpath)
200-
225+
self.assertTrue(2 * len(inputEventBranch) == len(outputEventBranch))
226+
# shutil.rmtree(tmpdirpath)

0 commit comments

Comments
 (0)