Skip to content

Commit fc4d9c3

Browse files
authored
Merge pull request #179 from DiamondLightSource/zenodo
CI: Run tests on larger datasets
2 parents 12302ac + 2058f31 commit fc4d9c3

File tree

6 files changed

+267
-1
lines changed

6 files changed

+267
-1
lines changed

.github/workflows/httomolibgpu_tests_run_iris.yml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ jobs:
1414
image: nvidia/cuda:12.6.3-devel-ubi8
1515
env:
1616
NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }}
17+
options: --gpus all --runtime=nvidia
1718

1819
defaults:
1920
run:
@@ -23,6 +24,10 @@ jobs:
2324
- name: Checkout repository code
2425
uses: actions/checkout@v4
2526

27+
- name: Set up CUDA environment
28+
run: |
29+
echo "LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
30+
2631
- name: Create conda environment
2732
uses: mamba-org/setup-micromamba@v1
2833
with:
@@ -36,6 +41,14 @@ jobs:
3641
pip install .[dev]
3742
micromamba list
3843
39-
- name: Run tests
44+
- name: Run unit tests on small data
4045
run: |
4146
pytest tests/
47+
48+
# Optional: Run Zenodo tests only if PR has a label
49+
- name: Download and run Zenodo tests
50+
if: contains(github.event.pull_request.labels.*.name, 'run-zenodo-tests')
51+
run: |
52+
chmod +x ./.scripts/download_zenodo.py
53+
./.scripts/download_zenodo.py zenodo-tests/large_data_archive
54+
pytest zenodo-tests/

.github/workflows/main-checks.yml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
name: Main Branch Tests
2+
on:
3+
push:
4+
branches:
5+
- main
6+
7+
jobs:
8+
iris-gpu:
9+
runs-on: iris-gpu
10+
container:
11+
image: nvidia/cuda:12.6.3-devel-ubi8
12+
env:
13+
NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }}
14+
options: --gpus all --runtime=nvidia
15+
16+
defaults:
17+
run:
18+
shell: bash -l {0}
19+
20+
steps:
21+
- uses: actions/checkout@v4
22+
23+
- name: Set up CUDA environment
24+
run: |
25+
echo "LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
26+
27+
- name: Create conda environment
28+
uses: mamba-org/setup-micromamba@v1
29+
with:
30+
environment-file: conda/environment.yml
31+
environment-name: httomo
32+
post-cleanup: 'all'
33+
init-shell: bash
34+
35+
- name: Download test data from Zenodo
36+
run: |
37+
chmod +x ./.scripts/download_zenodo.py
38+
./.scripts/download_zenodo.py zenodo-tests/large_data_archive
39+
40+
- name: Install httomolibgpu
41+
run: |
42+
pip install .[dev]
43+
micromamba list
44+
45+
- name: Run all tests (including Zenodo)
46+
run: |
47+
pytest tests/ zenodo-tests/

.scripts/download_zenodo.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#!/usr/bin/env python3
2+
3+
import json
4+
import urllib.request
5+
import hashlib
6+
import sys
7+
import os
8+
from pathlib import Path
9+
10+
11+
def calculate_md5(filename):
12+
"""Calculate MD5 hash of a file."""
13+
md5_hash = hashlib.md5()
14+
with open(filename, "rb") as f:
15+
for chunk in iter(lambda: f.read(4096), b""):
16+
md5_hash.update(chunk)
17+
return md5_hash.hexdigest()
18+
19+
20+
def download_zenodo_files(output_dir: Path):
21+
"""
22+
Download all files from Zenodo record 14338424 and verify their checksums.
23+
24+
Args:
25+
output_dir: Directory where files should be downloaded
26+
"""
27+
try:
28+
print("Fetching files from Zenodo record 14338424...")
29+
with urllib.request.urlopen("https://zenodo.org/api/records/14338424") as response:
30+
data = json.loads(response.read())
31+
32+
# Create output directory if it doesn't exist
33+
output_dir.mkdir(parents=True, exist_ok=True)
34+
35+
# Now 'files' is a list, not a dictionary
36+
for file_info in data["files"]:
37+
filename = file_info["key"] # The 'key' is the filename
38+
output_file = output_dir / filename
39+
print(f"Downloading {filename}...")
40+
url = file_info["links"]["self"] # The link to download the file
41+
42+
expected_md5 = file_info["checksum"].split(":")[1] # Extract MD5 hash
43+
44+
# Download the file
45+
urllib.request.urlretrieve(url, output_file)
46+
47+
# Verify checksum
48+
actual_md5 = calculate_md5(output_file)
49+
if actual_md5 == expected_md5:
50+
print(f"✓ Verified {filename}")
51+
else:
52+
print(f"✗ Checksum verification failed for {filename}")
53+
print(f"Expected: {expected_md5}")
54+
print(f"Got: {actual_md5}")
55+
sys.exit(1)
56+
57+
print("\nAll files downloaded and verified successfully!")
58+
59+
except Exception as e:
60+
print(f"Error: {str(e)}", file=sys.stderr)
61+
sys.exit(1)
62+
63+
64+
if __name__ == "__main__":
65+
if len(sys.argv) != 2:
66+
print("Usage: download_zenodo.py <output_directory>")
67+
sys.exit(1)
68+
69+
output_dir = Path(sys.argv[1])
70+
download_zenodo_files(output_dir)

