Skip to content

Commit b582acc

Browse files
authored
Merge pull request #214 from DiamondLightSource/datavalidator
Data validation checker
2 parents 4d7ad21 + 760512a commit b582acc

File tree

17 files changed

+503
-39
lines changed

17 files changed

+503
-39
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/**
 * Replace every NaN or +/-Inf element of a 3D volume with zero, in place.
 *
 * One thread handles one element. The volume is addressed as a flat buffer
 * with x (extent N) varying fastest, then y (extent M), then z (extent Z).
 *
 * @param data    Flat buffer of Z * M * N elements, modified in place.
 * @param Z       Extent of the slowest-varying dimension.
 * @param M       Extent of the middle dimension.
 * @param N       Extent of the fastest-varying dimension.
 * @param result  Points to at least one int; result[0] is set to 1 when a
 *                non-finite value is found. The kernel never clears it, so
 *                the caller must zero it before the launch.
 */
template <typename Type>
__global__ void remove_nan_inf(Type *data, int Z, int M, int N, int *result) {
  // Widen to 64 bits before multiplying so that neither the thread
  // coordinates nor the flat index below can wrap in 32-bit arithmetic.
  const long long i = static_cast<long long>(blockDim.x) * blockIdx.x + threadIdx.x;
  const long long j = static_cast<long long>(blockDim.y) * blockIdx.y + threadIdx.y;
  const long long k = static_cast<long long>(blockDim.z) * blockIdx.z + threadIdx.z;

  // Threads that fall outside the volume have nothing to do.
  if (i >= N || j >= M || k >= Z)
    return;

  // N is cast first: `N * M * (long long)k` would evaluate N * M as int * int
  // and could overflow before being widened.
  const long long row_stride = static_cast<long long>(N);
  const long long slice_stride = row_stride * static_cast<long long>(M);
  const long long index = i + row_stride * j + slice_stride * k;

  float val = float(data[index]); /* needs a cast to float for isnan isinf functions to work */
  if (isnan(val) || isinf(val)) {
    // Every thread that gets here stores the same constant, so the order in
    // which concurrent threads write the flag does not matter.
    result[0] = 1;
    Type zero = 0;
    data[index] = zero;
  }
}

httomolibgpu/misc/corr.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818
# Created By : Tomography Team at DLS <[email protected]>
1919
# Created Date: 21/October/2022
2020
# ---------------------------------------------------------------------------
21-
""" Module for data correction. For more detailed information see :ref:`data_correction_module`.
22-
23-
"""
21+
"""Module for data correction. For more detailed information see :ref:`data_correction_module`."""
2422

2523
import numpy as np
2624
from typing import Union
@@ -38,6 +36,7 @@
3836
else:
3937
load_cuda_module = Mock()
4038

39+
from httomolibgpu.misc.supp_func import data_checker
4140

