Skip to content

Commit 4e8a4c0

Browse files
authored
feat: prepare to add content agnostic layout generation (#35)
* feat: enhance layout processing with new processors and dataset handling
* feat: add pytest-lazy-fixtures to development dependencies
* refactor: remove unused abstract methods and imports from Processor class
* refactor: streamline processing and selection in layout tests
* refactor: rename _invoke to invoke and update parameter names for clarity
* feat: enhance poster layout processing with dataset handling and bbox discretization
* feat: add base and content-aware serializers for layout generation
* feat: implement layout ranker and selector modules with content-aware functionality
* feat: add GenTypeSerializer and refactor layout serialization methods
* feat: add GenTypeSerializer to the public API of serializers
* feat: add unit tests for layout ranker and content-aware selector modules
* fix: ensure proper newline at end of file in test files
* refactor: remove unused imports and clean up code in multiple files
* chore: update version to 0.7.0 in pyproject.toml and uv.lock
* refactor: simplify chain of transformations in GenTypeProcessor
* fix: raise NotImplementedError in LayoutSerializer invoke method
* refactor: comment out unused imports and methods in GenTypeSizeProcessor
* fix: add --last-failed option to pytest command in Makefile
1 parent c06c007 commit 4e8a4c0

39 files changed

Lines changed: 870 additions & 357 deletions

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ typecheck: install
2929
.PHONY: test
3030
test: install
3131
uv run pytest -vs \
32+
--last-failed \
3233
--log-cli-level=INFO \
3334
--import-mode=importlib \
3435
--cov

examples/poster_layout.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from typing import List, cast
44

55
from langchain.chat_models import init_chat_model
6+
from langchain.smith.evaluation.progress import ProgressBarCallback
67
from tqdm.auto import tqdm
78

89
from layout_prompter import LayoutPrompter
910
from layout_prompter.datasets import load_poster_layout
1011
from layout_prompter.models import (
1112
LayoutData,
13+
PosterLayoutSerializedData,
1214
PosterLayoutSerializedOutputData,
1315
ProcessedLayoutData,
1416
)
@@ -19,6 +21,7 @@
1921
)
2022
from layout_prompter.preprocessors import ContentAwareProcessor
2123
from layout_prompter.settings import PosterLayoutSettings
24+
from layout_prompter.transforms import DiscretizeBboxes
2225
from layout_prompter.utils.workers import get_num_workers
2326
from layout_prompter.visualizers import ContentAwareVisualizer
2427

@@ -59,33 +62,58 @@ def parse_args() -> argparse.Namespace:
5962

6063

6164
def main(args: argparse.Namespace) -> None:
65+
# Load the settings for Poster Layout dataset
6266
settings = PosterLayoutSettings()
67+
# Load the dataset
6368
hf_dataset = load_poster_layout()
6469

70+
# Convert the Hugging Face dataset to a dictionary of LayoutData
6571
dataset = {
6672
split: [
6773
LayoutData.model_validate(data)
6874
for data in tqdm(hf_dataset[split], desc=f"Processing for {split}")
6975
]
7076
for split in hf_dataset
7177
}
78+
tng_dataset, tst_dataset = dataset["train"], dataset["test"]
7279

73-
processor = ContentAwareProcessor(target_canvas_size=settings.canvas_size)
80+
# Define the content-aware processor
81+
processor = ContentAwareProcessor()
82+
83+
# Process the training dataset to generate candidate examples
7484
candidate_examples = cast(
7585
List[ProcessedLayoutData],
7686
processor.batch(
77-
inputs=dataset["train"],
87+
inputs=tng_dataset,
7888
config={
7989
"max_concurrency": args.num_workers or get_num_workers(),
90+
"callbacks": [ProgressBarCallback(total=len(tng_dataset))],
8091
},
8192
),
8293
)
83-
# inference_examples = processor.invoke(input=dataset["test"])
8494

95+
# Select a random test example or use a fixed index for reproducibility
8596
# idx = random.choice(range(len(dataset["test"])))
8697
idx = 443
87-
inference_example = cast(
88-
ProcessedLayoutData, processor.invoke(input=dataset["test"][idx])
98+
test_data = tst_dataset[idx]
99+
100+
# Process the test data
101+
inference_example = cast(ProcessedLayoutData, processor.invoke(input=test_data))
102+
103+
# Define the discretizer for bounding boxes
104+
bbox_discretizer = DiscretizeBboxes()
105+
106+
# Apply the bbox discretizer to candidate examples and test data
107+
candidate_examples = cast(
108+
List[ProcessedLayoutData],
109+
bbox_discretizer.batch(
110+
candidate_examples,
111+
config={"configurable": {"target_canvas_size": settings.canvas_size}},
112+
),
113+
)
114+
inference_example = bbox_discretizer.invoke(
115+
inference_example,
116+
config={"configurable": {"target_canvas_size": settings.canvas_size}},
89117
)
90118

91119
layout_prompter = LayoutPrompter(
@@ -101,9 +129,16 @@ def main(args: argparse.Namespace) -> None:
101129
model=args.model_id,
102130
),
103131
ranker=LayoutPrompterRanker(),
104-
schema=PosterLayoutSerializedOutputData,
105132
)
106-
outputs = layout_prompter.invoke(input=inference_example)
133+
outputs = layout_prompter.invoke(
134+
input=inference_example,
135+
config={
136+
"configurable": {
137+
"input_schema": PosterLayoutSerializedData,
138+
"output_schema": PosterLayoutSerializedOutputData,
139+
}
140+
},
141+
)
107142

