1- from typing import Dict , List , Type , cast
1+ from typing import Dict , List , Type
22
33import pytest
44from langchain .chat_models import init_chat_model
1212 LayoutSerializedOutputData ,
1313 PosterLayoutSerializedData ,
1414 PosterLayoutSerializedOutputData ,
15- ProcessedLayoutData ,
1615 Rico25SerializedData ,
1716 Rico25SerializedOutputData ,
1817)
2423from layout_prompter .preprocessors import ContentAwareProcessor
2524from layout_prompter .settings import PosterLayoutSettings , Rico25Settings , TaskSettings
2625from layout_prompter .transforms import DiscretizeBboxes
27- from layout_prompter .typehints import PilImage
2826from layout_prompter .utils import get_num_workers
2927from layout_prompter .utils .testing import LayoutPrompterTestCase
3028from layout_prompter .visualizers import ContentAwareVisualizer
@@ -69,24 +67,7 @@ def test_gen_type_task(
6967 input_schema : Type [LayoutSerializedData ],
7068 output_schema : Type [LayoutSerializedOutputData ],
7169 ):
72- # tng_dataset = layout_dataset["train"]
73- # val_dataset = layout_dataset["validation"]
74- # tst_dataset = layout_dataset["test"]
75-
76- # processor = GenTypeProcessor()
77-
78- # examples = cast(
79- # List[ProcessedLayoutData],
80- # processor.batch(
81- # inputs=tng_dataset,
82- # # config={
83- # # "max_concurrency": 4,
84- # # },
85- # ),
86- # )
87-
88- # breakpoint()
89- pass
70+ raise NotImplementedError
9071
9172 @pytest .mark .parametrize (
9273 argnames = ("layout_dataset" , "settings" , "input_schema" , "output_schema" ),
@@ -119,15 +100,12 @@ def test_content_aware_generation(
119100 processor = ContentAwareProcessor ()
120101
121102 # Process the training dataset to get candidate examples
122- candidate_examples = cast (
123- List [ProcessedLayoutData ],
124- processor .batch (
125- inputs = tng_dataset ,
126- config = {
127- "max_concurrency" : get_num_workers (max_concurrency = 4 ),
128- "callbacks" : [ProgressBarCallback (total = len (tng_dataset ))],
129- },
130- ),
103+ candidate_examples = processor .batch (
104+ inputs = tng_dataset ,
105+ config = {
106+ "max_concurrency" : get_num_workers (max_concurrency = 4 ),
107+ "callbacks" : [ProgressBarCallback (total = len (tng_dataset ))],
108+ },
131109 )
132110
133111 # Select a random test example
@@ -142,13 +120,11 @@ def test_content_aware_generation(
142120 bbox_discretizer = DiscretizeBboxes ()
143121
144122 # Apply the bbox discretizer to candidate examples and test data
145- candidate_examples = cast (
146- List [ProcessedLayoutData ],
147- bbox_discretizer .batch (
148- candidate_examples ,
149- config = {"configurable" : {"target_canvas_size" : target_canvas_size }},
150- ),
123+ candidate_examples = bbox_discretizer .batch (
124+ candidate_examples ,
125+ config = {"configurable" : {"target_canvas_size" : target_canvas_size }},
151126 )
127+
152128 processed_test_data = bbox_discretizer .invoke (
153129 processed_test_data ,
154130 config = {"configurable" : {"target_canvas_size" : target_canvas_size }},
@@ -187,18 +163,16 @@ def test_content_aware_generation(
187163 canvas_size = settings .canvas_size ,
188164 labels = settings .labels ,
189165 )
190- visualizations = cast (
191- List [PilImage ],
192- visualizer .batch (
193- inputs = output .ranked_outputs ,
194- config = {
195- "configurable" : {
196- "resize_ratio" : 2.0 ,
197- "bg_image" : test_data .content_image .copy (),
198- "content_bboxes" : processed_test_data .discrete_content_bboxes ,
199- }
200- },
201- ),
166+ # Perform the visualization
167+ visualizations = visualizer .batch (
168+ inputs = output .ranked_outputs ,
169+ config = {
170+ "configurable" : {
171+ "resize_ratio" : 2.0 ,
172+ "bg_image" : test_data .content_image .copy (),
173+ "content_bboxes" : processed_test_data .discrete_content_bboxes ,
174+ }
175+ },
202176 )
203177
204178 # Create the save directory
0 commit comments