Skip to content

Commit 9a1c70f

Browse files
authored
Merge pull request #170 from wkentaro/flags2rgb
feat: add flags2rgb composition
2 parents 3bdf377 + 05fb367 commit 9a1c70f

8 files changed

Lines changed: 406 additions & 4 deletions

File tree

README.md

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ pip install imgviz[all]
4747
```python
4848
# getting_started.py
4949

50+
import numpy as np
51+
5052
import imgviz
5153

5254
# sample data of rgb, depth, class label and instance masks
@@ -72,6 +74,15 @@ masks = data["masks"] == 1
7274
captions = [data["class_names"][l] for l in labels]
7375
maskviz = imgviz.instances2rgb(gray, masks=masks, labels=labels, captions=captions)
7476

77+
# per-instance flags as pie glyphs
78+
centers = np.array([np.argwhere(m).mean(axis=0) for m in masks])
79+
flags = np.column_stack(
80+
(masks.sum(axis=(1, 2)) < 7000, centers[:, 1] < rgb.shape[1] / 2)
81+
)
82+
flagviz = imgviz.flags2rgb(
83+
rgb, flags=flags, centers=centers, flag_names=["small", "left"], wedges="all"
84+
)
85+
7586
# tile instance masks
7687
insviz = [
7788
(rgb * m[:, :, None])[b[0] : b[2], b[1] : b[3]] for b, m in zip(bboxes, masks)
@@ -81,9 +92,9 @@ insviz = imgviz.resize(insviz, height=rgb.shape[0])
8192

8293
# tile visualization
8394
tiled = imgviz.tile(
84-
[rgb, depthviz, labelviz, maskviz, insviz],
95+
[rgb, depthviz, labelviz, maskviz, flagviz, insviz],
8596
row=1,
86-
col=5,
97+
col=6,
8798
border=(255, 255, 255),
8899
border_width=5,
89100
)
@@ -108,6 +119,10 @@ tiled = imgviz.tile(
108119
<td><pre><a href="examples/draw.py">examples/draw.py</a></pre></td>
109120
<td><img src="https://github.com/wkentaro/imgviz/raw/main/examples/assets/draw.jpg" width="37.79047619047619%" /></td>
110121
</tr>
122+
<tr>
123+
<td><pre><a href="examples/flags2rgb.py">examples/flags2rgb.py</a></pre></td>
124+
<td><img src="https://github.com/wkentaro/imgviz/raw/main/examples/assets/flags2rgb.jpg" width="77.93103448275862%" /></td>
125+
</tr>
111126
<tr>
112127
<td><pre><a href="examples/flow2rgb.py">examples/flow2rgb.py</a></pre></td>
113128
<td><img src="https://github.com/wkentaro/imgviz/raw/main/examples/assets/flow2rgb.jpg" width="52.21052631578947%" /></td>

assets/getting_started.jpg

54.3 KB
Loading

examples/assets/flags2rgb.jpg

55.8 KB
Loading

examples/flags2rgb.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#!/usr/bin/env python
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
6+
import imgviz
7+
8+
9+
def flags2rgb() -> None:
10+
data = imgviz.data.arc2017()
11+
12+
masks = data["masks"] == 1
13+
bboxes = data["bboxes"]
14+
centers = np.array([np.argwhere(mask).mean(axis=0) for mask in masks])
15+
flags = np.column_stack(
16+
(
17+
masks.sum(axis=(1, 2)) < 7000,
18+
(bboxes[:, 2] - bboxes[:, 0]) > (bboxes[:, 3] - bboxes[:, 1]),
19+
centers[:, 1] < data["rgb"].shape[1] / 2,
20+
)
21+
)
22+
flag_names = ["small", "tall", "left"]
23+
24+
flagviz1 = imgviz.flags2rgb(
25+
data["rgb"], flags=flags, centers=centers, flag_names=flag_names
26+
)
27+
flagviz2 = imgviz.flags2rgb(
28+
data["rgb"], flags=flags, centers=centers, flag_names=flag_names, wedges="all"
29+
)
30+
31+
plt.figure(dpi=200)
32+
33+
plt.subplot(131)
34+
plt.title("rgb")
35+
plt.imshow(data["rgb"])
36+
plt.axis("off")
37+
38+
plt.subplot(132)
39+
plt.title('flags (wedges="on")')
40+
plt.imshow(flagviz1)
41+
plt.axis("off")
42+
43+
plt.subplot(133)
44+
plt.title('flags (wedges="all")')
45+
plt.imshow(flagviz2)
46+
plt.axis("off")
47+
48+
49+
if __name__ == "__main__":
50+
from _base import run_example
51+
52+
run_example(flags2rgb)

getting_started.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
# -----------------------------------------------------------------------------
1010
# GETTING_STARTED {{
11+
import numpy as np
12+
1113
import imgviz
1214

1315
# sample data of rgb, depth, class label and instance masks
@@ -33,6 +35,15 @@
3335
captions = [data["class_names"][l] for l in labels]
3436
maskviz = imgviz.instances2rgb(gray, masks=masks, labels=labels, captions=captions)
3537

38+
# per-instance flags as pie glyphs
39+
centers = np.array([np.argwhere(m).mean(axis=0) for m in masks])
40+
flags = np.column_stack(
41+
(masks.sum(axis=(1, 2)) < 7000, centers[:, 1] < rgb.shape[1] / 2)
42+
)
43+
flagviz = imgviz.flags2rgb(
44+
rgb, flags=flags, centers=centers, flag_names=["small", "left"], wedges="all"
45+
)
46+
3647
# tile instance masks
3748
insviz = [
3849
(rgb * m[:, :, None])[b[0] : b[2], b[1] : b[3]] for b, m in zip(bboxes, masks)
@@ -42,9 +53,9 @@
4253

4354
# tile visualization
4455
tiled = imgviz.tile(
45-
[rgb, depthviz, labelviz, maskviz, insviz],
56+
[rgb, depthviz, labelviz, maskviz, flagviz, insviz],
4657
row=1,
47-
col=5,
58+
col=6,
4859
border=(255, 255, 255),
4960
border_width=5,
5061
)

imgviz/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ._diff import diff
2626
from ._dtype import bool2ubyte
2727
from ._dtype import float2ubyte
28+
from ._flags import flags2rgb
2829
from ._flow import Flow2Rgb
2930
from ._flow import flow2rgb
3031
from ._instances import instances2rgb

imgviz/_flags.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Sequence
4+
from typing import Final
5+
from typing import Literal
6+
7+
import numpy as np
8+
from numpy.typing import NDArray
9+
10+
from . import _color
11+
from . import _label
12+
from . import _utils
13+
from . import components
14+
from . import draw as draw_module
15+
16+
17+
def flags2rgb(
18+
image: NDArray[np.uint8],
19+
flags: NDArray[np.bool_],
20+
centers: NDArray[np.floating],
21+
flag_names: Sequence[str],
22+
diameter: float = 30,
23+
flag_colors: NDArray[np.uint8] | None = None,
24+
wedges: Literal["on", "all"] = "on",
25+
font_size: int = 25,
26+
font_path: str | None = None,
27+
loc: Literal["lt", "rt", "lb", "rb"] = "rb",
28+
) -> NDArray[np.uint8]:
29+
"""Visualize per-instance boolean flags as pie glyphs with a legend.
30+
31+
Wedge angle encodes flag identity, never quantity. With ``wedges="on"``,
32+
color identifies the flag and angle is just packing: only active flags get
33+
wedges, packed clockwise in flag-index order (zero active flags draws no
34+
pie, one draws a solid disc). With ``wedges="all"``, both color and angle
35+
identify the flag: every flag keeps a fixed wedge at a fixed angle, drawn
36+
light gray when off, and the legend gains one synthetic ("off", gray)
37+
entry. Around 5 wedges stay legible at the default diameter.
38+
39+
Args:
40+
image: RGB image with shape (H, W, 3).
41+
flags: Boolean flags with shape (N, F).
42+
centers: Pie centers with shape (N, 2). [(cy, cx), ...]
43+
flag_names: Flag names with length F.
44+
diameter: Diameter of each pie in pixels.
45+
flag_colors: Color for each flag with shape (F, 3). By default,
46+
:func:`~imgviz.label_colormap` rows 1 to F are used.
47+
wedges: Which flags get wedges ('on', 'all').
48+
font_size: Font size of the legend.
49+
font_path: Font path.
50+
loc: Location of legend ('lt', 'rt', 'lb', 'rb').
51+
52+
Returns:
53+
Visualized image with shape (H, W, 3).
54+
"""
55+
OFF_COLOR: Final = (200, 200, 200)
56+
57+
if not isinstance(image, np.ndarray):
58+
raise TypeError(f"image must be a numpy array, but got {type(image).__name__}")
59+
if image.dtype != np.uint8:
60+
raise ValueError(f"image dtype must be np.uint8, but got {image.dtype}")
61+
if image.ndim == 2:
62+
image = _color.gray2rgb(image)
63+
if image.ndim != 3:
64+
raise ValueError(f"image must be 2 or 3 dimensional, but got {image.ndim}")
65+
66+
if not isinstance(flags, np.ndarray):
67+
raise TypeError(f"flags must be a numpy array, but got {type(flags).__name__}")
68+
if flags.dtype != bool:
69+
raise ValueError(f"flags dtype must be bool, but got {flags.dtype}")
70+
if flags.ndim != 2:
71+
raise ValueError(f"flags must be 2 dimensional (N, F), but got {flags.ndim}")
72+
n_instances, n_flags = flags.shape
73+
74+
centers = np.asarray(centers)
75+
if centers.shape != (n_instances, 2):
76+
raise ValueError(
77+
f"centers shape must be (N, 2) matching {n_instances} instances, "
78+
f"but got {centers.shape}"
79+
)
80+
81+
if len(flag_names) != n_flags:
82+
raise ValueError(
83+
f"flag_names must have one name per flag, "
84+
f"but got {len(flag_names)=}, {n_flags=}"
85+
)
86+
87+
if flag_colors is None:
88+
flag_colors = _label.label_colormap()[1 : n_flags + 1]
89+
if flag_colors.shape != (n_flags, 3):
90+
raise ValueError(
91+
f"flag_colors shape must be ({n_flags}, 3), but got {flag_colors.shape}"
92+
)
93+
94+
if wedges not in ("on", "all"):
95+
raise ValueError(f"unsupported wedges: {wedges}")
96+
97+
dst = _utils.numpy_to_pillow(image)
98+
for i in range(n_instances):
99+
fills: list[draw_module.Ink]
100+
if wedges == "on":
101+
fills = [flag_colors[j] for j in range(n_flags) if flags[i, j]]
102+
else:
103+
fills = [
104+
flag_colors[j] if flags[i, j] else OFF_COLOR for j in range(n_flags)
105+
]
106+
if not fills:
107+
continue
108+
cy, cx = centers[i]
109+
draw_module.pie_(
110+
image=dst,
111+
center=(cy, cx),
112+
diameter=diameter,
113+
fills=fills,
114+
outline=(255, 255, 255),
115+
width=1,
116+
)
117+
118+
items: list[components.LegendItem] = list(zip(flag_names, flag_colors))
119+
if wedges == "all":
120+
items.append(("off", OFF_COLOR))
121+
components.legend_(
122+
image=dst, items=items, font_size=font_size, font_path=font_path, loc=loc
123+
)
124+
return _utils.pillow_to_numpy(dst)

0 commit comments

Comments
 (0)