Skip to content

Commit 2261c27

Browse files
committed
import: assert importer returns valid types
1 parent f262d5b commit 2261c27

File tree

2 files changed

+116
-59
lines changed

2 files changed

+116
-59
lines changed

src/fava/core/ingest.py

Lines changed: 91 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import datetime
56
import os
67
import sys
78
import traceback
@@ -33,11 +34,13 @@
3334
from fava.util.date import local_today
3435

3536
if TYPE_CHECKING: # pragma: no cover
36-
import datetime
3737
from collections.abc import Iterable
3838
from collections.abc import Mapping
3939
from collections.abc import Sequence
40+
from typing import Any
4041
from typing import Callable
42+
from typing import ParamSpec
43+
from typing import TypeVar
4144

4245
from fava.beans.abc import Directive
4346
from fava.beans.ingest import FileMemo
@@ -46,6 +49,9 @@
4649
HookOutput = list[tuple[str, list[Directive]]]
4750
Hooks = Sequence[Callable[[HookOutput, Sequence[Directive]], HookOutput]]
4851

52+
P = ParamSpec("P")
53+
T = TypeVar("T")
54+
4955

5056
class IngestError(BeancountError):
5157
"""An error with one of the importers."""
@@ -60,6 +66,16 @@ def __init__(self) -> None:
6066
)
6167

6268

69+
class ImporterInvalidTypeError(FavaAPIError):
70+
"""One of the importer methods returned an unexpected type."""
71+
72+
def __init__(self, attr: str, expected: type[Any], actual: Any) -> None:
73+
super().__init__(
74+
f"Got unexpected type from importer as {attr}:"
75+
f" expected {expected!s}, got {type(actual)!s}:"
76+
)
77+
78+
6379
class ImporterExtractError(ImporterMethodCallError):
6480
"""Error calling extract for importer."""
6581

@@ -155,61 +171,87 @@ class FileImporters:
155171
importers: list[FileImportInfo]
156172

157173

158-
def get_name(importer: BeanImporterProtocol | Importer) -> str:
159-
"""Get the name of an importer."""
160-
try:
161-
if isinstance(importer, Importer):
162-
return importer.name
163-
return importer.name()
164-
except Exception as err:
165-
raise ImporterMethodCallError from err
174+
def _catch_any(func: Callable[P, T]) -> Callable[P, T]:
175+
"""Helper to catch any exception that might be raised by the importer."""
166176

177+
def wrapper(*args: P.args, **kwds: P.kwargs) -> T:
178+
try:
179+
return func(*args, **kwds)
180+
except Exception as err:
181+
if isinstance(err, ImporterInvalidTypeError):
182+
raise
183+
raise ImporterMethodCallError from err
167184

168-
def importer_identify(
169-
importer: BeanImporterProtocol | Importer, path: Path
170-
) -> bool:
171-
"""Get the name of an importer."""
172-
try:
173-
if isinstance(importer, Importer):
174-
return importer.identify(str(path))
175-
return importer.identify(get_cached_file(path))
176-
except Exception as err:
177-
raise ImporterMethodCallError from err
185+
return wrapper
178186

179187

180-
def file_import_info(
181-
path: Path,
182-
importer: BeanImporterProtocol | Importer,
183-
) -> FileImportInfo:
184-
"""Generate info about a file with an importer."""
185-
filename = str(path)
186-
try:
188+
def _assert_type(attr: str, value: T, type_: type[T]) -> T:
189+
"""Helper to validate types returned by importer methods."""
190+
if not isinstance(value, type_):
191+
raise ImporterInvalidTypeError(attr, type_, value)
192+
return value
193+
194+
195+
class WrappedImporter:
196+
"""A wrapper to safely call importer methods."""
197+
198+
importer: BeanImporterProtocol | Importer
199+
200+
def __init__(self, importer: BeanImporterProtocol | Importer) -> None:
201+
self.importer = importer
202+
203+
@property
204+
@_catch_any
205+
def name(self) -> str:
206+
"""Get the name of the importer."""
207+
importer = self.importer
208+
name = (
209+
importer.name
210+
if isinstance(importer, Importer)
211+
else importer.name()
212+
)
213+
return _assert_type("name", name, str)
214+
215+
@_catch_any
216+
def identify(self: WrappedImporter, path: Path) -> bool:
217+
"""Whether the importer matches the file."""
218+
importer = self.importer
219+
matches = (
220+
importer.identify(str(path))
221+
if isinstance(importer, Importer)
222+
else importer.identify(get_cached_file(path))
223+
)
224+
return _assert_type("identify", matches, bool)
225+
226+
@_catch_any
227+
def file_import_info(self, path: Path) -> FileImportInfo:
228+
"""Generate info about a file with an importer."""
229+
importer = self.importer
187230
if isinstance(importer, Importer):
188-
account = importer.account(filename)
189-
date = importer.date(filename)
190-
name = importer.filename(filename)
231+
str_path = str(path)
232+
account = importer.account(str_path)
233+
date = importer.date(str_path)
234+
filename = importer.filename(str_path)
191235
else:
192236
file = get_cached_file(path)
193237
account = importer.file_account(file)
194238
date = importer.file_date(file)
195-
name = importer.file_name(file)
196-
except Exception as err:
197-
raise ImporterMethodCallError from err
239+
filename = importer.file_name(file)
198240

