|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -import itertools |
4 |
| -from collections import Counter |
5 |
| -from collections.abc import Iterable, Iterator, Sequence |
| 3 | +from collections import Counter, defaultdict |
| 4 | +from collections.abc import Callable, Hashable, Iterable, Iterator, Sequence |
6 | 5 | from typing import TYPE_CHECKING, Literal, TypeVar, Union, cast
|
7 | 6 |
|
8 | 7 | import pandas as pd
|
@@ -269,10 +268,7 @@ def _combine_all_along_first_dim(
|
269 | 268 | combine_attrs: CombineAttrsOptions = "drop",
|
270 | 269 | ):
|
271 | 270 | # Group into lines of datasets which must be combined along dim
|
272 |
| - # need to sort by _new_tile_id first for groupby to work |
273 |
| - # TODO: is the sorted need? |
274 |
| - combined_ids = dict(sorted(combined_ids.items(), key=_new_tile_id)) |
275 |
| - grouped = itertools.groupby(combined_ids.items(), key=_new_tile_id) |
| 271 | + grouped = groupby_defaultdict(list(combined_ids.items()), key=_new_tile_id) |
276 | 272 |
|
277 | 273 | # Combine all of these datasets along dim
|
278 | 274 | new_combined_ids = {}
|
@@ -606,6 +602,21 @@ def vars_as_keys(ds):
|
606 | 602 | return tuple(sorted(ds))
|
607 | 603 |
|
608 | 604 |
|
| 605 | +K = TypeVar("K", bound=Hashable) |
| 606 | + |
| 607 | + |
| 608 | +def groupby_defaultdict( |
| 609 | + iter: list[T], |
| 610 | + key: Callable[[T], K], |
| 611 | +) -> Iterator[tuple[K, Iterator[T]]]: |
| 612 | + """replacement for itertools.groupby""" |
| 613 | + idx = defaultdict(list) |
| 614 | + for i, obj in enumerate(iter): |
| 615 | + idx[key(obj)].append(i) |
| 616 | + for k, ix in idx.items(): |
| 617 | + yield k, (iter[i] for i in ix) |
| 618 | + |
| 619 | + |
609 | 620 | def _combine_single_variable_hypercube(
|
610 | 621 | datasets,
|
611 | 622 | fill_value=dtypes.NA,
|
@@ -965,8 +976,7 @@ def combine_by_coords(
|
965 | 976 | ]
|
966 | 977 |
|
967 | 978 | # Group by data vars
|
968 |
| - sorted_datasets = sorted(data_objects, key=vars_as_keys) |
969 |
| - grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys) |
| 979 | + grouped_by_vars = groupby_defaultdict(data_objects, key=vars_as_keys) |
970 | 980 |
|
971 | 981 | # Perform the multidimensional combine on each group of data variables
|
972 | 982 | # before merging back together
|
|
0 commit comments