1010import numpy as np
1111from rapidocr import RapidOCR
1212
13- from .logger import Logger
14- from .utils import which_type
13+ from .utils .logger import Logger
14+ from .utils .utils import error_log , which_type
15+
16+ logger = Logger (logger_name = __name__ ).get_log ()
1517
1618
1719class RapidOCRPDF :
1820 def __init__ (self , dpi = 200 , ocr_params : Optional [Dict ] = None ):
1921 self .dpi = dpi
2022 self .ocr_engine = RapidOCR (params = ocr_params )
2123 self .empty_list = []
22- self .logger = Logger (logger_name = __name__ ).get_log ()
2324
2425 def __call__ (
25- self , content : Union [str , Path , bytes ], force_ocr : bool = False
26+ self ,
27+ content : Union [str , Path , bytes ],
28+ force_ocr : bool = False ,
29+ page_num_list : Optional [List [int ]] = None ,
2630 ) -> List [List [Union [str , str , str ]]]:
2731 try :
2832 file_type = which_type (content )
@@ -35,10 +39,12 @@ def __call__(
3539 try :
3640 pdf_data = self .load_pdf (content )
3741 except RapidOCRPDFError as e :
38- self . logger .error (e )
42+ logger .error ("%s \n %s" , e , error_log () )
3943 return self .empty_list
4044
41- txts_dict , need_ocr_idxs = self .extract_texts (pdf_data , force_ocr )
45+ txts_dict , need_ocr_idxs = self .extract_texts (
46+ pdf_data , force_ocr , page_num_list
47+ )
4248
4349 ocr_res_dict = self .get_ocr_res_streaming (pdf_data , need_ocr_idxs )
4450
@@ -60,21 +66,41 @@ def load_pdf(pdf_content: Union[str, Path, bytes]) -> bytes:
6066
6167 raise RapidOCRPDFError (f"{ type (pdf_content )} is not in [str, Path, bytes]." )
6268
63- def extract_texts (self , pdf_data : bytes , force_ocr : bool ) -> Tuple [Dict , List ]:
69+ def extract_texts (
70+ self , pdf_data : bytes , force_ocr : bool , page_num_list : Optional [List [int ]]
71+ ) -> Tuple [Dict , List ]:
6472 texts , need_ocr_idxs = {}, []
6573 with fitz .open (stream = pdf_data ) as doc :
74+ page_num_list = self .get_page_num_range (page_num_list , doc .page_count )
6675 for i , page in enumerate (doc ):
76+ if page_num_list is not None and i not in page_num_list :
77+ continue
78+
6779 if force_ocr :
6880 need_ocr_idxs .append (i )
6981 continue
7082
7183 text = page .get_text ("text" , sort = True )
7284 if text :
73- texts [str ( i ) ] = text
85+ texts [i ] = text
7486 else :
7587 need_ocr_idxs .append (i )
7688 return texts , need_ocr_idxs
7789
90+ @staticmethod
91+ def get_page_num_range (
92+ page_num_list : Optional [List [int ]], page_count : int
93+ ) -> Optional [List [int ]]:
94+ if page_num_list is None :
95+ return None
96+
97+ if max (page_num_list ) >= page_count :
98+ raise RapidOCRPDFError (
99+ f"The max value of { page_num_list } is greater than total page nums: { page_count } "
100+ )
101+
102+ return page_num_list
103+
78104 def get_ocr_res_streaming (self , pdf_data : bytes , need_ocr_idxs : List ) -> Dict :
79105 def convert_img (page ):
80106 pix = page .get_pixmap (dpi = self .dpi )
@@ -96,7 +122,7 @@ def convert_img(page):
96122 sum (preds .scores ) / len (preds .scores ) if preds .scores else 0.0
97123 )
98124
99- ocr_res [str ( i ) ] = {
125+ ocr_res [i ] = {
100126 "text" : "\n " .join (preds .txts ),
101127 "avg_confidence" : avg_score ,
102128 }
@@ -121,27 +147,36 @@ class RapidOCRPDFError(Exception):
121147 pass
122148
123149
124- def main ( ):
150+ def parse_args ( arg_list : Optional [ List [ str ]] = None ):
125151 parser = argparse .ArgumentParser ()
126- parser .add_argument (
127- "-path" , "--file_path" , type = str , help = "File path, PDF or images"
128- )
152+ parser .add_argument ("pdf_path" , type = str )
153+ parser .add_argument ("--dpi" , type = int , default = 200 )
129154 parser .add_argument (
130155 "-f" ,
131156 "--force_ocr" ,
132157 action = "store_true" ,
133158 default = False ,
134159 help = "Whether to use ocr for all pages." ,
135160 )
136- args = parser .parse_args ()
161+ parser .add_argument (
162+ "--page_num_list" ,
163+ type = int ,
164+ nargs = "*" ,
165+ default = None ,
166+ help = "Which pages will be extracted. e.g. 0 1 2. Note: the index of page num starts from 0." ,
167+ )
168+ args = parser .parse_args (arg_list )
169+ return args
137170
138- pdf_extracter = RapidOCRPDF ()
139171
172+ def main (arg_list : Optional [List [str ]] = None ):
173+ args = parse_args (arg_list )
174+ pdf_extracter = RapidOCRPDF (args .dpi )
140175 try :
141- result = pdf_extracter (args .file_path , args .force_ocr )
176+ result = pdf_extracter (args .pdf_path , args .force_ocr , args . page_num_list )
142177 print (result )
143178 except Exception as e :
144- print ( f"[ERROR] { e } " )
179+ logger . error ( "%s \n %s" , e , error_log () )
145180
146181
147182if __name__ == "__main__" :
0 commit comments