Skip to content

Fix notion columnlist retrieval #208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
## 0.2.2-dev0

### Fixes

* **Fix Notion Pagination** Iterate over Notion paginated results using the `next_cursor` and `start_cursor` properties.

## 0.2.1

### Enhancements
Expand Down Expand Up @@ -131,7 +137,7 @@

* **Model serialization with nested models** Logic updated to properly handle serializing pydantic models that have nested configs with secret values.
* **Sharepoint permission config requirement** The sharepoint connector was expecting the permission config, even though it should have been optional.
* **Sharepoint CLI permission params made optional
* **Sharepoint CLI permission params made optional**

### Enhancements

Expand Down Expand Up @@ -182,7 +188,7 @@
* **Chroma dict settings should allow string inputs**
* **Move opensearch non-secret fields out of access config**
* **Support string inputs for dict type model fields** Use the `BeforeValidator` support from pydantic to map a string value to a dict if that's provided.
* **Move opensearch non-secret fields out of access config
* **Move opensearch non-secret fields out of access config**

### Fixes

Expand Down
2 changes: 1 addition & 1 deletion unstructured_ingest/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.1" # pragma: no cover
__version__ = "0.2.2-dev0" # pragma: no cover
10 changes: 6 additions & 4 deletions unstructured_ingest/connector/notion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,12 @@ def iterate_list(
block_id: str,
**kwargs: Any,
) -> Generator[List[Block], None, None]:
next_cursor = None
while True:
response: dict = (
self.retry_handler(super().list, block_id=block_id, **kwargs)
self.retry_handler(super().list, block_id=block_id, start_cursor=next_cursor, **kwargs)
if self.retry_handler
else super().list(block_id=block_id, **kwargs)
else super().list(block_id=block_id, start_cursor=next_cursor, **kwargs)
) # type: ignore
child_blocks = [Block.from_dict(data=b) for b in response.pop("results", [])]
yield child_blocks
Expand Down Expand Up @@ -149,11 +150,12 @@ def query(self, database_id: str, **kwargs: Any) -> Tuple[List[Page], dict]:
return pages, resp

def iterate_query(self, database_id: str, **kwargs: Any) -> Generator[List[Page], None, None]:
next_cursor = None
while True:
response: dict = (
self.retry_handler(super().query, database_id=database_id, **kwargs)
self.retry_handler(super().query, database_id=database_id, start_cursor=next_cursor, **kwargs)
if (self.retry_handler)
else (super().query(database_id=database_id, **kwargs))
else (super().query(database_id=database_id, start_cursor=next_cursor, **kwargs))
) # type: ignore
pages = [Page.from_dict(data=p) for p in response.pop("results", [])]
for p in pages:
Expand Down
149 changes: 90 additions & 59 deletions unstructured_ingest/connector/notion/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,41 +30,30 @@


@dataclass
class TextExtractionResponse:
text: Optional[str] = None
child_pages: List[str] = field(default_factory=list)
child_databases: List[str] = field(default_factory=list)


@dataclass
class HtmlExtractionResponse:
html: Optional[HtmlTag] = None
class ProcessBlockResponse:
html_elements: List[Tuple[BlockBase, HtmlTag]] = field(default_factory=list)
child_pages: List[str] = field(default_factory=list)
child_databases: List[str] = field(default_factory=list)


