Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions tests/unit/_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
import imgviz


@pytest.fixture
def flow_3_4() -> NDArray[np.float32]:
flow = np.zeros((4, 4, 2), dtype=np.float32)
flow[:, :, 0] = 3.0
flow[:, :, 1] = 4.0
return flow


@pytest.mark.parametrize(
"use_class", [pytest.param(False, id="flow2rgb"), pytest.param(True, id="Flow2Rgb")]
)
Expand Down Expand Up @@ -44,3 +52,56 @@ def test_flow2rgb_handles_negative_zero_v() -> None:
assert flow_viz.dtype == np.uint8
assert flow_viz.shape == (4, 4, 3)
np.testing.assert_array_equal(flow_viz, imgviz.flow2rgb(flow_pos))


def test_flow2rgb_rejects_non_3d() -> None:
flow = np.zeros((4, 4), dtype=np.float32)
with pytest.raises(ValueError, match="flow must be 3 dimensional"):
imgviz.flow2rgb(flow)


def test_flow2rgb_rejects_wrong_channel_count() -> None:
flow = np.zeros((4, 4, 3), dtype=np.float32)
with pytest.raises(ValueError, match=r"flow must have shape \(H, W, 2\)"):
imgviz.flow2rgb(flow)


def test_flow2rgb_rejects_non_float_dtype() -> None:
flow = np.zeros((4, 4, 2), dtype=np.uint8)
with pytest.raises(ValueError, match="flow dtype must be float"):
imgviz.flow2rgb(flow)


def test_flow2rgb_return_max_reports_norm(flow_3_4: NDArray[np.float32]) -> None:
viz, max_norm = imgviz.flow2rgb(flow_3_4, return_max=True)

assert max_norm == pytest.approx(5.0)
np.testing.assert_array_equal(viz, imgviz.flow2rgb(flow_3_4, max_norm=max_norm))


def test_Flow2Rgb_caches_max_norm_after_first_call(
flow_3_4: NDArray[np.float32],
) -> None:
converter = imgviz.Flow2Rgb()
assert converter.max_norm is None

converter(flow_3_4)
assert converter.max_norm == pytest.approx(5.0)

flow_b = np.zeros((4, 4, 2), dtype=np.float32)
flow_b[:, :, 0] = 6.0
flow_b[:, :, 1] = 8.0
viz_b = converter(flow_b)

assert converter.max_norm == pytest.approx(5.0)
np.testing.assert_array_equal(viz_b, imgviz.flow2rgb(flow_b, max_norm=5.0))


def test_Flow2Rgb_keeps_explicit_max_norm(flow_3_4: NDArray[np.float32]) -> None:
converter = imgviz.Flow2Rgb(max_norm=2.0)
assert converter.max_norm == pytest.approx(2.0)

viz = converter(flow_3_4)

assert converter.max_norm == pytest.approx(2.0)
np.testing.assert_array_equal(viz, imgviz.flow2rgb(flow_3_4, max_norm=2.0))