Skip to content

Commit 21208a9

Browse files
authored
Merge pull request #221 from DiamondLightSource/rescaletoint_C
Making rescale_to_int GPU only function
2 parents 3e42834 + 02696ca commit 21208a9

File tree

3 files changed

+28
-74
lines changed

3 files changed

+28
-74
lines changed

httomolibgpu/misc/rescale.py

Lines changed: 26 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,32 @@
2424
from httomolibgpu import cupywrapper
2525

2626
cp = cupywrapper.cp
27-
cupy_run = cupywrapper.cupy_run
2827

2928
from typing import Literal, Optional, Tuple, Union
3029

3130
from httomolibgpu.misc.supp_func import data_checker
3231

32+
3333
__all__ = [
3434
"rescale_to_int",
3535
]
3636

3737

3838
def rescale_to_int(
39-
data: Union[np.ndarray, cp.ndarray],
39+
data: cp.ndarray,
4040
perc_range_min: float = 0.0,
4141
perc_range_max: float = 100.0,
4242
bits: Literal[8, 16, 32] = 8,
4343
glob_stats: Optional[Tuple[float, float, float, int]] = None,
44-
) -> Union[np.ndarray, cp.ndarray]:
44+
) -> cp.ndarray:
4545
"""
4646
Rescales the data given as float32 type and converts it into the range of an unsigned integer type
4747
with the given number of bits. For more detailed information and examples, see :ref:`method_rescale_to_int`.
4848
4949
Parameters
5050
----------
51-
data : Union[np.ndarray, cp.ndarray]
52-
Input data as a numpy or cupy array (the function is cpu-gpu agnostic)
51+
data : cp.ndarray
52+
Input data as a cupy array
5353
perc_range_min: float, optional
5454
The lower cutoff point in the input data, in percent of the data range (defaults to 0).
5555
The lower bound is computed as min + perc_range_min/100*(max-min)
@@ -69,7 +69,7 @@ def rescale_to_int(
6969
7070
Returns
7171
-------
72-
Union[np.ndarray, cp.ndarray]
72+
cp.ndarray
7373
The original data, clipped to the range specified with the perc_range_min and
7474
perc_range_max, and scaled to the full range of the output integer type
7575
"""
@@ -82,18 +82,13 @@ def rescale_to_int(
8282

8383
data = data_checker(data, verbosity=True, method_name="rescale_to_int")
8484

85-
if cupy_run:
86-
xp = cp.get_array_module(data)
87-
else:
88-
import numpy as xp
89-
9085
# get the min and max integer values of the output type
91-
output_min = xp.iinfo(output_dtype).min
92-
output_max = xp.iinfo(output_dtype).max
86+
output_min = cp.iinfo(output_dtype).min
87+
output_max = cp.iinfo(output_dtype).max
9388

9489
if not isinstance(glob_stats, tuple):
95-
min_value = float(xp.min(data))
96-
max_value = float(xp.max(data))
90+
min_value = float(cp.min(data))
91+
max_value = float(cp.max(data))
9792
else:
9893
min_value = glob_stats[0]
9994
max_value = glob_stats[1]
@@ -102,32 +97,21 @@ def rescale_to_int(
10297
input_min = (perc_range_min * (range_intensity) / 100) + min_value
10398
input_max = (perc_range_max * (range_intensity) / 100) + min_value
10499

100+
factor = cp.float32(1.0)
105101
if (input_max - input_min) != 0.0:
106-
factor = xp.float32((output_max - output_min) / (input_max - input_min))
107-
else:
108-
factor = 1.0
109-
110-
res = xp.empty(data.shape, dtype=output_dtype)
111-
if xp.__name__ == "numpy":
112-
if input_max == pow(2, 32):
113-
input_max -= 1
114-
res = np.copy(data.astype(float))
115-
res[data.astype(float) < input_min] = int(input_min)
116-
res[data.astype(float) > input_max] = int(input_max)
117-
res -= input_min
118-
res *= factor
119-
res = output_dtype(res)
120-
else:
121-
rescale_kernel = cp.ElementwiseKernel(
122-
"T x, raw T input_min, raw T input_max, raw T factor",
123-
"O out",
124-
"""
125-
T x_clean = isnan(x) || isinf(x) ? T(0) : x;
126-
T x_clipped = x_clean < input_min ? input_min : (x_clean > input_max ? input_max : x_clean);
127-
T x_rebased = x_clipped - input_min;
128-
out = O(x_rebased * factor);
129-
""",
130-
"rescale_to_int",
131-
)
132-
rescale_kernel(data, input_min, input_max, factor, res)
102+
factor = cp.float32((output_max - output_min) / (input_max - input_min))
103+
104+
res = cp.empty(data.shape, dtype=output_dtype)
105+
rescale_kernel = cp.ElementwiseKernel(
106+
"T x, raw T input_min, raw T input_max, raw T factor",
107+
"O out",
108+
"""
109+
T x_clean = isnan(x) || isinf(x) ? T(0) : x;
110+
T x_clipped = x_clean < input_min ? input_min : (x_clean > input_max ? input_max : x_clean);
111+
T x_rebased = x_clipped - input_min;
112+
out = O(x_rebased * factor);
113+
""",
114+
"rescale_to_int",
115+
)
116+
rescale_kernel(data, input_min, input_max, factor, res)
133117
return res

httomolibgpu/misc/supp_func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def data_checker(
162162
) -> bool:
163163
"""
164164
Function that performs the variety of checks on input data, in some cases also correct the data and prints warnings.
165-
Currently it checks for: the presence of infs and nans in data.
165+
Currently it checks for: the presence of infs and nans in data.
166166
167167
Parameters
168168
----------

tests/test_misc/test_rescale.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,11 @@ def test_rescale_no_change():
1414
res_dev = rescale_to_int(
1515
data_dev, bits=8, glob_stats=(0.0, 255.0, 100.0, data.size)
1616
)
17-
res_cpu = rescale_to_int(data, bits=8, glob_stats=(0.0, 255.0, 100.0, data.size))
1817

1918
res = cp.asnumpy(res_dev).astype(np.float32)
2019

2120
assert res_dev.dtype == np.uint8
22-
assert res_cpu.dtype == np.uint8
2321
np.testing.assert_array_equal(res, data)
24-
np.testing.assert_array_equal(res, res_cpu)
2522

2623

2724
@pytest.mark.parametrize("bits", [8, 16, 32])
@@ -31,28 +28,22 @@ def test_rescale_no_change_no_stats(bits: Literal[8, 16, 32]):
3128
data[13, 1] = (2**bits) - 1
3229
data_dev = cp.asarray(data)
3330
res_dev = rescale_to_int(data_dev, bits=bits)
34-
res_cpu = rescale_to_int(data, bits=bits)
3531

3632
res_dev_float32 = cp.asnumpy(res_dev).astype(np.float32)
3733

3834
assert res_dev.dtype.itemsize == bits // 8
3935
np.testing.assert_array_equal(res_dev_float32, data)
40-
assert res_cpu.dtype.itemsize == bits // 8
41-
res_cpu_float32 = np.float32(res_cpu)
42-
np.testing.assert_array_equal(res_dev_float32, res_cpu_float32)
4336

4437

4538
def test_rescale_double():
4639
data = np.ones((30, 50), dtype=np.float32)
4740

4841
data_dev = cp.asarray(data)
4942
res_dev = rescale_to_int(data_dev, bits=8, glob_stats=(0, 2, 100, data.size))
50-
res_cpu = rescale_to_int(data, bits=8, glob_stats=(0, 2, 100, data.size))
5143

5244
res = cp.asnumpy(res_dev).astype(np.float32)
5345

5446
np.testing.assert_array_almost_equal(res, 127.0)
55-
np.testing.assert_array_almost_equal(res_cpu, 127.0)
5647

5748

5849
def test_rescale_handles_nan_inf():
@@ -63,25 +54,21 @@ def test_rescale_handles_nan_inf():
6354

6455
data_dev = cp.asarray(data)
6556
res_dev = rescale_to_int(data_dev, bits=8, glob_stats=(0, 2, 100, data.size))
66-
res_cpu = rescale_to_int(data, bits=8, glob_stats=(0, 2, 100, data.size))
6757

6858
res = cp.asnumpy(res_dev).astype(np.float32)
6959

7060
np.testing.assert_array_equal(res[0, 0:3], 0.0)
71-
np.testing.assert_array_equal(res_cpu[0, 0:3], 0.0)
7261

7362

7463
def test_rescale_double_offset():
7564
data = np.ones((30, 50), dtype=np.float32) + 10
7665

7766
data_dev = cp.asarray(data)
7867
res_dev = rescale_to_int(data_dev, bits=8, glob_stats=(10, 12, 100, data.size))
79-
res_cpu = rescale_to_int(data, bits=8, glob_stats=(10, 12, 100, data.size))
8068

8169
res = cp.asnumpy(res_dev).astype(np.float32)
8270

8371
np.testing.assert_array_almost_equal(res, 127.0)
84-
np.testing.assert_array_almost_equal(res_cpu, 127.0)
8572

8673

8774
@pytest.mark.parametrize("bits", [8, 16])
@@ -99,14 +86,6 @@ def test_rescale_double_offset_min_percentage(bits: Literal[8, 16, 32]):
9986
perc_range_max=90.0,
10087
)
10188

102-
res_cpu = rescale_to_int(
103-
data,
104-
bits=bits,
105-
glob_stats=(10, 20, 100, data.size),
106-
perc_range_min=10.0,
107-
perc_range_max=90.0,
108-
)
109-
11089
res = cp.asnumpy(res_dev).astype(np.float32)
11190

11291
max = (2**bits) - 1
@@ -116,22 +95,13 @@ def test_rescale_double_offset_min_percentage(bits: Literal[8, 16, 32]):
11695
assert res[0, 0] == 0.0
11796
assert res[0, 1] == max
11897

119-
res_cpu = res_cpu.astype(np.float32)
120-
np.testing.assert_array_almost_equal(res_cpu[1:, :], num)
121-
assert res_cpu[0, 0] == 0.0
122-
assert res_cpu[0, 1] == max
123-
12498

12599
def test_tomo_data_scale(data):
126-
data_cpu = data.get()
127100
res_dev = rescale_to_int(
128101
data.astype(cp.float32), perc_range_min=10, perc_range_max=90, bits=8
129102
)
130-
res_cpu = rescale_to_int(data_cpu, perc_range_min=10, perc_range_max=90, bits=8)
131103
res = res_dev.get()
132-
assert res_dev.dtype == np.uint8
133-
assert res_dev.dtype == np.uint8
134-
np.testing.assert_array_equal(res_cpu, res)
104+
assert res.dtype == np.uint8
135105

136106

137107
@pytest.mark.perf

0 commit comments

Comments
 (0)