Skip to content

Commit afed28b

Browse files
committed
refactor: use iterators for checks
1 parent 3bc0d13 commit afed28b

File tree

6 files changed

+169
-170
lines changed

6 files changed

+169
-170
lines changed

proselint/command_line.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def proselint(paths=None, config=None, version=None,
113113
except Exception:
114114
traceback.print_exc()
115115
sys.exit(2)
116-
errors = lint(f, config, debug=debug)
116+
errors = lint(f, config)
117117
num_errors += len(errors)
118118
print_errors(fp, errors, output_json=output_json, compact=compact)
119119

proselint/registry/checks/__init__.py

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44

55
from __future__ import annotations
66

7+
from collections.abc import Iterator
78
from enum import Enum
9+
from itertools import chain, islice
10+
from math import ceil
811
from re import RegexFlag
912
from typing import TYPE_CHECKING, NamedTuple, Optional
1013

@@ -90,61 +93,57 @@ class CheckFlags(NamedTuple):
9093
ppm_threshold: int = 0
9194

9295
@staticmethod
93-
def truncate(results: list[CheckResult], limit: int) -> list[CheckResult]:
96+
def truncate(
97+
results: Iterator[CheckResult], limit: int
98+
) -> Iterator[CheckResult]:
9499
"""
95100
Truncate a list of results to a given threshold.
96101
97102
This also notes how many times the check flagged prior to truncation.
98103
"""
99-
if limit == 0 or (num_results := len(results)) <= limit:
104+
if limit == 0:
100105
return results
101106

102-
last_result = results[limit - 1]
103-
num_extras = num_results - limit
104-
last_message = last_result.message + " Found {} elsewhere.".format(
105-
"once" if num_extras == 1 else f"{num_extras} times"
106-
)
107-
108-
return [
109-
*results[0 : limit - 1],
110-
CheckResult(
111-
start_pos=last_result.start_pos,
112-
end_pos=last_result.end_pos,
113-
check_path=last_result.check_path,
114-
message=last_message,
115-
replacements=None,
107+
return chain(
108+
islice(results, limit - 1),
109+
(
110+
CheckResult(
111+
start_pos=result.start_pos,
112+
end_pos=result.end_pos,
113+
check_path=result.check_path,
114+
message=f"{result.message} Also found elsewhere.",
115+
replacements=result.replacements,
116+
)
117+
for result in islice(results, 1)
116118
),
117-
]
119+
)
118120

119121
@staticmethod
120122
def apply_threshold(
121-
results: list[CheckResult], threshold: int, length: int
122-
) -> list[CheckResult]:
123+
results: Iterator[CheckResult], threshold: int, length: int
124+
) -> Iterator[CheckResult]:
123125
"""Return an error if the specified PPM `threshold` is surpassed."""
124-
if threshold == 0 or length == 0 or (num_results := len(results)) < 2:
125-
return []
126+
if 0 in {threshold, length}:
127+
return results
126128

127-
length = max(length, 1000)
128-
if (ppm := (num_results / length) * 1e6) <= threshold:
129-
return []
130-
return [
129+
req_results = max(ceil((threshold / 1e6) * max(length, 1000)), 2)
130+
return (
131131
CheckResult(
132-
start_pos=results[0].start_pos,
133-
end_pos=results[0].end_pos,
134-
check_path=results[0].check_path,
135-
message=results[0].message + f" Reached {ppm:.0f} ppm.",
132+
start_pos=result.start_pos,
133+
end_pos=result.end_pos,
134+
check_path=result.check_path,
135+
message=f"{result.message} Surpassed {threshold} ppm.",
136136
replacements=None,
137137
)
138-
]
138+
for result in islice(results, req_results - 1, req_results)
139+
)
139140

140141
def apply(
141-
self, results: list[CheckResult], text_len: int
142-
) -> list[CheckResult]:
143-
"""Apply the specified flags to a list of `results`."""
142+
self, results: Iterator[CheckResult], text_len: int
143+
) -> Iterator[CheckResult]:
144+
"""Apply the specified flags to an iterator of `results`."""
144145
return CheckFlags.truncate(
145-
CheckFlags.apply_threshold(results, self.ppm_threshold, text_len)
146-
if self.ppm_threshold
147-
else results,
146+
CheckFlags.apply_threshold(results, self.ppm_threshold, text_len),
148147
self.results_limit,
149148
)
150149

@@ -176,14 +175,14 @@ def matches_partial(self, partial: str) -> bool:
176175

177176
return self.path_segments[:len(partial_segments)] == partial_segments
178177

179-
def check(self, text: str) -> list[CheckResult]:
178+
def check(self, text: str) -> Iterator[CheckResult]:
180179
"""Apply the check over `text`."""
181180
return (
182181
self.check_type
183182
if callable(self.check_type)
184183
else self.check_type.check
185184
)(text, self)
186185

