Skip to content

Commit 682e782

Browse files
committed
ENH: enable strict zip (almost) everywhere
1 parent 51589f5 commit 682e782

File tree

7 files changed

+26
-15
lines changed

7 files changed

+26
-15
lines changed

docs/extensions/show_all_units.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,13 @@ def as_rest_table(data, full=False):
149149
data = data if data else [["No Data"]]
150150
table = []
151151
# max size of each column
152-
sizes = list(map(max, zip(*[[len(str(elt)) for elt in member] for member in data])))
152+
sizes = [
153+
max(*args)
154+
for args in zip(
155+
*[[len(str(elt)) for elt in member] for member in data],
156+
strict=True,
157+
)
158+
]
153159
num_elts = len(sizes)
154160

155161
if full:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ exclude = [
118118
ignore = [
119119
"E501",
120120
"B904",
121-
"B905", # zip-without-explicit-strict
122121
]
123122
select = [
124123
"E",
@@ -130,6 +129,7 @@ select = [
130129
"I", # isort
131130
"UP", # pyupgrade
132131
"NPY", # numpy specific rules
132+
"RUF007", # zip-instead-of-pairwise
133133
]
134134

135135
[tool.ruff.lint.isort]

unyt/_array_functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,9 @@ def _histogramdd(
296296

297297
if weights is not None and hasattr(weights, "units"):
298298
counts *= weights.units
299-
return counts, tuple(_bin * getattr(s, "units", 1) for _bin, s in zip(bins, sample))
299+
return counts, tuple(
300+
_bin * getattr(s, "units", 1) for _bin, s in zip(bins, sample, strict=True)
301+
)
300302

301303

302304
if NUMPY_VERSION >= Version("1.24"):

unyt/_parsing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
66
"""
77

8+
import itertools
89
import token
910

1011
from sympy import Basic, Float, Integer, Rational, Symbol, sqrt
@@ -23,7 +24,7 @@ def _auto_positive_symbol(tokens, local_dict, global_dict):
2324
result = []
2425

2526
tokens.append((None, None)) # so zip traverses all tokens
26-
for tok, nextTok in zip(tokens, tokens[1:]):
27+
for tok, nextTok in itertools.pairwise(tokens):
2728
tokNum, tokVal = tok
2829
nextTokNum, nextTokVal = nextTok
2930
if tokNum == token.NAME:

unyt/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,7 +1368,7 @@ def from_astropy(cls, arr, unit_registry=None):
13681368
u = arr
13691369
_arr = 1.0 * u
13701370
ap_units = []
1371-
for base, exponent in zip(u.bases, u.powers):
1371+
for base, exponent in zip(u.bases, u.powers, strict=True):
13721372
unit_str = base.to_string()
13731373
# we have to do this because AstroPy is silly and defines
13741374
# hour as "h"
@@ -2114,7 +2114,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
21142114
# out_arr is an ndarray
21152115
out.units = Unit("", registry=self.units.registry)
21162116
elif isinstance(out, tuple):
2117-
for o, oa in zip(out, out_arr):
2117+
for o, oa in zip(out, out_arr, strict=True):
21182118
if o is None:
21192119
continue
21202120
o.units = oa.units
@@ -2710,7 +2710,7 @@ def loadtxt(fname, dtype="float", delimiter="\t", usecols=None, comments="#"):
27102710
arrays = [arrays]
27112711
if usecols is not None:
27122712
units = [units[col] for col in usecols]
2713-
ret = tuple(unyt_array(arr, unit) for arr, unit in zip(arrays, units))
2713+
ret = tuple(unyt_array(arr, unit) for arr, unit in zip(arrays, units, strict=True))
27142714
if len(ret) == 1:
27152715
return ret[0]
27162716
return ret

unyt/dimensions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,9 @@ def new_f(*args, **kwargs):
277277
If the units do not match.
278278
279279
"""
280-
for arg_name, arg_value in chain(zip(names_of_args, args), kwargs.items()):
280+
for arg_name, arg_value in chain(
281+
zip(names_of_args, args, strict=False), kwargs.items()
282+
):
281283
if arg_name in arg_units: # function argument needs to be checked
282284
dimension = arg_units[arg_name]
283285
if not _has_dimensions(arg_value, dimension):
@@ -378,7 +380,7 @@ def new_f(*args, **kwargs):
378380
else:
379381
result_tuple = (results,)
380382

381-
for result, dimension in zip(result_tuple, r_units):
383+
for result, dimension in zip(result_tuple, r_units, strict=True):
382384
if not _has_dimensions(result, dimension):
383385
raise TypeError(f"result '{result}' does not match {dimension}")
384386
return results

unyt/tests/test_unyt_array.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -590,15 +590,15 @@ def test_comparisons():
590590
[True, True, False],
591591
)
592592

593-
for op, answer in zip(ops, answers):
593+
for op, answer in zip(ops, answers, strict=True):
594594
operate_and_compare(a1, a2, op, answer)
595-
for op, answer in zip(ops, answers):
595+
for op, answer in zip(ops, answers, strict=True):
596596
operate_and_compare(a1, dimless, op, answer)
597597

598-
for op, answer in zip(ops, answers):
598+
for op, answer in zip(ops, answers, strict=True):
599599
operate_and_compare(a1, a3, op, answer)
600600

601-
for op, answer in zip(ops, answers):
601+
for op, answer in zip(ops, answers, strict=True):
602602
operate_and_compare(a1, a3.in_units("cm"), op, answer)
603603

604604
# Check that comparisons with dimensionless quantities work in both
@@ -862,7 +862,7 @@ def test_iteration():
862862
"""
863863
a = np.arange(3)
864864
b = unyt_array(np.arange(3), "cm")
865-
for ia, ib in zip(a, b):
865+
for ia, ib in zip(a, b, strict=True):
866866
assert_equal(ia, ib.value)
867867
assert_equal(ib.units, b.units)
868868

@@ -1153,7 +1153,7 @@ def binary_ufunc_comparison(ufunc, a, b):
11531153
if isinstance(ret, tuple):
11541154
assert isinstance(out, tuple)
11551155
assert len(out) == len(ret)
1156-
for o, r in zip(out, ret):
1156+
for o, r in zip(out, ret, strict=True):
11571157
assert_array_equal(r, o)
11581158
else:
11591159
assert_array_equal(ret, out)

0 commit comments

Comments
 (0)