Skip to content

Commit 20ffa85

Browse files
authored
feat: organize base class (#36)
* feat: add LayoutTransform class and update visualizer batch method * feat: refactor transforms to inherit from LayoutTransform * refactor: clean up unused imports and improve test structure
1 parent 8efe288 commit 20ffa85

9 files changed

Lines changed: 81 additions & 60 deletions

File tree

src/layout_prompter/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from .base import LayoutTransform
12
from .discretize_bboxes import DiscretizeBboxes
23
from .label_dict_sort import LabelDictSort
34
from .lexicographic_sort import LexicographicSort
45
from .saliency_map_to_bboxes import SaliencyMapToBboxes
56
from .shuffle_elements import ShuffleElements
67

78
__all__ = [
9+
"LayoutTransform",
810
"DiscretizeBboxes",
911
"LabelDictSort",
1012
"LexicographicSort",
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import abc
2+
from typing import Any, List, Optional, Union
3+
4+
from langchain_core.runnables import RunnableSerializable
5+
from langchain_core.runnables.config import RunnableConfig
6+
7+
from layout_prompter.models import LayoutData, ProcessedLayoutData
8+
9+
10+
class LayoutTransform(RunnableSerializable):
11+
@abc.abstractmethod
12+
def invoke(
13+
self,
14+
input: Union[LayoutData, ProcessedLayoutData],
15+
config: Optional[RunnableConfig] = None,
16+
**kwargs: Any,
17+
) -> Any:
18+
raise NotImplementedError
19+
20+
def batch(
21+
self,
22+
inputs: Union[List[LayoutData], List[ProcessedLayoutData]],
23+
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
24+
*,
25+
return_exceptions: bool = False,
26+
**kwargs: Any | None,
27+
) -> List:
28+
return super().batch(
29+
inputs, config, return_exceptions=return_exceptions, **kwargs
30+
)

src/layout_prompter/transforms/discretize_bboxes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
import copy
22
from typing import Any, Optional, Union
33

4-
from langchain_core.runnables import RunnableSerializable
54
from langchain_core.runnables.config import RunnableConfig
65
from loguru import logger
76

87
from layout_prompter.models import CanvasSize, LayoutData, ProcessedLayoutData
98
from layout_prompter.utils import Configuration
109

10+
from .base import LayoutTransform
11+
1112

1213
class DiscretizeBboxesConfig(Configuration):
1314
"""Configuration for Transform classes."""
1415

1516
target_canvas_size: CanvasSize
1617

1718

18-
class DiscretizeBboxes(RunnableSerializable):
19+
class DiscretizeBboxes(LayoutTransform):
1920
name: str = "discretize-bboxes"
2021

2122
def invoke(

src/layout_prompter/transforms/label_dict_sort.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import copy
22
from typing import Any, List, Tuple, Union, cast
33

4-
from langchain_core.runnables import Runnable
54
from langchain_core.runnables.config import RunnableConfig
65
from loguru import logger
76

87
from layout_prompter.models import LayoutData, NormalizedBbox, ProcessedLayoutData
98

9+
from .base import LayoutTransform
1010

11-
class LabelDictSort(Runnable):
11+
12+
class LabelDictSort(LayoutTransform):
1213
name: str = "label-dict-sort"
1314

1415
def invoke(

src/layout_prompter/transforms/lexicographic_sort.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import copy
22
from typing import Any, List, Tuple, Union, cast
33

4-
from langchain_core.runnables import Runnable
54
from langchain_core.runnables.config import RunnableConfig
65
from loguru import logger
76

87
from layout_prompter.models import LayoutData, NormalizedBbox, ProcessedLayoutData
98

9+
from .base import LayoutTransform
1010

11-
class LexicographicSort(Runnable):
11+
12+
class LexicographicSort(LayoutTransform):
1213
name: str = "lexicographic-sort"
1314

1415
def invoke(

src/layout_prompter/transforms/saliency_map_to_bboxes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
import cv2
44
import numpy as np
5-
from langchain_core.runnables import RunnableSerializable
65

76
from layout_prompter.typehints import PilImage
87

8+
from .base import LayoutTransform
99

10-
class SaliencyMapToBboxes(RunnableSerializable):
10+
11+
class SaliencyMapToBboxes(LayoutTransform):
1112
name: str = "saliency-map-to-bboxes"
1213

1314
threshold: int = 100
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
# import copy
22
# from typing import Any, Union
33

4-
from langchain_core.runnables import RunnableSerializable
54

65
# from langchain_core.runnables.config import RunnableConfig
76
# from loguru import logger
8-
97
# from layout_prompter.models import CanvasSize, LayoutData, ProcessedLayoutData
8+
from .base import LayoutTransform
109

1110

12-
class ShuffleElements(RunnableSerializable):
11+
class ShuffleElements(LayoutTransform):
1312
name: str = "shuffle-elements"

src/layout_prompter/visualizers/base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,5 +107,17 @@ def invoke(
107107
input: Union[ProcessedLayoutData, LayoutSerializedOutputData],
108108
config: Optional[RunnableConfig] = None,
109109
**kwargs: Any,
110-
) -> Any:
110+
) -> PilImage:
111111
raise NotImplementedError
112+
113+
def batch(
114+
self,
115+
inputs: Union[List[ProcessedLayoutData], List[LayoutSerializedOutputData]],
116+
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
117+
*,
118+
return_exceptions: bool = False,
119+
**kwargs: Any | None,
120+
) -> List[PilImage]:
121+
return super().batch(
122+
inputs, config, return_exceptions=return_exceptions, **kwargs
123+
)

tests/layout_prompter_test.py

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Type, cast
1+
from typing import Dict, List, Type
22

33
import pytest
44
from langchain.chat_models import init_chat_model
@@ -12,7 +12,6 @@
1212
LayoutSerializedOutputData,
1313
PosterLayoutSerializedData,
1414
PosterLayoutSerializedOutputData,
15-
ProcessedLayoutData,
1615
Rico25SerializedData,
1716
Rico25SerializedOutputData,
1817
)
@@ -24,7 +23,6 @@
2423
from layout_prompter.preprocessors import ContentAwareProcessor
2524
from layout_prompter.settings import PosterLayoutSettings, Rico25Settings, TaskSettings
2625
from layout_prompter.transforms import DiscretizeBboxes
27-
from layout_prompter.typehints import PilImage
2826
from layout_prompter.utils import get_num_workers
2927
from layout_prompter.utils.testing import LayoutPrompterTestCase
3028
from layout_prompter.visualizers import ContentAwareVisualizer
@@ -69,24 +67,7 @@ def test_gen_type_task(
6967
input_schema: Type[LayoutSerializedData],
7068
output_schema: Type[LayoutSerializedOutputData],
7169
):
72-
# tng_dataset = layout_dataset["train"]
73-
# val_dataset = layout_dataset["validation"]
74-
# tst_dataset = layout_dataset["test"]
75-
76-
# processor = GenTypeProcessor()
77-
78-
# examples = cast(
79-
# List[ProcessedLayoutData],
80-
# processor.batch(
81-
# inputs=tng_dataset,
82-
# # config={
83-
# # "max_concurrency": 4,
84-
# # },
85-
# ),
86-
# )
87-
88-
# breakpoint()
89-
pass
70+
raise NotImplementedError
9071

9172
@pytest.mark.parametrize(
9273
argnames=("layout_dataset", "settings", "input_schema", "output_schema"),
@@ -119,15 +100,12 @@ def test_content_aware_generation(
119100
processor = ContentAwareProcessor()
120101

121102
# Process the training dataset to get candidate examples
122-
candidate_examples = cast(
123-
List[ProcessedLayoutData],
124-
processor.batch(
125-
inputs=tng_dataset,
126-
config={
127-
"max_concurrency": get_num_workers(max_concurrency=4),
128-
"callbacks": [ProgressBarCallback(total=len(tng_dataset))],
129-
},
130-
),
103+
candidate_examples = processor.batch(
104+
inputs=tng_dataset,
105+
config={
106+
"max_concurrency": get_num_workers(max_concurrency=4),
107+
"callbacks": [ProgressBarCallback(total=len(tng_dataset))],
108+
},
131109
)
132110

133111
# Select a random test example
@@ -142,13 +120,11 @@ def test_content_aware_generation(
142120
bbox_discretizer = DiscretizeBboxes()
143121

144122
# Apply the bbox discretizer to candidate examples and test data
145-
candidate_examples = cast(
146-
List[ProcessedLayoutData],
147-
bbox_discretizer.batch(
148-
candidate_examples,
149-
config={"configurable": {"target_canvas_size": target_canvas_size}},
150-
),
123+
candidate_examples = bbox_discretizer.batch(
124+
candidate_examples,
125+
config={"configurable": {"target_canvas_size": target_canvas_size}},
151126
)
127+
152128
processed_test_data = bbox_discretizer.invoke(
153129
processed_test_data,
154130
config={"configurable": {"target_canvas_size": target_canvas_size}},
@@ -187,18 +163,16 @@ def test_content_aware_generation(
187163
canvas_size=settings.canvas_size,
188164
labels=settings.labels,
189165
)
190-
visualizations = cast(
191-
List[PilImage],
192-
visualizer.batch(
193-
inputs=output.ranked_outputs,
194-
config={
195-
"configurable": {
196-
"resize_ratio": 2.0,
197-
"bg_image": test_data.content_image.copy(),
198-
"content_bboxes": processed_test_data.discrete_content_bboxes,
199-
}
200-
},
201-
),
166+
# Perform the visualization
167+
visualizations = visualizer.batch(
168+
inputs=output.ranked_outputs,
169+
config={
170+
"configurable": {
171+
"resize_ratio": 2.0,
172+
"bg_image": test_data.content_image.copy(),
173+
"content_bboxes": processed_test_data.discrete_content_bboxes,
174+
}
175+
},
202176
)
203177

204178
# Create the save directory

0 commit comments

Comments
 (0)