Skip to content

Commit 487dff2

Browse files
authored
Merge pull request #98 from OpenGATE/update_uproot
Try with a version of uproot <5.7.0 to avoid error
2 parents 26d2e80 + a6dac79 commit 487dff2

2 files changed

Lines changed: 100 additions & 73 deletions

File tree

gatetools/merge_root.py

Lines changed: 100 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@
1919
# -----------------------------------------------------------------------------
2020

2121

22-
import gatetools as gt
22+
import logging
23+
2324
import numpy as np
2425
import tqdm
25-
import logging
26-
logger=logging.getLogger(__name__)
26+
import uproot
27+
28+
import gatetools as gt
29+
30+
logger = logging.getLogger(__name__)
31+
2732

2833
def unicity(root_keys):
2934
"""
@@ -38,34 +43,30 @@ def unicity(root_keys):
3843
name = name[0]
3944
if not name in root_array:
4045
root_array.append(name)
41-
return(root_array)
46+
return root_array
47+
4248

4349
def merge_root(rootfiles, outputfile, incrementRunId=False):
4450
"""
4551
Merge root files in output files
4652
"""
47-
try:
48-
import uproot
49-
except:
50-
print("uproot4 is mandatory to merge root file. Please, do:")
51-
print("pip install uproot")
5253

5354
uproot.default_library = "np"
5455

5556
out = uproot.recreate(outputfile)
5657

57-
#Previous ID values to be able to increment runIn or EventId
58+
# Previous ID values to be able to increment runIn or EventId
5859
previousId = {}
5960

60-
#create the dict reading all input root files
61-
trees = {} #TTree with TBranch
62-
hists = {} #Directory with THist
63-
pbar = tqdm.tqdm(total = len(rootfiles))
61+
# create the dict reading all input root files
62+
trees = {} # TTree with TBranch
63+
hists = {} # Directory with THist
64+
pbar = tqdm.tqdm(total=len(rootfiles))
6465
for file in rootfiles:
6566
root = uproot.open(file)
6667
root_keys = unicity(root.keys())
6768
for tree in root_keys:
68-
if hasattr(root[tree], 'keys'):
69+
if hasattr(root[tree], "keys"):
6970
if not tree in trees:
7071
trees[tree] = {}
7172
trees[tree]["rootDictType"] = {}
@@ -75,67 +76,95 @@ def merge_root(rootfiles, outputfile, incrementRunId=False):
7576
hists[tree]["rootDictValue"] = {}
7677
previousId[tree] = {}
7778
for branch in root[tree].keys():
78-
if isinstance(root[tree],uproot.reading.ReadOnlyDirectory):
79+
if isinstance(root[tree], uproot.reading.ReadOnlyDirectory):
7980
array = root[tree][branch].values()
8081
if len(array) > 0:
8182
branchName = tree + "/" + branch
82-
if type(array[0]) is type('c'):
83+
if type(array[0]) is type("c"):
8384
array = np.array([0 for xi in array])
8485
if not branchName in hists[tree]["rootDictType"]:
85-
hists[tree]["rootDictType"][branchName] = root[tree][branch].to_numpy()
86-
hists[tree]["rootDictValue"][branchName] = np.zeros(array.shape)
86+
hists[tree]["rootDictType"][branchName] = root[tree][
87+
branch
88+
].to_numpy()
89+
hists[tree]["rootDictValue"][branchName] = np.zeros(
90+
array.shape
91+
)
8792
hists[tree]["rootDictValue"][branchName] += array
8893
else:
8994
array = root[tree][branch].array(library="np")
90-
if len(array) > 0 and not (type(array[0]) is type(np.ndarray(2,))):
91-
if type(array[0]) is type('c'):
95+
if len(array) > 0 and (
96+
type(array[0])
97+
is not type(
98+
np.ndarray(
99+
2,
100+
)
101+
)
102+
):
103+
if type(array[0]) is type("c"):
92104
array = np.array([0 for xi in array])
93-
if not branch in trees[tree]["rootDictType"]:
105+
if branch not in trees[tree]["rootDictType"]:
94106
trees[tree]["rootDictType"][branch] = type(array[0])
95107
trees[tree]["rootDictValue"][branch] = np.array([])
96-
if (not incrementRunId and branch.startswith('eventID')) or (incrementRunId and branch.startswith('runID')):
97-
if not branch in previousId[tree]:
108+
if (
109+
not incrementRunId and branch.startswith("eventID")
110+
) or (incrementRunId and branch.startswith("runID")):
111+
if branch not in previousId[tree]:
98112
previousId[tree][branch] = 0
99113
array += previousId[tree][branch]
100-
previousId[tree][branch] = max(array) +1
101-
trees[tree]["rootDictValue"][branch] = np.append(trees[tree]["rootDictValue"][branch], array)
114+
previousId[tree][branch] = max(array) + 1
115+
trees[tree]["rootDictValue"][branch] = np.append(
116+
trees[tree]["rootDictValue"][branch], array
117+
)
102118
pbar.update(1)
103119
pbar.close()
104120

105-
#Set the dict in the output root file
121+
# Set the dict in the output root file
106122
for tree in trees:
107-
if not trees[tree]["rootDictValue"] == {} or not trees[tree]["rootDictType"] == {}:
108-
#out.mktree(tree, trees[tree]["rootDictType"])
109-
out[tree] = trees[tree]["rootDictValue"]
123+
if (
124+
not trees[tree]["rootDictValue"] == {}
125+
or not trees[tree]["rootDictType"] == {}
126+
):
127+
out.mktree(tree, trees[tree]["rootDictType"])
128+
out[tree].extend(trees[tree]["rootDictValue"])
110129
for hist in hists:
111-
if not hists[hist]["rootDictValue"] == {} or not hists[hist]["rootDictType"] == {}:
130+
if (
131+
not hists[hist]["rootDictValue"] == {}
132+
or not hists[hist]["rootDictType"] == {}
133+
):
112134
for branch in hists[hist]["rootDictValue"]:
113135
for i in range(len(hists[hist]["rootDictValue"][branch])):
114-
hists[hist]["rootDictType"][branch][0][i] = hists[hist]["rootDictValue"][branch][i]
115-
out[branch[:-2]] = hists[hist]["rootDictType"][branch]
136+
hists[hist]["rootDictType"][branch][0][i] = hists[hist][
137+
"rootDictValue"
138+
][branch][i]
139+
out.mktree(branch[:-2], hists[hist]["rootDictType"][branch])
140+
out[branch[:-2]].extend(hists[hist]["rootDictType"][branch])
116141

117142

118143
#####################################################################################
119-
import unittest
120-
import tempfile
121-
import wget
122144
import os
123145
import shutil
146+
import tempfile
147+
import unittest
148+
124149
import numpy as np
150+
import uproot
151+
import wget
152+
125153
from .logging_conf import LoggedTestCase
126154

155+
127156
class Test_MergeRoot(LoggedTestCase):
128157
def test_merge_root_phsp(self):
129-
try:
130-
import uproot
131-
except:
132-
print("uproot4 is mandatory to merge root file. Please, do:")
133-
print("pip install uproot")
134-
135-
logger.info('Test_MergeRoot test_merge_root_phsp')
158+
logger.info("Test_MergeRoot test_merge_root_phsp")
136159
tmpdirpath = tempfile.mkdtemp()
137-
filenameRoot = wget.download("https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/phsp.root?inline=false", out=tmpdirpath, bar=None)
138-
gt.merge_root([filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"))
160+
filenameRoot = wget.download(
161+
"https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/phsp.root?inline=false",
162+
out=tmpdirpath,
163+
bar=None,
164+
)
165+
gt.merge_root(
166+
[filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root")
167+
)
139168
input = uproot.open(filenameRoot)
140169
output = uproot.open(os.path.join(tmpdirpath, "output.root"))
141170
self.assertTrue(output.keys() == input.keys())
@@ -144,57 +173,56 @@ def test_merge_root_phsp(self):
144173
self.assertTrue(outputTree.keys() == inputTree.keys())
145174
inputBranch = inputTree[inputTree.keys()[1]].array(library="np")
146175
outputBranch = outputTree[outputTree.keys()[1]].array(library="np")
147-
self.assertTrue(2*len(inputBranch) == len(outputBranch))
176+
self.assertTrue(2 * len(inputBranch) == len(outputBranch))
148177
shutil.rmtree(tmpdirpath)
149178

150179
def test_merge_root_pet_incrementEvent(self):
151-
try:
152-
import uproot
153-
except:
154-
print("uproot4 is mandatory to merge root file. Please, do:")
155-
print("pip install uproot")
156-
157-
logger.info('Test_MergeRoot test_merge_root_pet')
180+
logger.info("Test_MergeRoot test_merge_root_pet")
158181
tmpdirpath = tempfile.mkdtemp()
159-
filenameRoot = wget.download("https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false", out=tmpdirpath, bar=None)
160-
gt.merge_root([filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"))
182+
filenameRoot = wget.download(
183+
"https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false",
184+
out=tmpdirpath,
185+
bar=None,
186+
)
187+
gt.merge_root(
188+
[filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root")
189+
)
161190
input = uproot.open(filenameRoot)
162191
output = uproot.open(os.path.join(tmpdirpath, "output.root"))
163192
inputTree = input[input.keys()[0]]
164193
outputTree = output[output.keys()[0]]
165194
inputRunBranch = inputTree[inputTree.keys()[0]].array(library="np")
166195
outputRunBranch = outputTree[outputTree.keys()[0]].array(library="np")
167196
self.assertTrue(max(inputRunBranch) == max(outputRunBranch))
168-
self.assertTrue(2*len(inputRunBranch) == len(outputRunBranch))
197+
self.assertTrue(2 * len(inputRunBranch) == len(outputRunBranch))
169198
inputEventBranch = inputTree[inputTree.keys()[1]].array(library="np")
170199
outputEventBranch = outputTree[outputTree.keys()[1]].array(library="np")
171-
self.assertTrue(2*max(inputEventBranch)+1 == max(outputEventBranch))
172-
self.assertTrue(2*len(inputEventBranch) == len(outputEventBranch))
200+
self.assertTrue(2 * max(inputEventBranch) + 1 == max(outputEventBranch))
201+
self.assertTrue(2 * len(inputEventBranch) == len(outputEventBranch))
173202
shutil.rmtree(tmpdirpath)
174203

175204
def test_merge_root_pet_incrementRun(self):
176-
try:
177-
import uproot
178-
except:
179-
print("uproot4 is mandatory to merge root file. Please, do:")
180-
print("pip install uproot")
181-
182-
logger.info('Test_MergeRoot test_merge_root_pet')
205+
logger.info("Test_MergeRoot test_merge_root_pet")
183206
tmpdirpath = tempfile.mkdtemp()
184207
print(tmpdirpath)
185-
filenameRoot = wget.download("https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false", out=tmpdirpath, bar=None)
186-
gt.merge_root([filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"), True)
208+
filenameRoot = wget.download(
209+
"https://gitlab.in2p3.fr/opengate/gatetools_data/-/raw/master/pet.root?inline=false",
210+
out=tmpdirpath,
211+
bar=None,
212+
)
213+
gt.merge_root(
214+
[filenameRoot, filenameRoot], os.path.join(tmpdirpath, "output.root"), True
215+
)
187216
input = uproot.open(filenameRoot)
188217
output = uproot.open(os.path.join(tmpdirpath, "output.root"))
189218
inputTree = input[input.keys()[0]]
190219
outputTree = output[output.keys()[0]]
191220
inputRunBranch = inputTree[inputTree.keys()[0]].array(library="np")
192221
outputRunBranch = outputTree[outputTree.keys()[0]].array(library="np")
193-
self.assertTrue(2*max(inputRunBranch)+1 == max(outputRunBranch))
194-
self.assertTrue(2*len(inputRunBranch) == len(outputRunBranch))
222+
self.assertTrue(2 * max(inputRunBranch) + 1 == max(outputRunBranch))
223+
self.assertTrue(2 * len(inputRunBranch) == len(outputRunBranch))
195224
inputEventBranch = inputTree[inputTree.keys()[1]].array(library="np")
196225
outputEventBranch = outputTree[outputTree.keys()[1]].array(library="np")
197226
self.assertTrue(max(inputEventBranch) == max(outputEventBranch))
198-
self.assertTrue(2*len(inputEventBranch) == len(outputEventBranch))
199-
#shutil.rmtree(tmpdirpath)
200-
227+
self.assertTrue(2 * len(inputEventBranch) == len(outputEventBranch))
228+
# shutil.rmtree(tmpdirpath)

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,3 @@ gt_digi_mac_converter = "gatetools.bin.gt_digi_mac_converter:convert_macro"
8585
gate_split_and_run = "gatetools.clustertools.gate_split_and_run:runJobs"
8686
opengate_run = "gatetools.clustertools.opengate_run:runJobs_click"
8787
computeElapsedTime = "gatetools.clustertools.computeElapsedTime:computeElapsedTime"
88-

0 commit comments

Comments
 (0)