22# @Author: SWHL
3344import argparse
5- import warnings
65from pathlib import Path
7- from typing import Dict , List , Tuple , Union
6+ from typing import Dict , List , Optional , Tuple , Union
87
98import cv2
10- import filetype
119import fitz
1210import numpy as np
11+ from rapidocr import RapidOCR
1312
14- from .utils import import_package
13+ from .logger import Logger
14+ from .utils import which_type
1515
1616
17- class PDFExtracter :
18- def __init__ (self , dpi = 200 , ** ocr_kwargs ):
17+ class RapidOCRPDF :
18+ def __init__ (self , dpi = 200 , ocr_params : Optional [ Dict ] = None ):
1919 self .dpi = dpi
20-
21- ocr_engine = import_package ("rapidocr_onnxruntime" )
22- if ocr_engine is None :
23- ocr_engine = import_package ("rapidocr_openvino" )
24-
25- if ocr_engine is None :
26- ocr_engine = import_package ("rapidocr_paddle" )
27-
28- if ocr_engine is not None :
29- ocr_kwargs .update ({
30- "det_use_cuda" : True ,
31- "cls_use_cuda" : True ,
32- "rec_use_cuda" : True
33- })
34- else :
35- raise ModuleNotFoundError (
36- "Can't find the rapidocr_onnxruntime/rapidocr_openvino/rapidocr_paddle package.\n Please pip install rapidocr_onnxruntime to run the code."
37- )
38-
39- self .text_sys = ocr_engine .RapidOCR (** ocr_kwargs )
20+ self .ocr_engine = RapidOCR (params = ocr_params )
4021 self .empty_list = []
22+ self .logger = Logger (logger_name = __name__ ).get_log ()
4123
4224 def __call__ (
43- self ,
44- content : Union [str , Path , bytes ],
45- force_ocr : bool = False ,
25+ self , content : Union [str , Path , bytes ], force_ocr : bool = False
4626 ) -> List [List [Union [str , str , str ]]]:
4727 try :
48- file_type = self . which_type (content )
28+ file_type = which_type (content )
4929 except (FileExistsError , TypeError ) as e :
50- raise PDFExtracterError ("The input content is empty." ) from e
30+ raise RapidOCRPDFError ("The input content is empty." ) from e
5131
5232 if file_type != "pdf" :
53- raise PDFExtracterError ("The file type is not PDF format." )
33+ raise RapidOCRPDFError ("The file type is not PDF format." )
5434
5535 try :
5636 pdf_data = self .load_pdf (content )
57- except PDFExtracterError as e :
58- warnings . warn ( str ( e ) )
37+ except RapidOCRPDFError as e :
38+ self . logger . error ( e )
5939 return self .empty_list
6040
6141 txts_dict , need_ocr_idxs = self .extract_texts (pdf_data , force_ocr )
@@ -69,7 +49,7 @@ def __call__(
6949 def load_pdf (pdf_content : Union [str , Path , bytes ]) -> bytes :
7050 if isinstance (pdf_content , (str , Path )):
7151 if not Path (pdf_content ).exists ():
72- raise PDFExtracterError (f"{ pdf_content } does not exist." )
52+ raise RapidOCRPDFError (f"{ pdf_content } does not exist." )
7353
7454 with open (pdf_content , "rb" ) as f :
7555 data = f .read ()
@@ -78,7 +58,7 @@ def load_pdf(pdf_content: Union[str, Path, bytes]) -> bytes:
7858 if isinstance (pdf_content , bytes ):
7959 return pdf_content
8060
81- raise PDFExtracterError (f"{ type (pdf_content )} is not in [str, Path, bytes]." )
61+ raise RapidOCRPDFError (f"{ type (pdf_content )} is not in [str, Path, bytes]." )
8262
8363 def extract_texts (self , pdf_data : bytes , force_ocr : bool ) -> Tuple [Dict , List ]:
8464 texts , need_ocr_idxs = {}, []
@@ -107,20 +87,19 @@ def convert_img(page):
10787 with fitz .open (stream = pdf_data ) as doc :
10888 for i in need_ocr_idxs :
10989 img = convert_img (doc [i ])
110- preds , _ = self .text_sys (img )
111- if preds :
112- text = []
113- confidences = []
114- for pred in preds :
115- _ , rec_res , confidence = pred
116- text .append (rec_res )
117- confidences .append (float (confidence ))
118-
119- avg_confidence = np .mean (confidences ) if confidences else 0.0
120- ocr_res [str (i )] = {
121- "text" : "\n " .join (text ),
122- "avg_confidence" : avg_confidence
123- }
90+
91+ preds = self .ocr_engine (img )
92+ if preds .txts is None :
93+ continue
94+
95+ avg_score = (
96+ sum (preds .scores ) / len (preds .scores ) if preds .scores else 0.0
97+ )
98+
99+ ocr_res [str (i )] = {
100+ "text" : "\n " .join (preds .txts ),
101+ "avg_confidence" : avg_score ,
102+ }
124103 return ocr_res
125104
126105 def merge_direct_ocr (self , txts_dict : Dict , ocr_res_dict : Dict ) -> List [List [str ]]:
@@ -131,25 +110,14 @@ def merge_direct_ocr(self, txts_dict: Dict, ocr_res_dict: Dict) -> List[List[str
131110 for page_idx , ocr_data in ocr_res_dict .items ():
132111 final_result [page_idx ] = {
133112 "text" : ocr_data ["text" ],
134- "avg_confidence" : ocr_data ["avg_confidence" ]
113+ "avg_confidence" : ocr_data ["avg_confidence" ],
135114 }
136115
137116 final_result = dict (sorted (final_result .items (), key = lambda x : int (x [0 ])))
138- return [[k , v ["text" ], str (v ["avg_confidence" ])] for k , v in final_result .items ()]
139-
140- @staticmethod
141- def which_type (content : Union [bytes , str , Path ]) -> str :
142- if isinstance (content , (str , Path )) and not Path (content ).exists ():
143- raise FileExistsError (f"{ content } does not exist." )
144-
145- kind = filetype .guess (content )
146- if kind is None :
147- raise TypeError (f"The type of { content } does not support." )
148-
149- return kind .extension
117+ return [[k , v ["text" ], v ["avg_confidence" ]] for k , v in final_result .items ()]
150118
151119
152- class PDFExtracterError (Exception ):
120+ class RapidOCRPDFError (Exception ):
153121 pass
154122
155123
@@ -167,7 +135,7 @@ def main():
167135 )
168136 args = parser .parse_args ()
169137
170- pdf_extracter = PDFExtracter ()
138+ pdf_extracter = RapidOCRPDF ()
171139
172140 try :
173141 result = pdf_extracter (args .file_path , args .force_ocr )
0 commit comments