diff --git a/proselint/config/__init__.py b/proselint/config/__init__.py index b0cd9f1de..2ea99a53a 100644 --- a/proselint/config/__init__.py +++ b/proselint/config/__init__.py @@ -1,9 +1,11 @@ """Configuration for proselint.""" import json +from collections.abc import Hashable, Mapping from importlib.resources import files +from itertools import chain from pathlib import Path -from typing import TypedDict +from typing import TypeAlias, TypedDict, TypeVar, cast from warnings import showwarning as warn from proselint import config @@ -23,14 +25,22 @@ class Config(TypedDict): checks: dict[str, bool] -DEFAULT: Config = json.loads((files(config) / "default.json").read_text()) +DEFAULT = cast( + "Config", json.loads((files(config) / "default.json").read_text()) +) +Checks: TypeAlias = Mapping[str, "bool | Checks"] +KT_co = TypeVar("KT_co", bound=Hashable, covariant=True) +VT_co = TypeVar("VT_co", covariant=True) -def _deepmerge_dicts(base: dict, overrides: dict) -> dict: + +def _deepmerge_dicts( + base: dict[KT_co, VT_co], overrides: dict[KT_co, VT_co] +) -> dict[KT_co, VT_co]: # fmt: off return base | overrides | { key: ( - _deepmerge_dicts(b_value, o_value) + _deepmerge_dicts(b_value, o_value) # pyright: ignore[reportUnknownArgumentType] if isinstance(b_value := base[key], dict) else o_value ) @@ -39,6 +49,29 @@ def _deepmerge_dicts(base: dict, overrides: dict) -> dict: } +def _flatten_checks(checks: Checks, prefix: str = "") -> dict[str, bool]: + return dict( + chain.from_iterable( + [(full_key, value)] + if isinstance(value, bool) + else _flatten_checks(value, full_key).items() + for key, value in checks.items() + for full_key in [f"{prefix}.{key}" if prefix else key] + ) + ) + + +def _sort_by_specificity(checks: dict[str, bool]) -> dict[str, bool]: + """Sort selected checks by depth (specificity) in descending order.""" + return dict( + sorted( + checks.items(), + key=lambda x: x[0].count("."), + reverse=True, + ) + ) + + def load_from(config_path: Path | None = None) -> Config: """ Read various config paths, allowing user overrides. @@ -50,9 +83,9 @@ def load_from(config_path: Path | None = None) -> Config: for path in config_paths: if path.is_file(): - result: Config = _deepmerge_dicts( - result, - json.loads(path.read_text()), + result = _deepmerge_dicts( + cast("dict[str, object]", result), + json.loads(path.read_text()), # pyright: ignore[reportAny] ) if path.suffix == ".json" and (old := path.with_suffix("")).is_file(): warn( @@ -62,4 +95,9 @@ def load_from(config_path: Path | None = None) -> Config: 0, ) - return result + result = cast("Config", result) + + return Config( + max_errors=result.get("max_errors", 0), + checks=_sort_by_specificity(_flatten_checks(result.get("checks", {}))), + ) diff --git a/proselint/registry/__init__.py b/proselint/registry/__init__.py index fd546e028..2bdf98e08 100644 --- a/proselint/registry/__init__.py +++ b/proselint/registry/__init__.py @@ -53,20 +53,24 @@ def checks(self) -> list[Check]: def get_all_enabled( self, enabled: dict[str, bool] = DEFAULT["checks"] ) -> list[Check]: - """Filter registered checks by config values based on their keys.""" - self.enabled_checks = enabled + """ + Filter registered checks by config values based on their keys. - enabled_checks: list[str] = [] - skipped_checks: list[str] = [] - for key, key_enabled in self.enabled_checks.items(): - (skipped_checks, enabled_checks)[key_enabled].append(key) + This assumes that keys are not nested, and sorted in descending order + of depth (specificity). For example, all keys should look like + `a.b.c`, and `a.b.c` should come before `a.b`. + """ + self.enabled_checks = enabled return [ check for check in self.checks - if not any(check.matches_partial(key) for key in skipped_checks) - and any( - check.path == key or check.matches_partial(key) - for key in enabled_checks + if next( + ( + value + for prefix, value in self.enabled_checks.items() + if check.matches_partial(prefix) + ), + False, ) ] diff --git a/tests/test-proselintrc.json b/tests/test-proselintrc.json index 0c055c5b0..bc762f6d0 100644 --- a/tests/test-proselintrc.json +++ b/tests/test-proselintrc.json @@ -10,7 +10,7 @@ "lexical_illusions": true, "malapropisms": true, "misc": true, - "mixed_metaphors": true, + "mixed_metaphors": true, "mondegreens": true, "needless_variants": true, "nonwords": true, diff --git a/tests/test_config_flag.py b/tests/test_config_flag.py index a70bc3b64..9f846178f 100644 --- a/tests/test_config_flag.py +++ b/tests/test_config_flag.py @@ -8,7 +8,9 @@ from proselint.command_line import get_parser, proselint from proselint.config import ( DEFAULT, - _deepmerge_dicts, # pyright: ignore[reportUnknownVariableType, reportPrivateUsage] + _deepmerge_dicts, # pyright: ignore[reportPrivateUsage] + _flatten_checks, # pyright: ignore[reportPrivateUsage] + _sort_by_specificity, # pyright: ignore[reportPrivateUsage] load_from, ) @@ -29,6 +31,44 @@ def test_deepmerge_dicts() -> None: } +def test_sort_by_specificity() -> None: + """Test sort_by_specificity sorts by dot count descending.""" + unsorted = { + "a": True, + "a.b.c": False, + "x.y": True, + "a.b": True, + } + + sorted_checks = _sort_by_specificity(unsorted) + keys = list(sorted_checks.keys()) + + dots = [key.count(".") for key in keys] + + assert dots == sorted(dots, reverse=True) + assert keys[0] == "a.b.c" + assert keys[-1] == "a" + + assert sorted_checks["a.b.c"] is False + assert sorted_checks["a"] is True + + +def test_flatten_checks() -> None: + """Test flatten_checks.""" + assert _flatten_checks({"a": True, "b": False}) == { + "a": True, + "b": False, + } + + assert _flatten_checks({"x": {"y": True, "z": False}, "w": True}) == { + "x.y": True, + "x.z": False, + "w": True, + } + + assert _flatten_checks({"a": {"b": {"c": True}}}) == {"a.b.c": True} + + def test_load_from() -> None: """Test load_options by specifying a user options path.""" overrides = load_from(CONFIG_FILE) diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 000000000..4cfce7751 --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,39 @@ +"""Test the registry module.""" + +from proselint.config import ( + _sort_by_specificity, # pyright: ignore[reportPrivateUsage] +) +from proselint.registry import CheckRegistry + + +def test_specific_overrides_general() -> None: + """Test that specific config keys override general ones.""" + checks = _sort_by_specificity( + { + "typography": True, + "typography.symbols": False, + "typography.symbols.curly_quotes": True, + "typography.punctuation.hyperbole": False, + } + ) + + registry = CheckRegistry() + enabled = registry.get_all_enabled(checks) + + paths = {check.path for check in enabled} + + assert "typography.symbols.curly_quotes" in paths + assert "typography.punctuation.hyperbole" not in paths + + assert all( + path == "typography.symbols.curly_quotes" + or not path.startswith("typography.symbols.") + for path in paths + ) + + assert any( + path.startswith("typography.") + and not path.startswith("typography.symbols.") + and path != "typography.punctuation.hyperbole" + for path in paths + )