108143
visualizer = ContentAwareVisualizer(
109144
canvas_size=settings.canvas_size, labels=settings.labels

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "layout-prompter"
3-
version = "0.6.0"
3+
version = "0.7.0"
44
description = "LangChain-based LayoutPrompter for content-agnostic/content-aware layout generation powered by LLM."
55
readme = "README.md"
66
authors = [
@@ -41,6 +41,7 @@ dev = [
4141
"mypy>=1.0.0",
4242
"pytest>=6.0.0",
4343
"pytest-cov>=6.0.0",
44+
"pytest-lazy-fixtures>=1.2.0",
4445
"ruff>=0.1.5",
4546
]
4647

src/layout_prompter/datasets/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .poster_layout import load_poster_layout, load_raw_poster_layout
2-
from .rico import load_rico
2+
from .rico import load_raw_rico, load_rico25
33

44
__all__ = [
5-
"load_rico",
5+
"load_raw_rico",
66
"load_poster_layout",
77
"load_raw_poster_layout",
8+
"load_rico25",
89
]

src/layout_prompter/datasets/poster_layout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def convert_to_layout_data_format(example):
7575
labels = np.array(list(map(id2label, anns["cls_elem"])))
7676
assert len(bboxes) == len(labels)
7777

78-
# Convert bboxes to [x, y, w, h] format
78+
# Convert bboxes to (left, top, width, height) format
7979
bboxes[:, 2] -= bboxes[:, 0]
8080
bboxes[:, 3] -= bboxes[:, 1]
8181

src/layout_prompter/datasets/rico.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,131 @@
11
import datasets as ds
2+
import numpy as np
3+
from loguru import logger
24

5+
from layout_prompter.models import LayoutData
6+
from layout_prompter.settings import Rico25Settings
7+
from layout_prompter.utils import normalize_bboxes
38

4-
def load_rico(
9+
10+
def _filter_empty_bboxes(example):
11+
return len([bbox for child in example["children"] for bbox in child["bounds"]]) > 0
12+
13+
14+
def _filter_too_many_bboxes(example, max_elements: int = 10):
15+
return (
16+
len([bbox for child in example["children"] for bbox in child["bounds"]])
17+
<= max_elements
18+
)
19+
20+
21+
def load_raw_rico(
522
dataset_name: str = "creative-graphic-design/Rico",
23+
dataset_type: str = "ui-screenshots-and-hierarchies-with-semantic-annotations",
624
) -> ds.DatasetDict:
25+
# Load the RICO dataset
726
dataset = ds.load_dataset(
827
dataset_name,
9-
name="ui-screenshots-and-view-hierarchies",
28+
name=dataset_type,
1029
)
1130
assert isinstance(dataset, ds.DatasetDict)
31+
return dataset
32+
33+
34+
def load_rico25(
35+
dataset_name: str = "creative-graphic-design/Rico",
36+
dataset_type: str = "ui-screenshots-and-hierarchies-with-semantic-annotations",
37+
num_proc: int = 32,
38+
max_elements: int = 10,
39+
) -> ds.DatasetDict:
40+
# Load the RICO settings
41+
settings = Rico25Settings()
42+
43+
# Load the RICO dataset
44+
dataset = load_raw_rico(
45+
dataset_name=dataset_name,
46+
dataset_type=dataset_type,
47+
)
48+
49+
dataset = dataset.filter(
50+
_filter_empty_bboxes,
51+
desc="Filter out empty bboxes",
52+
num_proc=num_proc,
53+
)
54+
dataset = dataset.filter(
55+
_filter_too_many_bboxes,
56+
fn_kwargs={"max_elements": max_elements},
57+
desc="Filter by max elements",
58+
num_proc=num_proc,
59+
)
60+
61+
train_feature = dataset["train"].features
62+
train_children_feature = train_feature["children"].feature
63+
component_labeler = train_children_feature.feature["component_label"]
64+
65+
def convert_to_layout_data(example):
66+
# Get the canvas size
67+
W, H = example["bounds"][2:]
68+
69+
# Get the children associated with the example
70+
children = example["children"]
71+
72+
# Get bboxes from children and filter out invalid ones
73+
bboxes = np.array(
74+
[bbox for child in children for bbox in child["bounds"]],
75+
)
76+
77+
# Get labels from children
78+
labels = [
79+
component_labeler.int2str(label_id)
80+
for child in children
81+
for label_id in child["component_label"]
82+
]
83+
84+
# Ensure bboxes and labels have the same length
85+
assert len(bboxes) == len(labels)
86+
87+
# Convert bboxes to (left, top, width, height) format
88+
bboxes[:, 2] -= bboxes[:, 0]
89+
bboxes[:, 3] -= bboxes[:, 1]
90+
91+
# Normalize bboxes
92+
bboxes = normalize_bboxes(bboxes=bboxes, w=W, h=H)
93+
94+
# Get the canvas size as a dictionary
95+
canvas_size = settings.canvas_size.model_dump()
96+
97+
data = {
98+
"bboxes": [
99+
{
100+
"left": bbox[0],
101+
"top": bbox[1],
102+
"width": bbox[2],
103+
"height": bbox[3],
104+
}
105+
for bbox in bboxes.tolist()
106+
],
107+
"labels": labels,
108+
"canvas_size": canvas_size,
109+
"encoded_image": None,
110+
"content_bboxes": None,
111+
}
112+
113+
try:
114+
# Ensure the data conforms to the `LayoutData` model
115+
assert LayoutData.model_validate(data)
116+
except Exception as err:
117+
logger.trace(f"Data validation failed: {err}. Data: {example=}. ")
118+
return None
119+
120+
return data
121+
122+
dataset = dataset.map(
123+
convert_to_layout_data,
124+
desc="Convert RICO dataset to LayoutData format",
125+
remove_columns=dataset.column_names["train"],
126+
num_proc=num_proc,
127+
)
128+
129+
logger.debug(dataset)
12130

13131
return dataset

src/layout_prompter/models/serialized_data.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,31 @@
1111
]
1212

1313
Rico25ClassNames = Literal[
14-
"text",
15-
"image",
16-
"icon",
17-
"list-item",
18-
"text-button",
19-
"toolbar",
20-
"web-view",
21-
"input",
22-
"card",
23-
"advertisement",
24-
"background-image",
25-
"drawer",
26-
"radio-button",
27-
"checkbox",
28-
"multi-tab",
29-
"pager-indicator",
30-
"modal",
31-
"on/off-switch",
32-
"slider",
33-
"map-view",
34-
"button-bar",
35-
"video",
36-
"bottom-navigation",
37-
"number-stepper",
38-
"date-picker",
14+
"Text",
15+
"Image",
16+
"Icon",
17+
"Text Button",
18+
"List Item",
19+
"Input",
20+
"Background Image",
21+
"Card",
22+
"Web View",
23+
"Radio Button",
24+
"Drawer",
25+
"Checkbox",
26+
"Advertisement",
27+
"Modal",
28+
"Pager Indicator",
29+
"Slider",
30+
"On/Off Switch",
31+
"Button Bar",
32+
"Toolbar",
33+
"Number Stepper",
34+
"Multi-Tab",
35+
"Date Picker",
36+
"Map View",
37+
"Video",
38+
"Bottom Navigation",
3939
]
4040

4141

src/layout_prompter/modules/rankers/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .base import LayoutRanker
2+
from .layout_prompter import LayoutPrompterRanker
3+
4+
__all__ = [
5+
"LayoutRanker",
6+
"LayoutPrompterRanker",
7+
]

src/layout_prompter/modules/rankers/base.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import abc
2+
from typing import Any, List, Optional
3+
4+
from langchain_core.runnables import RunnableSerializable
5+
from langchain_core.runnables.config import RunnableConfig
6+
7+
from layout_prompter.models import LayoutSerializedOutputData
8+
9+
10+
class LayoutRanker(RunnableSerializable):
11+
"""Base class for layout ranking algorithms."""
12+
13+
@abc.abstractmethod
14+
def invoke(
15+
self,
16+
input: List[LayoutSerializedOutputData],
17+
config: Optional[RunnableConfig] = None,
18+
**kwargs: Any,
19+
) -> List[LayoutSerializedOutputData]:
20+
raise NotImplementedError

src/layout_prompter/modules/rankers.py renamed to src/layout_prompter/modules/rankers/layout_prompter.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import abc
21
from typing import Any, List, Optional, Tuple
32

43
import numpy as np
5-
from langchain_core.runnables import RunnableSerializable
64
from langchain_core.runnables.config import RunnableConfig
75
from pydantic import model_validator
86
from typing_extensions import Self
@@ -13,18 +11,7 @@
1311
compute_overlap,
1412
)
1513

16-
17-
class LayoutRanker(RunnableSerializable):
18-
"""Base class for layout ranking algorithms."""
19-
20-
@abc.abstractmethod
21-
def invoke(
22-
self,
23-
input: List[LayoutSerializedOutputData],
24-
config: Optional[RunnableConfig] = None,
25-
**kwargs: Any,
26-
) -> List[LayoutSerializedOutputData]:
27-
raise NotImplementedError
14+
from .base import LayoutRanker
2815

2916

3017
class LayoutPrompterRanker(LayoutRanker):

0 commit comments

Comments
 (0)