Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: CI

on:
push:
branches: [main]
pull_request:
branches: [main]
paths-ignore:
- "README.md"

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
make setup
make install

- name: Format
run: |
make format

- name: Lint
run: |
make lint

- name: Type check
run: |
make typecheck

# - name: Run tests
# run: |
# make test
31 changes: 31 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#
# Installation
#

.PHONY: setup
setup:
pip install -U uv

.PHONY: install
install:
uv sync --all-extras

#
# linter/formatter/typecheck
#

.PHONY: lint
lint: install
uv run ruff check --output-format=github .

.PHONY: format
format: install
uv run ruff format --check --diff .

.PHONY: typecheck
typecheck: install
uv run mypy --cache-dir=/dev/null .

.PHONY: test
test: install
uv run pytest -vsx --log-cli-level=INFO
4 changes: 2 additions & 2 deletions src/layout_prompter/modules/rankers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ def invoke(
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> List[SerializedOutputData]:
breakpoint()
return super().invoke(input, config, **kwargs)
raise NotImplementedError


class LayoutPrompterRanker(BaseModel, LayoutRanker):
Expand Down Expand Up @@ -62,6 +61,7 @@ def invoke(

min_vals = np.min(metrics_arr, axis=0, keepdims=True)
max_vals = np.max(metrics_arr, axis=0, keepdims=True)

scaled_metrics = (metrics_arr - min_vals) / (max_vals - min_vals)

quality = (
Expand Down
12 changes: 6 additions & 6 deletions src/layout_prompter/modules/selectors.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import random
from abc import abstractmethod
from typing import Any, List, Optional, Tuple
from typing import List, Optional, Tuple

import cv2
import numpy as np
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_core.runnables import Runnable
from langchain_core.runnables.config import RunnableConfig
from pydantic import BaseModel, ValidationInfo, field_validator, model_validator
from pydantic import BaseModel, model_validator
from typing_extensions import Self

from layout_prompter.models import ProcessedLayoutData
Expand Down Expand Up @@ -36,7 +33,10 @@ def select_examples( # type: ignore[override]
) -> List[ProcessedLayoutData]:
raise NotImplementedError

def add_example(self, example: ProcessedLayoutData) -> Any:
def add_example( # type: ignore[override]
self,
example: ProcessedLayoutData,
) -> None:
self.examples.append(example)

def _is_filter(self, data: ProcessedLayoutData) -> bool:
Expand Down
20 changes: 12 additions & 8 deletions src/layout_prompter/modules/serializers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging
from typing import Any, Final, List

from langchain_core.prompt_values import ChatPromptValue
Expand All @@ -14,6 +15,8 @@

from layout_prompter.models import Coordinates, ProcessedLayoutData, SerializedData

logger = logging.getLogger(__name__)

SYSTEM_PROMPT: Final[str] = """\
Please generate a layout based on the given information. You need to ensure that the generated layout looks realistic, with elements well aligned and avoiding unnecessary overlap.

Expand All @@ -25,22 +28,19 @@
{layout_domain} layout

## Canvas Size
{canvas_width}px x {canvas_height}px
"""
{canvas_width}px x {canvas_height}px"""

CONTENT_AWARE_CONSTRAINT: Final[str] = """\
# Constraints
## Content Constraint
{content_constraint}

## Element Type Constraint
{type_constraint}
"""
{type_constraint}"""

SERIALIZED_LAYOUT: Final[str] = """\
# Serialized Layout
{serialized_layout}
"""
{serialized_layout}"""


class LayoutSerializerInput(BaseModel):
Expand Down Expand Up @@ -97,16 +97,18 @@ def _get_content_constraint(self, data: ProcessedLayoutData) -> str:
return self._convert_to_double_bracket(content_constraint)

def _get_type_constraint(self, data: ProcessedLayoutData) -> str:
assert data.labels is not None
type_constraint = json.dumps(
{idx: label for idx, label in enumerate(data.labels)}
)
return self._convert_to_double_bracket(type_constraint)

def _get_serialized_layout(self, data: ProcessedLayoutData) -> str:
assert len(data.labels) == len(data.discrete_gold_bboxes)
assert data.labels is not None and data.discrete_bboxes is not None
labels, discrete_gold_bboxes = data.labels, data.discrete_bboxes

serialized_data_list = []
for class_name, bbox in zip(data.labels, data.discrete_gold_bboxes):
for class_name, bbox in zip(labels, discrete_gold_bboxes):
left, top, width, height = bbox

serialized_data = SerializedData(
Expand Down Expand Up @@ -173,4 +175,6 @@ def invoke(
}
)
assert isinstance(final_prompt, ChatPromptValue)
logger.debug(final_prompt.to_messages())

return final_prompt
2 changes: 2 additions & 0 deletions src/layout_prompter/transforms/discretize_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def invoke(
config: RunnableConfig | None = None,
**kwargs: Any,
) -> ProcessedLayoutData:
assert input.bboxes is not None and input.labels is not None

canvas_size = input.canvas_size.model_dump()

bboxes, labels = copy.deepcopy(input.bboxes), copy.deepcopy(input.labels)
Expand Down
2 changes: 2 additions & 0 deletions src/layout_prompter/transforms/label_dict_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def invoke(
config: RunnableConfig | None = None,
**kwargs: Any,
) -> ProcessedLayoutData:
assert input.bboxes is not None and input.labels is not None

canvas_size = input.canvas_size
bboxes, labels = copy.deepcopy(input.bboxes), copy.deepcopy(input.labels)
content_bboxes = (
Expand Down
3 changes: 3 additions & 0 deletions src/layout_prompter/transforms/lexicographic_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def invoke(
config: RunnableConfig | None = None,
**kwargs: Any,
) -> ProcessedLayoutData:
assert input.bboxes is not None and input.labels is not None

canvas_size = input.canvas_size
bboxes, labels = copy.deepcopy(input.bboxes), copy.deepcopy(input.labels)
content_bboxes = (
Expand All @@ -40,6 +42,7 @@ def invoke(
)

# Extract left and top coordinates from bboxes
assert input.bboxes is not None
left, top, _, _ = input.bboxes.T

# Get the indices of the sorted bboxes based on left and top coordinates
Expand Down
3 changes: 1 addition & 2 deletions src/layout_prompter/transforms/saliency_map_to_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import cv2
import numpy as np
from langchain_core.runnables import Runnable
from PIL import Image
from pydantic import BaseModel

from layout_prompter.typehints import PilImage
Expand Down Expand Up @@ -51,6 +50,6 @@ def invoke( # type: ignore[override]
contours, _ = cv2.findContours(
thresholded_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
bboxes = self.get_filtered_bboxes(contours)
bboxes = self.get_filtered_bboxes(contours) # type: ignore[arg-type]

return bboxes if len(bboxes) != 0 else None