zenodo-tests/conftest.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import os
2+
import cupy as cp
3+
import numpy as np
4+
import pytest
5+
6+
CUR_DIR = os.path.abspath(os.path.dirname(__file__))
7+
8+
9+
@pytest.fixture(scope="session")
10+
def test_data_path():
11+
return os.path.join(CUR_DIR, "large_data_archive")
12+
13+
14+
@pytest.fixture(scope="session")
15+
def data_i12LFOV_file(test_data_path):
16+
in_file = os.path.join(test_data_path, "i12LFOV.npz")
17+
return np.load(in_file)
18+
19+
20+
@pytest.fixture(scope="session")
21+
def data_i12_sandstone_file(test_data_path):
22+
in_file = os.path.join(test_data_path, "i12_sandstone_50sinoslices.npz")
23+
return np.load(in_file)
24+
25+
26+
@pytest.fixture(scope="session")
27+
def data_geant4sim_file(test_data_path):
28+
in_file = os.path.join(test_data_path, "geant4_640_540_proj360.npz")
29+
return np.load(in_file)
30+
31+
@pytest.fixture
32+
def i12LFOV_data(data_i12LFOV_file):
33+
return (
34+
cp.asarray(data_i12LFOV_file["projdata"]),
35+
data_i12LFOV_file["angles"],
36+
cp.asarray(data_i12LFOV_file["flats"]),
37+
cp.asarray(data_i12LFOV_file["darks"]),
38+
)
39+
40+
41+
@pytest.fixture
42+
def i12sandstone_data(data_i12_sandstone_file):
43+
return (
44+
cp.asarray(data_i12_sandstone_file["projdata"]),
45+
data_i12_sandstone_file["angles"],
46+
cp.asarray(data_i12_sandstone_file["flats"]),
47+
cp.asarray(data_i12_sandstone_file["darks"]),
48+
)
49+
50+
51+
@pytest.fixture
52+
def geantsim_data(data_geant4sim_file):
53+
return (
54+
cp.asarray(data_geant4sim_file["projdata"]),
55+
data_geant4sim_file["angles"],
56+
cp.asarray(data_geant4sim_file["flats"]),
57+
cp.asarray(data_geant4sim_file["darks"]),
58+
)
59+
60+
61+
@pytest.fixture
62+
def ensure_clean_memory():
63+
cp.get_default_memory_pool().free_all_blocks()
64+
cp.get_default_pinned_memory_pool().free_all_blocks()
65+
yield None
66+
cp.get_default_memory_pool().free_all_blocks()
67+
cp.get_default_pinned_memory_pool().free_all_blocks()

zenodo-tests/test_recon/__init__.py

Whitespace-only changes.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import cupy as cp
2+
import numpy as np
3+
import pytest
4+
5+
from httomolibgpu.prep.normalize import normalize
6+
from httomolibgpu.recon.rotation import find_center_vo
7+
8+
9+
def test_center_vo_i12LFOV(i12LFOV_data, ensure_clean_memory):
10+
projdata = i12LFOV_data[0]
11+
flats = i12LFOV_data[2]
12+
darks = i12LFOV_data[3]
13+
del i12LFOV_data
14+
15+
data_normalised = normalize(projdata, flats, darks, minus_log=False)
16+
del flats, darks, projdata
17+
18+
mid_slice = data_normalised.shape[1] // 2
19+
cor = find_center_vo(data_normalised[:, mid_slice, :])
20+
21+
assert cor == 1197.75
22+
assert cor.dtype == np.float32
23+
24+
25+
def test_center_vo_average_i12LFOV(i12LFOV_data, ensure_clean_memory):
26+
projdata = i12LFOV_data[0]
27+
flats = i12LFOV_data[2]
28+
darks = i12LFOV_data[3]
29+
del i12LFOV_data
30+
31+
data_normalised = normalize(projdata, flats, darks, minus_log=False)
32+
del flats, darks, projdata
33+
34+
cor = find_center_vo(data_normalised[:, 10:25, :], average_radius=5)
35+
36+
assert cor == 1199.25
37+
assert cor.dtype == np.float32
38+
39+
40+
def test_center_vo_i12_sandstone(i12sandstone_data, ensure_clean_memory):
41+
projdata = i12sandstone_data[0]
42+
flats = i12sandstone_data[2]
43+
darks = i12sandstone_data[3]
44+
del i12sandstone_data
45+
46+
data_normalised = normalize(projdata, flats, darks, minus_log=True)
47+
del flats, darks, projdata
48+
49+
mid_slice = data_normalised.shape[1] // 2
50+
cor = find_center_vo(data_normalised[:, mid_slice, :])
51+
52+
assert cor == 1253.75
53+
assert cor.dtype == np.float32
54+
55+
56+
def test_center_vo_i12_geantsim(geantsim_data, ensure_clean_memory):
57+
projdata = geantsim_data[0]
58+
flats = geantsim_data[2]
59+
darks = geantsim_data[3]
60+
del geantsim_data
61+
62+
data_normalised = normalize(projdata, flats, darks, minus_log=True)
63+
del flats, darks, projdata
64+
65+
mid_slice = data_normalised.shape[1] // 2
66+
cor = find_center_vo(data_normalised[:, mid_slice, :])
67+
68+
assert cor == 319.5
69+
assert cor.dtype == np.float32

0 commit comments

Comments
 (0)