187-
def check_with_flags(self, text: str) -> list[CheckResult]:
186+
def check_with_flags(self, text: str) -> Iterator[CheckResult]:
188187
"""Apply the check over `text`, including specified `CheckFlags`."""
189188
return self.flags.apply(self.check(text), len(text))

proselint/registry/checks/types.py

Lines changed: 78 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Check types."""
22

3+
from collections.abc import Generator, Iterable, Iterator
34
from functools import partial
4-
from itertools import chain, filterfalse
5+
from itertools import chain, filterfalse, zip_longest
56
from re import (
67
Match,
78
Pattern,
@@ -10,11 +11,26 @@
1011
search,
1112
)
1213
from re import compile as rcompile
13-
from typing import Callable, NamedTuple, Union
14+
from typing import Callable, NamedTuple, TypeVar, Union
1415

1516
from proselint.registry.checks import Check, CheckResult, Padding
1617

17-
CheckFn = Callable[[str, Check], list[CheckResult]]
18+
CheckFn = Callable[[str, Check], Iterator[CheckResult]]
19+
20+
T = TypeVar("T")
21+
22+
23+
def _takewhile_peek(
24+
predicate: Callable[[T], bool], iterable: Iterable[T]
25+
) -> Generator[T]:
26+
"""Take elements from `iterable` while `predicate` is consecutively true."""
27+
prev = True
28+
for x in iterable:
29+
current = predicate(x)
30+
if not prev and not current:
31+
break
32+
yield x
33+
prev = current
1834

1935

2036
class Consistency(NamedTuple):
@@ -29,32 +45,41 @@ def process_pair(
2945
check: Check,
3046
flag: Union[RegexFlag, int],
3147
pair: tuple[str, str],
32-
) -> list[CheckResult]:
48+
) -> Iterator[CheckResult]:
3349
"""Check a term pair over `text`."""
34-
matches = [list(finditer(term, text, flag)) for term in pair]
50+
# Unzip the zip of pair matches while both of the last pair were truthy
51+
# Reads the minimum possible elements to generate results
52+
matches: tuple[tuple[Union[Match[str], None], ...], ...] = tuple(
53+
zip(
54+
*_takewhile_peek(
55+
all,
56+
zip_longest(*(finditer(term, text, flag) for term in pair)),
57+
)
58+
)
59+
)
3560

36-
if not len(matches[0]) and not len(matches[1]):
37-
return []
61+
if not matches:
62+
return iter(())
3863

39-
idx_minority = len(matches[0]) > len(matches[1])
64+
idx_minority = matches[1][-1] is None
4065
majority_term = pair[not idx_minority]
41-
return [
66+
return (
4267
CheckResult(
4368
start_pos=m.start() + check.offset[0],
4469
end_pos=m.end() + check.offset[1],
4570
check_path=check.path,
4671
message=check.message.format(majority_term, m.group(0)),
4772
replacements=majority_term,
4873
)
49-
for m in matches[idx_minority]
50-
]
74+
for m in filter(None, matches[idx_minority])
75+
)
5176

52-
def check(self, text: str, check: Check) -> list[CheckResult]:
77+
def check(self, text: str, check: Check) -> Iterator[CheckResult]:
5378
"""Check the consistency of given term pairs in `text`."""
5479
flag = check.re_flag
5580
process_pair = partial(Consistency.process_pair, text, check, flag)
5681

57-
return list(chain.from_iterable(map(process_pair, self.term_pairs)))
82+
return chain.from_iterable(map(process_pair, self.term_pairs))
5883

5984

6085
class PreferredForms(NamedTuple):
@@ -63,12 +88,12 @@ class PreferredForms(NamedTuple):
6388
items: dict[str, str]
6489
padding: Padding = Padding.WORDS_IN_TEXT
6590

66-
def check(self, text: str, check: Check) -> list[CheckResult]:
91+
def check(self, text: str, check: Check) -> Iterator[CheckResult]:
6792
"""Check for terms to be replaced with a preferred form in `text`."""
6893
offset = self.padding.to_offset_from(check.offset)
6994
flag = check.re_flag
7095

71-
return [
96+
return (
7297
CheckResult(
7398
start_pos=m.start() + offset[0],
7499
end_pos=m.end() + offset[1],
@@ -78,7 +103,7 @@ def check(self, text: str, check: Check) -> list[CheckResult]:
78103
)
79104
for original, replacement in self.items.items()
80105
for m in finditer(self.padding.format(original), text, flag)
81-
]
106+
)
82107

83108

84109
class PreferredFormsSimple(NamedTuple):
@@ -104,7 +129,7 @@ def map_match(
104129
replacements=replacement,
105130
)
106131

107-
def check(self, text: str, check: Check) -> list[CheckResult]:
132+
def check(self, text: str, check: Check) -> Iterator[CheckResult]:
108133
"""Check for terms to be replaced with a preferred form in `text`."""
109134
offset = self.padding.to_offset_from(check.offset)
110135
flag = check.re_flag
@@ -116,7 +141,33 @@ def check(self, text: str, check: Check) -> list[CheckResult]:
116141
else next(iter(self.items))
117142
)
118143

119-
return list(map(map_match, finditer(pattern, text, flag)))
144+
return map(map_match, finditer(pattern, text, flag))
145+
146+
147+
def _process_existence(
148+
pattern: str,
149+
exceptions: tuple[str, ...],
150+
offset: tuple[int, int],
151+
text: str,
152+
check: Check,
153+
) -> Iterator[CheckResult]:
154+
"""Match against `pattern` respecting `offset` in `text`."""
155+
flag = check.re_flag
156+
157+
return (
158+
CheckResult(
159+
start_pos=m.start() + offset[0],
160+
end_pos=m.end() + offset[1],
161+
check_path=check.path,
162+
message=check.message.format(m.group(0).strip()),
163+
replacements=None,
164+
)
165+
for m in finditer(pattern, text, flag)
166+
if not any(
167+
search(exception, m.group(0).strip(), flag)
168+
for exception in exceptions
169+
)
170+
)
120171

121172

122173
class Existence(NamedTuple):
@@ -126,31 +177,17 @@ class Existence(NamedTuple):
126177
padding: Padding = Padding.WORDS_IN_TEXT
127178
exceptions: tuple[str, ...] = ()
128179

129-
def check(self, text: str, check: Check) -> list[CheckResult]:
180+
def check(self, text: str, check: Check) -> Iterator[CheckResult]:
130181
"""Check for the existence of terms in `text`."""
131182
offset = self.padding.to_offset_from(check.offset)
132-
flag = check.re_flag
133183

134184
pattern = self.padding.format(
135185
Padding.SAFE_JOIN.format("|".join(self.items))
136186
if len(self.items) > 1
137187
else next(iter(self.items))
138188
)
139189

140-
return [
141-
CheckResult(
142-
start_pos=m.start() + offset[0],
143-
end_pos=m.end() + offset[1],
144-
check_path=check.path,
145-
message=check.message.format(m.group(0).strip()),
146-
replacements=None,
147-
)
148-
for m in finditer(pattern, text, flag)
149-
if not any(
150-
search(exception, m.group(0).strip(), flag)
151-
for exception in self.exceptions
152-
)
153-
]
190+
return _process_existence(pattern, self.exceptions, offset, text, check)
154191

155192

156193
class ExistenceSimple(NamedTuple):
@@ -159,24 +196,11 @@ class ExistenceSimple(NamedTuple):
159196
pattern: str
160197
exceptions: tuple[str, ...] = ()
161198

162-
def check(self, text: str, check: Check) -> list[CheckResult]:
199+
def check(self, text: str, check: Check) -> Iterator[CheckResult]:
163200
"""Check for the existence of a single pattern in `text`."""
164-
flag = check.re_flag
165-
166-
return [
167-
CheckResult(
168-
start_pos=m.start() + check.offset[0],
169-
end_pos=m.end() + check.offset[1],
170-
check_path=check.path,
171-
message=check.message.format(m.group(0).strip()),
172-
replacements=None,
173-
)
174-
for m in finditer(self.pattern, text, flag)
175-
if not any(
176-
search(exception, m.group(0).strip(), flag)
177-
for exception in self.exceptions
178-
)
179-
]
201+
return _process_existence(
202+
self.pattern, self.exceptions, check.offset, text, check
203+
)
180204

181205

182206
_DEFAULT_TOKENIZER = rcompile(
@@ -198,15 +222,15 @@ def _allowed_match(
198222
m_text = match.group(0)
199223
return (m_text.lower() if ignore_case else m_text) in allowed
200224

201-
def check(self, text: str, check: Check) -> list[CheckResult]:
225+
def check(self, text: str, check: Check) -> Iterator[CheckResult]:
202226
"""Check for words in `text` that are not `allowed`."""
203227
allowed_match = partial(
204228
ReverseExistence._allowed_match,
205229
self.allowed,
206230
ignore_case=check.ignore_case,
207231
)
208232

209-
return [
233+
return (
210234
CheckResult(
211235
start_pos=m.start() + check.offset[0] + 1,
212236
end_pos=m.end() + check.offset[1],
@@ -215,7 +239,7 @@ def check(self, text: str, check: Check) -> list[CheckResult]:
215239
replacements=None,
216240
)
217241
for m in filterfalse(allowed_match, self.TOKENIZER.finditer(text))
218-
]
242+
)
219243

220244

221245
CheckType = Union[

0 commit comments

Comments
 (0)