4241
__all__ = [
4342
"median_filter",
@@ -74,7 +73,6 @@ def median_filter(
7473
If the input array is not three dimensional.
7574
"""
7675
input_type = data.dtype
77-
7876
if input_type not in ["float32", "uint16"]:
7977
raise ValueError("The input data should be either float32 or uint16 data type")
8078

@@ -84,6 +82,10 @@ def median_filter(
8482
else:
8583
raise ValueError("The input array must be a 3D array")
8684

85+
data = data_checker(
86+
data, verbosity=True, method_name="median_filter_or_remove_outlier"
87+
)
88+
8789
if kernel_size not in [3, 5, 7, 9, 11, 13]:
8890
raise ValueError("Please select a correct kernel size: 3, 5, 7, 9, 11, 13")
8991

httomolibgpu/misc/denoise.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
# Created By : Tomography Team at DLS <[email protected]>
1919
# Created Date: 18/December/2024
2020
# ---------------------------------------------------------------------------
21-
""" Module for data denoising. For more detailed information see :ref:`data_denoising_module`.
22-
"""
21+
"""Module for data denoising. For more detailed information see :ref:`data_denoising_module`."""
2322

2423
import numpy as np
2524
from typing import Union, Optional
@@ -29,9 +28,10 @@
2928
cp = cupywrapper.cp
3029
cupy_run = cupywrapper.cupy_run
3130

32-
from numpy import float32
3331
from unittest.mock import Mock
3432

33+
from httomolibgpu.misc.supp_func import data_checker
34+
3535
if cupy_run:
3636
from ccpi.filters.regularisersCuPy import ROF_TV, PD_TV
3737
else:
@@ -82,6 +82,8 @@ def total_variation_ROF(
8282
If the input array is not float32 data type.
8383
"""
8484

85+
data = data_checker(data, verbosity=True, method_name="total_variation_ROF")
86+
8587
return ROF_TV(
8688
data, regularisation_parameter, iterations, time_marching_parameter, gpu_id
8789
)
@@ -127,6 +129,8 @@ def total_variation_PD(
127129
If the input array is not float32 data type.
128130
"""
129131

132+
data_checker(data, verbosity=True, method_name="total_variation_PD")
133+
130134
methodTV = 0
131135
if not isotropic:
132136
methodTV = 1

httomolibgpu/misc/morph.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535

3636
from typing import Literal
3737

38+
from httomolibgpu.misc.supp_func import data_checker
39+
3840
__all__ = [
3941
"sino_360_to_180",
4042
"data_resampler",
@@ -66,6 +68,8 @@ def sino_360_to_180(
6668
if data.ndim != 3:
6769
raise ValueError("only 3D data is supported")
6870

71+
data = data_checker(data, verbosity=True, method_name="sino_360_to_180")
72+
6973
dx, dy, dz = data.shape
7074

7175
overlap = int(np.round(overlap))
@@ -136,6 +140,8 @@ def data_resampler(
136140
data = cp.expand_dims(data, 1)
137141
axis = 1
138142

143+
data = data_checker(data, verbosity=True, method_name="data_resampler")
144+
139145
N, M, Z = cp.shape(data)
140146

141147
if axis == 0:

httomolibgpu/misc/rescale.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818
# Created By : Tomography Team at DLS <[email protected]>
1919
# Created Date: 1 March 2024
2020
# ---------------------------------------------------------------------------
21-
""" Module for data rescaling. For more detailed information see :ref:`data_rescale_module`.
22-
23-
"""
21+
"""Module for data rescaling. For more detailed information see :ref:`data_rescale_module`."""
2422

2523
import numpy as np
2624
from httomolibgpu import cupywrapper
@@ -30,6 +28,8 @@
3028

3129
from typing import Literal, Optional, Tuple, Union
3230

31+
from httomolibgpu.misc.supp_func import data_checker
32+
3333
__all__ = [
3434
"rescale_to_int",
3535
]
@@ -80,6 +80,8 @@ def rescale_to_int(
8080
else:
8181
output_dtype = np.uint32
8282

83+
data = data_checker(data, verbosity=True, method_name="rescale_to_int")
84+
8385
if cupy_run:
8486
xp = cp.get_array_module(data)
8587
else:
@@ -109,7 +111,6 @@ def rescale_to_int(
109111
if xp.__name__ == "numpy":
110112
if input_max == pow(2, 32):
111113
input_max -= 1
112-
data[np.logical_not(np.isfinite(data))] = 0
113114
res = np.copy(data.astype(float))
114115
res[data.astype(float) < input_min] = int(input_min)
115116
res[data.astype(float) > input_max] = int(input_max)

httomolibgpu/misc/supp_func.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
# ---------------------------------------------------------------------------
4+
# Copyright 2022 Diamond Light Source Ltd.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
# ---------------------------------------------------------------------------
18+
# Created By : Tomography Team at DLS <[email protected]>
19+
# Created Date: 02/June/2025
20+
# ---------------------------------------------------------------------------
21+
"""This is a collection of supplementary functions (utils) to perform various data checks"""
22+
23+
from httomolibgpu import cupywrapper
24+
from typing import Optional
25+
26+
cp = cupywrapper.cp
27+
cupy_run = cupywrapper.cupy_run
28+
29+
import numpy as np
30+
31+
from unittest.mock import Mock
32+
33+
if cupy_run:
34+
from httomolibgpu.cuda_kernels import load_cuda_module
35+
else:
36+
load_cuda_module = Mock()
37+
38+
39+
def _naninfs_check(
    data: cp.ndarray,
    verbosity: bool = True,
    method_name: Optional[str] = None,
) -> cp.ndarray:
    """
    Find NaN and +-Inf values in the input data, replace them with zeros in place
    and, if requested, print a warning.

    CuPy arrays are processed by the ``remove_nan_inf`` CUDA kernel; NumPy arrays
    are processed with ``np.isfinite`` / ``np.nan_to_num``.

    Parameters
    ----------
    data : cp.ndarray
        Input CuPy or Numpy array either float32 or uint16 data type.
        For CuPy input the array must be 2D or 3D.
    verbosity : bool
        If enabled, a warning is printed when data contains infs or nans.
    method_name : str, optional.
        Method's name for which input data is tested. Used in the warning text only.

    Returns
    -------
    ndarray
        The same array object that was passed in, with nans and infs
        (if any were present) converted to zeros.
    """
    present_nans_infs_b = False

    if cupy_run:
        xp = cp.get_array_module(data)
    else:
        import numpy as xp

    if xp.__name__ == "cupy":
        input_type = data.dtype
        if len(data.shape) == 2:
            # a single 2D image is treated as a stack of depth one
            dy, dx = data.shape
            dz = 1
        else:
            dz, dy, dx = data.shape

        # The kernel writes this flag through an `int *`, so the buffer has to be
        # a 32-bit integer; a one-byte uint8 buffer would be written past its end.
        present_nans_infs = cp.zeros(1, dtype=cp.int32)

        block_x = 128
        # setting grid/block parameters: one thread per element,
        # 128 threads along x and one grid row/slice per y/z index
        block_dims = (block_x, 1, 1)
        grid_x = (dx + block_x - 1) // block_x
        grid_y = dy
        grid_z = dz
        grid_dims = (grid_x, grid_y, grid_z)
        # CuPy passes plain Python ints as 64-bit `long long`; the kernel declares
        # the dimensions as `int`, so they are passed as explicit 32-bit scalars.
        params = (
            data,
            np.int32(dz),
            np.int32(dy),
            np.int32(dx),
            present_nans_infs,
        )

        # NOTE(review): any dtype other than float32 is dispatched to the
        # `unsigned short` instantiation, and the kernel indexes `data` as a flat
        # C-ordered buffer - confirm callers only pass contiguous float32/uint16.
        kernel_args = "remove_nan_inf<{0}>".format(
            "float" if input_type == "float32" else "unsigned short"
        )

        module = load_cuda_module("remove_nan_inf", name_expressions=[kernel_args])
        remove_nan_inf_kernel = module.get_function(kernel_args)
        remove_nan_inf_kernel(grid_dims, block_dims, params)

        # copying the flag back to the host
        if int(present_nans_infs.get()[0]) == 1:
            present_nans_infs_b = True
    else:
        if not np.all(np.isfinite(data)):
            present_nans_infs_b = True
            # copy=False makes the correction in place, matching the GPU path
            np.nan_to_num(data, copy=False, nan=0.0, posinf=0.0, neginf=0.0)

    if present_nans_infs_b and verbosity:
        print(
            f"Warning!!! Input data to method: {method_name} contains Inf's or/and NaN's. This will be corrected but it is recommended to check the validity of input to the method."
        )

    return data
109+
110+
111+
def _zeros_check(
    data: cp.ndarray,
    verbosity: bool = True,
    percentage_threshold: float = 50,
    method_name: Optional[str] = None,
) -> bool:
    """
    Count the zeros present in the data and report when their share of all
    elements reaches ``percentage_threshold``.

    Parameters
    ----------
    data : cp.ndarray
        Input CuPy or Numpy array.
    verbosity : bool
        If enabled, a warning is printed when the share of zeros reaches the threshold.
    percentage_threshold: float:
        Percentage of all data points; if the share of zeros in the input data is
        equal to or larger than this value, the data is flagged.
    method_name : str, optional.
        Method's name for which input data is tested. Used in the warning text only.

    Returns
    -------
    bool
        True if the data contains too many zeros. An empty array is never flagged.
    """
    if cupy_run:
        xp = cp.get_array_module(data)
    else:
        import numpy as xp

    elements_total = int(data.size)
    if elements_total == 0:
        # nothing to count; also avoids dividing by zero below
        return False

    zero_elements_total = elements_total - int(xp.count_nonzero(data))
    zeros_percentage = (zero_elements_total / elements_total) * 100

    warning_zeros = zeros_percentage >= percentage_threshold
    if warning_zeros and verbosity:
        print(
            f"Warning!!! Input data to method: {method_name} contains more than {percentage_threshold} percent of zeros."
        )

    return warning_zeros
156+
157+
158+
def data_checker(
    data: cp.ndarray,
    verbosity: bool = True,
    method_name: Optional[str] = None,
) -> cp.ndarray:
    """
    Perform a series of checks on the input data, correcting the data in some cases
    and printing warnings.
    Currently it checks for: the presence of infs and nans in data (these are
    replaced with zeros); the number of zero elements (warning only).

    Parameters
    ----------
    data : xp.ndarray
        Input CuPy or Numpy array either float32 or uint16 data type.
    verbosity : bool
        If enabled, warnings are printed when data contains infs or nans,
        or when 50 percent or more of its elements are zeros.
    method_name : str, optional.
        Method's name for which input data is tested. Used in the warning text only.

    Returns
    -------
    cp.ndarray
        The input array, with nans and infs replaced by zeros if any were found.
    """

    # nans/infs are turned into zeros first, so they are counted by the zeros check
    data = _naninfs_check(data, verbosity=verbosity, method_name=method_name)

    # the result of the zeros check is informational only (warning printed inside)
    _zeros_check(
        data,
        verbosity=verbosity,
        percentage_threshold=50,
        method_name=method_name,
    )

    return data

httomolibgpu/prep/alignment.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535

3636
from typing import Dict, List, Tuple
3737

38+
from httomolibgpu.misc.supp_func import data_checker
39+
3840
__all__ = [
3941
"distortion_correction_proj_discorpy",
4042
]
@@ -86,6 +88,10 @@ def distortion_correction_proj_discorpy(
8688
if len(data.shape) == 2:
8789
data = cp.expand_dims(data, axis=0)
8890

91+
data = data_checker(
92+
data, verbosity=True, method_name="distortion_correction_proj_discorpy"
93+
)
94+
8995
# Get info from metadata txt file
9096
xcenter, ycenter, list_fact = _load_metadata_txt(metadata_path)
9197

0 commit comments

Comments
 (0)