11"""Check types."""
22
3+ from collections .abc import Generator , Iterable , Iterator
34from functools import partial
4- from itertools import chain , filterfalse
5+ from itertools import chain , filterfalse , zip_longest
56from re import (
67 Match ,
78 Pattern ,
1011 search ,
1112)
1213from re import compile as rcompile
13- from typing import Callable , NamedTuple , Union
14+ from typing import Callable , NamedTuple , TypeVar , Union
1415
1516from proselint .registry .checks import Check , CheckResult , Padding
1617
17- CheckFn = Callable [[str , Check ], list [CheckResult ]]
18+ CheckFn = Callable [[str , Check ], Iterator [CheckResult ]]
19+
T = TypeVar("T")


def _takewhile_peek(
    predicate: Callable[[T], bool], iterable: Iterable[T]
) -> Iterator[T]:
    """Yield items of `iterable` until `predicate` fails twice in a row.

    Unlike `itertools.takewhile`, the first item that fails `predicate` is
    still yielded (the "peek"). Iteration stops, without yielding, at the
    very next item only if that item fails as well; a single failing item
    followed by a passing one does not end iteration.

    Nothing past the item that triggers the stop is read from `iterable`.
    """
    previous_passed = True
    for item in iterable:
        passed = predicate(item)
        # Second consecutive failure: stop before yielding this item.
        if not (previous_passed or passed):
            break
        yield item
        previous_passed = passed
1834
1935
2036class Consistency (NamedTuple ):
@@ -29,32 +45,41 @@ def process_pair(
2945 check : Check ,
3046 flag : Union [RegexFlag , int ],
3147 pair : tuple [str , str ],
32- ) -> list [CheckResult ]:
48+ ) -> Iterator [CheckResult ]:
3349 """Check a term pair over `text`."""
34- matches = [list (finditer (term , text , flag )) for term in pair ]
50+ # Unzip the zip of pair matches while both of the last pair were truthy
51+ # Reads the minimum possible elements to generate results
52+ matches : tuple [tuple [Union [Match [str ], None ], ...], ...] = tuple (
53+ zip (
54+ * _takewhile_peek (
55+ all ,
56+ zip_longest (* (finditer (term , text , flag ) for term in pair )),
57+ )
58+ )
59+ )
3560
36- if not len ( matches [ 0 ]) and not len ( matches [ 1 ]) :
37- return []
61+ if not matches :
62+ return iter (())
3863
39- idx_minority = len ( matches [0 ]) > len ( matches [ 1 ])
64+ idx_minority = matches [1 ][ - 1 ] is None
4065 majority_term = pair [not idx_minority ]
41- return [
66+ return (
4267 CheckResult (
4368 start_pos = m .start () + check .offset [0 ],
4469 end_pos = m .end () + check .offset [1 ],
4570 check_path = check .path ,
4671 message = check .message .format (majority_term , m .group (0 )),
4772 replacements = majority_term ,
4873 )
49- for m in matches [idx_minority ]
50- ]
74+ for m in filter ( None , matches [idx_minority ])
75+ )
5176
52- def check (self , text : str , check : Check ) -> list [CheckResult ]:
77+ def check (self , text : str , check : Check ) -> Iterator [CheckResult ]:
5378 """Check the consistency of given term pairs in `text`."""
5479 flag = check .re_flag
5580 process_pair = partial (Consistency .process_pair , text , check , flag )
5681
57- return list ( chain .from_iterable (map (process_pair , self .term_pairs ) ))
82+ return chain .from_iterable (map (process_pair , self .term_pairs ))
5883
5984
6085class PreferredForms (NamedTuple ):
@@ -63,12 +88,12 @@ class PreferredForms(NamedTuple):
6388 items : dict [str , str ]
6489 padding : Padding = Padding .WORDS_IN_TEXT
6590
66- def check (self , text : str , check : Check ) -> list [CheckResult ]:
91+ def check (self , text : str , check : Check ) -> Iterator [CheckResult ]:
6792 """Check for terms to be replaced with a preferred form in `text`."""
6893 offset = self .padding .to_offset_from (check .offset )
6994 flag = check .re_flag
7095
71- return [
96+ return (
7297 CheckResult (
7398 start_pos = m .start () + offset [0 ],
7499 end_pos = m .end () + offset [1 ],
@@ -78,7 +103,7 @@ def check(self, text: str, check: Check) -> list[CheckResult]:
78103 )
79104 for original , replacement in self .items .items ()
80105 for m in finditer (self .padding .format (original ), text , flag )
81- ]
106+ )
82107
83108
84109class PreferredFormsSimple (NamedTuple ):
@@ -104,7 +129,7 @@ def map_match(
104129 replacements = replacement ,
105130 )
106131
107- def check (self , text : str , check : Check ) -> list [CheckResult ]:
132+ def check (self , text : str , check : Check ) -> Iterator [CheckResult ]:
108133 """Check for terms to be replaced with a preferred form in `text`."""
109134 offset = self .padding .to_offset_from (check .offset )
110135 flag = check .re_flag
@@ -116,7 +141,33 @@ def check(self, text: str, check: Check) -> list[CheckResult]:
116141 else next (iter (self .items ))
117142 )
118143
119- return list (map (map_match , finditer (pattern , text , flag )))
144+ return map (map_match , finditer (pattern , text , flag ))
145+
146+
def _process_existence(
    pattern: str,
    exceptions: tuple[str, ...],
    offset: tuple[int, int],
    text: str,
    check: Check,
) -> Iterator[CheckResult]:
    """Lazily report each match of `pattern` in `text` that is not excepted.

    A match is skipped when any regex in `exceptions` is found in the
    matched text after its surrounding whitespace is stripped. Each reported
    span is the match span with `offset[0]` added to the start and
    `offset[1]` added to the end. `check` supplies the regex flag, the check
    path, and the message template, which is formatted with the stripped
    matched text.
    """
    flag = check.re_flag

    def is_excepted(found: Match[str]) -> bool:
        # Exceptions are tested against the stripped match, not raw `text`.
        stripped = found.group(0).strip()
        return any(
            search(exception, stripped, flag) for exception in exceptions
        )

    def to_result(found: Match[str]) -> CheckResult:
        return CheckResult(
            start_pos=found.start() + offset[0],
            end_pos=found.end() + offset[1],
            check_path=check.path,
            message=check.message.format(found.group(0).strip()),
            replacements=None,
        )

    return (
        to_result(found)
        for found in finditer(pattern, text, flag)
        if not is_excepted(found)
    )
