@@ -575,20 +575,27 @@ def take(x, indices): # pylint: disable=unused-argument
575575 raise NotImplementedError ('implemented in the backend' )
576576
577577
578- @optools .add_to_registry (aliases = ['kde.group_by_indices' ])
579578@arolla .optools .as_backend_operator (
579+ 'kde.slices._group_by_indices' ,
580+ qtype_inference_expr = qtypes .DATA_SLICE ,
581+ )
582+ def _group_by_indices (* args ): # pylint: disable=unused-argument
583+ raise NotImplementedError ('implemented in the backend' )
584+
585+
586+ @optools .add_to_registry (aliases = ['kde.group_by_indices' ])
587+ @optools .as_lambda_operator (
580588 'kde.slices.group_by_indices' ,
581589 qtype_constraints = [
582590 (
583591 M .qtype .get_field_count (P .args ) > 0 ,
584592 'expected at least one argument' ,
585593 ),
586594 qtype_utils .expect_data_slice_args (P .args ),
595+ qtype_utils .expect_data_slice (P .sort ),
587596 ],
588- qtype_inference_expr = qtypes .DATA_SLICE ,
589- experimental_aux_policy = py_boxing .DEFAULT_BOXING_POLICY ,
590597)
591- def group_by_indices (* args ): # pylint: disable=unused-argument
598+ def group_by_indices (* args , sort = False ): # pylint: disable=redefined-outer-name, unused-argument
592599 """Returns a indices DataSlice with injected grouped_by dimension.
593600
594601 The resulting DataSlice has get_ndim() + 1. The first `get_ndim() - 1`
@@ -599,7 +606,8 @@ def group_by_indices(*args): # pylint: disable=unused-argument
599606 dimension. `kde.take(x, kde.group_by_indices(x))` would group the items in
600607 `x` by their values.
601608
602- Groups are ordered by the appearance of the first object in the group.
609+ If sort=True, groups are ordered by value; otherwise groups are ordered by
610+ the appearance of the first object in the group.
603611
604612 Example 1:
605613 x: kd.slice([1, 3, 2, 1, 2, 3, 1, 3])
@@ -609,20 +617,26 @@ def group_by_indices(*args): # pylint: disable=unused-argument
609617 the items in the original DataSlice.
610618
611619 Example 2:
620+ x: kd.slice([1, 3, 2, 1, 2, 3, 1, 3]), sort=True
621+ result: kd.slice([[0, 3, 6], [2, 4], [1, 5, 7]])
622+
623+ Groups are now ordered by value.
624+
625+ Example 3:
612626 x: kd.slice([[1, 2, 1, 3, 1, 3], [1, 3, 1]])
613627 result: kd.slice([[[0, 2, 4], [1], [3, 5]], [[0, 2], [1]]])
614628
615629 We have three groups in the first sublist in order: 1, 2, 3 and two groups
616630 in the second sublist in order: 1, 3.
617631 Each sublist contains the indices of the items in the original sublist.
618632
619- Example 3 :
633+ Example 4 :
620634 x: kd.slice([1, 3, 2, 1, None, 3, 1, None])
621635 result: kd.slice([[0, 3, 6], [1, 5], [2]])
622636
623637 Missing values are not listed in the result.
624638
625- Example 4 :
639+ Example 5 :
626640 x: kd.slice([1, 2, 3, 1, 2, 3, 1, 3]),
627641 y: kd.slice([7, 4, 0, 9, 4, 0, 7, 0]),
628642 result: kd.slice([[0, 6], [1, 4], [2, 5, 7], [3]])
@@ -633,57 +647,13 @@ def group_by_indices(*args): # pylint: disable=unused-argument
633647 Args:
634648 *args: DataSlices keys to group by. All data slices must have the same
635649 shape. Scalar DataSlices are not supported.
650+ sort: Whether groups should be ordered by value.
636651
637652 Returns:
638653 INT64 DataSlice with indices and injected grouped_by dimension.
639654 """
640- raise NotImplementedError ('implemented in the backend' )
641-
642-
643- @optools .add_to_registry (aliases = ['kde.group_by_indices_sorted' ])
644- @arolla .optools .as_backend_operator (
645- 'kde.slices.group_by_indices_sorted' ,
646- qtype_constraints = [
647- (
648- M .qtype .get_field_count (P .args ) > 0 ,
649- 'expected at least one argument' ,
650- ),
651- qtype_utils .expect_data_slice_args (P .args ),
652- ],
653- qtype_inference_expr = qtypes .DATA_SLICE ,
654- experimental_aux_policy = py_boxing .DEFAULT_BOXING_POLICY ,
655- )
656- def group_by_indices_sorted (* args ): # pylint: disable=unused-argument
657- """Similar to `group_by_indices` but groups are sorted by the value.
658-
659- Each argument must contain the values of one type.
660-
661- Mixed types are not supported.
662- ExprQuote and DType are not supported.
663-
664- Example 1:
665- x: kd.slice([1, 3, 2, 1, 2, 3, 1, 3])
666- result: kd.slice([[0, 3, 6], [2, 4], [1, 5, 7]])
667-
668- We have three groups in order: 1, 2, 3. Each sublist contains the indices of
669- the items in the original DataSlice.
670-
671- Example 2:
672- x: kd.slice([1, 2, 3, 1, 2, 3, 1, 3]),
673- y: kd.slice([9, 4, 0, 3, 4, 0, 9, 0]),
674- result: kd.slice([[3], [0, 6], [1, 4], [2, 5, 7]])
675-
676- With several arguments keys is a tuple.
677- In this example we have the following groups: (1, 3), (1, 9), (2, 4), (3, 0)
678-
679- Args:
680- *args: DataSlices keys to group by. All data slices must have the same
681- shape. Scalar DataSlices are not supported.
682-
683- Returns:
684- INT64 DataSlice with indices and injected grouped_by dimension.
685- """
686- raise NotImplementedError ('implemented in the backend' )
655+ args = arolla .optools .fix_trace_args (args )
656+ return M .core .apply_varargs (_group_by_indices , sort , args )
687657
688658
689659@optools .add_to_registry (aliases = ['kde.group_by' ])
@@ -692,9 +662,10 @@ def group_by_indices_sorted(*args): # pylint: disable=unused-argument
692662 qtype_constraints = [
693663 qtype_utils .expect_data_slice (P .x ),
694664 qtype_utils .expect_data_slice_args (P .args ),
665+ qtype_utils .expect_data_slice (P .sort ),
695666 ],
696667)
697- def group_by (x , * args ):
668+ def group_by (x , * args , sort = False ): # pylint: disable=redefined-outer-name
698669 """Returns permutation of `x` with injected grouped_by dimension.
699670
700671 The resulting DataSlice has get_ndim() + 1. The first `get_ndim() - 1`
@@ -705,38 +676,43 @@ def group_by(x, *args):
705676 keys. If length of `args` is greater than 1, the key is a tuple.
706677 If `args` is empty, the key is `x`.
707678
708- Groups are ordered by the appearance of the first item in the group.
679+ If sort=True, groups are ordered by value; otherwise groups are ordered by
680+ the appearance of the first object in the group.
709681
710682 Example 1:
711683 x: kd.slice([1, 3, 2, 1, 2, 3, 1, 3])
712684 result: kd.slice([[1, 1, 1], [3, 3, 3], [2, 2]])
713685
714686 Example 2:
687+ x: kd.slice([1, 3, 2, 1, 2, 3, 1, 3]), sort=True
688+ result: kd.slice([[1, 1, 1], [2, 2], [3, 3, 3]])
689+
690+ Example 3:
715691 x: kd.slice([[1, 2, 1, 3, 1, 3], [1, 3, 1]])
716692 result: kd.slice([[[1, 1, 1], [2], [3, 3]], [[1, 1], [3]]])
717693
718- Example 3 :
694+ Example 4 :
719695 x: kd.slice([1, 3, 2, 1, None, 3, 1, None])
720696 result: kd.slice([[1, 1, 1], [3, 3], [2]])
721697
722698 Missing values are not listed in the result.
723699
724- Example 4 :
700+ Example 5 :
725701 x: kd.slice([1, 2, 3, 4, 5, 6, 7, 8]),
726702 y: kd.slice([7, 4, 0, 9, 4, 0, 7, 0]),
727703 result: kd.slice([[1, 7], [2, 5], [3, 6, 8], [4]])
728704
729705 When *args is present, `x` is not used for the key.
730706
731- Example 5 :
707+ Example 6 :
732708 x: kd.slice([1, 2, 3, 4, None, 6, 7, 8]),
733709 y: kd.slice([7, 4, 0, 9, 4, 0, 7, None]),
734710 result: kd.slice([[1, 7], [2, None], [3, 6], [4]])
735711
736712 Items with missing key is not listed in the result.
737713 Missing `x` values are missing in the result.
738714
739- Example 6 :
715+ Example 7 :
740716 x: kd.slice([1, 2, 3, 4, 5, 6, 7, 8]),
741717 y: kd.slice([7, 4, 0, 9, 4, 0, 7, 0]),
742718 z: kd.slice([A, D, B, A, D, C, A, B]),
@@ -751,22 +727,25 @@ def group_by(x, *args):
751727 *args: DataSlices keys to group by. All data slices must have the same shape
752728 as x. Scalar DataSlices are not supported. If not present, `x` is used as
753729 the key.
730+ sort: Whether groups should be ordered by value.
754731
755732 Returns:
756733 DataSlice with the same shape and schema as `x` with injected grouped
757734 by dimension.
758735 """
759736 args = arolla .optools .fix_trace_args (args )
760737 dispatch_op = arolla .types .DispatchOperator (
761- 'x, args' ,
738+ 'x, args, sort ' ,
762739 x_is_key_case = arolla .types .DispatchCase (
763- take (P .x , group_by_indices ( P .x )),
740+ take (P .x , _group_by_indices ( P . sort , P .x )),
764741 condition = M .qtype .get_field_count (P .args ) == 0 ,
765742 ),
766743 # TODO: add assertion: x has the same shape as other args.
767- default = take (P .x , M .core .apply_varargs (group_by_indices , P .args )),
744+ default = take (
745+ P .x , M .core .apply_varargs (_group_by_indices , P .sort , P .args )
746+ ),
768747 )
769- return dispatch_op (x , args )
748+ return dispatch_op (x , args , sort )
770749
771750
772751@optools .as_lambda_operator (
0 commit comments