diff --git a/pygexml/image.py b/pygexml/image.py new file mode 100644 index 0000000..8fcd7e4 --- /dev/null +++ b/pygexml/image.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass +from dataclasses_json import DataClassJsonMixin + + +@dataclass +class Image(DataClassJsonMixin): + filename: str + width: int | None + height: int | None diff --git a/pygexml/page.py b/pygexml/page.py index 116b1ce..ff05343 100644 --- a/pygexml/page.py +++ b/pygexml/page.py @@ -9,6 +9,7 @@ from lxml.etree import _Element as Element, QName from .geometry import Point, Box, Polygon, GeometryError +from .image import Image def find_child(element: Element, name: str) -> Element | None: @@ -225,7 +226,7 @@ def all_words(self) -> Iterable[str]: @dataclass class Page(DataClassJsonMixin): - image_filename: str + image: Image regions: dict[ID, TextRegion] @classmethod @@ -234,12 +235,24 @@ def from_xml(cls, element: Element) -> "Page": raise PageXMLError("Wrong element given") if "imageFilename" not in element.attrib: - raise PageXMLError("No filename found") + raise PageXMLError("No image filename found") regions = find_children(element, "TextRegion") return Page( - image_filename=str(element.attrib["imageFilename"]), + image=Image( + filename=str(element.attrib["imageFilename"]), + width=( + int(element.attrib["imageWidth"]) + if "imageWidth" in element.attrib + else None + ), + height=( + int(element.attrib["imageHeight"]) + if "imageHeight" in element.attrib + else None + ), + ), regions={ tr.id: tr for tr in (TextRegion.from_xml(region) for region in regions) }, @@ -289,8 +302,22 @@ def from_alto(cls, element: Element) -> "Page": text_blocks = find_children(printspace_element, "TextBlock") + # ALTO allows for float values, but we convert to int for consistency with PAGE XML + image_width = ( + int(float(page_element.attrib["WIDTH"])) + if "WIDTH" in page_element.attrib + else None + ) + image_height = ( + int(float(page_element.attrib["HEIGHT"])) + if "HEIGHT" in page_element.attrib + else None + ) + return Page( - image_filename=image_filename, + image=Image( + filename=image_filename, width=image_width, height=image_height + ), regions={ tb.id: tb for tb in (TextRegion.from_alto(tb) for tb in text_blocks) }, diff --git a/pygexml/strategies.py b/pygexml/strategies.py index d2964ce..851b8ee 100644 --- a/pygexml/strategies.py +++ b/pygexml/strategies.py @@ -6,6 +6,7 @@ import hypothesis.strategies as st from pygexml.geometry import Point, Box, Polygon +from pygexml.image import Image from pygexml.page import Coords, Page, TextLine, TextRegion st_points = st.builds(Point, x=st.integers(min_value=0), y=st.integers(min_value=0)) @@ -60,10 +61,32 @@ def st_simple_text(**kwargs): ), ) +st_images = st.builds( + Image, + filename=st_simple_text(), + width=st.one_of(st.none(), st.integers(min_value=1)), + height=st.one_of(st.none(), st.integers(min_value=1)), +) + +st_images_with_dimensions = st.builds( + Image, + filename=st_simple_text(), + width=st.integers(min_value=1), + height=st.integers(min_value=1), +) + @st.composite def st_pages(draw): - image_filename = draw(st_simple_text()) + image = draw(st_images) + regions = {tr.id: tr for tr in draw(st.lists(st_text_regions))} + page = Page(image=image, regions=regions) + return page + + +@st.composite +def st_pages_with_dimensions(draw): + image = draw(st_images_with_dimensions) regions = {tr.id: tr for tr in draw(st.lists(st_text_regions))} - page = Page(image_filename=image_filename, regions=regions) + page = Page(image=image, regions=regions) return page diff --git a/pygexml/svg.py b/pygexml/svg.py new file mode 100644 index 0000000..2467d50 --- /dev/null +++ b/pygexml/svg.py @@ -0,0 +1,133 @@ +from lxml import etree +from lxml.etree import _Element as Element + +from .page import Page, TextRegion, TextLine + +SVG_NS = "http://www.w3.org/2000/svg" +XLINK_NS = "http://www.w3.org/1999/xlink" + + +class SVGError(Exception): + pass + + +def _coords_path(coords_str: str) -> str: + return f"M {coords_str} Z" + + +def _baseline_path_d(line: TextLine) -> str: + box = line.coords.polygon.bounding_box() + y_baseline = box.top_left.y + (box.bottom_right.y - box.top_left.y) * 2 // 3 + return f"M {box.top_left.x},{y_baseline} {box.bottom_right.x},{y_baseline}" + + +def _line_to_svg(line: TextLine) -> Element: + g = etree.Element(f"{{{SVG_NS}}}g", attrib={"id": line.id, "class": "TextLine"}) + etree.SubElement( + g, + f"{{{SVG_NS}}}path", + attrib={ + "d": _coords_path(str(line.coords)), + "class": "Coords", + }, + ) + etree.SubElement( + g, + f"{{{SVG_NS}}}path", + attrib={ + "id": f"bl-{line.id}", + "d": _baseline_path_d(line), + "class": "Baseline", + }, + ) + if line.text: + text = etree.SubElement(g, f"{{{SVG_NS}}}text") + text_path = etree.SubElement( + text, + f"{{{SVG_NS}}}textPath", + attrib={f"{{{XLINK_NS}}}href": f"#bl-{line.id}", "textLength": "100%"}, + ) + tspan = etree.SubElement( + text_path, f"{{{SVG_NS}}}tspan", attrib={"class": "Text"} + ) + tspan.text = line.text + return g + + +def _region_to_svg(region: TextRegion) -> Element: + g = etree.Element(f"{{{SVG_NS}}}g", attrib={"id": region.id, "class": "TextRegion"}) + etree.SubElement( + g, + f"{{{SVG_NS}}}path", + attrib={ + "d": _coords_path(str(region.coords)), + "class": "Coords", + }, + ) + for line in region.textlines.values(): + g.append(_line_to_svg(line)) + return g + + +def _default_style(width: int, height: int) -> Element: + font_size = max(width, height) // 60 + style = etree.Element(f"{{{SVG_NS}}}style") + style.text = ( + f"\n" + f" path.Coords {{ fill: rgba(100,160,255,0.12); stroke: steelblue; stroke-width: {max(width, height) // 1500}; }}\n" + f" path.Baseline {{ stroke: #e74c3c; stroke-width: {max(width, height) // 2000}; fill: none; }}\n" + f" .TextLine text {{ font-size: {font_size}px; font-family: serif; fill: #000; opacity: 0; transition: opacity 0.15s; }}\n" + f" .TextLine:hover text {{ opacity: 1; }}\n" + f" " + ) + return style + + +def page_to_svg(page: Page, include_style: bool = True) -> Element: + if page.image.width is None: + raise SVGError("Image width is required for SVG generation") + if page.image.height is None: + raise SVGError("Image height is required for SVG generation") + + width = page.image.width + height = page.image.height + + svg = etree.Element( + f"{{{SVG_NS}}}svg", + # the official way to do it although stubs are wrong: + nsmap={None: SVG_NS, "xlink": XLINK_NS}, # type: ignore + attrib={ + "width": str(width), + "height": str(height), + "viewBox": f"0 0 {width} {height}", + }, + ) + + etree.SubElement( + svg, + f"{{{SVG_NS}}}image", + attrib={ + "x": "0", + "y": "0", + "width": str(width), + "height": str(height), + f"{{{XLINK_NS}}}href": page.image.filename, + "preserveAspectRatio": "none", + }, + ) + + if include_style: + svg.insert(0, _default_style(width, height)) + + for region in page.regions.values(): + svg.append(_region_to_svg(region)) + + return svg + + +def page_to_svg_string(page: Page, include_style: bool = True) -> str: + return etree.tostring( + page_to_svg(page, include_style=include_style), + encoding="unicode", + pretty_print=True, + ) diff --git a/test/test_image.py b/test/test_image.py new file mode 100644 index 0000000..a2a0c9b --- /dev/null +++ b/test/test_image.py @@ -0,0 +1,29 @@ +from hypothesis import given +import hypothesis.strategies as st + +from pygexml.strategies import st_images +from pygexml.image import Image + + +def test_image_example() -> None: + image = Image(filename="a.jpg", width=800, height=600) + assert image.filename == "a.jpg" + assert image.width == 800 + assert image.height == 600 + + +@given( + st.text(), + st.one_of(st.none(), st.integers(min_value=1)), + st.one_of(st.none(), st.integers(min_value=1)), +) +def test_image_arbitrary(filename: str, width: int, height: int) -> None: + image = Image(filename=filename, width=width, height=height) + assert image.filename == filename + assert image.width == width + assert image.height == height + + +@given(st_images) +def test_image_serialization_roundtrip_arbitrary(image: Image) -> None: + assert Image.from_dict(image.to_dict()) == image diff --git a/test/test_page.py b/test/test_page.py index 23745e0..1c6df8e 100644 --- a/test/test_page.py +++ b/test/test_page.py @@ -8,6 +8,7 @@ from pygexml.strategies import * from pygexml.geometry import Point, Box, Polygon +from pygexml.image import Image from pygexml.page import Coords, ID, TextLine, TextRegion, Page ############## Tests for Coords #################### @@ -410,7 +411,7 @@ def test_page_from_element_example() -> None: """)) - assert pa.image_filename == "7895328.jpg" + assert pa.image == Image(filename="7895328.jpg", width=4279, height=5315) assert pa.regions == { "tr-1": TextRegion( id="tr-1", @@ -454,10 +455,22 @@ def test_page_wrong_element() -> None: def test_page_no_filename() -> None: xml = "" - with pytest.raises(Exception, match="No filename found"): + with pytest.raises(Exception, match="No image filename found"): Page.from_xml(etree.fromstring(xml)) +def test_page_no_image_width() -> None: + xml = """""" + pa = Page.from_xml(etree.fromstring(xml)) + assert pa.image == Image(filename="a.jpg", width=None, height=600) + + +def test_page_no_image_height() -> None: + xml = """""" + pa = Page.from_xml(etree.fromstring(xml)) + assert pa.image == Image(filename="a.jpg", width=800, height=None) + + def test_page_from_string() -> None: pa = Page.from_xml_string(""" @@ -479,7 +492,7 @@ def test_page_from_string() -> None: """) # use default PageXML namespace - assert pa.image_filename == "a.jpg" + assert pa.image == Image(filename="a.jpg", width=4217, height=1742) assert pa.regions == { "b": TextRegion( id="b", @@ -515,7 +528,7 @@ def test_from_xml_file_example(tmp_path: Path) -> None: xml_filepath.write_text(content, encoding="utf-8") result = Page.from_xml_file(xml_filepath) - assert result.image_filename == "a.jpg" + assert result.image == Image(filename="a.jpg", width=4217, height=1742) assert result.regions == { "b": TextRegion( id="b", @@ -562,7 +575,7 @@ def test_page_from_alto_example() -> None: """)) - assert pa.image_filename == "a.jpg" + assert pa.image == Image(filename="a.jpg", width=None, height=None) assert pa.regions == { "tr-1": TextRegion( id="tr-1", @@ -589,6 +602,30 @@ def test_page_from_alto_example() -> None: } +def test_page_alto_with_dimensions() -> None: + pa = Page.from_alto(etree.fromstring(""" + + + + a.jpg + + + + + + + + + + + + + + + """)) + assert pa.image == Image(filename="a.jpg", width=800, height=600) + + def test_page_alto_wrong_element() -> None: with pytest.raises(Exception, match="Wrong element given"): Page.from_alto(etree.fromstring("!!!")) @@ -687,7 +724,7 @@ def test_page_alto_from_string() -> None: """ page = Page.from_alto_string(alto_string) - assert page.image_filename == "a.jpg" + assert page.image == Image(filename="a.jpg", width=None, height=None) assert page.regions == { "tr-1": TextRegion( id="tr-1", @@ -730,7 +767,7 @@ def test_page_alto_from_file_example(tmp_path: Path) -> None: filepath.write_text(alto_string, encoding="utf-8") result = Page.from_alto_file(filepath) - assert result.image_filename == "a.jpg" + assert result.image == Image(filename="a.jpg", width=None, height=None) assert result.regions == { "tr-1": TextRegion( id="tr-1", @@ -769,7 +806,7 @@ def test_page_region_lookup_not_found(id: str, page: Page) -> None: def test_page_all_text_and_words() -> None: pa = Page( - image_filename="a", + image=Image(filename="a", width=None, height=None), regions={ "a": TextRegion( id="a", @@ -807,7 +844,7 @@ def test_page_all_arbitrary_text_and_words(page: Page) -> None: def test_page_serialization_roundtrip() -> None: pa = Page( - image_filename="a.jpg", + image=Image(filename="a.jpg", width=1920, height=1080), regions={ "tr-1": TextRegion( id="tr-1", diff --git a/test/test_svg.py b/test/test_svg.py new file mode 100644 index 0000000..6510e04 --- /dev/null +++ b/test/test_svg.py @@ -0,0 +1,221 @@ +from typing import Any + +import pytest +from hypothesis import given +from lxml import etree +from lxml.etree import _Element as Element + +from pygexml.strategies import st_pages_with_dimensions +from pygexml.image import Image +from pygexml.page import Coords, TextLine, TextRegion, Page +from pygexml.svg import SVGError, page_to_svg, page_to_svg_string + +SVG_NS = "http://www.w3.org/2000/svg" +XLINK_NS = "http://www.w3.org/1999/xlink" + + +def make_page(**kwargs: Any) -> Page: + defaults: dict[str, Any] = dict( + image=Image(filename="a.jpg", width=800, height=600), + regions={}, + ) + return Page(**(defaults | kwargs)) + + +############## Tests for page_to_svg #################### + + +def test_page_to_svg_raises_without_image_width() -> None: + page = make_page(image=Image(filename="a.jpg", width=None, height=600)) + with pytest.raises(SVGError, match="width"): + page_to_svg(page) + + +def test_page_to_svg_raises_without_image_height() -> None: + page = make_page(image=Image(filename="a.jpg", width=800, height=None)) + with pytest.raises(SVGError, match="height"): + page_to_svg(page) + + +def test_page_to_svg_returns_svg_element() -> None: + svg = page_to_svg(make_page()) + assert isinstance(svg, Element) + assert svg.tag == f"{{{SVG_NS}}}svg" + + +def test_page_to_svg_dimensions() -> None: + svg = page_to_svg(make_page(image=Image(filename="a.jpg", width=800, height=600))) + assert svg.attrib["width"] == "800" + assert svg.attrib["height"] == "600" + assert svg.attrib["viewBox"] == "0 0 800 600" + + +def test_page_to_svg_image_element() -> None: + svg = page_to_svg(make_page(image=Image(filename="a.jpg", width=800, height=600))) + images = svg.findall(f"{{{SVG_NS}}}image") + assert len(images) == 1 + img = images[0] + assert img.attrib[f"{{{XLINK_NS}}}href"] == "a.jpg" + assert img.attrib["width"] == "800" + assert img.attrib["height"] == "600" + + +def test_page_to_svg_text_regions() -> None: + page = make_page( + regions={ + "r1": TextRegion( + id="r1", + coords=Coords.parse("0,0 10,0 10,10 0,10"), + textlines={ + "l1": TextLine( + id="l1", coords=Coords.parse("1,1 9,1 9,9 1,9"), text="foo" + ), + }, + ), + } + ) + svg = page_to_svg(page) + groups = svg.findall(f"{{{SVG_NS}}}g") + assert len(groups) == 1 + region_g = groups[0] + assert region_g.attrib["id"] == "r1" + assert region_g.attrib["class"] == "TextRegion" + line_groups = region_g.findall(f"{{{SVG_NS}}}g") + assert len(line_groups) == 1 + assert line_groups[0].attrib["id"] == "l1" + assert line_groups[0].attrib["class"] == "TextLine" + + +def test_page_to_svg_coords_path() -> None: + page = make_page( + regions={ + "r1": TextRegion( + id="r1", + coords=Coords.parse("0,0 10,0 10,10 0,10"), + textlines={}, + ), + } + ) + svg = page_to_svg(page) + region_g = svg.find(f"{{{SVG_NS}}}g") + assert region_g is not None + path = region_g.find(f"{{{SVG_NS}}}path") + assert path is not None + assert path.attrib["d"] == "M 0,0 10,0 10,10 0,10 Z" + assert path.attrib["class"] == "Coords" + + +############## Tests for page_to_svg_string #################### + + +def test_page_to_svg_string_example() -> None: + result = page_to_svg_string( + make_page(image=Image(filename="a.jpg", width=800, height=600)) + ) + assert isinstance(result, str) + assert 'xmlns="http://www.w3.org/2000/svg"' in result + assert 'xlink:href="a.jpg"' in result + assert 'viewBox="0 0 800 600"' in result + + +def test_page_to_svg_string_is_valid_xml() -> None: + result = page_to_svg_string(make_page()) + root = etree.fromstring(result.encode("utf-8")) + assert root.tag == f"{{{SVG_NS}}}svg" + + +def test_page_to_svg_string_raises_without_dimensions() -> None: + page = make_page(image=Image(filename="a.jpg", width=None, height=None)) + with pytest.raises(SVGError): + page_to_svg_string(page) + + +@given(st_pages_with_dimensions()) +def test_page_to_svg_string_arbitrary_with_dimensions(page: Page) -> None: + result = page_to_svg_string(page) + root = etree.fromstring(result.encode("utf-8")) + assert root.tag == f"{{{SVG_NS}}}svg" + + +def test_page_to_svg_includes_style_by_default() -> None: + svg = page_to_svg(make_page()) + assert svg.find(f"{{{SVG_NS}}}style") is not None + + +def test_page_to_svg_style_contains_hover_rule() -> None: + svg = page_to_svg(make_page()) + style = svg.find(f"{{{SVG_NS}}}style") + assert style is not None + assert ".TextLine:hover" in (style.text or "") + + +def test_page_to_svg_no_style_when_disabled() -> None: + svg = page_to_svg(make_page(), include_style=False) + assert svg.find(f"{{{SVG_NS}}}style") is None + + +def test_page_to_svg_string_no_style_when_disabled() -> None: + result = page_to_svg_string(make_page(), include_style=False) + assert " Page: + return make_page( + regions={ + "r1": TextRegion( + id="r1", + coords=Coords.parse("0,0 10,0 10,10 0,10"), + textlines={ + "l1": TextLine( + id="l1", coords=Coords.parse("1,1 9,1 9,9 1,9"), text=text + ), + }, + ), + } + ) + + +def get_line_g(page: Page) -> Element: + svg = page_to_svg(page) + region_g = svg.find(f"{{{SVG_NS}}}g") + assert region_g is not None + line_g = region_g.find(f"{{{SVG_NS}}}g") + assert line_g is not None + return line_g + + +def test_page_to_svg_line_has_baseline_path() -> None: + line_g = get_line_g(make_page_with_line()) + paths = line_g.findall(f"{{{SVG_NS}}}path") + assert len(paths) == 2 + baseline = next(p for p in paths if p.attrib.get("class") == "Baseline") + assert baseline.attrib["id"] == "bl-l1" + + +def test_page_to_svg_line_baseline_from_bounding_box() -> None: + # coords "1,1 9,1 9,9 1,9": x=[1,9], y=[1,9], height=8, y_baseline=1+8*2//3=6 + line_g = get_line_g(make_page_with_line()) + paths = line_g.findall(f"{{{SVG_NS}}}path") + baseline = next(p for p in paths if p.attrib.get("class") == "Baseline") + assert baseline.attrib["d"] == "M 1,6 9,6" + + +def test_page_to_svg_line_text_content() -> None: + line_g = get_line_g(make_page_with_line("Hallo Welt")) + text = line_g.find(f"{{{SVG_NS}}}text") + assert text is not None + text_path = text.find(f"{{{SVG_NS}}}textPath") + assert text_path is not None + assert text_path.attrib[f"{{{XLINK_NS}}}href"] == "#bl-l1" + tspan = text_path.find(f"{{{SVG_NS}}}tspan") + assert tspan is not None + assert tspan.text == "Hallo Welt" + assert tspan.attrib["class"] == "Text" + + +def test_page_to_svg_line_no_text_element_when_empty() -> None: + line_g = get_line_g(make_page_with_line("")) + assert line_g.find(f"{{{SVG_NS}}}text") is None