def extract_page_html(
def process_block(
client: Client,
page_id: str,
logger: logging.Logger,
) -> HtmlExtractionResponse:
page_id_uuid = UUID(page_id)
parent_block: Block,
start_level: int = 0,
) -> ProcessBlockResponse:
block_id_uuid = UUID(parent_block.id)
html_elements: List[Tuple[BlockBase, HtmlTag]] = []
parent_block: Block = client.blocks.retrieve(block_id=page_id) # type: ignore
head = None
if isinstance(parent_block.block, notion_blocks.ChildPage):
head = Head([], Title([], parent_block.block.title))
child_pages: List[str] = []
child_databases: List[str] = []
parents: List[Tuple[int, Block]] = [(0, parent_block)]
processed_block_ids = []
parents: List[Tuple[int, Block]] = [(start_level, parent_block)]
processed_block_ids: List[str] = []
while len(parents) > 0:
level, parent = parents.pop(0)
parent_html = parent.get_html()
if parent_html:
html_elements.append((parent.block, parent_html))
logger.debug(f"processing block: {parent}")
if isinstance(parent.block, notion_blocks.ChildPage) and parent.id != str(page_id_uuid):
if isinstance(parent.block, notion_blocks.ChildPage) and parent.id != str(block_id_uuid):
child_pages.append(parent.id)
continue
if isinstance(parent.block, notion_blocks.ChildDatabase):
Expand All @@ -77,8 +66,10 @@ def extract_page_html(
child_databases.extend(table_response.child_databases)
continue
if isinstance(parent.block, notion_blocks.ColumnList):
column_html = build_columned_list(client=client, column_parent=parent)
html_elements.append((parent.block, column_html))
build_columned_list_response = build_columned_list(client=client, logger=logger, column_parent=parent, level=level)
child_pages.extend(build_columned_list_response.child_pages)
child_databases.extend(build_columned_list_response.child_databases)
html_elements.append((parent.block, build_columned_list_response.columned_list_html))
continue
if isinstance(parent.block, notion_blocks.BulletedListItem):
bullet_list_resp = build_bulleted_list_children(
Expand All @@ -96,7 +87,12 @@ def extract_page_html(
if numbered_list_children := numbered_list_resp.child_list:
html_elements.append((parent.block, numbered_list_children))
continue
if parent.block.can_have_children() and parent.has_children:
if parent.has_children:
if not parent.block.can_have_children():
# TODO: wrap in div?
logger.error(f"WARNING! block {parent.type} cannot have children: {parent}")
continue

children = []
for children_block in client.blocks.children.iterate_list( # type: ignore
block_id=parent.id,
Expand All @@ -107,36 +103,56 @@ def extract_page_html(
for child in children:
if child.id not in processed_block_ids:
parents.append((level + 1, child))
processed_block_ids.append(parent)

# Join list items
joined_html_elements = []
numbered_list_items = []
bullet_list_items = []
for block, html in html_elements:
if isinstance(block, notion_blocks.BulletedListItem):
bullet_list_items.append(html)
continue
if isinstance(block, notion_blocks.NumberedListItem):
numbered_list_items.append(html)
continue
if len(numbered_list_items) > 0:
joined_html_elements.append(Ol([], numbered_list_items))
numbered_list_items = []
if len(bullet_list_items) > 0:
joined_html_elements.append(Ul([], bullet_list_items))
bullet_list_items = []
joined_html_elements.append(html)

body = Body([], joined_html_elements)
processed_block_ids.append(parent.id)

return ProcessBlockResponse(
html_elements=html_elements,
child_pages=child_pages,
child_databases=child_databases,
)



@dataclass
class TextExtractionResponse:
text: Optional[str] = None
child_pages: List[str] = field(default_factory=list)
child_databases: List[str] = field(default_factory=list)


@dataclass
class HtmlExtractionResponse:
html: Optional[HtmlTag] = None
child_pages: List[str] = field(default_factory=list)
child_databases: List[str] = field(default_factory=list)


def extract_page_html(
client: Client,
page_id: str,
logger: logging.Logger,
) -> HtmlExtractionResponse:
parent_block: Block = client.blocks.retrieve(block_id=page_id) # type: ignore
head = None
if isinstance(parent_block.block, notion_blocks.ChildPage):
head = Head([], Title([], parent_block.block.title))

process_block_response = process_block(
client=client,
logger=logger,
parent_block=parent_block,
start_level=0,
)
body = Body([], [html for block, html in process_block_response.html_elements])
all_elements = [body]
if head:
all_elements = [head] + all_elements
full_html = Html([], all_elements)

return HtmlExtractionResponse(
full_html,
child_pages=child_pages,
child_databases=child_databases,
child_pages=process_block_response.child_pages,
child_databases=process_block_response.child_databases,
)


Expand Down Expand Up @@ -454,29 +470,44 @@ def build_table(client: Client, table: Block) -> BuildTableResponse:
child_databases=child_databases,
)

@dataclass
class BuildColumnedListResponse:
columned_list_html: HtmlTag
child_pages: List[str] = field(default_factory=list)
child_databases: List[str] = field(default_factory=list)

def build_columned_list(client: Client, column_parent: Block) -> HtmlTag:

def build_columned_list(client: Client, logger: logging.Logger, column_parent: Block, level: int = 0) -> BuildColumnedListResponse:
if not isinstance(column_parent.block, notion_blocks.ColumnList):
raise ValueError(f"block type not column list: {type(column_parent.block)}")
columns: List[Block] = []
child_pages: List[str] = []
child_databases: List[str] = []
for column_chunk in client.blocks.children.iterate_list( # type: ignore
block_id=column_parent.id,
):
columns.extend(column_chunk)
num_columns = len(columns)
columns_content = []
for column in columns:
for column_content_chunk in client.blocks.children.iterate_list( # type: ignore
block_id=column.id,
):
columns_content.append(
Div(
[Style(f"width:{100/num_columns}%; float: left")],
[content.block.get_html() for content in column_content_chunk],
),
)
column_content_response = process_block(
client=client,
logger=logger,
parent_block=column,
start_level=level + 1,
)
columns_content.append(
Div(
[Style(f"width:{100/num_columns}%; float: left")],
[html for block, html in column_content_response.html_elements],
),
)

return Div([], columns_content)
return BuildColumnedListResponse(
columned_list_html=Div([], columns_content),
child_pages=child_pages,
child_databases=child_databases,
)


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Heading(BlockBase):

@staticmethod
def can_have_children() -> bool:
return False
return True

@classmethod
def from_dict(cls, data: dict):
Expand Down