Commit 1a7d933

Merge pull request #52 from DiamondLightSource/estimator_fix
Estimator fix
2 parents: 9aac61c + 6273432 · commit 1a7d933

File tree

5 files changed: +390 −3 lines changed

httomo_backends/methods_database/packages/backends/httomolibgpu/httomolibgpu.yaml

Lines changed: 9 additions & 0 deletions

@@ -163,6 +163,15 @@ recon:
       memory_gpu:
         multiplier: None
         method: module
+    LPRec3d_tomobar:
+      pattern: sinogram
+      output_dims_change: True
+      implementation: gpu_cupy
+      save_result_default: True
+      padding: False
+      memory_gpu:
+        multiplier: None
+        method: module
     SIRT3d_tomobar:
       pattern: sinogram
       output_dims_change: True
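The entry above registers LPRec3d_tomobar alongside the other tomobar reconstructors so its memory estimator can be looked up from the methods database. A minimal reading sketch follows; it assumes PyYAML is available and that methods sit under recon → algorithm, which is inferred from the hunk header rather than confirmed by this diff.

```python
# Sketch only: the nesting under "recon:" is an assumption inferred from this diff.
import yaml

yaml_path = (
    "httomo_backends/methods_database/packages/backends/"
    "httomolibgpu/httomolibgpu.yaml"
)
with open(yaml_path) as f:
    db = yaml.safe_load(f)

entry = db["recon"]["algorithm"]["LPRec3d_tomobar"]
print(entry["pattern"])               # expected: sinogram
print(entry["memory_gpu"]["method"])  # expected: module
```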

httomo_backends/methods_database/packages/backends/httomolibgpu/supporting_funcs/recon/algorithm.py

Lines changed: 113 additions & 1 deletion

@@ -23,14 +23,16 @@
 import math
 from typing import Tuple
 import numpy as np
-from httomo_backends.cufft import CufftType, cufft_estimate_1d
+from httomo_backends.cufft import CufftType, cufft_estimate_1d, cufft_estimate_2d
 
 __all__ = [
     "_calc_memory_bytes_FBP3d_tomobar",
+    "_calc_memory_bytes_LPRec3d_tomobar",
     "_calc_memory_bytes_SIRT3d_tomobar",
     "_calc_memory_bytes_CGLS3d_tomobar",
     "_calc_output_dim_FBP2d_astra",
     "_calc_output_dim_FBP3d_tomobar",
+    "_calc_output_dim_LPRec3d_tomobar",
     "_calc_output_dim_SIRT3d_tomobar",
     "_calc_output_dim_CGLS3d_tomobar",
 ]

@@ -58,6 +60,10 @@ def _calc_output_dim_FBP3d_tomobar(non_slice_dims_shape, **kwargs):
     return __calc_output_dim_recon(non_slice_dims_shape, **kwargs)
 
 
+def _calc_output_dim_LPRec3d_tomobar(non_slice_dims_shape, **kwargs):
+    return __calc_output_dim_recon(non_slice_dims_shape, **kwargs)
+
+
 def _calc_output_dim_SIRT3d_tomobar(non_slice_dims_shape, **kwargs):
     return __calc_output_dim_recon(non_slice_dims_shape, **kwargs)
 

@@ -153,6 +159,112 @@ def _calc_memory_bytes_FBP3d_tomobar(
     return (tot_memory_bytes, fixed_amount)
 
 
+def _calc_memory_bytes_LPRec3d_tomobar(
+    non_slice_dims_shape: Tuple[int, int],
+    dtype: np.dtype,
+    **kwargs,
+) -> Tuple[int, int]:
+    # Based on: https://github.com/dkazanc/ToMoBAR/pull/112/commits/4704ecdc6ded3dd5ec0583c2008aa104f30a8a39
+
+    angles_tot = non_slice_dims_shape[0]
+    DetectorsLengthH = non_slice_dims_shape[1]
+    SLICES = 200  # dummy multiplier+divisor to pass large batch size threshold
+
+    n = DetectorsLengthH
+
+    odd_horiz = False
+    if (n % 2) != 0:
+        n = n + 1  # dealing with the odd horizontal detector size
+        odd_horiz = True
+
+    eps = 1e-4  # accuracy of usfft
+    mu = -np.log(eps) / (2 * n * n)
+    m = int(np.ceil(2 * n * 1 / np.pi * np.sqrt(-mu * np.log(eps) + (mu * n) * (mu * n) / 4)))
+
+    center_size = 6144
+    center_size = min(center_size, n * 2 + m * 2)
+
+    oversampling_level = 2  # at least 2 or larger required
+    ne = oversampling_level * n
+    padding_m = ne // 2 - n // 2
+
+    if "angles" in kwargs:
+        angles = kwargs["angles"]
+        sorted_theta_cpu = np.sort(angles)
+        theta_full_range = abs(sorted_theta_cpu[angles_tot - 1] - sorted_theta_cpu[0])
+        angle_range_pi_count = 1 + int(np.ceil(theta_full_range / math.pi))
+    else:
+        angle_range_pi_count = 1 + int(np.ceil(2))  # assume a 2 * PI projection angle range
+
+    output_dims = __calc_output_dim_recon(non_slice_dims_shape, **kwargs)
+    if odd_horiz:
+        output_dims = tuple(x + 1 for x in output_dims)
+
+    in_slice_size = np.prod(non_slice_dims_shape) * dtype.itemsize
+    padded_in_slice_size = np.prod(non_slice_dims_shape) * np.float32().itemsize
+    theta_size = angles_tot * np.float32().itemsize
+    sorted_theta_indices_size = angles_tot * np.int64().itemsize
+    sorted_theta_size = angles_tot * np.float32().itemsize
+    recon_output_size = (n + 1) * (n + 1) * np.float32().itemsize if odd_horiz else n * n * np.float32().itemsize  # 264
+    linspace_size = n * np.float32().itemsize
+    meshgrid_size = 2 * n * n * np.float32().itemsize
+    phi_size = 6 * n * n * np.float32().itemsize
+    angle_range_size = center_size * center_size * 1 + angle_range_pi_count * 2 * np.int32().itemsize
+    c1dfftshift_size = n * np.int8().itemsize
+    c2dfftshift_slice_size = 4 * n * n * np.int8().itemsize
+    filter_size = (n // 2 + 1) * np.float32().itemsize
+    rfftfreq_size = filter_size
+    scaled_filter_size = filter_size
+    tmp_p_input_slice = np.prod(non_slice_dims_shape) * np.float32().itemsize
+    padded_tmp_p_input_slice = angles_tot * (n + padding_m * 2) * dtype.itemsize
+    rfft_result_size = padded_tmp_p_input_slice
+    filtered_rfft_result_size = rfft_result_size
+    rfft_plan_slice_size = cufft_estimate_1d(nx=(n + padding_m * 2), fft_type=CufftType.CUFFT_R2C, batch=angles_tot * SLICES) / SLICES
+    irfft_result_size = filtered_rfft_result_size
+    irfft_scratch_memory_size = filtered_rfft_result_size
+    irfft_plan_slice_size = cufft_estimate_1d(nx=(n + padding_m * 2), fft_type=CufftType.CUFFT_C2R, batch=angles_tot * SLICES) / SLICES
+    conversion_to_complex_size = np.prod(non_slice_dims_shape) * np.complex64().itemsize / 2
+    datac_size = np.prod(non_slice_dims_shape) * np.complex64().itemsize / 2
+    fde_size = (2 * m + 2 * n) * (2 * m + 2 * n) * np.complex64().itemsize / 2
+    shifted_datac_size = datac_size
+    fft_result_size = datac_size
+    backshifted_datac_size = datac_size
+    scaled_backshifted_datac_size = datac_size
+    fft_plan_slice_size = cufft_estimate_1d(nx=n, fft_type=CufftType.CUFFT_C2C, batch=angles_tot * SLICES) / SLICES
+    fde_view_size = 4 * n * n * np.complex64().itemsize / 2
+    shifted_fde_view_size = fde_view_size
+    ifft2_scratch_memory_size = fde_view_size
+    ifft2_plan_slice_size = cufft_estimate_2d(nx=(2 * n), ny=(2 * n), fft_type=CufftType.CUFFT_C2C) / 2
+    fde2_size = n * n * np.complex64().itemsize / 2
+    concatenate_size = fde2_size
+    circular_mask_size = np.prod(output_dims) / 2 * np.int64().itemsize * 4
+
+    after_recon_swapaxis_slice = np.prod(non_slice_dims_shape) * np.float32().itemsize
+
+    tot_memory_bytes = int(
+        max(
+            in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + tmp_p_input_slice + padded_tmp_p_input_slice + rfft_result_size + filtered_rfft_result_size + irfft_result_size + irfft_scratch_memory_size
+            , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + tmp_p_input_slice + datac_size + conversion_to_complex_size
+            , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + fde_size + datac_size + shifted_datac_size + fft_result_size + backshifted_datac_size + scaled_backshifted_datac_size
+            , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + ifft2_plan_slice_size + shifted_fde_view_size + ifft2_scratch_memory_size
+            , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + ifft2_plan_slice_size + fde2_size + concatenate_size
+            , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + ifft2_plan_slice_size + after_recon_swapaxis_slice
+        )
+    )
+
+    fixed_amount = int(
+        max(
+            theta_size + phi_size + linspace_size + meshgrid_size
+            , theta_size + sorted_theta_indices_size + sorted_theta_size + phi_size + angle_range_size + c1dfftshift_size + c2dfftshift_slice_size + filter_size + rfftfreq_size + scaled_filter_size
+            , theta_size + sorted_theta_indices_size + sorted_theta_size + phi_size + circular_mask_size
+        )
+    )
+
+    return (tot_memory_bytes * 1.1, fixed_amount)
+
+
 def _calc_memory_bytes_SIRT3d_tomobar(
     non_slice_dims_shape: Tuple[int, int],
     dtype: np.dtype,
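For orientation, a usage sketch of the new estimator (not part of the commit). The import path is inferred from the file location above; the angle count, detector width, and dtype are illustrative, passing recon_size=None assumes the shared output-dim helper reads that keyword, and running it requires the cuFFT plan-size helpers that cufft_estimate_1d/2d wrap.

```python
# Illustrative only: values below are made up; requires a working
# httomo_backends install with its cuFFT bindings available.
import numpy as np
from httomo_backends.methods_database.packages.backends.httomolibgpu.supporting_funcs.recon.algorithm import (
    _calc_memory_bytes_LPRec3d_tomobar,
)

angles = np.linspace(0, np.pi, 1801, dtype=np.float32)
per_slice_bytes, fixed_bytes = _calc_memory_bytes_LPRec3d_tomobar(
    non_slice_dims_shape=(angles.size, 2560),  # (angles, detector width)
    dtype=np.dtype(np.float32),
    angles=angles,
    recon_size=None,  # assumption: forwarded to the output-dim helper
)
print(f"per-slice: {per_slice_bytes / 2**20:.1f} MiB, fixed: {fixed_bytes / 2**20:.1f} MiB")
```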
New file

Lines changed: 159 additions & 0 deletions

@@ -0,0 +1,159 @@
+from os import path
+import sys
+import traceback
+
+from cupy.cuda import memory_hook
+
+
+class PeakMemoryLineProfileHook(memory_hook.MemoryHook):
+    name = 'LineProfileHook'
+
+    def __init__(self, running_peak_root_file_names, max_depth=0):
+        self._memory_frames = {}
+        self._root = MemoryFrame(None, None)
+        self._filename = path.abspath(__file__)
+        self._max_depth = max_depth
+        self.running_peak_root_file_names = running_peak_root_file_names
+
+    # callback
+    def malloc_preprocess(self, device_id, size, mem_size):
+        self._create_frame_tree(used_bytes=mem_size)
+
+    # callback
+    def alloc_preprocess(self, device_id, mem_size):
+        self._create_frame_tree(acquired_bytes=mem_size)
+
+    def free_preprocess(self, **kwargs):
+        mem_size = kwargs.get("mem_size")
+        if mem_size is not None:
+            self._create_frame_tree(freed_bytes=mem_size)
+
+    def _create_frame_tree(self, used_bytes=0, acquired_bytes=0, freed_bytes=0):
+        self._root.used_bytes += used_bytes
+        self._root.acquired_bytes += acquired_bytes
+        self._root.freed_bytes += freed_bytes
+        parent = self._root
+        for depth, stackframe in enumerate(self._extract_stackframes()):
+            if 0 < self._max_depth <= depth + 1:
+                break
+            memory_frame = self._add_frame(parent, stackframe)
+            memory_frame.used_bytes += used_bytes
+            memory_frame.acquired_bytes += acquired_bytes
+            memory_frame.freed_bytes += freed_bytes
+            parent = memory_frame
+
+    def _extract_stackframes(self):
+        stackframes = traceback.extract_stack()
+        stackframes = [StackFrame(st) for st in stackframes]
+        stackframes = [
+            st for st in stackframes if st.filename != self._filename]
+        return stackframes
+
+    def _key_frame(self, parent, stackframe):
+        return (parent,
+                stackframe.filename,
+                stackframe.lineno,
+                stackframe.name)
+
+    def _add_frame(self, parent, stackframe):
+        key = self._key_frame(parent, stackframe)
+        if key in self._memory_frames:
+            memory_frame = self._memory_frames[key]
+        else:
+            memory_frame = MemoryFrame(parent, stackframe)
+            self._memory_frames[key] = memory_frame
+        return memory_frame
+
+    def print_report(self, file=sys.stdout):
+        """Prints a report of line memory profiling."""
+        line = '_root (%s, %s, %s)\n' % self._root.humanized_bytes()
+        file.write(line)
+
+        running_peak_bytes = [0]
+        running_used_bytes = [0]
+        for child in self._root.children:
+            self._print_frame(child, running_peak_bytes, running_used_bytes, depth=1, file=file)
+        file.flush()
+
+    def _print_frame(self, memory_frame, running_peak_bytes, running_used_bytes, depth=0, file=sys.stdout):
+        indent = ' ' * (depth * 2)
+        st = memory_frame.stackframe
+        used_bytes, acquired_bytes, freed_bytes = memory_frame.humanized_bytes()
+
+        humanized_running_peak_bytes = None
+        humanized_running_used_bytes = None
+        if path.basename(st.filename) in self.running_peak_root_file_names:
+            running_used_bytes[0] += memory_frame.used_bytes
+            running_peak_bytes[0] = max(running_peak_bytes[0], running_used_bytes[0])
+            running_used_bytes[0] -= memory_frame.freed_bytes
+
+            humanized_running_peak_bytes = MemoryFrame.humanized_size(running_peak_bytes[0])
+            humanized_running_used_bytes = MemoryFrame.humanized_size(running_used_bytes[0])
+
+        line = '%s%s:%s:%s (%s, %s, %s, %s, %s)\n' % (
+            indent, st.filename, st.lineno, st.name,
+            used_bytes, acquired_bytes, freed_bytes, humanized_running_peak_bytes, humanized_running_used_bytes)
+        file.write(line)
+        for child in memory_frame.children:
+            self._print_frame(child, running_peak_bytes, running_used_bytes, depth=depth + 1, file=file)
+
+
+class StackFrame(object):
+    """Compatibility layer for outputs of traceback.extract_stack().
+
+    Attributes:
+        filename (string): filename
+        lineno (int): line number
+        name (string): function name
+    """
+
+    def __init__(self, obj):
+        if isinstance(obj, tuple):  # < 3.5
+            self.filename = obj[0]
+            self.lineno = obj[1]
+            self.name = obj[2]
+        else:  # >= 3.5 FrameSummary
+            self.filename = obj.filename
+            self.lineno = obj.lineno
+            self.name = obj.name
+
+
+class MemoryFrame(object):
+    """A single stack frame along with sum of memory usage at the frame.
+
+    Attributes:
+        stackframe (FrameSummary): stackframe from traceback.extract_stack().
+        parent (MemoryFrame): parent frame, that is, caller.
+        children (list of MemoryFrame): child frames, that is, callees.
+        used_bytes (int): memory bytes that users used from CuPy memory pool.
+        acquired_bytes (int): memory bytes that CuPy memory pool acquired
+        freed_bytes (int): memory bytes that were released to the CuPy memory pool
+            from GPU device.
+    """
+
+    def __init__(self, parent, stackframe):
+        self.stackframe = stackframe
+        self.children = []
+        self._set_parent(parent)
+        self.used_bytes = 0
+        self.acquired_bytes = 0
+        self.freed_bytes = 0
+
+    def humanized_bytes(self):
+        used_bytes = MemoryFrame.humanized_size(self.used_bytes)
+        acquired_bytes = MemoryFrame.humanized_size(self.acquired_bytes)
+        freed_bytes = MemoryFrame.humanized_size(self.freed_bytes)
+        return (used_bytes, acquired_bytes, freed_bytes)
+
+    @staticmethod
+    def humanized_size(size):
+        for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E']:
+            if size < 1024.0:
+                return '%3.2f%sB' % (size, unit)
+            size /= 1024.0
+        return '%.2f%sB' % (size, 'Z')
+
+    def _set_parent(self, parent):
+        if parent and parent not in parent.children:
+            self.parent = parent
+            parent.children.append(self)
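PeakMemoryLineProfileHook extends CuPy's MemoryHook, so it can wrap GPU code as a context manager and then print a per-frame report that adds a running peak for frames whose basename appears in running_peak_root_file_names. A small usage sketch follows (not part of the commit); the file name and workload are placeholders.

```python
# Sketch only: assumes the class above is importable and CuPy is installed;
# "algorithm.py" is a placeholder for the file whose frames should feed the
# running-peak columns of the report.
import cupy as cp

hook = PeakMemoryLineProfileHook(running_peak_root_file_names=["algorithm.py"])
with hook:  # MemoryHook registers/unregisters itself as a context manager
    data = cp.random.random((1801, 2560), dtype=cp.float32)
    spectrum = cp.fft.rfft(data, axis=1)
hook.print_report()  # per-frame (used, acquired, freed, running peak, running used)
```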

tests/conftest.py

Lines changed: 30 additions & 0 deletions

@@ -21,6 +21,36 @@ def ensure_clean_memory():
     cache.clear()
 
 
+def pytest_configure(config):
+    config.addinivalue_line(
+        "markers", "full: mark tests to run more GPU-memory consuming tests"
+    )
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--full",
+        action="store_true",
+        default=False,
+        help="run more GPU memory hungry tests",
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--full"):
+        skip_other = pytest.mark.skip(reason="not a GPU hungry test")
+        for item in items:
+            if "full" not in item.keywords:
+                item.add_marker(skip_other)
+    else:
+        skip_perf = pytest.mark.skip(
+            reason="this GPU memory hungry test, use '--full' to run"
+        )
+        for item in items:
+            if "full" in item.keywords:
+                item.add_marker(skip_perf)
+
+
 @pytest.fixture(scope="session")
 def test_data_path():
     return CUR_DIR / "test_data"
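With these hooks, a plain pytest run skips tests marked full, while pytest --full runs only the marked, GPU-memory hungry ones. A hypothetical test module illustrating the marker (test names and bodies are placeholders, not from the commit):

```python
# Hypothetical example of using the "full" marker added in conftest.py.
import pytest


@pytest.mark.full
def test_lprec_estimator_on_large_volume():
    # would allocate a large GPU volume; collected only with `pytest --full`
    assert True


def test_lprec_estimator_smoke():
    # lightweight check; skipped when `--full` is given
    assert True
```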
