88from PIL import Image
99
1010import unstructured_inference .inference .layout as layout
11+ import unstructured_inference .inference .elements as elements
1112import unstructured_inference .models .base as models
1213import unstructured_inference .models .detectron2 as detectron2
1314import unstructured_inference .models .tesseract as tesseract
@@ -20,9 +21,9 @@ def mock_image():
2021
2122@pytest .fixture
2223def mock_page_layout ():
23- text_block = layout .TextRegion (2 , 4 , 6 , 8 , text = "A very repetitive narrative. " * 10 )
24+ text_block = layout .EmbeddedTextRegion (2 , 4 , 6 , 8 , text = "A very repetitive narrative. " * 10 )
2425
25- title_block = layout .TextRegion (1 , 2 , 3 , 4 , text = "A Catchy Title" )
26+ title_block = layout .EmbeddedTextRegion (1 , 2 , 3 , 4 , text = "A Catchy Title" )
2627
2728 return [text_block , title_block ]
2829
@@ -49,7 +50,7 @@ def detect(self, *args):
4950 image = Image .fromarray (np .random .randint (12 , 24 , (40 , 40 )), mode = "RGB" )
5051 text_block = layout .TextRegion (1 , 2 , 3 , 4 , text = None )
5152
52- assert layout .ocr (text_block , image = image ) == mock_text
53+ assert elements .ocr (text_block , image = image ) == mock_text
5354
5455
5556class MockLayoutModel :
@@ -69,12 +70,12 @@ def test_get_page_elements(monkeypatch, mock_page_layout):
6970 number = 0 , image = image , layout = mock_page_layout , model = MockLayoutModel (mock_page_layout )
7071 )
7172
72- elements = page .get_elements (inplace = False )
73+ elements = page .get_elements_with_model (inplace = False )
7374
7475 assert str (elements [0 ]) == "A Catchy Title"
7576 assert str (elements [1 ]).startswith ("A very repetitive narrative." )
7677
77- page .get_elements (inplace = True )
78+ page .get_elements_with_model (inplace = True )
7879 assert elements == page .elements
7980
8081
@@ -95,13 +96,13 @@ def test_get_page_elements_with_ocr(monkeypatch):
9596 doc_layout = [text_block , image_block ]
9697
9798 monkeypatch .setattr (detectron2 , "is_detectron2_available" , lambda * args : True )
98- monkeypatch .setattr (layout , "ocr" , lambda * args : "An Even Catchier Title" )
99+ monkeypatch .setattr (elements , "ocr" , lambda * args : "An Even Catchier Title" )
99100
100101 image = Image .fromarray (np .random .randint (12 , 14 , size = (40 , 10 , 3 )), mode = "RGB" )
101102 page = layout .PageLayout (
102103 number = 0 , image = image , layout = doc_layout , model = MockLayoutModel (doc_layout )
103104 )
104- page .get_elements ()
105+ page .get_elements_with_model ()
105106
106107 assert str (page ) == "\n \n An Even Catchier Title"
107108
@@ -174,7 +175,7 @@ def tolist(self):
174175 return [1 , 2 , 3 , 4 ]
175176
176177
177- class MockTextRegion (layout .TextRegion ):
178+ class MockEmbeddedTextRegion (layout .EmbeddedTextRegion ):
178179 def __init__ (self , type = None , text = None , ocr_text = None ):
179180 self .type = type
180181 self .text = text
@@ -193,7 +194,7 @@ def __init__(self, layout=None, model=None, ocr_strategy="auto", extract_tables=
193194 self .ocr_strategy = ocr_strategy
194195 self .extract_tables = extract_tables
195196
196- def ocr (self , text_block : MockTextRegion ):
197+ def ocr (self , text_block : MockEmbeddedTextRegion ):
197198 return text_block .ocr_text
198199
199200
@@ -202,15 +203,15 @@ def ocr(self, text_block: MockTextRegion):
202203 [("base" , 0.0 ), ("" , 0.0 ), ("(cid:2)" , 1.0 ), ("(cid:1)a" , 0.5 ), ("c(cid:1)ab" , 0.25 )],
203204)
204205def test_cid_ratio (text , expected ):
205- assert layout .cid_ratio (text ) == expected
206+ assert elements .cid_ratio (text ) == expected
206207
207208
208209@pytest .mark .parametrize (
209210 "text, expected" ,
210211 [("base" , False ), ("(cid:2)" , True ), ("(cid:1234567890)" , True ), ("jkl;(cid:12)asdf" , True )],
211212)
212213def test_is_cid_present (text , expected ):
213- assert layout .is_cid_present (text ) == expected
214+ assert elements .is_cid_present (text ) == expected
214215
215216
216217class MockLayout :
@@ -241,7 +242,7 @@ def filter_by(self, *args, **kwargs):
241242 ],
242243)
243244def test_get_element_from_block (block_text , layout_texts , mock_image , expected_text ):
244- with patch ("unstructured_inference.inference.layout .ocr" , return_value = "ocr" ):
245+ with patch ("unstructured_inference.inference.elements .ocr" , return_value = "ocr" ):
245246 block = layout .TextRegion (0 , 0 , 10 , 10 , text = block_text )
246247 captured_layout = [
247248 layout .TextRegion (i + 1 , i + 1 , i + 2 , i + 2 , text = text )
@@ -263,7 +264,7 @@ def test_from_image_file(monkeypatch, mock_page_layout, filetype):
263264 def mock_get_elements (self , * args , ** kwargs ):
264265 self .elements = [mock_page_layout ]
265266
266- monkeypatch .setattr (layout .PageLayout , "get_elements " , mock_get_elements )
267+ monkeypatch .setattr (layout .PageLayout , "get_elements_with_model " , mock_get_elements )
267268 elements = (
268269 layout .DocumentLayout .from_image_file (f"sample-docs/loremipsum.{ filetype } " )
269270 .pages [0 ]
@@ -301,12 +302,12 @@ def test_get_elements_from_layout(mock_page_layout, idx):
301302@pytest .mark .parametrize (
302303 "fixed_layouts, called_method, not_called_method" ,
303304 [
304- ([MockLayout ()], "get_elements_from_layout" , "get_elements " ),
305- (None , "get_elements " , "get_elements_from_layout" ),
305+ ([MockLayout ()], "get_elements_from_layout" , "get_elements_with_model " ),
306+ (None , "get_elements_with_model " , "get_elements_from_layout" ),
306307 ],
307308)
308309def test_from_file_fixed_layout (fixed_layouts , called_method , not_called_method ):
309- with patch .object (layout .PageLayout , "get_elements " , return_value = []), patch .object (
310+ with patch .object (layout .PageLayout , "get_elements_with_model " , return_value = []), patch .object (
310311 layout .PageLayout , "get_elements_from_layout" , return_value = []
311312 ):
312313 layout .DocumentLayout .from_file ("sample-docs/loremipsum.pdf" , fixed_layouts = fixed_layouts )
@@ -323,70 +324,105 @@ def test_invalid_ocr_strategy_raises(mock_image):
323324 ("text" , "expected" ), [("a\t s\x0c d\n fas\f d\r f\b " , "asdfasdf" ), ("\" '\\ " , "\" '\\ " )]
324325)
325326def test_remove_control_characters (text , expected ):
326- assert layout .remove_control_characters (text ) == expected
327+ assert elements .remove_control_characters (text ) == expected
327328
328329
329- no_text_region = layout .TextRegion (0 , 0 , 100 , 100 )
330- text_region = layout .TextRegion (0 , 0 , 100 , 100 , text = "test" )
331- cid_text_region = layout .TextRegion (0 , 0 , 100 , 100 , text = "(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)" )
330+ no_text_region = layout .EmbeddedTextRegion (0 , 0 , 100 , 100 )
331+ text_region = layout .EmbeddedTextRegion (0 , 0 , 100 , 100 , text = "test" )
332+ cid_text_region = layout .EmbeddedTextRegion (
333+ 0 , 0 , 100 , 100 , text = "(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)"
334+ )
332335overlapping_rect = layout .ImageTextRegion (50 , 50 , 150 , 150 )
333336nonoverlapping_rect = layout .ImageTextRegion (150 , 150 , 200 , 200 )
334- populated_text_region = layout .TextRegion (50 , 50 , 60 , 60 , text = "test" )
335- unpopulated_text_region = layout .TextRegion (50 , 50 , 60 , 60 , text = None )
337+ populated_text_region = layout .EmbeddedTextRegion (50 , 50 , 60 , 60 , text = "test" )
338+ unpopulated_text_region = layout .EmbeddedTextRegion (50 , 50 , 60 , 60 , text = None )
336339
337340
338341@pytest .mark .parametrize (
339- ("region" , "text_objects" , "image_objects " , "ocr_strategy" , "expected" ),
342+ ("region" , "objects " , "ocr_strategy" , "expected" ),
340343 [
341- (no_text_region , [], [ nonoverlapping_rect ], "auto" , False ),
342- (no_text_region , [], [ overlapping_rect ], "auto" , True ),
343- (no_text_region , [], [], "auto" , False ),
344- (no_text_region , [populated_text_region ], [ nonoverlapping_rect ], "auto" , False ),
345- (no_text_region , [populated_text_region ], [ overlapping_rect ], "auto" , False ),
346- (no_text_region , [populated_text_region ], [], "auto" , False ),
347- (no_text_region , [unpopulated_text_region ], [ nonoverlapping_rect ], "auto" , False ),
348- (no_text_region , [unpopulated_text_region ], [ overlapping_rect ], "auto" , True ),
349- (no_text_region , [unpopulated_text_region ], [], "auto" , False ),
344+ (no_text_region , [nonoverlapping_rect ], "auto" , False ),
345+ (no_text_region , [overlapping_rect ], "auto" , True ),
346+ (no_text_region , [], "auto" , False ),
347+ (no_text_region , [populated_text_region , nonoverlapping_rect ], "auto" , False ),
348+ (no_text_region , [populated_text_region , overlapping_rect ], "auto" , False ),
349+ (no_text_region , [populated_text_region ], "auto" , False ),
350+ (no_text_region , [unpopulated_text_region , nonoverlapping_rect ], "auto" , False ),
351+ (no_text_region , [unpopulated_text_region , overlapping_rect ], "auto" , True ),
352+ (no_text_region , [unpopulated_text_region ], "auto" , False ),
350353 * list (
351354 product (
352355 [text_region ],
353- [[], [populated_text_region ], [unpopulated_text_region ]],
354- [[], [nonoverlapping_rect ], [overlapping_rect ]],
356+ [
357+ [],
358+ [populated_text_region ],
359+ [unpopulated_text_region ],
360+ [nonoverlapping_rect ],
361+ [overlapping_rect ],
362+ [populated_text_region , nonoverlapping_rect ],
363+ [populated_text_region , overlapping_rect ],
364+ [unpopulated_text_region , nonoverlapping_rect ],
365+ [unpopulated_text_region , overlapping_rect ],
366+ ],
355367 ["auto" ],
356368 [False ],
357369 )
358370 ),
359371 * list (
360372 product (
361373 [cid_text_region ],
362- [[], [populated_text_region ], [unpopulated_text_region ]],
363- [[overlapping_rect ]],
374+ [
375+ [],
376+ [populated_text_region ],
377+ [unpopulated_text_region ],
378+ [overlapping_rect ],
379+ [populated_text_region , overlapping_rect ],
380+ [unpopulated_text_region , overlapping_rect ],
381+ ],
364382 ["auto" ],
365383 [True ],
366384 )
367385 ),
368386 * list (
369387 product (
370388 [no_text_region , text_region , cid_text_region ],
371- [[], [populated_text_region ], [unpopulated_text_region ]],
372- [[], [nonoverlapping_rect ], [overlapping_rect ]],
389+ [
390+ [],
391+ [populated_text_region ],
392+ [unpopulated_text_region ],
393+ [nonoverlapping_rect ],
394+ [overlapping_rect ],
395+ [populated_text_region , nonoverlapping_rect ],
396+ [populated_text_region , overlapping_rect ],
397+ [unpopulated_text_region , nonoverlapping_rect ],
398+ [unpopulated_text_region , overlapping_rect ],
399+ ],
373400 ["force" ],
374401 [True ],
375402 )
376403 ),
377404 * list (
378405 product (
379406 [no_text_region , text_region , cid_text_region ],
380- [[], [populated_text_region ], [unpopulated_text_region ]],
381- [[], [nonoverlapping_rect ], [overlapping_rect ]],
407+ [
408+ [],
409+ [populated_text_region ],
410+ [unpopulated_text_region ],
411+ [nonoverlapping_rect ],
412+ [overlapping_rect ],
413+ [populated_text_region , nonoverlapping_rect ],
414+ [populated_text_region , overlapping_rect ],
415+ [unpopulated_text_region , nonoverlapping_rect ],
416+ [unpopulated_text_region , overlapping_rect ],
417+ ],
382418 ["never" ],
383419 [False ],
384420 )
385421 ),
386422 ],
387423)
388- def test_ocr_image (region , text_objects , image_objects , ocr_strategy , expected ):
389- assert layout .needs_ocr (region , text_objects , image_objects , ocr_strategy ) is expected
424+ def test_ocr_image (region , objects , ocr_strategy , expected ):
425+ assert elements .needs_ocr (region , objects , ocr_strategy ) is expected
390426
391427
392428def test_load_pdf ():
0 commit comments