From 5251e9f83e572462ed2f0cb33c1d08963a09d046 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 16 Oct 2024 05:27:50 +0000 Subject: [PATCH] Improve: Stress-testing translations Extends the testing suite with several fuzzy tests for binary string tranlations from Albumentations core image processing library. https://github.com/albumentations-team/albucore/blob/c0e924b9d2e74f787be413c33b05f42575cdc2c7/tests/test_lut.py#L25-L33 Co-authored-by: Vladimir Iglovikov <5481618+ternaus@users.noreply.github.com --- scripts/test.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/scripts/test.py b/scripts/test.py index e6c8f64b..8f7859c0 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -1,6 +1,6 @@ from random import choice, randint from string import ascii_lowercase -from typing import Optional +from typing import Optional, Sequence, Dict import tempfile import os @@ -718,6 +718,63 @@ def test_alignment_score_random(first_length: int, second_length: int): ) == -baseline_edit_distance(a, b) +def baseline_translate(body: str, lut: Sequence) -> str: + return "".join([chr(lut[ord(c)]) for c in body]) + + +def translation_table_to_dict(lut: Sequence) -> Dict[str, str]: + return {chr(i): chr(lut[i]) for i in range(256)} + + +@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") +@pytest.mark.parametrize("length", range(1, 300)) +def test_translations(length: int): + + map_identity = np.arange(256, dtype=np.uint8) + map_invert = np.arange(255, -1, -1, dtype=np.uint8) + map_threshold = np.where(np.arange(256) > 127, 255, 0).astype(np.uint8) + dict_identity = translation_table_to_dict(map_identity) + dict_invert = translation_table_to_dict(map_invert) + dict_threshold = translation_table_to_dict(map_threshold) + view_identity = memoryview(map_identity) + view_invert = memoryview(map_invert) + view_threshold = memoryview(map_threshold) + + body = get_random_string(length=length) + body_bytes = body.encode("utf-8") + + # Check mapping strings and byte-strings into new strings + assert sz.translate(body, view_identity) == body + assert sz.translate(body_bytes, view_identity) == body_bytes + assert sz.translate(body_bytes, view_identity) == body_bytes.translate( + view_identity + ) + assert sz.translate(body_bytes, view_invert) == body_bytes.translate(view_invert) + assert sz.translate(body_bytes, view_threshold) == body_bytes.translate( + view_threshold + ) + + # Check in-place translations + body_after_identity = str(body) + sz.translate(body, view_identity, inplace=True) + assert body == body_after_identity.translate(dict_identity) + body_after_invert = str(body) + sz.translate(body, view_invert, inplace=True) + assert body == body_after_invert.translate(dict_invert) + body_after_threshold = str(body) + sz.translate(body, view_threshold, inplace=True) + assert body == body_after_threshold.translate(dict_threshold) + + +@pytest.mark.repeat(3) +@pytest.mark.parametrize("length", range(1, 300)) +@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed") +def test_translations_random(length: int): + body = get_random_string(length=length) + lut = np.random.randint(0, 256, size=256, dtype=np.uint8) + assert sz.translate(body, memoryview(lut)) == baseline_translate(body, lut) + + @pytest.mark.parametrize("list_length", [10, 20, 30, 40, 50]) @pytest.mark.parametrize("part_length", [5, 10]) @pytest.mark.parametrize("variability", [2, 3])