Skip to content

Commit 78049c6

Browse files
odimkacopybara-github
authored andcommitted
Add sort=True option to kd.group_by.
PiperOrigin-RevId: 714955232 Change-Id: Ia45344864f4887635e23698f2b61d5a227627beb
1 parent 928df3e commit 78049c6

File tree

10 files changed

+321
-361
lines changed

10 files changed

+321
-361
lines changed

koladata/operators/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ cc_library(
116116
"//koladata:object_factories",
117117
"//koladata:pointwise_utils",
118118
"//koladata:uuid_utils",
119+
"//koladata/internal:casting",
119120
"//koladata/internal:data_bag",
120121
"//koladata/internal:data_item",
121122
"//koladata/internal:data_slice",

koladata/operators/operators.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,14 +231,12 @@ OPERATOR("kde.slices._collapse", Collapse);
231231
OPERATOR_FAMILY("kde.slices._concat_or_stack",
232232
arolla::MakeVariadicInputOperatorFamily(ConcatOrStack));
233233
OPERATOR("kde.slices._dense_rank", DenseRank);
234+
OPERATOR_FAMILY("kde.slices._group_by_indices",
235+
arolla::MakeVariadicInputOperatorFamily(GroupByIndices));
234236
OPERATOR("kde.slices._inverse_mapping", InverseMapping);
235237
OPERATOR("kde.slices._ordinal_rank", OrdinalRank);
236238
OPERATOR("kde.slices._select", Select);
237239
OPERATOR_FAMILY("kde.slices.align", std::make_unique<AlignOperatorFamily>());
238-
OPERATOR_FAMILY("kde.slices.group_by_indices",
239-
arolla::MakeVariadicInputOperatorFamily(GroupByIndices));
240-
OPERATOR_FAMILY("kde.slices.group_by_indices_sorted",
241-
arolla::MakeVariadicInputOperatorFamily(GroupByIndicesSorted));
242240
OPERATOR("kde.slices.inverse_select", InverseSelect);
243241
OPERATOR("kde.slices.is_empty", IsEmpty);
244242
OPERATOR("kde.slices.reverse", Reverse);

koladata/operators/slices.cc

Lines changed: 46 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "koladata/data_bag.h"
4545
#include "koladata/data_slice.h"
4646
#include "koladata/data_slice_qtype.h"
47+
#include "koladata/internal/casting.h"
4748
#include "koladata/internal/data_item.h"
4849
#include "koladata/internal/data_slice.h"
4950
#include "koladata/internal/dtype.h"
@@ -501,47 +502,6 @@ class GroupByIndicesProcessor {
501502
bool sort_;
502503
};
503504

504-
absl::StatusOr<DataSlice> GroupByIndicesImpl(
505-
absl::Span<const DataSlice* const> slices, bool sort) {
506-
if (slices.empty()) {
507-
return absl::InvalidArgumentError("requires at least 1 argument");
508-
}
509-
const auto& shape = slices[0]->GetShape();
510-
if (shape.rank() == 0) {
511-
return absl::FailedPreconditionError(
512-
"group_by is not supported for scalar data");
513-
}
514-
GroupByIndicesProcessor processor(shape.edges().back(),
515-
/*sort=*/sort);
516-
for (const auto* const ds_ptr : slices) {
517-
const auto& ds = *ds_ptr;
518-
if (!ds.GetShape().IsEquivalentTo(shape)) {
519-
return absl::FailedPreconditionError(
520-
"all arguments must have the same shape");
521-
}
522-
if (sort) {
523-
if (ds.slice().is_mixed_dtype()) {
524-
return absl::FailedPreconditionError(
525-
"sort is not supported for mixed dtype");
526-
}
527-
if (!internal::IsKodaScalarQTypeSortable(ds.slice().dtype())) {
528-
return absl::FailedPreconditionError(absl::StrCat(
529-
"sort is not supported for ", ds.slice().dtype()->name()));
530-
}
531-
}
532-
processor.ProcessGroupKey(ds.slice());
533-
}
534-
auto [indices_array, group_split_points, item_split_points] =
535-
processor.CreateFinalDataSlice();
536-
ASSIGN_OR_RETURN(auto new_shape,
537-
shape.RemoveDims(/*from=*/shape.rank() - 1)
538-
.AddDims({std::move(group_split_points),
539-
std::move(item_split_points)}));
540-
return DataSlice::Create(
541-
internal::DataSliceImpl::Create(std::move(indices_array)),
542-
std::move(new_shape), internal::DataItem(schema::kInt64));
543-
}
544-
545505
struct Slice {
546506
int64_t start;
547507
std::optional<int64_t> stop;
@@ -824,14 +784,52 @@ absl::StatusOr<DataSlice> ConcatOrStack(
824784

825785
absl::StatusOr<DataSlice> GroupByIndices(
826786
absl::Span<const DataSlice* const> slices) {
827-
return GroupByIndicesImpl(slices, /*sort=*/false);
828-
}
787+
if (slices.size() < 2) {
788+
return absl::InvalidArgumentError(
789+
absl::StrCat("_group_by_indices expected at least 2 arguments, but "
790+
"got ", slices.size()));
791+
}
829792

830-
absl::StatusOr<DataSlice> GroupByIndicesSorted(
831-
absl::Span<const DataSlice* const> slices) {
832-
return GroupByIndicesImpl(slices, /*sort=*/true);
793+
ASSIGN_OR_RETURN(auto sort_bool, GetBoolArgument(*slices[0], "sort"));
794+
795+
const auto& shape = slices[1]->GetShape();
796+
if (shape.rank() == 0) {
797+
return absl::InvalidArgumentError(
798+
"group_by arguments must be DataSlices with ndim > 0, got DataItems");
799+
}
800+
801+
GroupByIndicesProcessor processor(shape.edges().back(), /*sort=*/sort_bool);
802+
for (const auto* const ds_ptr : slices.subspan(1)) {
803+
const auto& ds = *ds_ptr;
804+
if (!ds.GetShape().IsEquivalentTo(shape)) {
805+
return absl::InvalidArgumentError(
806+
"all arguments must have the same shape");
807+
}
808+
if (sort_bool) {
809+
if (ds.slice().is_mixed_dtype()) {
810+
return absl::InvalidArgumentError(
811+
"sort is not supported for mixed dtype");
812+
}
813+
if (!internal::IsKodaScalarQTypeSortable(ds.slice().dtype())) {
814+
return absl::InvalidArgumentError(absl::StrCat(
815+
"sort is not supported for ",
816+
schema::schema_internal::GetQTypeName(ds.slice().dtype())));
817+
}
818+
}
819+
processor.ProcessGroupKey(ds.slice());
820+
}
821+
auto [indices_array, group_split_points, item_split_points] =
822+
processor.CreateFinalDataSlice();
823+
ASSIGN_OR_RETURN(auto new_shape,
824+
shape.RemoveDims(/*from=*/shape.rank() - 1)
825+
.AddDims({std::move(group_split_points),
826+
std::move(item_split_points)}));
827+
return DataSlice::Create(
828+
internal::DataSliceImpl::Create(std::move(indices_array)),
829+
std::move(new_shape), internal::DataItem(schema::kInt64));
833830
}
834831

832+
835833
absl::StatusOr<DataSlice> Unique(const DataSlice& x, const DataSlice& sort) {
836834
if (x.is_item()) {
837835
return x;
@@ -844,7 +842,8 @@ absl::StatusOr<DataSlice> Unique(const DataSlice& x, const DataSlice& sort) {
844842
}
845843
if (!internal::IsKodaScalarQTypeSortable(x.slice().dtype())) {
846844
return absl::FailedPreconditionError(absl::StrCat(
847-
"sort is not supported for ", x.slice().dtype()->name()));
845+
"sort is not supported for ",
846+
schema::schema_internal::GetQTypeName(x.slice().dtype())));
848847
}
849848
}
850849

koladata/operators/slices.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,6 @@ absl::StatusOr<DataSlice> IsEmpty(const DataSlice& obj);
5656
absl::StatusOr<DataSlice> GroupByIndices(
5757
absl::Span<const DataSlice* const> slices);
5858

59-
// kde.slices.group_by_indices_sorted.
60-
absl::StatusOr<DataSlice> GroupByIndicesSorted(
61-
absl::Span<const DataSlice* const> slices);
62-
6359
// kde.slices.unique.
6460
absl::StatusOr<DataSlice> Unique(const DataSlice& x, const DataSlice& sort);
6561

py/koladata/operators/slices.py

Lines changed: 43 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -575,20 +575,27 @@ def take(x, indices): # pylint: disable=unused-argument
575575
raise NotImplementedError('implemented in the backend')
576576

577577

578-
@optools.add_to_registry(aliases=['kde.group_by_indices'])
579578
@arolla.optools.as_backend_operator(
579+
'kde.slices._group_by_indices',
580+
qtype_inference_expr=qtypes.DATA_SLICE,
581+
)
582+
def _group_by_indices(*args): # pylint: disable=unused-argument
583+
raise NotImplementedError('implemented in the backend')
584+
585+
586+
@optools.add_to_registry(aliases=['kde.group_by_indices'])
587+
@optools.as_lambda_operator(
580588
'kde.slices.group_by_indices',
581589
qtype_constraints=[
582590
(
583591
M.qtype.get_field_count(P.args) > 0,
584592
'expected at least one argument',
585593
),
586594
qtype_utils.expect_data_slice_args(P.args),
595+
qtype_utils.expect_data_slice(P.sort),
587596
],
588-
qtype_inference_expr=qtypes.DATA_SLICE,
589-
experimental_aux_policy=py_boxing.DEFAULT_BOXING_POLICY,
590597
)
591-
def group_by_indices(*args): # pylint: disable=unused-argument
598+
def group_by_indices(*args, sort=False): # pylint: disable=redefined-outer-name, unused-argument
592599
"""Returns a indices DataSlice with injected grouped_by dimension.
593600
594601
The resulting DataSlice has get_ndim() + 1. The first `get_ndim() - 1`
@@ -599,7 +606,8 @@ def group_by_indices(*args): # pylint: disable=unused-argument
599606
dimension. `kde.take(x, kde.group_by_indices(x))` would group the items in
600607
`x` by their values.
601608
602-
Groups are ordered by the appearance of the first object in the group.
609+
If sort=True groups are ordered by value, otherwise groups are ordered by the
610+
appearance of the first object in the group.
603611
604612
Example 1:
605613
x: kd.slice([1, 3, 2, 1, 2, 3, 1, 3])
@@ -609,20 +617,26 @@ def group_by_indices(*args): # pylint: disable=unused-argument
609617
the items in the original DataSlice.
610618
611619
Example 2:
620+
x: kd.slice([1, 3, 2, 1, 2, 3, 1, 3], sort=True)
621+
result: kd.slice([[0, 3, 6], [2, 4], [1, 5, 7]])
622+
623+
Groups are now ordered by value.
624+
625+
Example 3:
612626
x: kd.slice([[1, 2, 1, 3, 1, 3], [1, 3, 1]])
613627
result: kd.slice([[[0, 2, 4], [1], [3, 5]], [[0, 2], [1]]])
614628
615629
We have three groups in the first sublist in order: 1, 2, 3 and two groups
616630
in the second sublist in order: 1, 3.
617631
Each sublist contains the indices of the items in the original sublist.
618632
619-
Example 3:
633+
Example 4:
620634
x: kd.slice([1, 3, 2, 1, None, 3, 1, None])
621635
result: kd.slice([[0, 3, 6], [1, 5], [2]])
622636
623637
Missing values are not listed in the result.
624638
625-
Example 4:
639+
Example 5:
626640
x: kd.slice([1, 2, 3, 1, 2, 3, 1, 3]),
627641
y: kd.slice([7, 4, 0, 9, 4, 0, 7, 0]),
628642
result: kd.slice([[0, 6], [1, 4], [2, 5, 7], [3]])
@@ -633,57 +647,13 @@ def group_by_indices(*args): # pylint: disable=unused-argument
633647
Args:
634648
*args: DataSlices keys to group by. All data slices must have the same
635649
shape. Scalar DataSlices are not supported.
650+
sort: Whether groups should be ordered by value.
636651
637652
Returns:
638653
INT64 DataSlice with indices and injected grouped_by dimension.
639654
"""
640-
raise NotImplementedError('implemented in the backend')
641-
642-
643-
@optools.add_to_registry(aliases=['kde.group_by_indices_sorted'])
644-
@arolla.optools.as_backend_operator(
645-
'kde.slices.group_by_indices_sorted',
646-
qtype_constraints=[
647-
(
648-
M.qtype.get_field_count(P.args) > 0,
649-
'expected at least one argument',
650-
),
651-
qtype_utils.expect_data_slice_args(P.args),
652-
],
653-
qtype_inference_expr=qtypes.DATA_SLICE,
654-
experimental_aux_policy=py_boxing.DEFAULT_BOXING_POLICY,
655-
)
656-
def group_by_indices_sorted(*args): # pylint: disable=unused-argument
657-
"""Similar to `group_by_indices` but groups are sorted by the value.
658-
659-
Each argument must contain the values of one type.
660-
661-
Mixed types are not supported.
662-
ExprQuote and DType are not supported.
663-
664-
Example 1:
665-
x: kd.slice([1, 3, 2, 1, 2, 3, 1, 3])
666-
result: kd.slice([[0, 3, 6], [2, 4], [1, 5, 7]])
667-
668-
We have three groups in order: 1, 2, 3. Each sublist contains the indices of
669-
the items in the original DataSlice.
670-
671-
Example 2:
672-
x: kd.slice([1, 2, 3, 1, 2, 3, 1, 3]),
673-
y: kd.slice([9, 4, 0, 3, 4, 0, 9, 0]),
674-
result: kd.slice([[3], [0, 6], [1, 4], [2, 5, 7]])
675-
676-
With several arguments keys is a tuple.
677-
In this example we have the following groups: (1, 3), (1, 9), (2, 4), (3, 0)
678-
679-
Args:
680-
*args: DataSlices keys to group by. All data slices must have the same
681-
shape. Scalar DataSlices are not supported.
682-
683-
Returns:
684-
INT64 DataSlice with indices and injected grouped_by dimension.
685-
"""
686-
raise NotImplementedError('implemented in the backend')
655+
args = arolla.optools.fix_trace_args(args)
656+
return M.core.apply_varargs(_group_by_indices, sort, args)
687657

688658

689659
@optools.add_to_registry(aliases=['kde.group_by'])
@@ -692,9 +662,10 @@ def group_by_indices_sorted(*args): # pylint: disable=unused-argument
692662
qtype_constraints=[
693663
qtype_utils.expect_data_slice(P.x),
694664
qtype_utils.expect_data_slice_args(P.args),
665+
qtype_utils.expect_data_slice(P.sort),
695666
],
696667
)
697-
def group_by(x, *args):
668+
def group_by(x, *args, sort=False): # pylint: disable=redefined-outer-name
698669
"""Returns permutation of `x` with injected grouped_by dimension.
699670
700671
The resulting DataSlice has get_ndim() + 1. The first `get_ndim() - 1`
@@ -705,38 +676,43 @@ def group_by(x, *args):
705676
keys. If length of `args` is greater than 1, the key is a tuple.
706677
If `args` is empty, the key is `x`.
707678
708-
Groups are ordered by the appearance of the first item in the group.
679+
If sort=True groups are ordered by value, otherwise groups are ordered by the
680+
appearance of the first object in the group.
709681
710682
Example 1:
711683
x: kd.slice([1, 3, 2, 1, 2, 3, 1, 3])
712684
result: kd.slice([[1, 1, 1], [3, 3, 3], [2, 2]])
713685
714686
Example 2:
687+
x: kd.slice([1, 3, 2, 1, 2, 3, 1, 3], sort=True)
688+
result: kd.slice([[1, 1, 1], [2, 2], [3, 3, 3]])
689+
690+
Example 3:
715691
x: kd.slice([[1, 2, 1, 3, 1, 3], [1, 3, 1]])
716692
result: kd.slice([[[1, 1, 1], [2], [3, 3]], [[1, 1], [3]]])
717693
718-
Example 3:
694+
Example 4:
719695
x: kd.slice([1, 3, 2, 1, None, 3, 1, None])
720696
result: kd.slice([[1, 1, 1], [3, 3], [2]])
721697
722698
Missing values are not listed in the result.
723699
724-
Example 4:
700+
Example 5:
725701
x: kd.slice([1, 2, 3, 4, 5, 6, 7, 8]),
726702
y: kd.slice([7, 4, 0, 9, 4, 0, 7, 0]),
727703
result: kd.slice([[1, 7], [2, 5], [3, 6, 8], [4]])
728704
729705
When *args is present, `x` is not used for the key.
730706
731-
Example 5:
707+
Example 6:
732708
x: kd.slice([1, 2, 3, 4, None, 6, 7, 8]),
733709
y: kd.slice([7, 4, 0, 9, 4, 0, 7, None]),
734710
result: kd.slice([[1, 7], [2, None], [3, 6], [4]])
735711
736712
Items with missing key is not listed in the result.
737713
Missing `x` values are missing in the result.
738714
739-
Example 6:
715+
Example 7:
740716
x: kd.slice([1, 2, 3, 4, 5, 6, 7, 8]),
741717
y: kd.slice([7, 4, 0, 9, 4, 0, 7, 0]),
742718
z: kd.slice([A, D, B, A, D, C, A, B]),
@@ -751,22 +727,25 @@ def group_by(x, *args):
751727
*args: DataSlices keys to group by. All data slices must have the same shape
752728
as x. Scalar DataSlices are not supported. If not present, `x` is used as
753729
the key.
730+
sort: Whether groups should be ordered by value.
754731
755732
Returns:
756733
DataSlice with the same shape and schema as `x` with injected grouped
757734
by dimension.
758735
"""
759736
args = arolla.optools.fix_trace_args(args)
760737
dispatch_op = arolla.types.DispatchOperator(
761-
'x, args',
738+
'x, args, sort',
762739
x_is_key_case=arolla.types.DispatchCase(
763-
take(P.x, group_by_indices(P.x)),
740+
take(P.x, _group_by_indices(P.sort, P.x)),
764741
condition=M.qtype.get_field_count(P.args) == 0,
765742
),
766743
# TODO: add assertion: x has the same shape as other args.
767-
default=take(P.x, M.core.apply_varargs(group_by_indices, P.args)),
744+
default=take(
745+
P.x, M.core.apply_varargs(_group_by_indices, P.sort, P.args)
746+
),
768747
)
769-
return dispatch_op(x, args)
748+
return dispatch_op(x, args, sort)
770749

771750

772751
@optools.as_lambda_operator(

0 commit comments

Comments
 (0)