Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions imgviz/draw/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from ._arrow import arrow
from ._arrow import arrow_
from ._circle import circle
from ._circle import circle_
from ._ellipse import ellipse
Expand Down
97 changes: 97 additions & 0 deletions imgviz/draw/_arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import numpy as np
import PIL.Image
import PIL.ImageDraw
from numpy.typing import ArrayLike
from numpy.typing import NDArray

from .. import _utils
from ._ink import Ink
from ._ink import get_pil_ink


def arrow(
image: NDArray[np.uint8],
yx1: ArrayLike,
yx2: ArrayLike,
fill: Ink,
width: int = 1,
head_length_ratio: float = 0.1,
head_angle: float = 30.0,
) -> NDArray[np.uint8]:
"""Draw arrow on numpy array with Pillow.

Args:
image: Input image.
yx1: Tail (y, x) where the arrow starts.
yx2: Tip (y, x) where the arrowhead is drawn.
fill: RGB color to draw the arrow.
width: Line width.
head_length_ratio: Arrowhead length as a fraction of the shaft length.
head_angle: Half-angle of the arrowhead in degrees.

Returns:
Output image.
"""
dst = _utils.numpy_to_pillow(image)
arrow_(
image=dst,
yx1=yx1,
yx2=yx2,
fill=fill,
width=width,
head_length_ratio=head_length_ratio,
head_angle=head_angle,
)
return _utils.pillow_to_numpy(dst)


def arrow_(
image: PIL.Image.Image,
yx1: ArrayLike,
yx2: ArrayLike,
fill: Ink,
width: int = 1,
head_length_ratio: float = 0.1,
head_angle: float = 30.0,
) -> None:
"""Draw arrow on PIL image in-place.

Args:
image: PIL image to draw on (modified in-place).
yx1: Tail (y, x) where the arrow starts.
yx2: Tip (y, x) where the arrowhead is drawn.
fill: RGB color to draw the arrow.
width: Line width.
head_length_ratio: Arrowhead length as a fraction of the shaft length.
head_angle: Half-angle of the arrowhead in degrees.
"""
if not isinstance(image, PIL.Image.Image):
raise TypeError(
f"image must be PIL.Image.Image, but got {type(image).__name__}"
)
tail = np.asarray(yx1, dtype=float)
tip = np.asarray(yx2, dtype=float)
if tail.shape != (2,):
raise ValueError(f"yx1 must have shape (2,), but got {tail.shape}")
if tip.shape != (2,):
raise ValueError(f"yx2 must have shape (2,), but got {tip.shape}")

draw = PIL.ImageDraw.Draw(image)
pil_fill = get_pil_ink(fill)

y1, x1 = float(tail[0]), float(tail[1])
y2, x2 = float(tip[0]), float(tip[1])
draw.line([x1, y1, x2, y2], fill=pil_fill, width=width)

shaft = tip - tail
length = float(np.linalg.norm(shaft))
if length == 0:
return
uy, ux = shaft / length
head_length = length * head_length_ratio
for sign in (1, -1):
a = np.radians(sign * head_angle)
cos_a, sin_a = np.cos(a), np.sin(a)
by = y2 - head_length * float(uy * cos_a - ux * sin_a)
bx = x2 - head_length * float(uy * sin_a + ux * cos_a)
draw.line([x2, y2, bx, by], fill=pil_fill, width=width)
76 changes: 76 additions & 0 deletions tests/unit/draw/_arrow_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import numpy as np
import PIL.Image
import pytest

import imgviz


def test_arrow() -> None:
img = np.full((100, 100, 3), 255, dtype=np.uint8)
res = imgviz.draw.arrow(img, yx1=(20, 20), yx2=(20, 80), fill=(255, 0, 0))
assert res.shape == img.shape
assert res.dtype == img.dtype
assert not np.array_equal(res, img)


def test_arrow_adds_head_over_plain_line() -> None:
img = np.full((100, 100, 3), 255, dtype=np.uint8)
line = imgviz.draw.line(img, yx=[(20, 20), (20, 80)], fill=(255, 0, 0))
arrow = imgviz.draw.arrow(img, yx1=(20, 20), yx2=(20, 80), fill=(255, 0, 0))
changed_by_head = np.any(line != arrow, axis=2)
assert changed_by_head.sum() > 0


def test_arrow_head_is_at_the_tip() -> None:
img = np.full((100, 100, 3), 255, dtype=np.uint8)
line = imgviz.draw.line(img, yx=[(20, 20), (20, 80)], fill=(255, 0, 0))
arrow = imgviz.draw.arrow(img, yx1=(20, 20), yx2=(20, 80), fill=(255, 0, 0))
ys, xs = np.where(np.any(line != arrow, axis=2))
assert xs.mean() > 50 # closer to the tip (x=80) than the tail (x=20)


def test_arrow_head_length_ratio_scales_head() -> None:
img = np.full((100, 100, 3), 255, dtype=np.uint8)
small = imgviz.draw.arrow(
img, yx1=(50, 10), yx2=(50, 90), fill=(255, 0, 0), head_length_ratio=0.1
)
large = imgviz.draw.arrow(
img, yx1=(50, 10), yx2=(50, 90), fill=(255, 0, 0), head_length_ratio=0.3
)
assert np.any(small != img, axis=2).sum() < np.any(large != img, axis=2).sum()


def test_arrow_head_is_at_the_tip_diagonal() -> None:
img = np.full((100, 100, 3), 255, dtype=np.uint8)
line = imgviz.draw.line(img, yx=[(20, 20), (80, 80)], fill=(255, 0, 0))
arrow = imgviz.draw.arrow(img, yx1=(20, 20), yx2=(80, 80), fill=(255, 0, 0))
ys, xs = np.where(np.any(line != arrow, axis=2))
assert ys.mean() > 50 and xs.mean() > 50 # head near tip (80, 80), not tail


def test_arrow_zero_length_skips_head() -> None:
img = np.full((100, 100, 3), 255, dtype=np.uint8)
line = imgviz.draw.line(img, yx=[(50, 50), (50, 50)], fill=(255, 0, 0))
arrow = imgviz.draw.arrow(img, yx1=(50, 50), yx2=(50, 50), fill=(255, 0, 0))
assert np.array_equal(arrow, line) # no arrowhead, just the degenerate shaft


def test_arrow_in_place() -> None:
pil = PIL.Image.fromarray(np.full((100, 100, 3), 255, dtype=np.uint8))
before = np.asarray(pil).copy()
imgviz.draw.arrow_(pil, yx1=(20, 20), yx2=(20, 80), fill=(255, 0, 0))
assert not np.array_equal(np.asarray(pil), before)


def test_arrow_rejects_non_pil_image() -> None:
img = np.full((100, 100, 3), 255, dtype=np.uint8)
with pytest.raises(TypeError, match="PIL.Image.Image"):
imgviz.draw.arrow_(img, yx1=(20, 20), yx2=(20, 80), fill=(255, 0, 0))


@pytest.mark.parametrize("point", ["yx1", "yx2"])
def test_arrow_rejects_bad_point_shape(point: str) -> None:
img = np.full((100, 100, 3), 255, dtype=np.uint8)
kwargs = {"yx1": (20, 20), "yx2": (20, 80), point: (1, 2, 3)}
with pytest.raises(ValueError, match="shape"):
imgviz.draw.arrow(img, fill=(255, 0, 0), **kwargs)
Loading