22
33from __future__ import annotations
44
5+ import datetime
56import os
67import sys
78import traceback
3334from fava .util .date import local_today
3435
3536if TYPE_CHECKING : # pragma: no cover
36- import datetime
3737 from collections .abc import Iterable
3838 from collections .abc import Mapping
3939 from collections .abc import Sequence
40+ from typing import Any
4041 from typing import Callable
42+ from typing import ParamSpec
43+ from typing import TypeVar
4144
4245 from fava .beans .abc import Directive
4346 from fava .beans .ingest import FileMemo
4649 HookOutput = list [tuple [str , list [Directive ]]]
4750 Hooks = Sequence [Callable [[HookOutput , Sequence [Directive ]], HookOutput ]]
4851
52+ P = ParamSpec ("P" )
53+ T = TypeVar ("T" )
54+
4955
5056class IngestError (BeancountError ):
5157 """An error with one of the importers."""
@@ -60,6 +66,16 @@ def __init__(self) -> None:
6066 )
6167
6268
69+ class ImporterInvalidTypeError (FavaAPIError ):
70+ """One of the importer methods returned an unexpected type."""
71+
72+ def __init__ (self , attr : str , expected : type [Any ], actual : Any ) -> None :
73+ super ().__init__ (
74+ f"Got unexpected type from importer as { attr } :"
75+ f" expected { expected !s} , got { type (actual )!s} :"
76+ )
77+
78+
6379class ImporterExtractError (ImporterMethodCallError ):
6480 """Error calling extract for importer."""
6581
@@ -155,61 +171,87 @@ class FileImporters:
155171 importers : list [FileImportInfo ]
156172
157173
158- def get_name (importer : BeanImporterProtocol | Importer ) -> str :
159- """Get the name of an importer."""
160- try :
161- if isinstance (importer , Importer ):
162- return importer .name
163- return importer .name ()
164- except Exception as err :
165- raise ImporterMethodCallError from err
174+ def _catch_any (func : Callable [P , T ]) -> Callable [P , T ]:
175+ """Helper to catch any exception that might be raised by the importer."""
166176
177+ def wrapper (* args : P .args , ** kwds : P .kwargs ) -> T :
178+ try :
179+ return func (* args , ** kwds )
180+ except Exception as err :
181+ if isinstance (err , ImporterInvalidTypeError ):
182+ raise
183+ raise ImporterMethodCallError from err
167184
168- def importer_identify (
169- importer : BeanImporterProtocol | Importer , path : Path
170- ) -> bool :
171- """Get the name of an importer."""
172- try :
173- if isinstance (importer , Importer ):
174- return importer .identify (str (path ))
175- return importer .identify (get_cached_file (path ))
176- except Exception as err :
177- raise ImporterMethodCallError from err
185+ return wrapper
178186
179187
180- def file_import_info (
181- path : Path ,
182- importer : BeanImporterProtocol | Importer ,
183- ) -> FileImportInfo :
184- """Generate info about a file with an importer."""
185- filename = str (path )
186- try :
188+ def _assert_type (attr : str , value : T , type_ : type [T ]) -> T :
189+ """Helper to validate types return by importer methods."""
190+ if not isinstance (value , type_ ):
191+ raise ImporterInvalidTypeError (attr , type_ , value )
192+ return value
193+
194+
195+ class WrappedImporter :
196+ """A wrapper to safely call importer methods."""
197+
198+ importer : BeanImporterProtocol | Importer
199+
200+ def __init__ (self , importer : BeanImporterProtocol | Importer ) -> None :
201+ self .importer = importer
202+
203+ @property
204+ @_catch_any
205+ def name (self ) -> str :
206+ """Get the name of the importer."""
207+ importer = self .importer
208+ name = (
209+ importer .name
210+ if isinstance (importer , Importer )
211+ else importer .name ()
212+ )
213+ return _assert_type ("name" , name , str )
214+
215+ @_catch_any
216+ def identify (self : WrappedImporter , path : Path ) -> bool :
217+ """Whether the importer is matching the file."""
218+ importer = self .importer
219+ matches = (
220+ importer .identify (str (path ))
221+ if isinstance (importer , Importer )
222+ else importer .identify (get_cached_file (path ))
223+ )
224+ return _assert_type ("identify" , matches , bool )
225+
226+ @_catch_any
227+ def file_import_info (self , path : Path ) -> FileImportInfo :
228+ """Generate info about a file with an importer."""
229+ importer = self .importer
187230 if isinstance (importer , Importer ):
188- account = importer .account (filename )
189- date = importer .date (filename )
190- name = importer .filename (filename )
231+ str_path = str (path )
232+ account = importer .account (str_path )
233+ date = importer .date (str_path )
234+ filename = importer .filename (str_path )
191235 else :
192236 file = get_cached_file (path )
193237 account = importer .file_account (file )
194238 date = importer .file_date (file )
195- name = importer .file_name (file )
196- except Exception as err :
197- raise ImporterMethodCallError from err
239+ filename = importer .file_name (file )
198240
199- return FileImportInfo (
200- get_name ( importer ) ,
201- account or "" ,
202- date or local_today (),
203- name or Path ( filename ) .name ,
204- )
241+ return FileImportInfo (
242+ self . name ,
243+ _assert_type ( " account" , account or "" , str ) ,
244+ _assert_type ( " date" , date or local_today (), datetime . date ),
245+ _assert_type ( " filename" , filename or path .name , str ) ,
246+ )
205247
206248
207249# Copied here from beangulp to minimise the imports.
208250_FILE_TOO_LARGE_THRESHOLD = 8 * 1024 * 1024
209251
210252
211253def find_imports (
212- config : Sequence [BeanImporterProtocol | Importer ], directory : Path
254+ config : Sequence [WrappedImporter ], directory : Path
213255) -> Iterable [FileImporters ]:
214256 """Pair files and matching importers.
215257
@@ -223,31 +265,32 @@ def find_imports(
223265 continue
224266
225267 importers = [
226- file_import_info (path , importer )
268+ importer . file_import_info (path )
227269 for importer in config
228- if importer_identify ( importer , path )
270+ if importer . identify ( path )
229271 ]
230272 yield FileImporters (
231273 name = str (path ), basename = path .name , importers = importers
232274 )
233275
234276
235277def extract_from_file (
236- importer : BeanImporterProtocol | Importer ,
278+ wrapped_importer : WrappedImporter ,
237279 path : Path ,
238280 existing_entries : Sequence [Directive ],
239281) -> list [Directive ]:
240282 """Import entries from a document.
241283
242284 Args:
243- importer : The importer instance to handle the document.
285+ wrapped_importer : The importer instance to handle the document.
244286 path: Filesystem path to the document.
245287 existing_entries: Existing entries.
246288
247289 Returns:
248290 The list of imported entries.
249291 """
250292 filename = str (path )
293+ importer = wrapped_importer .importer
251294 if isinstance (importer , Importer ):
252295 entries = importer .extract (filename , existing = existing_entries )
253296 else :
@@ -269,7 +312,7 @@ def extract_from_file(
269312
270313def load_import_config (
271314 module_path : Path ,
272- ) -> tuple [Mapping [str , BeanImporterProtocol | Importer ], Hooks ]:
315+ ) -> tuple [Mapping [str , WrappedImporter ], Hooks ]:
273316 """Load the given import config and extract importers and hooks.
274317
275318 Args:
@@ -311,7 +354,8 @@ def load_import_config(
311354 "not satisfy importer protocol"
312355 )
313356 raise ImportConfigLoadError (msg )
314- importers [get_name (importer )] = importer
357+ wrapped_importer = WrappedImporter (importer )
358+ importers [wrapped_importer .name ] = wrapped_importer
315359 return importers , hooks
316360
317361
@@ -320,7 +364,7 @@ class IngestModule(FavaModule):
320364
321365 def __init__ (self , ledger : FavaLedger ) -> None :
322366 super ().__init__ (ledger )
323- self .importers : Mapping [str , BeanImporterProtocol | Importer ] = {}
367+ self .importers : Mapping [str , WrappedImporter ] = {}
324368 self .hooks : Hooks = []
325369 self .mtime : int | None = None
326370 self .errors : list [IngestError ] = []
@@ -359,7 +403,7 @@ def load_file(self) -> None: # noqa: D102
359403 try :
360404 self .importers , self .hooks = load_import_config (module_path )
361405 self .mtime = new_mtime
362- except ImportConfigLoadError as error :
406+ except FavaAPIError as error :
363407 msg = f"Error in import config '{ module_path } ': { error !s} "
364408 self ._error (msg )
365409
0 commit comments