Skip to content

Commit 7703554

Browse files
Merge pull request #790 from analysiscenter/r0.9.0
v0.9.0
2 parents d520c04 + f42eb39 commit 7703554

File tree

7 files changed

+409
-79
lines changed

7 files changed

+409
-79
lines changed

batchflow/models/torch/base.py

Lines changed: 214 additions & 74 deletions
Large diffs are not rendered by default.

batchflow/models/torch/base_mixins.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
import numpy as np
66
import torch
77

8-
from ...plotter import plot
98
from ...decorators import deprecated
109

10+
from ...utils_import import try_import
11+
plot = try_import(module='...plotter', package=__name__, attribute='plot',
12+
help='Try `pip install batchflow[image]`!')
13+
1114
# Also imports `tensorboard`, if necessary
1215

1316

batchflow/plotter/morphology.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""Morphological operations implemented with numba to replace cv2 dependency."""
2+
3+
import numpy as np
4+
from numba import njit, prange
5+
6+
7+
@njit
8+
def dilate(image, kernel, iterations=1):
9+
"""Dilate an image using a structuring element.
10+
11+
Parameters
12+
----------
13+
image : numpy.ndarray
14+
Input image to dilate.
15+
kernel : numpy.ndarray
16+
Structuring element (kernel) for dilation. Should contain 1s where
17+
the structuring element is active and 0s elsewhere.
18+
iterations : int, optional
19+
Number of times to apply the dilation. Default is 1.
20+
21+
Returns
22+
-------
23+
numpy.ndarray
24+
Dilated image with the same shape and dtype as input.
25+
26+
"""
27+
result = image.copy()
28+
29+
for _ in range(iterations):
30+
result = _single_dilate(result, kernel)
31+
32+
return result
33+
34+
@njit
35+
def erode(image, kernel, iterations=1):
36+
"""Erode an image using a structuring element.
37+
38+
Parameters
39+
----------
40+
image : numpy.ndarray
41+
Input image to erode.
42+
kernel : numpy.ndarray
43+
Structuring element (kernel) for erosion. Should contain 1s where
44+
the structuring element is active and 0s elsewhere.
45+
iterations : int, optional
46+
Number of times to apply the erosion. Default is 1.
47+
48+
Returns
49+
-------
50+
numpy.ndarray
51+
Eroded image with the same shape and dtype as input.
52+
53+
"""
54+
result = image.copy()
55+
56+
for _ in range(iterations):
57+
result = _single_erode(result, kernel)
58+
59+
return result
60+
61+
@njit(parallel=True)
62+
def _single_dilate(image, kernel):
63+
"""Single iteration of dilation operation."""
64+
height, width = image.shape
65+
kh, kw = kernel.shape
66+
kh_half, kw_half = kh // 2, kw // 2
67+
68+
# Create output array
69+
result = np.zeros_like(image)
70+
71+
# Apply dilation - for each output pixel, find max in kernel neighborhood
72+
for i in prange(height):
73+
for j in range(width):
74+
max_val = image[i, j] # Start with current pixel value
75+
76+
for ki in range(kh):
77+
for kj in range(kw):
78+
if kernel[ki, kj] > 0: # Only consider active kernel elements
79+
# Calculate the source image coordinates
80+
img_i = i + ki - kh_half
81+
img_j = j + kj - kw_half
82+
83+
# Check bounds
84+
if 0 <= img_i < height and 0 <= img_j < width:
85+
if image[img_i, img_j] > max_val:
86+
max_val = image[img_i, img_j]
87+
88+
result[i, j] = max_val
89+
90+
return result
91+
92+
@njit(parallel=True)
93+
def _single_erode(image, kernel):
94+
"""Single iteration of erosion operation."""
95+
height, width = image.shape
96+
kh, kw = kernel.shape
97+
kh_half, kw_half = kh // 2, kw // 2
98+
99+
# Create output array
100+
result = np.zeros_like(image)
101+
102+
# Apply erosion - for each output pixel, find min in kernel neighborhood
103+
for i in prange(height):
104+
for j in range(width):
105+
min_val = image[i, j] # Start with current pixel value
106+
107+
for ki in range(kh):
108+
for kj in range(kw):
109+
if kernel[ki, kj] > 0: # Only consider active kernel elements
110+
# Calculate the source image coordinates
111+
img_i = i + ki - kh_half
112+
img_j = j + kj - kw_half
113+
114+
# Check bounds - treat out of bounds as 0 for erosion
115+
if 0 <= img_i < height and 0 <= img_j < width:
116+
if image[img_i, img_j] < min_val:
117+
min_val = image[img_i, img_j]
118+
else:
119+
# Outside bounds treated as 0, so erosion result should be 0
120+
min_val = 0
121+
break
122+
123+
result[i, j] = min_val
124+
125+
return result

batchflow/plotter/plot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def flatten(self, data):
103103

104104
def dilate(self, data):
105105
""" Apply dilation to array. """
106-
import cv2
106+
from .morphology import dilate
107107
dilation_config = self.config.get('dilate', False)
108108

109109
default_kernel = np.ones((3, 1), dtype=np.uint8)
@@ -116,7 +116,7 @@ def dilate(self, data):
116116
dilation_config = {'kernel': np.ones(dilation_config, dtype=np.uint8)}
117117
elif 'kernel' in dilation_config and isinstance(dilation_config['kernel'], tuple):
118118
dilation_config['kernel'] = np.ones(dilation_config['kernel'], dtype=np.uint8)
119-
data = cv2.dilate(data.astype(np.float32), **dilation_config)
119+
data = dilate(data.astype(np.float32), **dilation_config)
120120
return data
121121

122122
def mask(self, data):

batchflow/tests/model_save_load_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
""" Test for model saving and loading """
22

3+
import os
34
import pickle
45

56
import pytest
@@ -186,3 +187,49 @@ def test_bare_model(self, save_path, model_class, pickle_module, outputs):
186187
loaded_predictions = model_load.predict(*args, **kwargs)
187188

188189
assert (np.concatenate(saved_predictions) == np.concatenate(loaded_predictions)).all()
190+
191+
@pytest.mark.parametrize("fmt", [None, 'onnx', 'openvino', 'safetensors'])
192+
@pytest.mark.parametrize("pickle_metadata", [False, True])
193+
def test_save_load_format(self, save_path, model_class, fmt, pickle_metadata):
194+
num_classes = 10
195+
dataset_size = 10
196+
image_shape = (2, 100, 100)
197+
198+
save_kwargs = {
199+
None: {},
200+
'onnx': dict(batch_size=dataset_size),
201+
'openvino': {},
202+
'safetensors': {},
203+
}
204+
load_kwargs = {
205+
None: {},
206+
'onnx': {},
207+
'openvino': {'device': 'cpu'},
208+
'safetensors': {},
209+
}
210+
211+
if fmt == 'openvino' and not pickle_metadata:
212+
save_path = os.path.splitext(save_path)[0] + '.xml'
213+
214+
model_config = {
215+
'classes': num_classes,
216+
'inputs_shapes': image_shape,
217+
'output': 'sigmoid'
218+
}
219+
220+
model_save = model_class(config=model_config)
221+
222+
batch_shape = (dataset_size, *image_shape)
223+
images_array = np.random.random(batch_shape)
224+
225+
inputs = images_array.astype('float32')
226+
227+
saved_predictions = model_save.predict(inputs, outputs='sigmoid')
228+
model_save.save(path=save_path, pickle_metadata=pickle_metadata, fmt=fmt, **save_kwargs[fmt])
229+
230+
load_config = {} if fmt != 'safetensors' else model_save.config
231+
model_load = model_class(config=load_config)
232+
model_load.load(path=save_path, fmt='pt' if pickle_metadata else fmt, **load_kwargs[fmt])
233+
loaded_predictions = model_load.predict(inputs, outputs='sigmoid')
234+
235+
assert np.isclose(np.concatenate(saved_predictions), np.concatenate(loaded_predictions), atol=1e-3).all()

batchflow/tests/research_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,7 @@ def f(a):
570570
assert research.results.df.iloc[0].a == f(2)
571571
assert research.results.df.iloc[0].b == f(3)
572572

573+
@pytest.mark.slow
573574
@pytest.mark.parametrize('dump_results', [False, True])
574575
@pytest.mark.parametrize('redirect_stdout', [True, 0, 1, 2, 3])
575576
@pytest.mark.parametrize('redirect_stderr', [True, 0, 1, 2, 3])

pyproject.toml

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "batchflow"
3-
version = "0.8.12"
3+
version = "0.9.0"
44
description = "ML pipelines, model configuration and batch management"
55
authors = [{ name = "Roman Kh", email = "[email protected]" }]
66
license = {text = "Apache License 2.0"}
@@ -25,7 +25,8 @@ dependencies = [
2525
"numba>=0.56",
2626
"llvmlite",
2727
"scipy>=1.9",
28-
"tqdm>=4.19"
28+
"tqdm>=4.19",
29+
"pytest>=8.3.4",
2930
]
3031

3132
[project.optional-dependencies]
@@ -74,6 +75,19 @@ telegram = [
7475
"pillow>=9.4,<11.0",
7576
]
7677

78+
safetensors = [
79+
"safetensors>=0.5.3",
80+
]
81+
82+
onnx = [
83+
"onnx>=1.14.0",
84+
"onnx2torch>=1.5.0",
85+
]
86+
87+
openvino = [
88+
"openvino>=2025.0.0",
89+
]
90+
7791
other = [
7892
"urllib3>=1.25"
7993
]

0 commit comments

Comments
 (0)