Skip to content

Commit a5901fa

Browse files
authored
Merge pull request #18 from desy-ml/pytorch-histogramdd
Replace custom histogramdd() with torch.histogramdd()
2 parents bf326a2 + 928b12b commit a5901fa

File tree

9 files changed

+269
-295
lines changed

9 files changed

+269
-295
lines changed

benchmark/cheetah/cheetah.ipynb

Lines changed: 58 additions & 28 deletions
Large diffs are not rendered by default.

cheetah/accelerator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def reading(self):
745745
image = dist.pdf(pos)
746746
image = np.flipud(image.T)
747747
elif isinstance(self.read_beam, ParticleBeam):
748-
image, _ = utils.histogramdd(
748+
image, _ = torch.histogramdd(
749749
torch.stack((self.read_beam.xs, self.read_beam.ys)),
750750
bins=self.pixel_bin_edges,
751751
)

cheetah/utils.py

Lines changed: 2 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def ocelot2cheetah(element, warnings=True):
113113
elif isinstance(element, oc.Monitor) and "BSC" in element.id:
114114
if warnings:
115115
print(
116-
"WARNING: Diagnostic screen was converted with default screen properties."
116+
"WARNING: Diagnostic screen was converted with default screen"
117+
" properties."
117118
)
118119
return acc.Screen((2448, 2040), (3.5488e-6, 2.5003e-6), name=element.id)
119120
elif isinstance(element, oc.Monitor) and "BPM" in element.id:
@@ -141,127 +142,3 @@ def subcell_of_ocelot(cell, start, end):
141142
break
142143

143144
return subcell
144-
145-
146-
_range = range
147-
148-
149-
def histogramdd(sample, bins=None, range=None, weights=None, remove_overflow=True):
150-
"""
151-
Pytorch version of n-dimensional histogram.
152-
153-
Taken from https://github.com/miranov25/RootInteractive/blob/b54446e09072e90e17f3da72d5244a20c8fdd209/RootInteractive/Tools/Histograms/histogramdd.py
154-
"""
155-
edges = None
156-
device = None
157-
custom_edges = False
158-
D, N = sample.shape
159-
if device == None:
160-
device = sample.device
161-
if bins == None:
162-
if edges == None:
163-
bins = 10
164-
custom_edges = False
165-
else:
166-
try:
167-
bins = edges.size(1) - 1
168-
except AttributeError:
169-
bins = torch.empty(D)
170-
for i in _range(len(edges)):
171-
bins[i] = edges[i].size(0) - 1
172-
bins = bins.to(device)
173-
custom_edges = True
174-
try:
175-
M = bins.size(0)
176-
if M != D:
177-
raise ValueError(
178-
"The dimension of bins must be equal to the dimension of sample x."
179-
)
180-
except AttributeError:
181-
# bins is either an integer or a list
182-
if type(bins) == int:
183-
bins = torch.full([D], bins, dtype=torch.long, device=device)
184-
elif torch.is_tensor(bins[0]):
185-
custom_edges = True
186-
edges = bins
187-
bins = torch.empty(D, dtype=torch.long)
188-
for i in _range(len(edges)):
189-
bins[i] = edges[i].size(0) - 1
190-
bins = bins.to(device)
191-
else:
192-
bins = torch.as_tensor(bins)
193-
if bins.dim() == 2:
194-
custom_edges = True
195-
edges = bins
196-
bins = torch.full([D], bins.size(1) - 1, dtype=torch.long, device=device)
197-
if custom_edges:
198-
use_old_edges = False
199-
if not torch.is_tensor(edges):
200-
use_old_edges = True
201-
edges_old = edges
202-
m = max(i.size(0) for i in edges)
203-
tmp = torch.empty([D, m], device=edges[0].device)
204-
for i in _range(D):
205-
s = edges[i].size(0)
206-
tmp[i, :] = edges[i][-1]
207-
tmp[i, :s] = edges[i][:]
208-
edges = tmp.to(device)
209-
k = torch.searchsorted(edges, sample)
210-
k = torch.min(k, (bins + 1).reshape(-1, 1))
211-
if use_old_edges:
212-
edges = edges_old
213-
else:
214-
edges = torch.unbind(edges)
215-
else:
216-
if range == None: # range is not defined
217-
range = torch.empty(2, D, device=device)
218-
if N == 0: # Empty histogram
219-
range[0, :] = 0
220-
range[1, :] = 1
221-
else:
222-
range[0, :] = torch.min(sample, 1)[0]
223-
range[1, :] = torch.max(sample, 1)[0]
224-
elif not torch.is_tensor(range): # range is a tuple
225-
r = torch.empty(2, D)
226-
for i in _range(D):
227-
if range[i] is not None:
228-
r[:, i] = torch.as_tensor(range[i])
229-
else:
230-
if N == 0: # Edge case: empty histogram
231-
r[0, i] = 0
232-
r[1, i] = 1
233-
r[0, i] = torch.min(sample[:, i])[0]
234-
r[1, i] = torch.max(sample[:, i])[0]
235-
range = r.to(device=device, dtype=sample.dtype)
236-
singular_range = torch.eq(
237-
range[0], range[1]
238-
) # If the range consists of only one point, pad it up.
239-
range[0, singular_range] -= 0.5
240-
range[1, singular_range] += 0.5
241-
edges = [
242-
torch.linspace(range[0, i], range[1, i], bins[i] + 1)
243-
for i in _range(len(bins))
244-
]
245-
tranges = torch.empty_like(range)
246-
tranges[1, :] = bins / (range[1, :] - range[0, :])
247-
tranges[0, :] = 1 - range[0, :] * tranges[1, :]
248-
k = torch.addcmul(
249-
tranges[0, :].reshape(-1, 1), sample, tranges[1, :].reshape(-1, 1)
250-
).long() # Get the right index
251-
k = torch.max(
252-
k, torch.zeros([], device=device, dtype=torch.long)
253-
) # Underflow bin
254-
k = torch.min(k, (bins + 1).reshape(-1, 1))
255-
256-
multiindex = torch.ones_like(bins)
257-
multiindex[1:] = torch.cumprod(torch.flip(bins[1:], [0]) + 2, -1).long()
258-
multiindex = torch.flip(multiindex, [0])
259-
l = torch.sum(k * multiindex.reshape(-1, 1), 0)
260-
hist = torch.bincount(
261-
l, minlength=(multiindex[0] * (bins[0] + 2)).item(), weights=weights
262-
)
263-
hist = hist.reshape(tuple(bins + 2))
264-
if remove_overflow:
265-
core = D * (slice(1, -1),)
266-
hist = hist[core]
267-
return hist, edges

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
setup(
99
name="cheetah-accelerator",
10-
version="0.5.16",
10+
version="0.5.17",
1111
author="Jan Kaiser & Oliver Stein",
1212
author_email="[email protected]",
1313
url="https://github.com/desy-ml/cheetah",

test/intro.ipynb

Lines changed: 69 additions & 38 deletions
Large diffs are not rendered by default.

test/ocelot_vs_joss.ipynb

Lines changed: 28 additions & 31 deletions
Large diffs are not rendered by default.

test/olivers_test.ipynb

Lines changed: 32 additions & 35 deletions
Large diffs are not rendered by default.

test/testmore.ipynb

Lines changed: 67 additions & 26 deletions
Large diffs are not rendered by default.

test/testtime.ipynb

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
"output_type": "stream",
1818
"text": [
1919
"[INFO ] : : \u001b[0mbeam.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
20-
"[INFO ] : : \u001b[0mhigh_order.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
21-
"[INFO ] : : : : : : : : : : : \u001b[0mradiation_py.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
20+
"[INFO ] : : : : : : : : \u001b[0mhigh_order.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
21+
"[INFO ] \u001b[0mradiation_py.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
22+
"[INFO ] \u001b[0mradiation_py.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
2223
"[INFO ] \u001b[0mcsr.py: module NUMBA is not installed. Install it to speed up calculation\u001b[0m\n",
2324
"[INFO ] \u001b[0mcsr.py: module PYFFTW is not installed. Install it to speed up calculation.\u001b[0m\n",
2425
"[INFO ] \u001b[0mcsr.py: module NUMEXPR is not installed. Install it to speed up calculation\u001b[0m\n",
@@ -48,8 +49,8 @@
4849
"metadata": {},
4950
"outputs": [],
5051
"source": [
51-
"beam1 = cheetah.ParameterBeam.from_astra(\"../distributions/ACHIP_EA1_2021.1351.001\")\n",
52-
"beam2 = cheetah.ParticleBeam.from_astra(\"../distributions/ACHIP_EA1_2021.1351.001\")"
52+
"beam1 = cheetah.ParameterBeam.from_astra(\"../benchmark/cheetah/ACHIP_EA1_2021.1351.001\")\n",
53+
"beam2 = cheetah.ParticleBeam.from_astra(\"../benchmark/cheetah/ACHIP_EA1_2021.1351.001\")"
5354
]
5455
},
5556
{
@@ -77,7 +78,7 @@
7778
"name": "stdout",
7879
"output_type": "stream",
7980
"text": [
80-
"91.7 µs ± 221 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
81+
"86.6 µs ± 816 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
8182
]
8283
}
8384
],
@@ -95,7 +96,7 @@
9596
"name": "stdout",
9697
"output_type": "stream",
9798
"text": [
98-
"902 µs ± 21.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
99+
"955 µs ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
99100
]
100101
}
101102
],
@@ -127,7 +128,7 @@
127128
"name": "stdout",
128129
"output_type": "stream",
129130
"text": [
130-
"z = 42.34949999999999 / 42.34949999999999 : applied: 3.36 ms ± 1.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
131+
"z = 42.34949999999999 / 42.34949999999999. Applied: 2.82 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
131132
]
132133
}
133134
],
@@ -146,7 +147,7 @@
146147
],
147148
"metadata": {
148149
"kernelspec": {
149-
"display_name": "Python 3.9.13 ('cheetah-test')",
150+
"display_name": "rl39",
150151
"language": "python",
151152
"name": "python3"
152153
},
@@ -160,12 +161,12 @@
160161
"name": "python",
161162
"nbconvert_exporter": "python",
162163
"pygments_lexer": "ipython3",
163-
"version": "3.9.13"
164+
"version": "3.9.15"
164165
},
165166
"orig_nbformat": 4,
166167
"vscode": {
167168
"interpreter": {
168-
"hash": "947d9ee3b458d99f0eb80dac10135f6d6a35887bd0ce9fb941e727a4631c373a"
169+
"hash": "343fe3b89e2d7877d61a0509fd880204236e5c07449e4c121f53f2530ef83fc9"
169170
}
170171
}
171172
},

0 commit comments

Comments
 (0)