199-
return FileImportInfo(
200-
get_name(importer),
201-
account or "",
202-
date or local_today(),
203-
name or Path(filename).name,
204-
)
241+
return FileImportInfo(
242+
self.name,
243+
_assert_type("account", account or "", str),
244+
_assert_type("date", date or local_today(), datetime.date),
245+
_assert_type("filename", filename or path.name, str),
246+
)
205247

206248

207249
# Copied here from beangulp to minimise the imports.
208250
_FILE_TOO_LARGE_THRESHOLD = 8 * 1024 * 1024
209251

210252

211253
def find_imports(
212-
config: Sequence[BeanImporterProtocol | Importer], directory: Path
254+
config: Sequence[WrappedImporter], directory: Path
213255
) -> Iterable[FileImporters]:
214256
"""Pair files and matching importers.
215257
@@ -223,31 +265,32 @@ def find_imports(
223265
continue
224266

225267
importers = [
226-
file_import_info(path, importer)
268+
importer.file_import_info(path)
227269
for importer in config
228-
if importer_identify(importer, path)
270+
if importer.identify(path)
229271
]
230272
yield FileImporters(
231273
name=str(path), basename=path.name, importers=importers
232274
)
233275

234276

235277
def extract_from_file(
236-
importer: BeanImporterProtocol | Importer,
278+
wrapped_importer: WrappedImporter,
237279
path: Path,
238280
existing_entries: Sequence[Directive],
239281
) -> list[Directive]:
240282
"""Import entries from a document.
241283
242284
Args:
243-
importer: The importer instance to handle the document.
285+
wrapped_importer: The importer instance to handle the document.
244286
path: Filesystem path to the document.
245287
existing_entries: Existing entries.
246288
247289
Returns:
248290
The list of imported entries.
249291
"""
250292
filename = str(path)
293+
importer = wrapped_importer.importer
251294
if isinstance(importer, Importer):
252295
entries = importer.extract(filename, existing=existing_entries)
253296
else:
@@ -269,7 +312,7 @@ def extract_from_file(
269312

270313
def load_import_config(
271314
module_path: Path,
272-
) -> tuple[Mapping[str, BeanImporterProtocol | Importer], Hooks]:
315+
) -> tuple[Mapping[str, WrappedImporter], Hooks]:
273316
"""Load the given import config and extract importers and hooks.
274317
275318
Args:
@@ -311,7 +354,8 @@ def load_import_config(
311354
"not satisfy importer protocol"
312355
)
313356
raise ImportConfigLoadError(msg)
314-
importers[get_name(importer)] = importer
357+
wrapped_importer = WrappedImporter(importer)
358+
importers[wrapped_importer.name] = wrapped_importer
315359
return importers, hooks
316360

317361

@@ -320,7 +364,7 @@ class IngestModule(FavaModule):
320364

321365
def __init__(self, ledger: FavaLedger) -> None:
322366
super().__init__(ledger)
323-
self.importers: Mapping[str, BeanImporterProtocol | Importer] = {}
367+
self.importers: Mapping[str, WrappedImporter] = {}
324368
self.hooks: Hooks = []
325369
self.mtime: int | None = None
326370
self.errors: list[IngestError] = []
@@ -359,7 +403,7 @@ def load_file(self) -> None: # noqa: D102
359403
try:
360404
self.importers, self.hooks = load_import_config(module_path)
361405
self.mtime = new_mtime
362-
except ImportConfigLoadError as error:
406+
except FavaAPIError as error:
363407
msg = f"Error in import config '{module_path}': {error!s}"
364408
self._error(msg)
365409

