Skip to content

Commit e7e6f17

Browse files
authored
♻️ Clean up groupby patch diff (#41)
1 parent 0b09dc0 commit e7e6f17

1 file changed

Lines changed: 28 additions & 19 deletions

File tree

src/mapply/_groupby.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from types import MethodType
33
from typing import Any, Callable
44

5-
from mapply.parallel import multiprocessing_imap
5+
from mapply.parallel import multiprocessing_imap, tqdm
66

77
logger = logging.getLogger(__name__)
88

@@ -22,6 +22,7 @@ def run_groupwise_apply( # noqa:CCR001
2222
def apply(self, f, data, axis=0):
2323
# patching https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/groupby/ops.py#L823
2424
# with a multiprocessing_imap
25+
# +
2526
from pandas.core.groupby.ops import _is_indexed_like
2627

2728
mutated = False
@@ -30,33 +31,41 @@ def apply(self, f, data, axis=0):
3031
result_values = []
3132

3233
# This calls DataSplitter.__iter__
33-
# -
34-
# zipped = zip(group_keys, splitter)
34+
zipped = zip(group_keys, splitter)
35+
3536
# +
36-
splitter = list(splitter)
37+
group_axes_list = []
38+
splitter_gen = (
39+
(
40+
# mimic the side-effects commented out below
41+
object.__setattr__(group, "name", key)
42+
or group_axes_list.append(group.axes)
43+
or group
44+
)
45+
for key, group in zipped
46+
)
47+
splitter_gen = tqdm(splitter_gen, disable=True, total=splitter.ngroups)
3748
zipped = zip(
38-
group_keys,
39-
splitter,
4049
multiprocessing_imap(
41-
f, splitter, n_workers=n_workers, progressbar=progressbar
50+
f, splitter_gen, n_workers=n_workers, progressbar=progressbar
4251
),
52+
group_axes_list,
4353
)
4454

4555
# -
4656
# for key, group in zipped:
57+
# # Pinning name is needed for
58+
# # test_group_apply_once_per_group,
59+
# # test_inconsistent_return_type, test_set_group_name,
60+
# # test_group_name_available_in_inference_pass,
61+
# # test_groupby_multi_timezone
62+
# object.__setattr__(group, "name", key)
63+
# # group might be modified
64+
# group_axes = group.axes
65+
# res = f(group)
4766
# +
48-
for key, group, res in zipped:
49-
# Pinning name is needed for
50-
# test_group_apply_once_per_group,
51-
# test_inconsistent_return_type, test_set_group_name,
52-
# test_group_name_available_in_inference_pass,
53-
# test_groupby_multi_timezone
54-
object.__setattr__(group, "name", key)
55-
56-
# group might be modified
57-
group_axes = group.axes
58-
# -
59-
# res = f(group)
67+
for res, group_axes in zipped:
68+
# no changes made below this line
6069
if not mutated and not _is_indexed_like(res, group_axes, axis):
6170
mutated = True
6271
result_values.append(res)

0 commit comments

Comments
 (0)