Skip to content

Commit ff43018

Browse files
authored
Merge pull request #5 from mitmedialab/rgb-array-input-fix
accept RGB/RGBA arrays in derive_configs (not just 2-D / paths)
2 parents 843297e + e8907c8 commit ff43018

1 file changed

Lines changed: 18 additions & 3 deletions

File tree

release/auto_config.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
from PIL import Image
5050
from scipy.ndimage import distance_transform_edt
5151

52+
from .skeletonize import _to_grayscale_float
53+
5254
# Reference stroke width the hand-tuned defaults assume. Computed as
5355
# the rough median of stroke widths in the bundled examples — 2.8 px
5456
# for the smallest, 4.5 px for the largest, with most at ~3 px.
@@ -63,15 +65,28 @@ def _load_binary(
6365
6466
The pipeline's ``Skeletonize.Config.Binarize(threshold=0.5)`` is
6567
the standard — values below 0.5 are strokes, above are background.
68+
69+
Accepts the same source shapes ``Skeletonize`` does: a file path, a
70+
2-D grayscale array (bool / integer / float), or a 3-D RGB/RGBA
71+
array. A multi-channel array is collapsed to a single grayscale
72+
channel via the exact converter the pipeline uses, so passing a raw
73+
``np.asarray(rgb_image)`` to ``default_pipeline`` works rather than
74+
raising deep inside config derivation.
6675
"""
6776
if isinstance(source, str):
6877
arr = np.asarray(Image.open(source).convert("L"), dtype=np.float64) / 255.0
6978
elif isinstance(source, np.ndarray):
7079
if source.dtype == np.bool_:
7180
return source
72-
arr = source.astype(np.float64)
73-
if arr.max() > 1.0 + 1e-6:
74-
arr = arr / 255.0
81+
if source.ndim == 3:
82+
# RGB/RGBA: composite-over-white (RGBA) then luma-reduce,
83+
# exactly as ``Skeletonize`` does for array sources. The
84+
# result is already a float grayscale in [0, 1].
85+
arr = _to_grayscale_float(source)
86+
else:
87+
arr = source.astype(np.float64)
88+
if arr.max() > 1.0 + 1e-6:
89+
arr = arr / 255.0
7590
else:
7691
raise TypeError(f"unsupported source type {type(source)!r}")
7792
return arr < 0.5

0 commit comments

Comments
 (0)