Skip to content

Commit 89526cb

Browse files
committed
Unify flattening
1 parent 788c14d commit 89526cb

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

tdom/processor.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -200,20 +200,24 @@ def _resolve_attrs(
200200
return _resolve_html_attrs(interpolated_attrs)
201201

202202

203+
def _flatten_nodes(nodes: t.Iterable[Node]) -> list[Node]:
204+
"""Flatten a list of Nodes, expanding any Fragments."""
205+
flat: list[Node] = []
206+
for node in nodes:
207+
if isinstance(node, Fragment):
208+
flat.extend(node.children)
209+
else:
210+
flat.append(node)
211+
return flat
212+
213+
203214
def _substitute_and_flatten_children(
204215
children: t.Iterable[TNode], interpolations: tuple[Interpolation, ...]
205216
) -> list[Node]:
206217
"""Substitute placeholders in a list of children and flatten any fragments."""
207-
new_children: list[Node] = []
208-
for child in children:
209-
substituted = _resolve_t_node(child, interpolations)
210-
if isinstance(substituted, Fragment):
211-
# This can happen if an interpolation results in a Fragment, for
212-
# instance if it is iterable.
213-
new_children.extend(substituted.children)
214-
else:
215-
new_children.append(substituted)
216-
return new_children
218+
resolved = [_resolve_t_node(child, interpolations) for child in children]
219+
flat = _flatten_nodes(resolved)
220+
return flat
217221

218222

219223
def _node_from_value(value: object) -> Node:
@@ -335,23 +339,18 @@ def _resolve_t_text_ref(
335339
if ref.is_static:
336340
return Text(ref.strings[0])
337341

338-
parts: list[Node] = []
339-
text_t = _resolve_ref(ref, interpolations)
340-
341-
for part in text_t:
342+
def to_node(part: str | Interpolation) -> Node:
342343
if isinstance(part, str):
343-
parts.append(Text(part))
344-
else:
345-
res = _node_from_value(format_interpolation(part))
346-
if isinstance(res, Fragment):
347-
parts.extend(res.children)
348-
else:
349-
parts.append(res)
344+
return Text(part)
345+
return _node_from_value(format_interpolation(part))
346+
347+
parts = [to_node(part) for part in _resolve_ref(ref, interpolations)]
348+
flat = _flatten_nodes(parts)
350349

351-
if len(parts) == 1 and isinstance(parts[0], Text):
352-
return parts[0]
350+
if len(flat) == 1 and isinstance(flat[0], Text):
351+
return flat[0]
353352

354-
return Fragment(children=parts)
353+
return Fragment(children=flat)
355354

356355

357356
def _resolve_t_node(t_node: TNode, interpolations: tuple[Interpolation, ...]) -> Node:

0 commit comments

Comments
 (0)