Skip to content

Commit 75a8531

Browse files
committed
Add a consistent rendering protocol.
This change provides a consistent API to render a htpy object as HTML or iterate over it. This commit introduces stream_chunks() which is identical with __iter__() but with a better name. With the introduction of Fragment, this commit makes render_node and iter_node redundant and they will be deprecated in another commit. More info: #86 (comment)
1 parent f7486d4 commit 75a8531

File tree

7 files changed

+92
-78
lines changed

7 files changed

+92
-78
lines changed

htpy/__init__.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,29 @@ class ContextProvider(t.Generic[T]):
133133
def __iter__(self) -> Iterator[str]:
134134
return iter_node(self)
135135

136-
def __str__(self) -> str:
136+
def __str__(self) -> _Markup:
137137
return render_node(self)
138138

139+
__html__ = __str__
140+
141+
def stream_chunks(self) -> Iterator[str]:
142+
return iter_node(self)
143+
139144

140145
@dataclasses.dataclass(frozen=True)
141146
class ContextConsumer(t.Generic[T]):
142147
context: Context[T]
143148
debug_name: str
144149
func: Callable[[T], Node]
145150

151+
def __str__(self) -> _Markup:
152+
return render_node(self)
153+
154+
__html__ = __str__
155+
156+
def stream_chunks(self) -> Iterator[str]:
157+
return iter_node(self)
158+
146159

147160
class _NO_DEFAULT:
148161
pass
@@ -189,13 +202,14 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It
189202
elif isinstance(x, ContextProvider):
190203
yield from _iter_node_context(x.node, {**context_dict, x.context: x.value}) # pyright: ignore [reportUnknownMemberType]
191204
elif isinstance(x, ContextConsumer):
192-
context_value = context_dict.get(x.context, x.context.default)
205+
context_value = context_dict.get(x.context, x.context.default) # pyright: ignore
206+
193207
if context_value is _NO_DEFAULT:
194208
raise LookupError(
195-
f'Context value for "{x.context.name}" does not exist, '
209+
f'Context value for "{x.context.name}" does not exist, ' # pyright: ignore
196210
f"requested by {x.debug_name}()."
197211
)
198-
yield from _iter_node_context(x.func(context_value), context_dict)
212+
yield from _iter_node_context(x.func(context_value), context_dict) # pyright: ignore
199213
elif isinstance(x, Fragment):
200214
for node in x._nodes: # pyright: ignore [reportPrivateUsage]
201215
yield from _iter_node_context(node, context_dict)
@@ -282,6 +296,9 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen
282296
def __iter__(self) -> Iterator[str]:
283297
return self._iter_context({})
284298

299+
def stream_chunks(self) -> Iterator[str]:
300+
return self._iter_context({})
301+
285302
def _iter_context(self, ctx: dict[Context[t.Any], t.Any]) -> Iterator[str]:
286303
yield f"<{self._name}{self._attrs}>"
287304
yield from _iter_node_context(self._children, ctx)
@@ -363,11 +380,14 @@ def __init__(self, *nodes: Node) -> None:
363380
def __iter__(self) -> Iterator[str]:
364381
return iter_node(self)
365382

366-
def __str__(self) -> str:
383+
def __str__(self) -> _Markup:
367384
return render_node(self)
368385

369386
__html__ = __str__
370387

388+
def stream_chunks(self) -> Iterator[str]:
389+
return iter_node(self)
390+
371391

372392
def render_node(node: Node) -> _Markup:
373393
return _Markup("".join(iter_node(node)))
@@ -538,3 +558,9 @@ def __html__(self) -> str: ...
538558
| Callable
539559
| Iterable
540560
)
561+
562+
563+
class Renderable(t.Protocol):
564+
def __str__(self) -> _Markup: ...
565+
def __html__(self) -> _Markup: ...
566+
def stream_chunks(self) -> Iterator[str]: ...

tests/conftest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77

8-
from htpy import Node, iter_node
8+
import htpy as h
99

1010
if t.TYPE_CHECKING:
1111
from collections.abc import Callable, Generator
@@ -17,7 +17,7 @@ class Trace:
1717

1818

1919
RenderResult: t.TypeAlias = list[str | Trace]
20-
RenderFixture: t.TypeAlias = t.Callable[[Node], RenderResult]
20+
RenderFixture: t.TypeAlias = t.Callable[[h.Renderable], RenderResult]
2121
TraceFixture: t.TypeAlias = t.Callable[[str], None]
2222

2323

@@ -52,14 +52,14 @@ def func(description: str) -> None:
5252
def render(render_result: RenderResult) -> Generator[RenderFixture, None, None]:
5353
called = False
5454

55-
def func(node: Node) -> RenderResult:
55+
def func(renderable: h.Renderable) -> RenderResult:
5656
nonlocal called
5757

5858
if called:
5959
raise AssertionError("render() must only be called once per test")
6060