tests/test_core_ingest.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@
1212
from fava.beans.abc import Note
1313
from fava.beans.abc import Transaction
1414
from fava.beans.ingest import BeanImporterProtocol
15-
from fava.core.ingest import file_import_info
1615
from fava.core.ingest import FileImportInfo
1716
from fava.core.ingest import filepath_in_primary_imports_folder
18-
from fava.core.ingest import get_name
1917
from fava.core.ingest import ImportConfigLoadError
20-
from fava.core.ingest import importer_identify
2118
from fava.core.ingest import ImporterExtractError
19+
from fava.core.ingest import ImporterInvalidTypeError
2220
from fava.core.ingest import load_import_config
21+
from fava.core.ingest import WrappedImporter
2322
from fava.helpers import FavaAPIError
2423
from fava.serialisation import serialise
2524
from fava.util.date import local_today
@@ -40,7 +39,7 @@ def test_ingest_file_import_info(
4039
assert importer
4140

4241
csv_path = test_data_dir / "import.csv"
43-
info = file_import_info(csv_path, importer)
42+
info = importer.file_import_info(csv_path)
4443
assert info.account == "Assets:Checking"
4544

4645

@@ -49,7 +48,7 @@ def __init__(self, acc: str = "Assets:Checking") -> None:
4948
self.acc = acc
5049

5150
def name(self) -> str:
52-
return self.acc
51+
return f"MinimalImporter({self.acc})"
5352

5453
def identify(self, file: FileMemo) -> bool:
5554
return self.acc in file.name
@@ -61,11 +60,11 @@ def file_account(self, _file: FileMemo) -> str:
6160
def test_ingest_file_import_info_minimal_importer(test_data_dir: Path) -> None:
6261
csv_path = test_data_dir / "import.csv"
6362

64-
info = file_import_info(csv_path, MinimalImporter("rawfile"))
65-
assert isinstance(info.account, str)
63+
importer = WrappedImporter(MinimalImporter())
64+
info = importer.file_import_info(csv_path)
6665
assert info == FileImportInfo(
67-
"rawfile",
68-
"rawfile",
66+
"MinimalImporter(Assets:Checking)",
67+
"Assets:Checking",
6968
local_today(),
7069
"import.csv",
7170
)
@@ -82,8 +81,9 @@ def test_ingest_file_import_info_account_method_errors(
8281
) -> None:
8382
csv_path = test_data_dir / "import.csv"
8483

84+
importer = WrappedImporter(AccountNameErrors())
8585
with pytest.raises(FavaAPIError) as err:
86-
file_import_info(csv_path, AccountNameErrors())
86+
importer.file_import_info(csv_path)
8787
assert "Some error reason..." in err.value.message
8888

8989

@@ -96,8 +96,9 @@ def identify(self, _file: FileMemo) -> bool:
9696
def test_ingest_identify_errors(test_data_dir: Path) -> None:
9797
csv_path = test_data_dir / "import.csv"
9898

99+
importer = WrappedImporter(IdentifyErrors())
99100
with pytest.raises(FavaAPIError) as err:
100-
importer_identify(IdentifyErrors(), csv_path)
101+
importer.identify(csv_path)
101102
assert "IDENTIFY_ERRORS" in err.value.message
102103

103104

@@ -108,11 +109,23 @@ def name(self) -> str:
108109

109110

110111
def test_ingest_get_name_errors() -> None:
112+
importer = WrappedImporter(ImporterNameErrors())
111113
with pytest.raises(FavaAPIError) as err:
112-
get_name(ImporterNameErrors())
114+
assert importer.name
113115
assert "GET_NAME_WILL_ERROR" in err.value.message
114116

115117

118+
class ImporterNameInvalidType(MinimalImporter):
119+
def name(self) -> str:
120+
return False # type: ignore[return-value]
121+
122+
123+
def test_ingest_get_name_invalid_type() -> None:
124+
importer = WrappedImporter(ImporterNameInvalidType())
125+
with pytest.raises(ImporterInvalidTypeError):
126+
assert importer.name
127+
128+
116129
@pytest.mark.skipif(
117130
sys.platform == "win32", reason="different error on windows"
118131
)

0 commit comments

Comments
 (0)