@@ -131,10 +131,18 @@ class ContextProvider(t.Generic[T]):
131131 node : Node
132132
133133 def __iter__ (self ) -> Iterator [str ]:
134- return iter_node (self )
134+ return _stream_chunks (self , {} )
135135
136- def __str__ (self ) -> str :
137- return render_node (self )
136+ def __str__ (self ) -> _Markup :
137+ return _as_markup (self )
138+
139+ __html__ = __str__
140+
141+ def stream_chunks (self ) -> Iterator [str ]:
142+ return _stream_chunks (self , {})
143+
144+ def encode (self , encoding : str = "utf-8" , errors : str = "strict" ) -> bytes :
145+ return str (self ).encode (encoding , errors )
138146
139147
140148@dataclasses .dataclass (frozen = True )
@@ -143,6 +151,17 @@ class ContextConsumer(t.Generic[T]):
143151 debug_name : str
144152 func : Callable [[T ], Node ]
145153
154+ def __str__ (self ) -> _Markup :
155+ return _as_markup (self )
156+
157+ __html__ = __str__
158+
159+ def stream_chunks (self ) -> Iterator [str ]:
160+ return _stream_chunks (self , {})
161+
162+ def encode (self , encoding : str = "utf-8" , errors : str = "strict" ) -> bytes :
163+ return str (self ).encode (encoding , errors )
164+
146165
147166class _NO_DEFAULT :
148167 pass
@@ -168,10 +187,10 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> ContextConsumer[T]:
168187
169188
170189def iter_node (x : Node ) -> Iterator [str ]:
171- return _iter_node_context ( x , {} )
190+ return fragment [ x ]. stream_chunks ( )
172191
173192
174- def _iter_node_context (x : Node , context_dict : dict [Context [t .Any ], t .Any ]) -> Iterator [str ]:
193+ def _stream_chunks (x : Node , context_dict : dict [Context [t .Any ], t .Any ]) -> Iterator [str ]:
175194 while not isinstance (x , BaseElement ) and callable (x ):
176195 x = x ()
177196
@@ -187,24 +206,25 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It
187206 if isinstance (x , BaseElement ):
188207 yield from x ._iter_context (context_dict ) # pyright: ignore [reportPrivateUsage]
189208 elif isinstance (x , ContextProvider ):
190- yield from _iter_node_context (x .node , {** context_dict , x .context : x .value }) # pyright: ignore [reportUnknownMemberType]
209+ yield from _stream_chunks (x .node , {** context_dict , x .context : x .value }) # pyright: ignore [reportUnknownMemberType]
191210 elif isinstance (x , ContextConsumer ):
192- context_value = context_dict .get (x .context , x .context .default )
211+ context_value = context_dict .get (x .context , x .context .default ) # pyright: ignore
212+
193213 if context_value is _NO_DEFAULT :
194214 raise LookupError (
195- f'Context value for "{ x .context .name } " does not exist, '
215+ f'Context value for "{ x .context .name } " does not exist, ' # pyright: ignore
196216 f"requested by { x .debug_name } ()."
197217 )
198- yield from _iter_node_context (x .func (context_value ), context_dict )
218+ yield from _stream_chunks (x .func (context_value ), context_dict ) # pyright: ignore
199219 elif isinstance (x , Fragment ):
200- yield from _iter_node_context (x ._node , context_dict ) # pyright: ignore
220+ yield from _stream_chunks (x ._node , context_dict ) # pyright: ignore
201221 elif isinstance (x , str | _HasHtml ):
202222 yield str (_escape (x ))
203223 elif isinstance (x , int ):
204224 yield str (x )
205225 elif isinstance (x , Iterable ) and not isinstance (x , _KnownInvalidChildren ): # pyright: ignore [reportUnnecessaryIsInstance]
206226 for child in x :
207- yield from _iter_node_context (child , context_dict )
227+ yield from _stream_chunks (child , context_dict )
208228 else :
209229 raise TypeError (f"{ x !r} is not a valid child element" )
210230
@@ -231,7 +251,7 @@ def __init__(self, name: str, attrs_str: str = "", children: Node = None) -> Non
231251 self ._children = children
232252
233253 def __str__ (self ) -> _Markup :
234- return _Markup ( "" . join ( self ) )
254+ return _as_markup ( self )
235255
236256 __html__ = __str__
237257
@@ -281,14 +301,14 @@ def __call__(self: BaseElementSelf, *args: t.Any, **kwargs: t.Any) -> BaseElemen
281301 def __iter__ (self ) -> Iterator [str ]:
282302 return self ._iter_context ({})
283303
304+ def stream_chunks (self ) -> Iterator [str ]:
305+ return self ._iter_context ({})
306+
284307 def _iter_context (self , ctx : dict [Context [t .Any ], t .Any ]) -> Iterator [str ]:
285308 yield f"<{ self ._name } { self ._attrs } >"
286- yield from _iter_node_context (self ._children , ctx )
309+ yield from _stream_chunks (self ._children , ctx )
287310 yield f"</{ self ._name } >"
288311
289- # Allow starlette Response.render to directly render this element without
290- # explicitly casting to str:
291- # https://github.com/encode/starlette/blob/5ed55c441126687106109a3f5e051176f88cd3e6/starlette/responses.py#L44-L49
292312 def encode (self , encoding : str = "utf-8" , errors : str = "strict" ) -> bytes :
293313 return str (self ).encode (encoding , errors )
294314
@@ -358,13 +378,19 @@ def __init__(self) -> None:
358378 self ._node : Node = None
359379
360380 def __iter__ (self ) -> Iterator [str ]:
361- return iter_node (self )
381+ return _stream_chunks (self , {} )
362382
363- def __str__ (self ) -> str :
364- return render_node (self )
383+ def __str__ (self ) -> _Markup :
384+ return _as_markup (self )
365385
366386 __html__ = __str__
367387
388+ def stream_chunks (self ) -> Iterator [str ]:
389+ return _stream_chunks (self , {})
390+
391+ def encode (self , encoding : str = "utf-8" , errors : str = "strict" ) -> bytes :
392+ return str (self ).encode (encoding , errors )
393+
368394
369395class _FragmentGetter :
370396 def __getitem__ (self , node : Node ) -> Fragment :
@@ -376,8 +402,12 @@ def __getitem__(self, node: Node) -> Fragment:
376402fragment = _FragmentGetter ()
377403
378404
405+ def _as_markup (renderable : Renderable ) -> _Markup :
406+ return _Markup ("" .join (renderable .stream_chunks ()))
407+
408+
379409def render_node (node : Node ) -> _Markup :
380- return _Markup ("" . join ( iter_node ( node )) )
410+ return _Markup (fragment [ node ] )
381411
382412
383413def comment (text : str ) -> Fragment :
@@ -545,3 +575,14 @@ def __html__(self) -> str: ...
545575 | Callable
546576 | Iterable
547577)
578+
579+
580+ class Renderable (t .Protocol ):
581+ def __str__ (self ) -> _Markup : ...
582+ def __html__ (self ) -> _Markup : ...
583+ def stream_chunks (self ) -> Iterator [str ]: ...
584+
585+ # Allow starlette Response.render to directly render this element without
586+ # explicitly casting to str:
587+ # https://github.com/encode/starlette/blob/5ed55c441126687106109a3f5e051176f88cd3e6/starlette/responses.py#L44-L49
588+ def encode (self , encoding : str = "utf-8" , errors : str = "strict" ) -> bytes : ...
0 commit comments