Commit e5dd9d5

!fix: return a list of elements in stage_for_transformers (#420)

* update stage_for_transformers to return a list of elements
* bump changelog and version
* flag breaking change
* fix last word bug in chunk_by_attention_window

1 parent 2f5c61c commit e5dd9d5
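
A minimal sketch of the breaking change described above. The stub tokenizer, its `model_max_length` value, and the element texts are illustrative and not taken from this commit; any object exposing `tokenize` and `model_max_length` should behave the same way:

    from unstructured.documents.elements import Text, Title
    from unstructured.staging import huggingface


    class StubTokenizer:
        # Stand-in for a transformers tokenizer; a small window so the example chunks.
        model_max_length = 20

        def tokenize(self, text):
            return text.split(" ")


    elements = [Title(text="A short title"), Text(text="hello " * 20)]
    staged = huggingface.stage_for_transformers(elements, StubTokenizer(), buffer=10)

    # Before this commit the brick returned a list of chunk strings; it now
    # returns a list of elements, one per chunk, so downstream bricks that
    # expect elements keep working.
    assert all(isinstance(element, Text) for element in staged)
    assert staged[0] == Title(text="A short title")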

File tree: 4 files changed (+47, -14 lines)

* CHANGELOG.md
* test_unstructured/staging/test_huggingface.py
* unstructured/__version__.py
* unstructured/staging/huggingface.py

Diff for: CHANGELOG.md (+9, -1)

@@ -1,10 +1,11 @@
-## 0.5.8-dev4
+## 0.5.8-dev5
 
 ### Enhancements
 
 * Update `elements_to_json` to return string when filename is not specified
 * `elements_from_json` may take a string instead of a filename with the `text` kwarg
 * `detect_filetype` now does a final fallback to file extension.
+* Empty tags are now skipped during the depth check for HTML processing.
 
 ### Features
 
@@ -18,6 +19,13 @@
 * Partitioning functions that accept a `text` kwarg no longer raise an error if an empty
   string is passed (and empty list of elements is returned instead).
 * `partition_json` no longer fails if the input is an empty list.
+* Fixed bug in `chunk_by_attention_window` that caused the last word in segments to be cut-off
+  in some cases.
+
+### BREAKING CHANGES
+
+* `stage_for_transformers` now returns a list of elements, making it consistent with other
+  staging bricks
 
 ## 0.5.7
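
A small sketch of the `chunk_by_attention_window` fix noted in the changelog above. The stub tokenizer and its `model_max_length` value are illustrative, not part of this commit:

    from unstructured.staging.huggingface import chunk_by_attention_window


    class StubTokenizer:
        model_max_length = 20

        def tokenize(self, text):
            return text.split(" ")


    text = ("hello " * 20 + "there " * 20).strip()
    chunks = chunk_by_attention_window(text, StubTokenizer(), buffer=10)

    # Before the fix, the running chunk was appended inside the loop on the
    # final split, before the last segment was added, so the last word could
    # be dropped. The fix flushes whatever text remains after the loop.
    assert chunks[-1].endswith("there")
    assert len(chunks) == 4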

Diff for: test_unstructured/staging/test_huggingface.py (+15, -6)

@@ -1,6 +1,6 @@
 import pytest
 
-from unstructured.documents.elements import Text
+from unstructured.documents.elements import Text, Title
 from unstructured.staging import huggingface
 
 
@@ -12,14 +12,23 @@ def tokenize(self, text):
 
 
 def test_stage_for_transformers():
-    elements = [Text(text="hello " * 20), Text(text="there " * 20)]
+    title_element = Title(text="Here is a wonderful story")
+    elements = [title_element, Text(text="hello " * 20 + "there " * 20)]
+
     tokenizer = MockTokenizer()
 
-    chunks = huggingface.stage_for_transformers(elements, tokenizer, buffer=10)
+    chunk_elements = huggingface.stage_for_transformers(elements, tokenizer, buffer=10)
 
-    hello_chunk = ("hello " * 10).strip()
-    there_chunk = ("there " * 10).strip()
-    assert chunks == [hello_chunk, hello_chunk, "\n\n" + there_chunk, there_chunk]
+    hello_chunk = Text(("hello " * 10).strip())
+    there_chunk = Text(("there " * 10).strip())
+
+    assert chunk_elements == [
+        title_element,
+        hello_chunk,
+        hello_chunk,
+        there_chunk,
+        there_chunk,
+    ]
 
 
 def test_chunk_by_attention_window():

Diff for: unstructured/__version__.py (+1, -1)

@@ -1 +1 @@
-__version__ = "0.5.8-dev4"  # pragma: no cover
+__version__ = "0.5.8-dev5"  # pragma: no cover

Diff for: unstructured/staging/huggingface.py (+22, -6)

@@ -1,19 +1,32 @@
+from copy import deepcopy
 from typing import Callable, List, Optional
 
 from transformers import PreTrainedTokenizer
 
-from unstructured.documents.elements import Text
+from unstructured.documents.elements import Element, NarrativeText, Text
 
 
 def stage_for_transformers(
     elements: List[Text],
     tokenizer: PreTrainedTokenizer,
     **chunk_kwargs,
-) -> List[str]:
+) -> List[Element]:
     """Stages text elements for transformers pipelines by chunking them into sections that can
     fit into the attention window for the model associated with the tokenizer."""
-    combined_text = "\n\n".join([str(element) for element in elements])
-    return chunk_by_attention_window(combined_text, tokenizer, **chunk_kwargs)
+    chunked_elements: List[Element] = []
+    for element in elements:
+        # NOTE(robinson) - Only chunk potentially lengthy text. Shorter text (like titles)
+        # should already fit into the attention window just fine.
+        if isinstance(element, (NarrativeText, Text)):
+            chunked_text = chunk_by_attention_window(element.text, tokenizer, **chunk_kwargs)
+            for chunk in chunked_text:
+                _chunk_element = deepcopy(element)
+                _chunk_element.text = chunk
+                chunked_elements.append(_chunk_element)
+        else:
+            chunked_elements.append(element)
+
+    return chunked_elements
 
 
 def chunk_by_attention_window(
@@ -68,8 +81,8 @@ def chunk_by_attention_window(
                 f"error is: \n\n{segment}",
             )
 
-        if chunk_size + num_tokens > max_chunk_size or i == (num_splits - 1):
-            chunks.append(chunk_text)
+        if chunk_size + num_tokens > max_chunk_size:
+            chunks.append(chunk_text + chunk_separator.strip())
             chunk_text = ""
             chunk_size = 0
 
@@ -79,4 +92,7 @@ def chunk_by_attention_window(
         chunk_text += segment
         chunk_size += num_tokens
 
+    if i == (num_splits - 1) and len(chunk_text) > 0:
+        chunks.append(chunk_text)
+
     return chunks
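
For reference, a usage sketch of the updated brick with a real Hugging Face tokenizer. The model name and element texts are illustrative and not part of this commit:

    from transformers import AutoTokenizer

    from unstructured.documents.elements import NarrativeText, Title
    from unstructured.staging import huggingface

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # illustrative model

    elements = [
        Title(text="A wonderful story"),
        NarrativeText(text="Once upon a time there was a very long narrative. " * 200),
    ]

    # Each returned item is an element whose text fits within the model's
    # attention window (minus the buffer), so it can be fed to a transformers
    # pipeline element by element.
    chunk_elements = huggingface.stage_for_transformers(elements, tokenizer, buffer=10)
    for element in chunk_elements:
        print(type(element).__name__, len(tokenizer.tokenize(element.text)))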