120171
121172
122173class Existence (NamedTuple ):
@@ -126,31 +177,17 @@ class Existence(NamedTuple):
126177 padding : Padding = Padding .WORDS_IN_TEXT
127178 exceptions : tuple [str , ...] = ()
128179
129- def check (self , text : str , check : Check ) -> list [CheckResult ]:
180+ def check (self , text : str , check : Check ) -> Iterator [CheckResult ]:
130181 """Check for the existence of terms in `text`."""
131182 offset = self .padding .to_offset_from (check .offset )
132- flag = check .re_flag
133183
134184 pattern = self .padding .format (
135185 Padding .SAFE_JOIN .format ("|" .join (self .items ))
136186 if len (self .items ) > 1
137187 else next (iter (self .items ))
138188 )
139189
140- return [
141- CheckResult (
142- start_pos = m .start () + offset [0 ],
143- end_pos = m .end () + offset [1 ],
144- check_path = check .path ,
145- message = check .message .format (m .group (0 ).strip ()),
146- replacements = None ,
147- )
148- for m in finditer (pattern , text , flag )
149- if not any (
150- search (exception , m .group (0 ).strip (), flag )
151- for exception in self .exceptions
152- )
153- ]
190+ return _process_existence (pattern , self .exceptions , offset , text , check )
154191
155192
156193class ExistenceSimple (NamedTuple ):
@@ -159,24 +196,11 @@ class ExistenceSimple(NamedTuple):
159196 pattern : str
160197 exceptions : tuple [str , ...] = ()
161198
162- def check (self , text : str , check : Check ) -> list [CheckResult ]:
199+ def check (self , text : str , check : Check ) -> Iterator [CheckResult ]:
163200 """Check for the existence of a single pattern in `text`."""
164- flag = check .re_flag
165-
166- return [
167- CheckResult (
168- start_pos = m .start () + check .offset [0 ],
169- end_pos = m .end () + check .offset [1 ],
170- check_path = check .path ,
171- message = check .message .format (m .group (0 ).strip ()),
172- replacements = None ,
173- )
174- for m in finditer (self .pattern , text , flag )
175- if not any (
176- search (exception , m .group (0 ).strip (), flag )
177- for exception in self .exceptions
178- )
179- ]
201+ return _process_existence (
202+ self .pattern , self .exceptions , check .offset , text , check
203+ )
180204
181205
182206_DEFAULT_TOKENIZER = rcompile (
@@ -198,15 +222,15 @@ def _allowed_match(
198222 m_text = match .group (0 )
199223 return (m_text .lower () if ignore_case else m_text ) in allowed
200224
201- def check (self , text : str , check : Check ) -> list [CheckResult ]:
225+ def check (self , text : str , check : Check ) -> Iterator [CheckResult ]:
202226 """Check for words in `text` that are not `allowed`."""
203227 allowed_match = partial (
204228 ReverseExistence ._allowed_match ,
205229 self .allowed ,
206230 ignore_case = check .ignore_case ,
207231 )
208232
209- return [
233+ return (
210234 CheckResult (
211235 start_pos = m .start () + check .offset [0 ] + 1 ,
212236 end_pos = m .end () + check .offset [1 ],
@@ -215,7 +239,7 @@ def check(self, text: str, check: Check) -> list[CheckResult]:
215239 replacements = None ,
216240 )
217241 for m in filterfalse (allowed_match , self .TOKENIZER .finditer (text ))
218- ]
242+ )
219243
220244
221245CheckType = Union [
0 commit comments