6161
called = True
62-
for chunk in iter_node(node):
62+
for chunk in renderable.stream_chunks():
6363
render_result.append(chunk)
6464

6565
return render_result

tests/test_comment.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,3 @@ def test_escape_three_dashes(render: RenderFixture) -> None:
2222

2323
def test_escape_four_dashes(render: RenderFixture) -> None:
2424
assert render(div[comment("foo----bar")]) == ["<div>", "<!-- foobar -->", "</div>"]
25-
26-
27-
def test_str() -> None:
28-
assert str(comment("foo")) == "<!-- foo -->"

tests/test_context.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import typing as t
44

5-
import markupsafe
65
import pytest
76

87
from htpy import Context, Fragment, Node, div
@@ -34,19 +33,6 @@ def test_context_provider(render: RenderFixture) -> None:
3433
assert render(result) == ["<div>", "Hello: c!", "</div>"]
3534

3635

37-
class Test_provider_outer_api:
38-
"""Ensure provider implements __iter__/__str__"""
39-
40-
def test_iter(self) -> None:
41-
result = letter_ctx.provider("c", div[display_letter("Hello")])
42-
assert list(result) == ["<div>", "Hello: c!", "</div>"]
43-
44-
def test_str(self) -> None:
45-
result = str(letter_ctx.provider("c", div[display_letter("Hello")]))
46-
assert result == "<div>Hello: c!</div>"
47-
assert isinstance(result, markupsafe.Markup)
48-
49-
5036
def test_no_default(render: RenderFixture) -> None:
5137
with pytest.raises(
5238
LookupError,

tests/test_element.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,6 @@ def test_void_element_repr() -> None:
2626
assert repr(htpy.hr("#a")) == """<VoidElement '<hr id="a">'>"""
2727

2828

29-
def test_markup_str() -> None:
30-
result = str(div(id="a"))
31-
assert isinstance(result, str)
32-
assert isinstance(result, markupsafe.Markup)
33-
assert result == '<div id="a"></div>'
34-
35-
3629
def test_element_type() -> None:
3730
assert_type(div, Element)
3831
assert isinstance(div, Element)
@@ -44,13 +37,6 @@ def test_element_type() -> None:
4437
assert isinstance(div()["a"], Element)
4538

4639

47-
def test_html_protocol() -> None:
48-
element = div["test"]
49-
result = element.__html__()
50-
assert result == "<div>test</div>"
51-
assert isinstance(result, markupsafe.Markup)
52-
53-
5440
def test_markupsafe_escape() -> None:
5541
result = markupsafe.escape(div["test"])
5642
assert result == "<div>test</div>"

tests/test_nodes.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

tests/test_renderable.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from dataclasses import dataclass
2+
3+
import markupsafe
4+
import pytest
5+
6+
import htpy as h
7+
8+
example_ctx: h.Context[str] = h.Context("example_ctx", default="default!")
9+
10+
11+
@example_ctx.consumer
12+
def example_consumer(value: str) -> str:
13+
return value
14+
15+
16+
@dataclass(frozen=True)
17+
class RenderableTestCase:
18+
renderable: h.Renderable
19+
expected_str: str
20+
expected_chunks: list[str]
21+
22+
23+
cases = [
24+
RenderableTestCase(h.a, "<a></a>", ["<a>", "</a>"]),
25+
RenderableTestCase(h.img, "<img>", ["<img>"]),
26+
RenderableTestCase(example_ctx.provider("hi!", "stuff!"), "stuff!", ["stuff!"]),
27+
RenderableTestCase(example_consumer(), "default!", ["default!"]),
28+
RenderableTestCase(h.Fragment("fragment!"), "fragment!", ["fragment!"]),
29+
# comment() is a Fragment but test it anyways for completeness
30+
RenderableTestCase(h.comment("comment!"), "<!-- comment! -->", ["<!-- comment! -->"]),
31+
]
32+
33+
34+
@pytest.mark.parametrize("case", cases)
35+
def test_str(case: RenderableTestCase) -> None:
36+
result = str(case.renderable)
37+
assert isinstance(result, str)
38+
assert isinstance(result, markupsafe.Markup)
39+
assert result == case.expected_str
40+
41+
42+
@pytest.mark.parametrize("case", cases)
43+
def test_html(case: RenderableTestCase) -> None:
44+
result = case.renderable.__html__()
45+
assert isinstance(result, str)
46+
assert isinstance(result, markupsafe.Markup)
47+
assert result == case.expected_str
48+
49+
50+
@pytest.mark.parametrize("case", cases)
51+
def test_stream_chunks(case: RenderableTestCase) -> None:
52+
result = list(case.renderable.stream_chunks())
53+
54+
# Ensure we get str back, not markup.
55+
assert type(result[0]) is str
56+
57+
assert result == case.expected_chunks

0 commit comments

Comments
 (0)