diff --git a/.scripts/download_zenodo.py b/.scripts/download_zenodo.py index f02c6fbd..b86ed8a6 100755 --- a/.scripts/download_zenodo.py +++ b/.scripts/download_zenodo.py @@ -18,15 +18,15 @@ def calculate_md5(filename): def download_zenodo_files(output_dir: Path): """ - Download all files from Zenodo record 14938787 and verify their checksums. + Download all files from Zenodo record 14979785 and verify their checksums. Args: output_dir: Directory where files should be downloaded """ try: - print("Fetching files from Zenodo record 14938787...") + print("Fetching files from Zenodo record 14979785...") with urllib.request.urlopen( - "https://zenodo.org/api/records/14938787" + "https://zenodo.org/api/records/14979785" ) as response: data = json.loads(response.read()) diff --git a/httomolibgpu/prep/stripe.py b/httomolibgpu/prep/stripe.py index 68267ac1..300971da 100644 --- a/httomolibgpu/prep/stripe.py +++ b/httomolibgpu/prep/stripe.py @@ -201,14 +201,23 @@ def remove_all_stripe( Corrected 3D tomographic data as a CuPy or NumPy array. """ - matindex = _create_matindex(data.shape[2], data.shape[0]) - for m in range(data.shape[1]): - sino = data[:, m, :] - sino = _rs_dead(sino, snr, la_size, matindex) - sino = _rs_sort(sino, sm_size, dim) - sino = cp.nan_to_num(sino) - data[:, m, :] = sino - return data + streams = [cp.cuda.Stream() for _ in range(4)] + output = data.copy() + def process_slice(m, stream): + with stream: + output[:, m, :] = _rs_dead(output[:, m, :], snr, la_size) + output[:, m, :] = _rs_sort(output[:, m, :], sm_size, dim) + output[:, m, :] = cp.nan_to_num(output[:, m, :]) + + # Distribute slices across streams + for i in range(data.shape[1]): + stream = streams[i % 4] + process_slice(i, stream) + + for stream in streams: + stream.synchronize() + + return output def _mpolyfit(x, y): @@ -252,7 +261,7 @@ def _detect_stripe(listdata, snr): return listmask -def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True): +def _rs_large(sinogram, snr, size, drop_ratio=0.1, norm=True): """ Remove large stripes. """ @@ -264,35 +273,35 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True): list1 = cp.mean(sinosort[ndrop : nrow - ndrop], axis=0) list2 = cp.mean(sinosmooth[ndrop : nrow - ndrop], axis=0) listfact = list1 / list2 - # Locate stripes listmask = _detect_stripe(listfact, snr) listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype) - matfact = cp.tile(listfact, (nrow, 1)) + # Normalize - if norm is True: - sinogram = sinogram / matfact - sinogram1 = cp.transpose(sinogram) - matcombine = cp.asarray(cp.dstack((matindex, sinogram1))) - - ids = cp.argsort(matcombine[:, :, 1], axis=1) - matsort = matcombine.copy() - matsort[:, :, 0] = cp.take_along_axis(matsort[:, :, 0], ids, axis=1) - matsort[:, :, 1] = cp.take_along_axis(matsort[:, :, 1], ids, axis=1) - - matsort[:, :, 1] = cp.transpose(sinosmooth) - ids = cp.argsort(matsort[:, :, 0], axis=1) - matsortback = matsort.copy() - matsortback[:, :, 0] = cp.take_along_axis(matsortback[:, :, 0], ids, axis=1) - matsortback[:, :, 1] = cp.take_along_axis(matsortback[:, :, 1], ids, axis=1) - - sino_corrected = cp.transpose(matsortback[:, :, 1]) + if norm: + sinogram /= cp.tile(listfact, (nrow, 1)) + + sino_transposed = sinogram.T + ids_sort = cp.argsort(sino_transposed, axis=1) + + # Apply sorting without explicit matindex + sino_sorted = cp.take_along_axis(sino_transposed, ids_sort, axis=1) + + # Smoothen sorted sinogram + sino_sorted[:, :] = cp.transpose(sinosmooth) + + # Restore original order + ids_restore = cp.argsort(ids_sort, axis=1) + sino_corrected = cp.take_along_axis(sino_sorted, ids_restore, axis=1).T + + # Apply corrections only to affected columns listxmiss = cp.where(listmask > 0.0)[0] sinogram[:, listxmiss] = sino_corrected[:, listxmiss] + return sinogram -def _rs_dead(sinogram, snr, size, matindex, norm=True): +def _rs_dead(sinogram, snr, size, norm=True): """remove unresponsive and fluctuating stripes""" sinogram = cp.copy(sinogram) # Make it mutable (nrow, _) = sinogram.shape @@ -316,14 +325,15 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True): if len(listxmiss) > 0: ids = cp.searchsorted(listx, listxmiss) weights = (listxmiss - listx[ids - 1]) / (listx[ids] - listx[ids - 1]) - # direct interpolation without making an extra copy - sinogram[:, listxmiss] = sinogram[:, listx[ids - 1]] + weights * ( - sinogram[:, listx[ids]] - sinogram[:, listx[ids - 1]] - ) + left_vals = cp.take(sinogram, listx[ids - 1], axis=1) + right_vals = cp.take(sinogram, listx[ids], axis=1) + diff = right_vals - left_vals + diff *= weights + sinogram[:, listxmiss] = left_vals + diff # Remove residual stripes if norm is True: - sinogram = _rs_large(sinogram, snr, size, matindex) + sinogram = _rs_large(sinogram, snr, size) return sinogram @@ -416,12 +426,3 @@ def raven_filter( data = data[pad_y : height - pad_y, :, pad_x : width - pad_x].real return cp.require(data, requirements="C") - - -def _create_matindex(nrow, ncol): - """ - Create a 2D array of indexes used for the sorting technique. - """ - listindex = cp.arange(0.0, ncol, 1.0) - matindex = cp.tile(listindex, (nrow, 1)) - return matindex.astype(np.float32) diff --git a/remove_all_stripe.py b/remove_all_stripe.py new file mode 100644 index 00000000..88d1acf2 --- /dev/null +++ b/remove_all_stripe.py @@ -0,0 +1,74 @@ +import cupy as cp +import numpy as np +import os +import time +from cupy.cuda import memory_hooks +from datetime import datetime +from math import isclose +from cupyx.profiler import time_range + +from httomolibgpu.prep.stripe import remove_all_stripe + +test_data_path = "/mnt/gpfs03/scratch/data/imaging/tomography/zenodo" +data_path = os.path.join(test_data_path, "synth_tomophantom1.npz") +data_file = np.load(data_path) +projdata = cp.asarray(cp.swapaxes(data_file["projdata"], 0, 1)) +angles = cp.asarray(data_file["angles"]) + +with time_range("all_stripe", color_id=0): + remove_all_stripe( + cp.copy(projdata), + snr=0.1, + la_size=71, + sm_size=31, + dim=1 + ) + + +# cold run +remove_all_stripe( + cp.copy(projdata), + snr=0.1, + la_size=71, + sm_size=31, + dim=1, +) + +dev = cp.cuda.Device() +dev.synchronize() +start = time.perf_counter_ns() +for _ in range(10): + remove_all_stripe( + cp.copy(projdata), + snr=0.1, + la_size=71, + sm_size=31, + dim=1, + ) + +dev.synchronize() +duration_ms = float(time.perf_counter_ns() - start) * 1e-6 / 10 + +print(duration_ms) + + +output = remove_all_stripe(cp.copy(projdata), snr=0.1, la_size=61, sm_size=21, dim=1) +residual_calc = projdata - output +norm_res = cp.linalg.norm(residual_calc.flatten()) +assert isclose(norm_res, 67917.71, abs_tol=10**-2) + +output = remove_all_stripe(cp.copy(projdata), snr=0.001, la_size=61, sm_size=21, dim=1) +residual_calc = projdata - output +norm_res = cp.linalg.norm(residual_calc.flatten()) +assert isclose(norm_res, 70015.51, abs_tol=10**-2) + +hook = memory_hooks.LineProfileHook() +with hook: + remove_all_stripe( + cp.copy(projdata), + snr=0.1, + la_size=71, + sm_size=31, + dim=1 + ) +hook.print_report() diff --git a/zenodo-tests/conftest.py b/zenodo-tests/conftest.py index a04e505d..b5cabfd1 100644 --- a/zenodo-tests/conftest.py +++ b/zenodo-tests/conftest.py @@ -185,6 +185,20 @@ def geant4_dataset1(geant4_dataset1_file): ) +@pytest.fixture(scope="session") +def synth_tomophantom1_file(test_data_path): + in_file = os.path.join(test_data_path, "synth_tomophantom1.npz") + return np.load(in_file) + + +@pytest.fixture +def synth_tomophantom1_dataset(synth_tomophantom1_file): + return ( + cp.asarray(cp.swapaxes(synth_tomophantom1_file["projdata"], 0, 1)), + synth_tomophantom1_file["angles"], + ) + + @pytest.fixture def ensure_clean_memory(): gc.collect() diff --git a/zenodo-tests/test_prep/test_stripe.py b/zenodo-tests/test_prep/test_stripe.py index 69bbced2..2e739cad 100644 --- a/zenodo-tests/test_prep/test_stripe.py +++ b/zenodo-tests/test_prep/test_stripe.py @@ -134,6 +134,38 @@ def test_remove_all_stripe_i12_dataset4( assert output.flags.c_contiguous +@pytest.mark.parametrize( + "dataset_fixture, snr_val, la_size_val, sm_size_val, norm_res_expected", + [ + ("synth_tomophantom1_dataset", 1.0, 61, 21, 53435.61), + ("synth_tomophantom1_dataset", 0.1, 61, 21, 67917.71), + ("synth_tomophantom1_dataset", 0.001, 61, 21, 70015.51), + ], + ids=["snr_1", "snr_2", "snr_3"], +) +def test_remove_all_stripe_synth_tomophantom1_dataset( + request, dataset_fixture, snr_val, la_size_val, sm_size_val, norm_res_expected +): + dataset = request.getfixturevalue(dataset_fixture) + force_clean_gpu_memory() + + output = remove_all_stripe( + cp.copy(dataset[0]), + snr=snr_val, + la_size=la_size_val, + sm_size=sm_size_val, + dim=1, + ) + + residual_calc = dataset[0] - output + norm_res = cp.linalg.norm(residual_calc.flatten()) + + assert isclose(norm_res, norm_res_expected, abs_tol=10**-2) + assert not np.isnan(output).any(), "Output contains NaN values" + assert output.dtype == np.float32 + assert output.flags.c_contiguous + + @pytest.mark.parametrize( "dataset_fixture, nvalue_val, vvalue_val, norm_res_expected", [