Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/htpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from htpy._compiler import compile as compile
from htpy._contexts import Context as Context
from htpy._contexts import ContextConsumer as ContextConsumer
from htpy._contexts import ContextProvider as ContextProvider
Expand Down
108 changes: 108 additions & 0 deletions src/htpy/_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import ast
import dataclasses
import inspect
import typing as t
from collections.abc import Callable, Iterator, Mapping

from markupsafe import Markup

from ._contexts import Context
from ._elements import BaseElement
from ._rendering import iter_chunks_node
from ._types import Node, Renderable

_py_compile = compile


class CompiledElement:
def __init__(self, *parts: Node):
self.parts = [Markup(part) if isinstance(part, str) else part for part in parts]

def __str__(self) -> Markup:
return Markup("".join(self.iter_chunks()))

def __html__(self) -> Markup:
return Markup("".join(self.iter_chunks()))

def iter_chunks(self, context: Mapping[Context[t.Any], t.Any] | None = None) -> Iterator[str]:
return iter_chunks_node(self.parts, context)

def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
return str(self).encode(encoding, errors)


@dataclasses.dataclass
class StaticCallArgs:
args: tuple[t.Any]
kwargs: dict[str, t.Any]

@classmethod
def from_call(cls, node: ast.Call) -> t.Self | None:
if not all(isinstance(arg, ast.Constant) for arg in node.args):
return None

if any(not isinstance(kwarg.value, ast.Constant) for kwarg in node.keywords):
return None

return cls(
tuple(arg.value for arg in node.args), # type:ignore
{kwarg.arg: kwarg.value.value for kwarg in node.keywords}, # type:ignore
)


def compile[T, **P](func: Callable[P, T]) -> Callable[P, T | Renderable]:
globals = func.__globals__
node = ast.parse(inspect.getsource(func))

class HtpyCompiler(ast.NodeTransformer):
def visit_FunctionDef(self, node: ast.FunctionDef) -> t.Any:
filtered_decorators: list[ast.expr] = [
dec
for dec in node.decorator_list
# TODO: Removing the decorator like this is fragile.
if isinstance(dec, ast.Name) and dec.id != "compile"
]

return ast.FunctionDef(
name=node.name,
args=node.args,
body=[self.visit(x) for x in node.body],
decorator_list=filtered_decorators,
returns=node.returns,
)

def visit_Call(self, node: ast.Call) -> t.Any:
match node.func:
case ast.Attribute(value=ast.Name(id=name), attr=attr):
resolved = getattr(globals[name], attr)

if not isinstance(resolved, BaseElement):
return node

if (static_call_args := StaticCallArgs.from_call(node)) is None:
return None

elem = resolved(*static_call_args.args, **static_call_args.kwargs)
return ast.Call(
ast.Name(id="__htpy__CompiledElement", ctx=ast.Load()),
[
ast.Constant(
value=Markup(elem).unescape(),
)
],
)
return

return node

new = ast.fix_missing_locations(HtpyCompiler().visit(node))
globals = {
**func.__globals__,
"__htpy__CompiledElement": CompiledElement,
}
exec(
_py_compile(new, filename=func.__code__.co_filename, mode="exec"),
globals,
)

return globals[func.__name__] # type: ignore
16 changes: 16 additions & 0 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import htpy as h
from htpy._compiler import CompiledElement


@h.compile
def trivial() -> h.VoidElement:
return h.img(src="lol.bmp")


def test_trivial() -> None:
result = trivial()

assert isinstance(result, CompiledElement)
assert result.parts == ['<img src="lol.bmp">']

assert str(result) == '<img src="lol.bmp">'
5 changes: 5 additions & 0 deletions tests/test_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import htpy as h
from htpy._compiler import CompiledElement

example_ctx: h.Context[str] = h.Context("example_ctx", default="default!")

Expand Down Expand Up @@ -61,6 +62,10 @@ def expected_bytes(self) -> bytes:
example_with_children(title="title!")["children!"],
["<div>", "<h1>", "title!", "</h1>", "<p>", "children!", "</p>", "</div>"],
),
RenderableTestCase(
CompiledElement('<div id="foo">', "hi!", "</foo>"),
['<div id="foo">', "hi!", "</foo>"],
),
]


Expand Down
Loading